diff --git a/haystack/preview/document_stores/__init__.py b/haystack/preview/document_stores/__init__.py index 6d70bca62..19ba0ecd2 100644 --- a/haystack/preview/document_stores/__init__.py +++ b/haystack/preview/document_stores/__init__.py @@ -1,2 +1,3 @@ +from haystack.preview.document_stores.protocols import Store, DuplicatePolicy from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore from haystack.preview.document_stores.errors import StoreError, DuplicateDocumentError, MissingDocumentError diff --git a/haystack/preview/document_stores/memory/document_store.py b/haystack/preview/document_stores/memory/document_store.py index 4e6f47e90..3d64e0d3b 100644 --- a/haystack/preview/document_stores/memory/document_store.py +++ b/haystack/preview/document_stores/memory/document_store.py @@ -8,12 +8,12 @@ import rank_bm25 from tqdm.auto import tqdm from haystack.preview.dataclasses import Document +from haystack.preview.document_stores.protocols import DuplicatePolicy from haystack.preview.document_stores.memory._filters import match from haystack.preview.document_stores.errors import DuplicateDocumentError, MissingDocumentError from haystack.utils.scipy_utils import expit logger = logging.getLogger(__name__) -DuplicatePolicy = Literal["skip", "overwrite", "fail"] # document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to # True (default). Scaling uses the expit function (inverse of the logit function) after applying a SCALING_FACTOR. A @@ -126,17 +126,17 @@ class MemoryDocumentStore: return [doc for doc in self.storage.values() if match(conditions=filters, document=doc)] return list(self.storage.values()) - def write_documents(self, documents: List[Document], duplicates: DuplicatePolicy = "fail") -> None: + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None: """ Writes (or overwrites) documents into the store. 
:param documents: a list of documents. - :param duplicates: documents with the same ID count as duplicates. When duplicates are met, + :param policy: documents with the same ID count as duplicates. When duplicates are met, the store can: - skip: keep the existing document and ignore the new one. - overwrite: remove the old document and write the new one. - fail: an error is raised - :raises DuplicateError: Exception trigger on duplicate document if `duplicates="fail"` + :raises DuplicateDocumentError: Exception raised on duplicate document if `policy=DuplicatePolicy.FAIL` :return: None """ if ( @@ -147,10 +147,10 @@ class MemoryDocumentStore: raise ValueError("Please provide a list of Documents.") for document in documents: - if document.id in self.storage.keys(): - if duplicates == "fail": + if policy != DuplicatePolicy.OVERWRITE and document.id in self.storage.keys(): + if policy == DuplicatePolicy.FAIL: raise DuplicateDocumentError(f"ID '{document.id}' already exists.") - if duplicates == "skip": + if policy == DuplicatePolicy.SKIP: logger.warning("ID '%s' already exists", document.id) self.storage[document.id] = document diff --git a/haystack/preview/document_stores/protocols.py b/haystack/preview/document_stores/protocols.py new file mode 100644 index 000000000..1c269351f --- /dev/null +++ b/haystack/preview/document_stores/protocols.py @@ -0,0 +1,126 @@ +from typing import Protocol, Optional, Dict, Any, List + +import logging +from enum import Enum + +from haystack.preview.dataclasses import Document + + +logger = logging.getLogger(__name__) + + +class DuplicatePolicy(Enum): + SKIP = "skip" + OVERWRITE = "overwrite" + FAIL = "fail" + + +class Store(Protocol): + """ + Stores Documents to be used by the components of a Pipeline. + + Classes implementing this protocol often store the documents permanently and allow specialized components to + perform retrieval on them, either by embedding, by keyword, hybrid, and so on, depending on the backend used. 
+ + In order to retrieve documents, consider using a Retriever that supports the document store implementation that + you're using. + """ + + def count_documents(self) -> int: + """ + Returns the number of documents stored. + """ + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + """ + Returns the documents that match the filters provided. + + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical operator (`"$and"`, + `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `$ne`, `"$in"`, `$nin`, `"$gt"`, `"$gte"`, `"$lt"`, + `"$lte"`) or a metadata field name. + + Logical operator keys take a dictionary of metadata field names and/or logical operators as value. Metadata + field names take a dictionary of comparison operators as value. Comparison operator keys take a single value or + (in case of `"$in"`) a list of values as value. If no logical operator is provided, `"$and"` is used as default + operation. If no comparison operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used + as default operation. + + Example: + + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators can take a list of + dictionaries as value. 
+ + Example: + + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` + + :param filters: the filters to apply to the document list. + :return: a list of Documents that match the given filters. + """ + + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None: + """ + Writes (or overwrites) documents into the store. + + :param documents: a list of documents. + :param policy: documents with the same ID count as duplicates. When duplicates are met, + the store can: + - skip: keep the existing document and ignore the new one. + - overwrite: remove the old document and write the new one. + - fail: an error is raised + :raises DuplicateDocumentError: Exception raised on duplicate document if `policy=DuplicatePolicy.FAIL` + :return: None + """ + + def delete_documents(self, document_ids: List[str]) -> None: + """ + Deletes all documents with matching document_ids from the document store. + Fails with `MissingDocumentError` if no document with this id is present in the store. + + :param document_ids: the document_ids to delete + """ diff --git a/haystack/preview/pipeline.py b/haystack/preview/pipeline.py index 6a087cb7b..ca105096b 100644 --- a/haystack/preview/pipeline.py +++ b/haystack/preview/pipeline.py @@ -11,6 +11,8 @@ from canals.pipeline import ( ) from canals.pipeline.sockets import find_input_sockets +from haystack.preview.document_stores.protocols import Store + class NoSuchStoreError(PipelineError): pass @@ -23,9 +25,9 @@ class Pipeline(CanalsPipeline): def __init__(self): super().__init__() - self.stores = {} + self.stores: Dict[str, Store] = {} - def add_store(self, name: str, store: object) -> None: + def add_store(self, name: str, store: Store) -> None: """ Make a store available to all nodes of this pipeline. 
@@ -43,7 +45,7 @@ class Pipeline(CanalsPipeline): """ return list(self.stores.keys()) - def get_store(self, name: str) -> object: + def get_store(self, name: str) -> Store: """ Returns the store associated with the given name. diff --git a/test/preview/document_stores/_base.py b/test/preview/document_stores/_base.py index b60daaa39..3f9497d65 100644 --- a/test/preview/document_stores/_base.py +++ b/test/preview/document_stores/_base.py @@ -5,17 +5,17 @@ import numpy as np import pandas as pd from haystack.preview.dataclasses import Document -from haystack.preview.document_stores import StoreError +from haystack.preview.document_stores import Store, StoreError, DuplicatePolicy from haystack.preview.document_stores import MissingDocumentError, DuplicateDocumentError class DocumentStoreBaseTests: @pytest.fixture - def docstore(self): + def docstore(self) -> Store: raise NotImplementedError() @pytest.fixture - def filterable_docs(self): + def filterable_docs(self) -> List[Document]: embedding_zero = np.zeros([768, 1]).astype(np.float32) embedding_one = np.ones([768, 1]).astype(np.float32) @@ -70,42 +70,42 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_count_empty(self, docstore): + def test_count_empty(self, docstore: Store): assert docstore.count_documents() == 0 @pytest.mark.unit - def test_count_not_empty(self, docstore): + def test_count_not_empty(self, docstore: Store): docstore.write_documents( [Document(content="test doc 1"), Document(content="test doc 2"), Document(content="test doc 3")] ) assert docstore.count_documents() == 3 @pytest.mark.unit - def test_no_filter_empty(self, docstore): + def test_no_filter_empty(self, docstore: Store): assert docstore.filter_documents() == [] assert docstore.filter_documents(filters={}) == [] @pytest.mark.unit - def test_no_filter_not_empty(self, docstore): + def test_no_filter_not_empty(self, docstore: Store): docs = [Document(content="test doc")] docstore.write_documents(docs) assert 
docstore.filter_documents() == docs assert docstore.filter_documents(filters={}) == docs @pytest.mark.unit - def test_filter_simple_metadata_value(self, docstore, filterable_docs): + def test_filter_simple_metadata_value(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": "100"}) assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"]) @pytest.mark.unit - def test_filter_simple_list_single_element(self, docstore, filterable_docs): + def test_filter_simple_list_single_element(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": ["100"]}) assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"]) @pytest.mark.unit - def test_filter_document_content(self, docstore, filterable_docs): + def test_filter_document_content(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"content": "A Foo Document 1"}) assert self.contains_same_docs( @@ -113,19 +113,19 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_filter_document_type(self, docstore, filterable_docs): + def test_filter_document_type(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"content_type": "table"}) assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.content_type == "table"]) @pytest.mark.unit - def test_filter_simple_list_one_value(self, docstore, filterable_docs): + def test_filter_simple_list_one_value(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": ["100"]}) assert 
self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") in ["100"]]) @pytest.mark.unit - def test_filter_simple_list(self, docstore, filterable_docs): + def test_filter_simple_list(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": ["100", "123"]}) assert self.contains_same_docs( @@ -133,49 +133,49 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_incorrect_filter_name(self, docstore, filterable_docs): + def test_incorrect_filter_name(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"non_existing_meta_field": ["whatever"]}) assert len(result) == 0 @pytest.mark.unit - def test_incorrect_filter_type(self, docstore, filterable_docs): + def test_incorrect_filter_type(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) with pytest.raises(ValueError, match="dictionaries or lists"): docstore.filter_documents(filters="something odd") @pytest.mark.unit - def test_incorrect_filter_value(self, docstore, filterable_docs): + def test_incorrect_filter_value(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": ["nope"]}) assert len(result) == 0 @pytest.mark.unit - def test_incorrect_filter_nesting(self, docstore, filterable_docs): + def test_incorrect_filter_nesting(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) with pytest.raises(ValueError, match="malformed"): docstore.filter_documents(filters={"number": {"page": "100"}}) @pytest.mark.unit - def test_deeper_incorrect_filter_nesting(self, docstore, filterable_docs): + def test_deeper_incorrect_filter_nesting(self, docstore: Store, filterable_docs: List[Document]): 
docstore.write_documents(filterable_docs) with pytest.raises(ValueError, match="malformed"): docstore.filter_documents(filters={"number": {"page": {"chapter": "intro"}}}) @pytest.mark.unit - def test_eq_filter_explicit(self, docstore, filterable_docs): + def test_eq_filter_explicit(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": {"$eq": "100"}}) assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"]) @pytest.mark.unit - def test_eq_filter_implicit(self, docstore, filterable_docs): + def test_eq_filter_implicit(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": "100"}) assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"]) @pytest.mark.unit - def test_eq_filter_table(self, docstore, filterable_docs): + def test_eq_filter_table(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"content": pd.DataFrame([1])}) assert self.contains_same_docs( @@ -188,7 +188,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_eq_filter_tensor(self, docstore, filterable_docs): + def test_eq_filter_tensor(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) embedding = np.zeros([768, 1]).astype(np.float32) result = docstore.filter_documents(filters={"embedding": embedding}) @@ -197,25 +197,25 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_deeper_incorrect_filter_nesting(self, docstore, filterable_docs): + def test_deeper_incorrect_filter_nesting(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) with pytest.raises(ValueError, match="malformed"): 
docstore.filter_documents(filters={"number": {"page": {"chapter": "intro"}}}) @pytest.mark.unit - def test_eq_filter_explicit(self, docstore, filterable_docs): + def test_eq_filter_explicit(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": {"$eq": "100"}}) assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"]) @pytest.mark.unit - def test_eq_filter_implicit(self, docstore, filterable_docs): + def test_eq_filter_implicit(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": "100"}) assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"]) @pytest.mark.unit - def test_in_filter_explicit(self, docstore, filterable_docs): + def test_in_filter_explicit(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": {"$in": ["100", "123", "n.a."]}}) assert self.contains_same_docs( @@ -223,7 +223,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_in_filter_implicit(self, docstore, filterable_docs): + def test_in_filter_implicit(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": ["100", "123", "n.a."]}) assert self.contains_same_docs( @@ -231,7 +231,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_in_filter_table(self, docstore, filterable_docs): + def test_in_filter_table(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"content": {"$in": [pd.DataFrame([1]), pd.DataFrame([2])]}}) assert self.contains_same_docs( @@ -245,7 +245,7 @@ class DocumentStoreBaseTests: ) 
@pytest.mark.unit - def test_in_filter_tensor(self, docstore, filterable_docs): + def test_in_filter_tensor(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) embedding_zero = np.zeros([768, 1]).astype(np.float32) embedding_one = np.ones([768, 1]).astype(np.float32) @@ -261,13 +261,13 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_ne_filter(self, docstore, filterable_docs): + def test_ne_filter(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": {"$ne": "100"}}) assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") != "100"]) @pytest.mark.unit - def test_ne_filter_table(self, docstore, filterable_docs): + def test_ne_filter_table(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"content": {"$ne": pd.DataFrame([1])}}) assert self.contains_same_docs( @@ -280,7 +280,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_ne_filter_tensor(self, docstore, filterable_docs): + def test_ne_filter_tensor(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) embedding = np.zeros([768, 1]).astype(np.float32) result = docstore.filter_documents(filters={"embedding": {"$ne": embedding}}) @@ -294,7 +294,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_nin_filter(self, docstore, filterable_docs): + def test_nin_filter(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": {"$nin": ["100", "123", "n.a."]}}) assert self.contains_same_docs( @@ -302,7 +302,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_nin_filter_table(self, docstore, filterable_docs): + def test_nin_filter_table(self, docstore: Store, 
filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"content": {"$nin": [pd.DataFrame([1]), pd.DataFrame([0])]}}) assert self.contains_same_docs( @@ -316,7 +316,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_nin_filter_tensor(self, docstore, filterable_docs): + def test_nin_filter_tensor(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) embedding_zeros = np.zeros([768, 1]).astype(np.float32) embedding_ones = np.zeros([768, 1]).astype(np.float32) @@ -335,7 +335,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_nin_filter(self, docstore, filterable_docs): + def test_nin_filter(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": {"$nin": ["100", "123", "n.a."]}}) assert self.contains_same_docs( @@ -343,7 +343,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_gt_filter(self, docstore, filterable_docs): + def test_gt_filter(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"number": {"$gt": 0.0}}) assert self.contains_same_docs( @@ -351,26 +351,26 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_gt_filter_non_numeric(self, docstore, filterable_docs): + def test_gt_filter_non_numeric(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) with pytest.raises(StoreError, match="Can't evaluate"): docstore.filter_documents(filters={"page": {"$gt": "100"}}) @pytest.mark.unit - def test_gt_filter_table(self, docstore, filterable_docs): + def test_gt_filter_table(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) with pytest.raises(StoreError, match="Can't evaluate"): docstore.filter_documents(filters={"content": {"$gt": 
pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) @pytest.mark.unit - def test_gt_filter_tensor(self, docstore, filterable_docs): + def test_gt_filter_tensor(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) embedding_zeros = np.zeros([768, 1]).astype(np.float32) with pytest.raises(StoreError, match="Can't evaluate"): docstore.filter_documents(filters={"embedding": {"$gt": embedding_zeros}}) @pytest.mark.unit - def test_gte_filter(self, docstore, filterable_docs): + def test_gte_filter(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"number": {"$gte": -2.0}}) assert self.contains_same_docs( @@ -378,26 +378,26 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_gte_filter_non_numeric(self, docstore, filterable_docs): + def test_gte_filter_non_numeric(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) with pytest.raises(StoreError, match="Can't evaluate"): docstore.filter_documents(filters={"page": {"$gte": "100"}}) @pytest.mark.unit - def test_gte_filter_table(self, docstore, filterable_docs): + def test_gte_filter_table(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) with pytest.raises(StoreError, match="Can't evaluate"): docstore.filter_documents(filters={"content": {"$gte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) @pytest.mark.unit - def test_gte_filter_tensor(self, docstore, filterable_docs): + def test_gte_filter_tensor(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) embedding_zeros = np.zeros([768, 1]).astype(np.float32) with pytest.raises(StoreError, match="Can't evaluate"): docstore.filter_documents(filters={"embedding": {"$gte": embedding_zeros}}) @pytest.mark.unit - def test_lt_filter(self, docstore, filterable_docs): + def test_lt_filter(self, docstore: Store, 
filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"number": {"$lt": 0.0}}) assert self.contains_same_docs( @@ -405,26 +405,26 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_lt_filter_non_numeric(self, docstore, filterable_docs): + def test_lt_filter_non_numeric(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) with pytest.raises(StoreError, match="Can't evaluate"): docstore.filter_documents(filters={"page": {"$lt": "100"}}) @pytest.mark.unit - def test_lt_filter_table(self, docstore, filterable_docs): + def test_lt_filter_table(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) with pytest.raises(StoreError, match="Can't evaluate"): docstore.filter_documents(filters={"content": {"$lt": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) @pytest.mark.unit - def test_lt_filter_tensor(self, docstore, filterable_docs): + def test_lt_filter_tensor(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) embedding_ones = np.ones([768, 1]).astype(np.float32) with pytest.raises(StoreError, match="Can't evaluate"): docstore.filter_documents(filters={"embedding": {"$lt": embedding_ones}}) @pytest.mark.unit - def test_lte_filter(self, docstore, filterable_docs): + def test_lte_filter(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"number": {"$lte": 2.0}}) assert self.contains_same_docs( @@ -432,26 +432,26 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_lte_filter_non_numeric(self, docstore, filterable_docs): + def test_lte_filter_non_numeric(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) with pytest.raises(StoreError, match="Can't evaluate"): docstore.filter_documents(filters={"page": {"$lte": "100"}}) 
@pytest.mark.unit - def test_lte_filter_table(self, docstore, filterable_docs): + def test_lte_filter_table(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) with pytest.raises(StoreError, match="Can't evaluate"): docstore.filter_documents(filters={"content": {"$lte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) @pytest.mark.unit - def test_lte_filter_tensor(self, docstore, filterable_docs): + def test_lte_filter_tensor(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) embedding_ones = np.ones([768, 1]).astype(np.float32) with pytest.raises(StoreError, match="Can't evaluate"): docstore.filter_documents(filters={"embedding": {"$lte": embedding_ones}}) @pytest.mark.unit - def test_filter_simple_implicit_and_with_multi_key_dict(self, docstore, filterable_docs): + def test_filter_simple_implicit_and_with_multi_key_dict(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"number": {"$lte": 2.0, "$gte": 0.0}}) assert self.contains_same_docs( @@ -464,7 +464,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_filter_simple_explicit_and_with_multikey_dict(self, docstore, filterable_docs): + def test_filter_simple_explicit_and_with_multikey_dict(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"number": {"$and": {"$lte": 0, "$gte": -2}}}) assert self.contains_same_docs( @@ -477,7 +477,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_filter_simple_explicit_and_with_list(self, docstore, filterable_docs): + def test_filter_simple_explicit_and_with_list(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"number": {"$and": [{"$lte": 2}, {"$gte": 0}]}}) assert self.contains_same_docs( 
@@ -490,7 +490,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_filter_simple_implicit_and(self, docstore, filterable_docs): + def test_filter_simple_implicit_and(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"number": {"$lte": 2.0, "$gte": 0}}) assert self.contains_same_docs( @@ -503,7 +503,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_filter_nested_explicit_and(self, docstore, filterable_docs): + def test_filter_nested_explicit_and(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) filters = {"$and": {"number": {"$and": {"$lte": 2, "$gte": 0}}, "name": {"$in": ["name_0", "name_1"]}}} result = docstore.filter_documents(filters=filters) @@ -522,7 +522,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_filter_nested_implicit_and(self, docstore, filterable_docs): + def test_filter_nested_implicit_and(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) filters_simplified = {"number": {"$lte": 2, "$gte": 0}, "name": ["name_0", "name_1"]} result = docstore.filter_documents(filters=filters_simplified) @@ -541,7 +541,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_filter_simple_or(self, docstore, filterable_docs): + def test_filter_simple_or(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) filters = {"$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}} result = docstore.filter_documents(filters=filters) @@ -558,7 +558,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_filter_nested_or(self, docstore, filterable_docs): + def test_filter_nested_or(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) filters = {"$or": {"name": {"$or": [{"$eq": "name_0"}, {"$eq": "name_1"}]}, "number": 
{"$lt": 1.0}}} result = docstore.filter_documents(filters=filters) @@ -575,7 +575,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_filter_nested_and_or_explicit(self, docstore, filterable_docs): + def test_filter_nested_and_or_explicit(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) filters_simplified = { "$and": {"page": {"$eq": "123"}, "$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}} @@ -597,7 +597,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_filter_nested_and_or_implicit(self, docstore, filterable_docs): + def test_filter_nested_and_or_implicit(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) filters_simplified = { "page": {"$eq": "123"}, @@ -620,7 +620,7 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_filter_nested_or_and(self, docstore, filterable_docs): + def test_filter_nested_or_and(self, docstore: Store, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) filters_simplified = { "$or": { @@ -645,7 +645,9 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_filter_nested_multiple_identical_operators_same_level(self, docstore, filterable_docs): + def test_filter_nested_multiple_identical_operators_same_level( + self, docstore: Store, filterable_docs: List[Document] + ): docstore.write_documents(filterable_docs) filters = { "$or": [ @@ -667,54 +669,54 @@ class DocumentStoreBaseTests: ) @pytest.mark.unit - def test_write(self, docstore): + def test_write(self, docstore: Store): doc = Document(content="test doc") docstore.write_documents([doc]) assert docstore.filter_documents(filters={"id": doc.id}) == [doc] @pytest.mark.unit - def test_write_duplicate_fail(self, docstore): + def test_write_duplicate_fail(self, docstore: Store): doc = Document(content="test doc") docstore.write_documents([doc]) with pytest.raises(DuplicateDocumentError, match=f"ID 
'{doc.id}' already exists."): - docstore.write_documents(documents=[doc]) + docstore.write_documents(documents=[doc], policy=DuplicatePolicy.FAIL) assert docstore.filter_documents(filters={"id": doc.id}) == [doc] @pytest.mark.unit - def test_write_duplicate_skip(self, docstore): + def test_write_duplicate_skip(self, docstore: Store): doc = Document(content="test doc") docstore.write_documents([doc]) - docstore.write_documents(documents=[doc], duplicates="skip") + docstore.write_documents(documents=[doc], policy=DuplicatePolicy.SKIP) assert docstore.filter_documents(filters={"id": doc.id}) == [doc] @pytest.mark.unit - def test_write_duplicate_overwrite(self, docstore): + def test_write_duplicate_overwrite(self, docstore: Store): doc1 = Document(content="test doc 1") doc2 = Document(content="test doc 2") object.__setattr__(doc2, "id", doc1.id) # Make two docs with different content but same ID docstore.write_documents([doc2]) - docstore.filter_documents(filters={"id": doc1.id}) == [doc2] - docstore.write_documents(documents=[doc1], duplicates="overwrite") + assert docstore.filter_documents(filters={"id": doc1.id}) == [doc2] + docstore.write_documents(documents=[doc1], policy=DuplicatePolicy.OVERWRITE) assert docstore.filter_documents(filters={"id": doc1.id}) == [doc1] @pytest.mark.unit - def test_write_not_docs(self, docstore): + def test_write_not_docs(self, docstore: Store): with pytest.raises(ValueError, match="Please provide a list of Documents"): docstore.write_documents(["not a document for sure"]) @pytest.mark.unit - def test_write_not_list(self, docstore): + def test_write_not_list(self, docstore: Store): with pytest.raises(ValueError, match="Please provide a list of Documents"): docstore.write_documents("not a list actually") @pytest.mark.unit - def test_delete_empty(self, docstore): + def test_delete_empty(self, docstore: Store): with pytest.raises(MissingDocumentError): docstore.delete_documents(["test"]) @pytest.mark.unit - def test_delete_not_empty(self, docstore): + def test_delete_not_empty(self, docstore: Store): doc = 
Document(content="test doc") docstore.write_documents([doc]) @@ -724,7 +726,7 @@ class DocumentStoreBaseTests: assert docstore.filter_documents(filters={"id": doc.id}) @pytest.mark.unit - def test_delete_not_empty_nonexisting(self, docstore): + def test_delete_not_empty_nonexisting(self, docstore: Store): doc = Document(content="test doc") docstore.write_documents([doc]) diff --git a/test/preview/document_stores/test_memory.py b/test/preview/document_stores/test_memory.py index 1344234bd..69206d44d 100644 --- a/test/preview/document_stores/test_memory.py +++ b/test/preview/document_stores/test_memory.py @@ -4,7 +4,7 @@ import pandas as pd import pytest from haystack.preview import Document -from haystack.preview.document_stores import MemoryDocumentStore +from haystack.preview.document_stores import Store, MemoryDocumentStore from test.preview.document_stores._base import DocumentStoreBaseTests @@ -19,7 +19,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): return MemoryDocumentStore() @pytest.mark.unit - def test_bm25_retrieval(self, docstore): + def test_bm25_retrieval(self, docstore: Store): docstore = MemoryDocumentStore() # Tests if the bm25_retrieval method returns the correct document based on the input query. docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")] @@ -29,7 +29,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): assert results[0].content == "Haystack supports multiple languages" @pytest.mark.unit - def test_bm25_retrieval_with_empty_document_store(self, docstore, caplog): + def test_bm25_retrieval_with_empty_document_store(self, docstore: Store, caplog): caplog.set_level(logging.INFO) # Tests if the bm25_retrieval method correctly returns an empty list when there are no documents in the store. results = docstore.bm25_retrieval(query="How to test this?", top_k=2) @@ -37,7 +37,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): assert "No documents found for BM25 retrieval. 
Returning empty list." in caplog.text @pytest.mark.unit - def test_bm25_retrieval_empty_query(self, docstore): + def test_bm25_retrieval_empty_query(self, docstore: Store): # Tests if the bm25_retrieval method returns a document when the query is an empty string. docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")] docstore.write_documents(docs) @@ -45,7 +45,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): docstore.bm25_retrieval(query="", top_k=1) @pytest.mark.unit - def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_string(self, docstore): + def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_string(self, docstore: Store): # Tests if the bm25_retrieval method returns a document when the query is an empty string. docs = [ Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}), @@ -57,7 +57,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): assert results[0].content == "Haystack supports multiple languages" @pytest.mark.unit - def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_list(self, docstore): + def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_list(self, docstore: Store): # Tests if the bm25_retrieval method returns a document when the query is an empty string. docs = [ Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}), @@ -69,7 +69,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): assert results[0].content == "Haystack supports multiple languages" @pytest.mark.unit - def test_bm25_retrieval_filter_two_allowed_doc_type_as_list(self, docstore): + def test_bm25_retrieval_filter_two_allowed_doc_type_as_list(self, docstore: Store): # Tests if the bm25_retrieval method returns a document when the query is an empty string. 
docs = [ Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}), @@ -80,7 +80,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): assert len(results) == 2 @pytest.mark.unit - def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_string(self, docstore): + def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_string(self, docstore: Store): # Tests if the bm25_retrieval method returns a document when the query is an empty string. docs = [ Document(content=pd.DataFrame({"language": ["Python", "Java"]}), content_type="table"), @@ -93,7 +93,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": "audio"}) @pytest.mark.unit - def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_list(self, docstore): + def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_list(self, docstore: Store): # Tests if the bm25_retrieval method returns a document when the query is an empty string. docs = [ Document(content=pd.DataFrame({"language": ["Python", "Java"]}), content_type="table"), @@ -106,7 +106,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["audio"]}) @pytest.mark.unit - def test_bm25_retrieval_filter_two_not_all_allowed_doc_type_as_list(self, docstore): + def test_bm25_retrieval_filter_two_not_all_allowed_doc_type_as_list(self, docstore: Store): # Tests if the bm25_retrieval method returns a document when the query is an empty string. 
docs = [ Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}), @@ -119,7 +119,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["text", "audio"]}) @pytest.mark.unit - def test_bm25_retrieval_with_different_top_k(self, docstore): + def test_bm25_retrieval_with_different_top_k(self, docstore: Store): # Tests if the bm25_retrieval method correctly changes the number of returned documents # based on the top_k parameter. docs = [ @@ -139,7 +139,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # Test two queries and make sure the results are different @pytest.mark.unit - def test_bm25_retrieval_with_two_queries(self, docstore): + def test_bm25_retrieval_with_two_queries(self, docstore: Store): # Tests if the bm25_retrieval method returns different documents for different queries. docs = [ Document(content="Javascript is a popular programming language"), @@ -158,7 +158,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # Test a query, add a new document and make sure results are appropriately updated @pytest.mark.unit - def test_bm25_retrieval_with_updated_docs(self, docstore): + def test_bm25_retrieval_with_updated_docs(self, docstore: Store): # Tests if the bm25_retrieval method correctly updates the retrieved documents when new # documents are added to the store. 
docs = [Document(content="Hello world")] @@ -178,7 +178,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): assert results[0].content == "Python is a popular programming language" @pytest.mark.unit - def test_bm25_retrieval_with_scale_score(self, docstore): + def test_bm25_retrieval_with_scale_score(self, docstore: Store): docs = [Document(content="Python programming"), Document(content="Java programming")] docstore.write_documents(docs) @@ -191,7 +191,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): assert results[0].score != results1[0].score @pytest.mark.unit - def test_bm25_retrieval_with_table_content(self, docstore): + def test_bm25_retrieval_with_table_content(self, docstore: Store): # Tests if the bm25_retrieval method correctly returns a dataframe when the content_type is table. table_content = pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web Development"]}) docs = [