feat: introduce Store protocol (v2) (#5259)

* add protocol and adapt pipeline

* review feedback & update tests

* pylint

* Update haystack/preview/document_stores/protocols.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Update haystack/preview/document_stores/memory/document_store.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* docstring of Store

* adapt memorydocumentstore

* fix tests

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Authored by ZanSara on 2023-07-07 12:10:08 +02:00; committed by GitHub
parent 90ff3817e7
commit f49bd3a12f
6 changed files with 232 additions and 101 deletions

File: haystack/preview/document_stores/__init__.py

@@ -1,2 +1,3 @@
+from haystack.preview.document_stores.protocols import Store, DuplicatePolicy
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
from haystack.preview.document_stores.errors import StoreError, DuplicateDocumentError, MissingDocumentError

File: haystack/preview/document_stores/memory/document_store.py

@@ -8,12 +8,12 @@ import rank_bm25
from tqdm.auto import tqdm
from haystack.preview.dataclasses import Document
+from haystack.preview.document_stores.protocols import DuplicatePolicy
from haystack.preview.document_stores.memory._filters import match
from haystack.preview.document_stores.errors import DuplicateDocumentError, MissingDocumentError
from haystack.utils.scipy_utils import expit
logger = logging.getLogger(__name__)
-DuplicatePolicy = Literal["skip", "overwrite", "fail"]
# document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to
# True (default). Scaling uses the expit function (inverse of the logit function) after applying a SCALING_FACTOR. A
@@ -126,17 +126,17 @@ class MemoryDocumentStore:
            return [doc for doc in self.storage.values() if match(conditions=filters, document=doc)]
        return list(self.storage.values())
-    def write_documents(self, documents: List[Document], duplicates: DuplicatePolicy = "fail") -> None:
+    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
        """
        Writes (or overwrites) documents into the store.
        :param documents: a list of documents.
-        :param duplicates: documents with the same ID count as duplicates. When duplicates are met,
+        :param policy: documents with the same ID count as duplicates. When duplicates are met,
            the store can:
            - skip: keep the existing document and ignore the new one.
            - overwrite: remove the old document and write the new one.
            - fail: an error is raised
-        :raises DuplicateError: Exception trigger on duplicate document if `duplicates="fail"`
+        :raises DuplicateError: Exception trigger on duplicate document if `policy=DuplicatePolicy.FAIL`
        :return: None
        """
        if (
@@ -147,10 +147,10 @@ class MemoryDocumentStore:
            raise ValueError("Please provide a list of Documents.")
        for document in documents:
-            if document.id in self.storage.keys():
+            if policy != DuplicatePolicy.OVERWRITE and document.id in self.storage.keys():
-                if duplicates == "fail":
+                if policy == DuplicatePolicy.FAIL:
                    raise DuplicateDocumentError(f"ID '{document.id}' already exists.")
-                if duplicates == "skip":
+                if policy == DuplicatePolicy.SKIP:
                    logger.warning("ID '%s' already exists", document.id)
            self.storage[document.id] = document

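For orientation, here is a minimal usage sketch of the new `policy` argument on `MemoryDocumentStore.write_documents`, assuming the preview imports exposed by the `__init__.py` above:

```python
from haystack.preview.dataclasses import Document
from haystack.preview.document_stores import MemoryDocumentStore, DuplicatePolicy, DuplicateDocumentError

store = MemoryDocumentStore()
doc = Document(content="hello world")
store.write_documents([doc])

# SKIP only logs a warning when the same ID is written again.
store.write_documents([doc], policy=DuplicatePolicy.SKIP)

# OVERWRITE replaces the stored document with the incoming one.
store.write_documents([doc], policy=DuplicatePolicy.OVERWRITE)

# FAIL (the default) raises DuplicateDocumentError on a duplicate ID.
try:
    store.write_documents([doc], policy=DuplicatePolicy.FAIL)
except DuplicateDocumentError:
    print(f"Document '{doc.id}' is already stored")
```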
File: haystack/preview/document_stores/protocols.py

@@ -0,0 +1,126 @@
+from typing import Protocol, Optional, Dict, Any, List
+import logging
+from enum import Enum
+
+from haystack.preview.dataclasses import Document
+
+logger = logging.getLogger(__name__)
+
+
+class DuplicatePolicy(Enum):
+    SKIP = "skip"
+    OVERWRITE = "overwrite"
+    FAIL = "fail"
+
+
+class Store(Protocol):
+    """
+    Stores Documents to be used by the components of a Pipeline.
+    Classes implementing this protocol often store the documents permanently and allow specialized components to
+    perform retrieval on them, either by embedding, by keyword, hybrid, and so on, depending on the backend used.
+    In order to retrieve documents, consider using a Retriever that supports the document store implementation that
+    you're using.
+    """
+
+    def count_documents(self) -> int:
+        """
+        Returns the number of documents stored.
+        """
+
+    def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
+        """
+        Returns the documents that match the filters provided.
+        Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical operator (`"$and"`,
+        `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `$ne`, `"$in"`, `$nin`, `"$gt"`, `"$gte"`, `"$lt"`,
+        `"$lte"`) or a metadata field name.
+        Logical operator keys take a dictionary of metadata field names and/or logical operators as value. Metadata
+        field names take a dictionary of comparison operators as value. Comparison operator keys take a single value or
+        (in case of `"$in"`) a list of values as value. If no logical operator is provided, `"$and"` is used as default
+        operation. If no comparison operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used
+        as default operation.
+        Example:
+        ```python
+        filters = {
+            "$and": {
+                "type": {"$eq": "article"},
+                "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
+                "rating": {"$gte": 3},
+                "$or": {
+                    "genre": {"$in": ["economy", "politics"]},
+                    "publisher": {"$eq": "nytimes"}
+                }
+            }
+        }
+        # or simpler using default operators
+        filters = {
+            "type": "article",
+            "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
+            "rating": {"$gte": 3},
+            "$or": {
+                "genre": ["economy", "politics"],
+                "publisher": "nytimes"
+            }
+        }
+        ```
+        To use the same logical operator multiple times on the same level, logical operators can take a list of
+        dictionaries as value.
+        Example:
+        ```python
+        filters = {
+            "$or": [
+                {
+                    "$and": {
+                        "Type": "News Paper",
+                        "Date": {
+                            "$lt": "2019-01-01"
+                        }
+                    }
+                },
+                {
+                    "$and": {
+                        "Type": "Blog Post",
+                        "Date": {
+                            "$gte": "2019-01-01"
+                        }
+                    }
+                }
+            ]
+        }
+        ```
+        :param filters: the filters to apply to the document list.
+        :return: a list of Documents that match the given filters.
+        """
+
+    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
+        """
+        Writes (or overwrites) documents into the store.
+        :param documents: a list of documents.
+        :param policy: documents with the same ID count as duplicates. When duplicates are met,
+            the store can:
+            - skip: keep the existing document and ignore the new one.
+            - overwrite: remove the old document and write the new one.
+            - fail: an error is raised
+        :raises DuplicateError: Exception trigger on duplicate document if `policy=DuplicatePolicy.FAIL`
+        :return: None
+        """
+
+    def delete_documents(self, document_ids: List[str]) -> None:
+        """
+        Deletes all documents with a matching document_ids from the document store.
+        Fails with `MissingDocumentError` if no document with this id is present in the store.
+        :param object_ids: the object_ids to delete
+        """

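Because `Store` is declared as a `typing.Protocol`, a document store satisfies it structurally: it only needs to expose these four methods with compatible signatures, without inheriting from anything. A minimal sketch under that assumption (the `DictDocumentStore` class below is hypothetical and not part of this PR):

```python
from typing import Any, Dict, List, Optional

from haystack.preview.dataclasses import Document
from haystack.preview.document_stores import Store, DuplicatePolicy, DuplicateDocumentError, MissingDocumentError


class DictDocumentStore:
    """Illustrative in-memory store; filtering is left out for brevity."""

    def __init__(self) -> None:
        self._docs: Dict[str, Document] = {}

    def count_documents(self) -> int:
        return len(self._docs)

    def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
        # A real implementation would honour the operator syntax documented above.
        return list(self._docs.values())

    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
        for doc in documents:
            if doc.id in self._docs and policy == DuplicatePolicy.FAIL:
                raise DuplicateDocumentError(f"ID '{doc.id}' already exists.")
            if doc.id in self._docs and policy == DuplicatePolicy.SKIP:
                continue
            self._docs[doc.id] = doc

    def delete_documents(self, document_ids: List[str]) -> None:
        for doc_id in document_ids:
            if doc_id not in self._docs:
                raise MissingDocumentError(f"ID '{doc_id}' not found.")
            del self._docs[doc_id]


store: Store = DictDocumentStore()  # accepted by a type checker; no explicit subclassing required
```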
File: haystack/preview/pipeline.py

@@ -11,6 +11,8 @@ from canals.pipeline import (
)
from canals.pipeline.sockets import find_input_sockets
+from haystack.preview.document_stores.protocols import Store
+
class NoSuchStoreError(PipelineError):
    pass
@@ -23,9 +25,9 @@ class Pipeline(CanalsPipeline):
    def __init__(self):
        super().__init__()
-        self.stores = {}
+        self.stores: Dict[str, Store] = {}
-    def add_store(self, name: str, store: object) -> None:
+    def add_store(self, name: str, store: Store) -> None:
        """
        Make a store available to all nodes of this pipeline.
@@ -43,7 +45,7 @@ class Pipeline(CanalsPipeline):
        """
        return list(self.stores.keys())
-    def get_store(self, name: str) -> object:
+    def get_store(self, name: str) -> Store:
        """
        Returns the store associated with the given name.

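With these annotations, stores registered on a `Pipeline` are typed as `Store` instead of plain `object`. A short usage sketch, assuming the module path `haystack.preview.pipeline` shown above:

```python
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview.pipeline import Pipeline

pipeline = Pipeline()
pipeline.add_store(name="documents", store=MemoryDocumentStore())

# get_store() now returns a Store, so type checkers know about
# count_documents, filter_documents, write_documents and delete_documents.
store = pipeline.get_store("documents")
print(store.count_documents())
```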
File: test/preview/document_stores/_base.py

@@ -5,17 +5,17 @@ import numpy as np
import pandas as pd
from haystack.preview.dataclasses import Document
-from haystack.preview.document_stores import StoreError
+from haystack.preview.document_stores import Store, StoreError, DuplicatePolicy
from haystack.preview.document_stores import MissingDocumentError, DuplicateDocumentError
class DocumentStoreBaseTests:
    @pytest.fixture
-    def docstore(self):
+    def docstore(self) -> Store:
        raise NotImplementedError()
    @pytest.fixture
-    def filterable_docs(self):
+    def filterable_docs(self) -> List[Document]:
        embedding_zero = np.zeros([768, 1]).astype(np.float32)
        embedding_one = np.ones([768, 1]).astype(np.float32)
@@ -70,42 +70,42 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_count_empty(self, docstore):
+    def test_count_empty(self, docstore: Store):
        assert docstore.count_documents() == 0
    @pytest.mark.unit
-    def test_count_not_empty(self, docstore):
+    def test_count_not_empty(self, docstore: Store):
        docstore.write_documents(
            [Document(content="test doc 1"), Document(content="test doc 2"), Document(content="test doc 3")]
        )
        assert docstore.count_documents() == 3
    @pytest.mark.unit
-    def test_no_filter_empty(self, docstore):
+    def test_no_filter_empty(self, docstore: Store):
        assert docstore.filter_documents() == []
        assert docstore.filter_documents(filters={}) == []
    @pytest.mark.unit
-    def test_no_filter_not_empty(self, docstore):
+    def test_no_filter_not_empty(self, docstore: Store):
        docs = [Document(content="test doc")]
        docstore.write_documents(docs)
        assert docstore.filter_documents() == docs
        assert docstore.filter_documents(filters={}) == docs
    @pytest.mark.unit
-    def test_filter_simple_metadata_value(self, docstore, filterable_docs):
+    def test_filter_simple_metadata_value(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": "100"})
        assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"])
    @pytest.mark.unit
-    def test_filter_simple_list_single_element(self, docstore, filterable_docs):
+    def test_filter_simple_list_single_element(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": ["100"]})
        assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"])
    @pytest.mark.unit
-    def test_filter_document_content(self, docstore, filterable_docs):
+    def test_filter_document_content(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"content": "A Foo Document 1"})
        assert self.contains_same_docs(
@@ -113,19 +113,19 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_filter_document_type(self, docstore, filterable_docs):
+    def test_filter_document_type(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"content_type": "table"})
        assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.content_type == "table"])
    @pytest.mark.unit
-    def test_filter_simple_list_one_value(self, docstore, filterable_docs):
+    def test_filter_simple_list_one_value(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": ["100"]})
        assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") in ["100"]])
    @pytest.mark.unit
-    def test_filter_simple_list(self, docstore, filterable_docs):
+    def test_filter_simple_list(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": ["100", "123"]})
        assert self.contains_same_docs(
@@ -133,49 +133,49 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_incorrect_filter_name(self, docstore, filterable_docs):
+    def test_incorrect_filter_name(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"non_existing_meta_field": ["whatever"]})
        assert len(result) == 0
    @pytest.mark.unit
-    def test_incorrect_filter_type(self, docstore, filterable_docs):
+    def test_incorrect_filter_type(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        with pytest.raises(ValueError, match="dictionaries or lists"):
            docstore.filter_documents(filters="something odd")
    @pytest.mark.unit
-    def test_incorrect_filter_value(self, docstore, filterable_docs):
+    def test_incorrect_filter_value(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": ["nope"]})
        assert len(result) == 0
    @pytest.mark.unit
-    def test_incorrect_filter_nesting(self, docstore, filterable_docs):
+    def test_incorrect_filter_nesting(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        with pytest.raises(ValueError, match="malformed"):
            docstore.filter_documents(filters={"number": {"page": "100"}})
    @pytest.mark.unit
-    def test_deeper_incorrect_filter_nesting(self, docstore, filterable_docs):
+    def test_deeper_incorrect_filter_nesting(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        with pytest.raises(ValueError, match="malformed"):
            docstore.filter_documents(filters={"number": {"page": {"chapter": "intro"}}})
    @pytest.mark.unit
-    def test_eq_filter_explicit(self, docstore, filterable_docs):
+    def test_eq_filter_explicit(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": {"$eq": "100"}})
        assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"])
    @pytest.mark.unit
-    def test_eq_filter_implicit(self, docstore, filterable_docs):
+    def test_eq_filter_implicit(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": "100"})
        assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"])
    @pytest.mark.unit
-    def test_eq_filter_table(self, docstore, filterable_docs):
+    def test_eq_filter_table(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"content": pd.DataFrame([1])})
        assert self.contains_same_docs(
@@ -188,7 +188,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_eq_filter_tensor(self, docstore, filterable_docs):
+    def test_eq_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        embedding = np.zeros([768, 1]).astype(np.float32)
        result = docstore.filter_documents(filters={"embedding": embedding})
@@ -197,25 +197,25 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_deeper_incorrect_filter_nesting(self, docstore, filterable_docs):
+    def test_deeper_incorrect_filter_nesting(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        with pytest.raises(ValueError, match="malformed"):
            docstore.filter_documents(filters={"number": {"page": {"chapter": "intro"}}})
    @pytest.mark.unit
-    def test_eq_filter_explicit(self, docstore, filterable_docs):
+    def test_eq_filter_explicit(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": {"$eq": "100"}})
        assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"])
    @pytest.mark.unit
-    def test_eq_filter_implicit(self, docstore, filterable_docs):
+    def test_eq_filter_implicit(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": "100"})
        assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"])
    @pytest.mark.unit
-    def test_in_filter_explicit(self, docstore, filterable_docs):
+    def test_in_filter_explicit(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": {"$in": ["100", "123", "n.a."]}})
        assert self.contains_same_docs(
@@ -223,7 +223,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_in_filter_implicit(self, docstore, filterable_docs):
+    def test_in_filter_implicit(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": ["100", "123", "n.a."]})
        assert self.contains_same_docs(
@@ -231,7 +231,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_in_filter_table(self, docstore, filterable_docs):
+    def test_in_filter_table(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"content": {"$in": [pd.DataFrame([1]), pd.DataFrame([2])]}})
        assert self.contains_same_docs(
@@ -245,7 +245,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_in_filter_tensor(self, docstore, filterable_docs):
+    def test_in_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        embedding_zero = np.zeros([768, 1]).astype(np.float32)
        embedding_one = np.ones([768, 1]).astype(np.float32)
@@ -261,13 +261,13 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_ne_filter(self, docstore, filterable_docs):
+    def test_ne_filter(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": {"$ne": "100"}})
        assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") != "100"])
    @pytest.mark.unit
-    def test_ne_filter_table(self, docstore, filterable_docs):
+    def test_ne_filter_table(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"content": {"$ne": pd.DataFrame([1])}})
        assert self.contains_same_docs(
@@ -280,7 +280,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_ne_filter_tensor(self, docstore, filterable_docs):
+    def test_ne_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        embedding = np.zeros([768, 1]).astype(np.float32)
        result = docstore.filter_documents(filters={"embedding": {"$ne": embedding}})
@@ -294,7 +294,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_nin_filter(self, docstore, filterable_docs):
+    def test_nin_filter(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": {"$nin": ["100", "123", "n.a."]}})
        assert self.contains_same_docs(
@@ -302,7 +302,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_nin_filter_table(self, docstore, filterable_docs):
+    def test_nin_filter_table(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"content": {"$nin": [pd.DataFrame([1]), pd.DataFrame([0])]}})
        assert self.contains_same_docs(
@@ -316,7 +316,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_nin_filter_tensor(self, docstore, filterable_docs):
+    def test_nin_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        embedding_zeros = np.zeros([768, 1]).astype(np.float32)
        embedding_ones = np.zeros([768, 1]).astype(np.float32)
@@ -335,7 +335,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_nin_filter(self, docstore, filterable_docs):
+    def test_nin_filter(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"page": {"$nin": ["100", "123", "n.a."]}})
        assert self.contains_same_docs(
@@ -343,7 +343,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_gt_filter(self, docstore, filterable_docs):
+    def test_gt_filter(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"number": {"$gt": 0.0}})
        assert self.contains_same_docs(
@@ -351,26 +351,26 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_gt_filter_non_numeric(self, docstore, filterable_docs):
+    def test_gt_filter_non_numeric(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        with pytest.raises(StoreError, match="Can't evaluate"):
            docstore.filter_documents(filters={"page": {"$gt": "100"}})
    @pytest.mark.unit
-    def test_gt_filter_table(self, docstore, filterable_docs):
+    def test_gt_filter_table(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        with pytest.raises(StoreError, match="Can't evaluate"):
            docstore.filter_documents(filters={"content": {"$gt": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}})
    @pytest.mark.unit
-    def test_gt_filter_tensor(self, docstore, filterable_docs):
+    def test_gt_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        embedding_zeros = np.zeros([768, 1]).astype(np.float32)
        with pytest.raises(StoreError, match="Can't evaluate"):
            docstore.filter_documents(filters={"embedding": {"$gt": embedding_zeros}})
    @pytest.mark.unit
-    def test_gte_filter(self, docstore, filterable_docs):
+    def test_gte_filter(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"number": {"$gte": -2.0}})
        assert self.contains_same_docs(
@@ -378,26 +378,26 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_gte_filter_non_numeric(self, docstore, filterable_docs):
+    def test_gte_filter_non_numeric(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        with pytest.raises(StoreError, match="Can't evaluate"):
            docstore.filter_documents(filters={"page": {"$gte": "100"}})
    @pytest.mark.unit
-    def test_gte_filter_table(self, docstore, filterable_docs):
+    def test_gte_filter_table(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        with pytest.raises(StoreError, match="Can't evaluate"):
            docstore.filter_documents(filters={"content": {"$gte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}})
    @pytest.mark.unit
-    def test_gte_filter_tensor(self, docstore, filterable_docs):
+    def test_gte_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        embedding_zeros = np.zeros([768, 1]).astype(np.float32)
        with pytest.raises(StoreError, match="Can't evaluate"):
            docstore.filter_documents(filters={"embedding": {"$gte": embedding_zeros}})
    @pytest.mark.unit
-    def test_lt_filter(self, docstore, filterable_docs):
+    def test_lt_filter(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"number": {"$lt": 0.0}})
        assert self.contains_same_docs(
@@ -405,26 +405,26 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_lt_filter_non_numeric(self, docstore, filterable_docs):
+    def test_lt_filter_non_numeric(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        with pytest.raises(StoreError, match="Can't evaluate"):
            docstore.filter_documents(filters={"page": {"$lt": "100"}})
    @pytest.mark.unit
-    def test_lt_filter_table(self, docstore, filterable_docs):
+    def test_lt_filter_table(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        with pytest.raises(StoreError, match="Can't evaluate"):
            docstore.filter_documents(filters={"content": {"$lt": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}})
    @pytest.mark.unit
-    def test_lt_filter_tensor(self, docstore, filterable_docs):
+    def test_lt_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        embedding_ones = np.ones([768, 1]).astype(np.float32)
        with pytest.raises(StoreError, match="Can't evaluate"):
            docstore.filter_documents(filters={"embedding": {"$lt": embedding_ones}})
    @pytest.mark.unit
-    def test_lte_filter(self, docstore, filterable_docs):
+    def test_lte_filter(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"number": {"$lte": 2.0}})
        assert self.contains_same_docs(
@@ -432,26 +432,26 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_lte_filter_non_numeric(self, docstore, filterable_docs):
+    def test_lte_filter_non_numeric(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        with pytest.raises(StoreError, match="Can't evaluate"):
            docstore.filter_documents(filters={"page": {"$lte": "100"}})
    @pytest.mark.unit
-    def test_lte_filter_table(self, docstore, filterable_docs):
+    def test_lte_filter_table(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        with pytest.raises(StoreError, match="Can't evaluate"):
            docstore.filter_documents(filters={"content": {"$lte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}})
    @pytest.mark.unit
-    def test_lte_filter_tensor(self, docstore, filterable_docs):
+    def test_lte_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        embedding_ones = np.ones([768, 1]).astype(np.float32)
        with pytest.raises(StoreError, match="Can't evaluate"):
            docstore.filter_documents(filters={"embedding": {"$lte": embedding_ones}})
    @pytest.mark.unit
-    def test_filter_simple_implicit_and_with_multi_key_dict(self, docstore, filterable_docs):
+    def test_filter_simple_implicit_and_with_multi_key_dict(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"number": {"$lte": 2.0, "$gte": 0.0}})
        assert self.contains_same_docs(
@@ -464,7 +464,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_filter_simple_explicit_and_with_multikey_dict(self, docstore, filterable_docs):
+    def test_filter_simple_explicit_and_with_multikey_dict(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"number": {"$and": {"$lte": 0, "$gte": -2}}})
        assert self.contains_same_docs(
@@ -477,7 +477,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_filter_simple_explicit_and_with_list(self, docstore, filterable_docs):
+    def test_filter_simple_explicit_and_with_list(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"number": {"$and": [{"$lte": 2}, {"$gte": 0}]}})
        assert self.contains_same_docs(
@@ -490,7 +490,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_filter_simple_implicit_and(self, docstore, filterable_docs):
+    def test_filter_simple_implicit_and(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        result = docstore.filter_documents(filters={"number": {"$lte": 2.0, "$gte": 0}})
        assert self.contains_same_docs(
@@ -503,7 +503,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_filter_nested_explicit_and(self, docstore, filterable_docs):
+    def test_filter_nested_explicit_and(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        filters = {"$and": {"number": {"$and": {"$lte": 2, "$gte": 0}}, "name": {"$in": ["name_0", "name_1"]}}}
        result = docstore.filter_documents(filters=filters)
@@ -522,7 +522,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_filter_nested_implicit_and(self, docstore, filterable_docs):
+    def test_filter_nested_implicit_and(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        filters_simplified = {"number": {"$lte": 2, "$gte": 0}, "name": ["name_0", "name_1"]}
        result = docstore.filter_documents(filters=filters_simplified)
@@ -541,7 +541,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_filter_simple_or(self, docstore, filterable_docs):
+    def test_filter_simple_or(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        filters = {"$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}}
        result = docstore.filter_documents(filters=filters)
@@ -558,7 +558,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_filter_nested_or(self, docstore, filterable_docs):
+    def test_filter_nested_or(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        filters = {"$or": {"name": {"$or": [{"$eq": "name_0"}, {"$eq": "name_1"}]}, "number": {"$lt": 1.0}}}
        result = docstore.filter_documents(filters=filters)
@@ -575,7 +575,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_filter_nested_and_or_explicit(self, docstore, filterable_docs):
+    def test_filter_nested_and_or_explicit(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        filters_simplified = {
            "$and": {"page": {"$eq": "123"}, "$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}}
@@ -597,7 +597,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_filter_nested_and_or_implicit(self, docstore, filterable_docs):
+    def test_filter_nested_and_or_implicit(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        filters_simplified = {
            "page": {"$eq": "123"},
@@ -620,7 +620,7 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_filter_nested_or_and(self, docstore, filterable_docs):
+    def test_filter_nested_or_and(self, docstore: Store, filterable_docs: List[Document]):
        docstore.write_documents(filterable_docs)
        filters_simplified = {
            "$or": {
@@ -645,7 +645,9 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_filter_nested_multiple_identical_operators_same_level(self, docstore, filterable_docs):
+    def test_filter_nested_multiple_identical_operators_same_level(
+        self, docstore: Store, filterable_docs: List[Document]
+    ):
        docstore.write_documents(filterable_docs)
        filters = {
            "$or": [
@@ -667,54 +669,54 @@ class DocumentStoreBaseTests:
        )
    @pytest.mark.unit
-    def test_write(self, docstore):
+    def test_write(self, docstore: Store):
        doc = Document(content="test doc")
        docstore.write_documents([doc])
        assert docstore.filter_documents(filters={"id": doc.id}) == [doc]
    @pytest.mark.unit
-    def test_write_duplicate_fail(self, docstore):
+    def test_write_duplicate_fail(self, docstore: Store):
        doc = Document(content="test doc")
        docstore.write_documents([doc])
        with pytest.raises(DuplicateDocumentError, match=f"ID '{doc.id}' already exists."):
-            docstore.write_documents(documents=[doc])
+            docstore.write_documents(documents=[doc], policy=DuplicatePolicy.FAIL)
        assert docstore.filter_documents(filters={"id": doc.id}) == [doc]
    @pytest.mark.unit
-    def test_write_duplicate_skip(self, docstore):
+    def test_write_duplicate_skip(self, docstore: Store):
        doc = Document(content="test doc")
        docstore.write_documents([doc])
-        docstore.write_documents(documents=[doc], duplicates="skip")
+        docstore.write_documents(documents=[doc], policy=DuplicatePolicy.SKIP)
        assert docstore.filter_documents(filters={"id": doc.id}) == [doc]
    @pytest.mark.unit
-    def test_write_duplicate_overwrite(self, docstore):
+    def test_write_duplicate_overwrite(self, docstore: Store):
        doc1 = Document(content="test doc 1")
        doc2 = Document(content="test doc 2")
        object.__setattr__(doc2, "id", doc1.id)  # Make two docs with different content but same ID
        docstore.write_documents([doc2])
        docstore.filter_documents(filters={"id": doc1.id}) == [doc2]
-        docstore.write_documents(documents=[doc1], duplicates="overwrite")
+        docstore.write_documents(documents=[doc1], policy=DuplicatePolicy.OVERWRITE)
        assert docstore.filter_documents(filters={"id": doc1.id}) == [doc1]
    @pytest.mark.unit
-    def test_write_not_docs(self, docstore):
+    def test_write_not_docs(self, docstore: Store):
        with pytest.raises(ValueError, match="Please provide a list of Documents"):
            docstore.write_documents(["not a document for sure"])
    @pytest.mark.unit
-    def test_write_not_list(self, docstore):
+    def test_write_not_list(self, docstore: Store):
        with pytest.raises(ValueError, match="Please provide a list of Documents"):
            docstore.write_documents("not a list actually")
    @pytest.mark.unit
-    def test_delete_empty(self, docstore):
+    def test_delete_empty(self, docstore: Store):
        with pytest.raises(MissingDocumentError):
            docstore.delete_documents(["test"])
    @pytest.mark.unit
-    def test_delete_not_empty(self, docstore):
+    def test_delete_not_empty(self, docstore: Store):
        doc = Document(content="test doc")
        docstore.write_documents([doc])
@@ -724,7 +726,7 @@ class DocumentStoreBaseTests:
        assert docstore.filter_documents(filters={"id": doc.id})
    @pytest.mark.unit
-    def test_delete_not_empty_nonexisting(self, docstore):
+    def test_delete_not_empty_nonexisting(self, docstore: Store):
        doc = Document(content="test doc")
        docstore.write_documents([doc])

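The base suite above is meant to be reused: a concrete store only overrides the `docstore` fixture, exactly as `TestMemoryDocumentStore` does in the file below. A sketch for a hypothetical third-party store (`my_package.MyDocumentStore` is invented for illustration):

```python
import pytest

from test.preview.document_stores._base import DocumentStoreBaseTests
from my_package import MyDocumentStore  # hypothetical Store implementation


class TestMyDocumentStore(DocumentStoreBaseTests):
    @pytest.fixture
    def docstore(self) -> MyDocumentStore:
        # Returning a fresh instance runs the whole shared suite (counting, filters,
        # duplicate policies, deletion) against the custom store.
        return MyDocumentStore()
```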
File: tests for MemoryDocumentStore (under test/preview/document_stores/)

@@ -4,7 +4,7 @@ import pandas as pd
import pytest
from haystack.preview import Document
-from haystack.preview.document_stores import MemoryDocumentStore
+from haystack.preview.document_stores import Store, MemoryDocumentStore
from test.preview.document_stores._base import DocumentStoreBaseTests
@@ -19,7 +19,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
        return MemoryDocumentStore()
    @pytest.mark.unit
-    def test_bm25_retrieval(self, docstore):
+    def test_bm25_retrieval(self, docstore: Store):
        docstore = MemoryDocumentStore()
        # Tests if the bm25_retrieval method returns the correct document based on the input query.
        docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
@@ -29,7 +29,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
        assert results[0].content == "Haystack supports multiple languages"
    @pytest.mark.unit
-    def test_bm25_retrieval_with_empty_document_store(self, docstore, caplog):
+    def test_bm25_retrieval_with_empty_document_store(self, docstore: Store, caplog):
        caplog.set_level(logging.INFO)
        # Tests if the bm25_retrieval method correctly returns an empty list when there are no documents in the store.
        results = docstore.bm25_retrieval(query="How to test this?", top_k=2)
@@ -37,7 +37,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
        assert "No documents found for BM25 retrieval. Returning empty list." in caplog.text
    @pytest.mark.unit
-    def test_bm25_retrieval_empty_query(self, docstore):
+    def test_bm25_retrieval_empty_query(self, docstore: Store):
        # Tests if the bm25_retrieval method returns a document when the query is an empty string.
        docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
        docstore.write_documents(docs)
@@ -45,7 +45,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
            docstore.bm25_retrieval(query="", top_k=1)
    @pytest.mark.unit
-    def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_string(self, docstore):
+    def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_string(self, docstore: Store):
        # Tests if the bm25_retrieval method returns a document when the query is an empty string.
        docs = [
            Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}),
@@ -57,7 +57,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
        assert results[0].content == "Haystack supports multiple languages"
    @pytest.mark.unit
-    def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_list(self, docstore):
+    def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_list(self, docstore: Store):
        # Tests if the bm25_retrieval method returns a document when the query is an empty string.
        docs = [
            Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}),
@@ -69,7 +69,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
        assert results[0].content == "Haystack supports multiple languages"
    @pytest.mark.unit
-    def test_bm25_retrieval_filter_two_allowed_doc_type_as_list(self, docstore):
+    def test_bm25_retrieval_filter_two_allowed_doc_type_as_list(self, docstore: Store):
        # Tests if the bm25_retrieval method returns a document when the query is an empty string.
        docs = [
            Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}),
@@ -80,7 +80,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
        assert len(results) == 2
    @pytest.mark.unit
-    def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_string(self, docstore):
+    def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_string(self, docstore: Store):
        # Tests if the bm25_retrieval method returns a document when the query is an empty string.
        docs = [
            Document(content=pd.DataFrame({"language": ["Python", "Java"]}), content_type="table"),
@@ -93,7 +93,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
            docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": "audio"})
    @pytest.mark.unit
-    def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_list(self, docstore):
+    def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_list(self, docstore: Store):
        # Tests if the bm25_retrieval method returns a document when the query is an empty string.
        docs = [
            Document(content=pd.DataFrame({"language": ["Python", "Java"]}), content_type="table"),
@@ -106,7 +106,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
            docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["audio"]})
    @pytest.mark.unit
-    def test_bm25_retrieval_filter_two_not_all_allowed_doc_type_as_list(self, docstore):
+    def test_bm25_retrieval_filter_two_not_all_allowed_doc_type_as_list(self, docstore: Store):
        # Tests if the bm25_retrieval method returns a document when the query is an empty string.
        docs = [
            Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}),
@@ -119,7 +119,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
            docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["text", "audio"]})
    @pytest.mark.unit
-    def test_bm25_retrieval_with_different_top_k(self, docstore):
+    def test_bm25_retrieval_with_different_top_k(self, docstore: Store):
        # Tests if the bm25_retrieval method correctly changes the number of returned documents
        # based on the top_k parameter.
        docs = [
@@ -139,7 +139,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
    # Test two queries and make sure the results are different
    @pytest.mark.unit
-    def test_bm25_retrieval_with_two_queries(self, docstore):
+    def test_bm25_retrieval_with_two_queries(self, docstore: Store):
        # Tests if the bm25_retrieval method returns different documents for different queries.
        docs = [
            Document(content="Javascript is a popular programming language"),
@@ -158,7 +158,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
    # Test a query, add a new document and make sure results are appropriately updated
    @pytest.mark.unit
-    def test_bm25_retrieval_with_updated_docs(self, docstore):
+    def test_bm25_retrieval_with_updated_docs(self, docstore: Store):
        # Tests if the bm25_retrieval method correctly updates the retrieved documents when new
        # documents are added to the store.
        docs = [Document(content="Hello world")]
@@ -178,7 +178,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
        assert results[0].content == "Python is a popular programming language"
    @pytest.mark.unit
-    def test_bm25_retrieval_with_scale_score(self, docstore):
+    def test_bm25_retrieval_with_scale_score(self, docstore: Store):
        docs = [Document(content="Python programming"), Document(content="Java programming")]
        docstore.write_documents(docs)
@@ -191,7 +191,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
        assert results[0].score != results1[0].score
    @pytest.mark.unit
-    def test_bm25_retrieval_with_table_content(self, docstore):
+    def test_bm25_retrieval_with_table_content(self, docstore: Store):
        # Tests if the bm25_retrieval method correctly returns a dataframe when the content_type is table.
        table_content = pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web Development"]})
        docs = [