mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-29 00:39:05 +00:00
feat: introduce Store protocol (v2) (#5259)
* add protocol and adapt pipeline * review feedback & update tests * pylint * Update haystack/preview/document_stores/protocols.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update haystack/preview/document_stores/memory/document_store.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * docstring of Store * adapt memorydocumentstore * fix tests --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
parent
90ff3817e7
commit
f49bd3a12f
@ -1,2 +1,3 @@
|
||||
from haystack.preview.document_stores.protocols import Store, DuplicatePolicy
|
||||
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
|
||||
from haystack.preview.document_stores.errors import StoreError, DuplicateDocumentError, MissingDocumentError
|
||||
|
||||
@ -8,12 +8,12 @@ import rank_bm25
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from haystack.preview.dataclasses import Document
|
||||
from haystack.preview.document_stores.protocols import DuplicatePolicy
|
||||
from haystack.preview.document_stores.memory._filters import match
|
||||
from haystack.preview.document_stores.errors import DuplicateDocumentError, MissingDocumentError
|
||||
from haystack.utils.scipy_utils import expit
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DuplicatePolicy = Literal["skip", "overwrite", "fail"]
|
||||
|
||||
# document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to
|
||||
# True (default). Scaling uses the expit function (inverse of the logit function) after applying a SCALING_FACTOR. A
|
||||
@ -126,17 +126,17 @@ class MemoryDocumentStore:
|
||||
return [doc for doc in self.storage.values() if match(conditions=filters, document=doc)]
|
||||
return list(self.storage.values())
|
||||
|
||||
def write_documents(self, documents: List[Document], duplicates: DuplicatePolicy = "fail") -> None:
|
||||
def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
|
||||
"""
|
||||
Writes (or overwrites) documents into the store.
|
||||
|
||||
:param documents: a list of documents.
|
||||
:param duplicates: documents with the same ID count as duplicates. When duplicates are met,
|
||||
:param policy: documents with the same ID count as duplicates. When duplicates are met,
|
||||
the store can:
|
||||
- skip: keep the existing document and ignore the new one.
|
||||
- overwrite: remove the old document and write the new one.
|
||||
- fail: an error is raised
|
||||
:raises DuplicateError: Exception trigger on duplicate document if `duplicates="fail"`
|
||||
:raises DuplicateError: Exception trigger on duplicate document if `policy=DuplicatePolicy.FAIL`
|
||||
:return: None
|
||||
"""
|
||||
if (
|
||||
@ -147,10 +147,10 @@ class MemoryDocumentStore:
|
||||
raise ValueError("Please provide a list of Documents.")
|
||||
|
||||
for document in documents:
|
||||
if document.id in self.storage.keys():
|
||||
if duplicates == "fail":
|
||||
if policy != DuplicatePolicy.OVERWRITE and document.id in self.storage.keys():
|
||||
if policy == DuplicatePolicy.FAIL:
|
||||
raise DuplicateDocumentError(f"ID '{document.id}' already exists.")
|
||||
if duplicates == "skip":
|
||||
if policy == DuplicatePolicy.SKIP:
|
||||
logger.warning("ID '%s' already exists", document.id)
|
||||
self.storage[document.id] = document
|
||||
|
||||
|
||||
126
haystack/preview/document_stores/protocols.py
Normal file
126
haystack/preview/document_stores/protocols.py
Normal file
@ -0,0 +1,126 @@
|
||||
from typing import Protocol, Optional, Dict, Any, List
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from haystack.preview.dataclasses import Document
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DuplicatePolicy(Enum):
    """
    Policy applied by `Store.write_documents` when an incoming document's ID
    already exists in the store.
    """

    SKIP = "skip"  # keep the existing document and ignore the new one
    OVERWRITE = "overwrite"  # remove the old document and write the new one
    FAIL = "fail"  # raise an error on the duplicate document
|
||||
|
||||
|
||||
class Store(Protocol):
    """
    Stores Documents to be used by the components of a Pipeline.

    Classes implementing this protocol often store the documents permanently and allow specialized components to
    perform retrieval on them, either by embedding, by keyword, hybrid, and so on, depending on the backend used.

    In order to retrieve documents, consider using a Retriever that supports the document store implementation that
    you're using.
    """

    def count_documents(self) -> int:
        """
        Returns the number of documents stored.
        """

    def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
        """
        Returns the documents that match the filters provided.

        Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical operator (`"$and"`,
        `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$ne"`, `"$in"`, `"$nin"`, `"$gt"`, `"$gte"`, `"$lt"`,
        `"$lte"`) or a metadata field name.

        Logical operator keys take a dictionary of metadata field names and/or logical operators as value. Metadata
        field names take a dictionary of comparison operators as value. Comparison operator keys take a single value or
        (in case of `"$in"`) a list of values as value. If no logical operator is provided, `"$and"` is used as default
        operation. If no comparison operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used
        as default operation.

        Example:

        ```python
        filters = {
            "$and": {
                "type": {"$eq": "article"},
                "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
                "rating": {"$gte": 3},
                "$or": {
                    "genre": {"$in": ["economy", "politics"]},
                    "publisher": {"$eq": "nytimes"}
                }
            }
        }
        # or simpler using default operators
        filters = {
            "type": "article",
            "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
            "rating": {"$gte": 3},
            "$or": {
                "genre": ["economy", "politics"],
                "publisher": "nytimes"
            }
        }
        ```

        To use the same logical operator multiple times on the same level, logical operators can take a list of
        dictionaries as value.

        Example:

        ```python
        filters = {
            "$or": [
                {
                    "$and": {
                        "Type": "News Paper",
                        "Date": {
                            "$lt": "2019-01-01"
                        }
                    }
                },
                {
                    "$and": {
                        "Type": "Blog Post",
                        "Date": {
                            "$gte": "2019-01-01"
                        }
                    }
                }
            ]
        }
        ```

        :param filters: the filters to apply to the document list.
        :return: a list of Documents that match the given filters.
        """

    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
        """
        Writes (or overwrites) documents into the store.

        :param documents: a list of documents.
        :param policy: documents with the same ID count as duplicates. When duplicates are met,
            the store can:
             - skip: keep the existing document and ignore the new one.
             - overwrite: remove the old document and write the new one.
             - fail: an error is raised.
        :raises DuplicateDocumentError: exception triggered on a duplicate document if
            `policy=DuplicatePolicy.FAIL`.
        :return: None
        """

    def delete_documents(self, document_ids: List[str]) -> None:
        """
        Deletes all documents with a matching document_ids from the document store.
        Fails with `MissingDocumentError` if no document with this id is present in the store.

        :param document_ids: the IDs of the documents to delete.
        """
|
||||
@ -11,6 +11,8 @@ from canals.pipeline import (
|
||||
)
|
||||
from canals.pipeline.sockets import find_input_sockets
|
||||
|
||||
from haystack.preview.document_stores.protocols import Store
|
||||
|
||||
|
||||
class NoSuchStoreError(PipelineError):
|
||||
pass
|
||||
@ -23,9 +25,9 @@ class Pipeline(CanalsPipeline):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.stores = {}
|
||||
self.stores: Dict[str, Store] = {}
|
||||
|
||||
def add_store(self, name: str, store: object) -> None:
|
||||
def add_store(self, name: str, store: Store) -> None:
|
||||
"""
|
||||
Make a store available to all nodes of this pipeline.
|
||||
|
||||
@ -43,7 +45,7 @@ class Pipeline(CanalsPipeline):
|
||||
"""
|
||||
return list(self.stores.keys())
|
||||
|
||||
def get_store(self, name: str) -> object:
|
||||
def get_store(self, name: str) -> Store:
|
||||
"""
|
||||
Returns the store associated with the given name.
|
||||
|
||||
|
||||
@ -5,17 +5,17 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from haystack.preview.dataclasses import Document
|
||||
from haystack.preview.document_stores import StoreError
|
||||
from haystack.preview.document_stores import Store, StoreError, DuplicatePolicy
|
||||
from haystack.preview.document_stores import MissingDocumentError, DuplicateDocumentError
|
||||
|
||||
|
||||
class DocumentStoreBaseTests:
|
||||
@pytest.fixture
|
||||
def docstore(self):
|
||||
def docstore(self) -> Store:
|
||||
raise NotImplementedError()
|
||||
|
||||
@pytest.fixture
|
||||
def filterable_docs(self):
|
||||
def filterable_docs(self) -> List[Document]:
|
||||
embedding_zero = np.zeros([768, 1]).astype(np.float32)
|
||||
embedding_one = np.ones([768, 1]).astype(np.float32)
|
||||
|
||||
@ -70,42 +70,42 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
    @pytest.mark.unit
    def test_count_empty(self, docstore: Store):
        """A store with no documents written reports a count of zero."""
        assert docstore.count_documents() == 0
||||
|
||||
@pytest.mark.unit
|
||||
def test_count_not_empty(self, docstore):
|
||||
def test_count_not_empty(self, docstore: Store):
|
||||
docstore.write_documents(
|
||||
[Document(content="test doc 1"), Document(content="test doc 2"), Document(content="test doc 3")]
|
||||
)
|
||||
assert docstore.count_documents() == 3
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_filter_empty(self, docstore):
|
||||
def test_no_filter_empty(self, docstore: Store):
|
||||
assert docstore.filter_documents() == []
|
||||
assert docstore.filter_documents(filters={}) == []
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_filter_not_empty(self, docstore):
|
||||
def test_no_filter_not_empty(self, docstore: Store):
|
||||
docs = [Document(content="test doc")]
|
||||
docstore.write_documents(docs)
|
||||
assert docstore.filter_documents() == docs
|
||||
assert docstore.filter_documents(filters={}) == docs
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_simple_metadata_value(self, docstore, filterable_docs):
|
||||
def test_filter_simple_metadata_value(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": "100"})
|
||||
assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"])
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_simple_list_single_element(self, docstore, filterable_docs):
|
||||
def test_filter_simple_list_single_element(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": ["100"]})
|
||||
assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"])
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_document_content(self, docstore, filterable_docs):
|
||||
def test_filter_document_content(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"content": "A Foo Document 1"})
|
||||
assert self.contains_same_docs(
|
||||
@ -113,19 +113,19 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_document_type(self, docstore, filterable_docs):
|
||||
def test_filter_document_type(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"content_type": "table"})
|
||||
assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.content_type == "table"])
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_simple_list_one_value(self, docstore, filterable_docs):
|
||||
def test_filter_simple_list_one_value(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": ["100"]})
|
||||
assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") in ["100"]])
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_simple_list(self, docstore, filterable_docs):
|
||||
def test_filter_simple_list(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": ["100", "123"]})
|
||||
assert self.contains_same_docs(
|
||||
@ -133,49 +133,49 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_incorrect_filter_name(self, docstore, filterable_docs):
|
||||
def test_incorrect_filter_name(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"non_existing_meta_field": ["whatever"]})
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_incorrect_filter_type(self, docstore, filterable_docs):
|
||||
def test_incorrect_filter_type(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
with pytest.raises(ValueError, match="dictionaries or lists"):
|
||||
docstore.filter_documents(filters="something odd")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_incorrect_filter_value(self, docstore, filterable_docs):
|
||||
def test_incorrect_filter_value(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": ["nope"]})
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_incorrect_filter_nesting(self, docstore, filterable_docs):
|
||||
def test_incorrect_filter_nesting(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
with pytest.raises(ValueError, match="malformed"):
|
||||
docstore.filter_documents(filters={"number": {"page": "100"}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_deeper_incorrect_filter_nesting(self, docstore, filterable_docs):
|
||||
def test_deeper_incorrect_filter_nesting(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
with pytest.raises(ValueError, match="malformed"):
|
||||
docstore.filter_documents(filters={"number": {"page": {"chapter": "intro"}}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_eq_filter_explicit(self, docstore, filterable_docs):
|
||||
def test_eq_filter_explicit(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": {"$eq": "100"}})
|
||||
assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"])
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_eq_filter_implicit(self, docstore, filterable_docs):
|
||||
def test_eq_filter_implicit(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": "100"})
|
||||
assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"])
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_eq_filter_table(self, docstore, filterable_docs):
|
||||
def test_eq_filter_table(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"content": pd.DataFrame([1])})
|
||||
assert self.contains_same_docs(
|
||||
@ -188,7 +188,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_eq_filter_tensor(self, docstore, filterable_docs):
|
||||
def test_eq_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
embedding = np.zeros([768, 1]).astype(np.float32)
|
||||
result = docstore.filter_documents(filters={"embedding": embedding})
|
||||
@ -197,25 +197,25 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_deeper_incorrect_filter_nesting(self, docstore, filterable_docs):
|
||||
def test_deeper_incorrect_filter_nesting(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
with pytest.raises(ValueError, match="malformed"):
|
||||
docstore.filter_documents(filters={"number": {"page": {"chapter": "intro"}}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_eq_filter_explicit(self, docstore, filterable_docs):
|
||||
def test_eq_filter_explicit(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": {"$eq": "100"}})
|
||||
assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"])
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_eq_filter_implicit(self, docstore, filterable_docs):
|
||||
def test_eq_filter_implicit(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": "100"})
|
||||
assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") == "100"])
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_in_filter_explicit(self, docstore, filterable_docs):
|
||||
def test_in_filter_explicit(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": {"$in": ["100", "123", "n.a."]}})
|
||||
assert self.contains_same_docs(
|
||||
@ -223,7 +223,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_in_filter_implicit(self, docstore, filterable_docs):
|
||||
def test_in_filter_implicit(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": ["100", "123", "n.a."]})
|
||||
assert self.contains_same_docs(
|
||||
@ -231,7 +231,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_in_filter_table(self, docstore, filterable_docs):
|
||||
def test_in_filter_table(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"content": {"$in": [pd.DataFrame([1]), pd.DataFrame([2])]}})
|
||||
assert self.contains_same_docs(
|
||||
@ -245,7 +245,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_in_filter_tensor(self, docstore, filterable_docs):
|
||||
def test_in_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
embedding_zero = np.zeros([768, 1]).astype(np.float32)
|
||||
embedding_one = np.ones([768, 1]).astype(np.float32)
|
||||
@ -261,13 +261,13 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_ne_filter(self, docstore, filterable_docs):
|
||||
def test_ne_filter(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": {"$ne": "100"}})
|
||||
assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.metadata.get("page") != "100"])
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_ne_filter_table(self, docstore, filterable_docs):
|
||||
def test_ne_filter_table(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"content": {"$ne": pd.DataFrame([1])}})
|
||||
assert self.contains_same_docs(
|
||||
@ -280,7 +280,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_ne_filter_tensor(self, docstore, filterable_docs):
|
||||
def test_ne_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
embedding = np.zeros([768, 1]).astype(np.float32)
|
||||
result = docstore.filter_documents(filters={"embedding": {"$ne": embedding}})
|
||||
@ -294,7 +294,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_nin_filter(self, docstore, filterable_docs):
|
||||
def test_nin_filter(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": {"$nin": ["100", "123", "n.a."]}})
|
||||
assert self.contains_same_docs(
|
||||
@ -302,7 +302,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_nin_filter_table(self, docstore, filterable_docs):
|
||||
def test_nin_filter_table(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"content": {"$nin": [pd.DataFrame([1]), pd.DataFrame([0])]}})
|
||||
assert self.contains_same_docs(
|
||||
@ -316,7 +316,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_nin_filter_tensor(self, docstore, filterable_docs):
|
||||
def test_nin_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
embedding_zeros = np.zeros([768, 1]).astype(np.float32)
|
||||
embedding_ones = np.zeros([768, 1]).astype(np.float32)
|
||||
@ -335,7 +335,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_nin_filter(self, docstore, filterable_docs):
|
||||
def test_nin_filter(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"page": {"$nin": ["100", "123", "n.a."]}})
|
||||
assert self.contains_same_docs(
|
||||
@ -343,7 +343,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gt_filter(self, docstore, filterable_docs):
|
||||
def test_gt_filter(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"number": {"$gt": 0.0}})
|
||||
assert self.contains_same_docs(
|
||||
@ -351,26 +351,26 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gt_filter_non_numeric(self, docstore, filterable_docs):
|
||||
def test_gt_filter_non_numeric(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
with pytest.raises(StoreError, match="Can't evaluate"):
|
||||
docstore.filter_documents(filters={"page": {"$gt": "100"}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gt_filter_table(self, docstore, filterable_docs):
|
||||
def test_gt_filter_table(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
with pytest.raises(StoreError, match="Can't evaluate"):
|
||||
docstore.filter_documents(filters={"content": {"$gt": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gt_filter_tensor(self, docstore, filterable_docs):
|
||||
def test_gt_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
embedding_zeros = np.zeros([768, 1]).astype(np.float32)
|
||||
with pytest.raises(StoreError, match="Can't evaluate"):
|
||||
docstore.filter_documents(filters={"embedding": {"$gt": embedding_zeros}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gte_filter(self, docstore, filterable_docs):
|
||||
def test_gte_filter(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"number": {"$gte": -2.0}})
|
||||
assert self.contains_same_docs(
|
||||
@ -378,26 +378,26 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gte_filter_non_numeric(self, docstore, filterable_docs):
|
||||
def test_gte_filter_non_numeric(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
with pytest.raises(StoreError, match="Can't evaluate"):
|
||||
docstore.filter_documents(filters={"page": {"$gte": "100"}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gte_filter_table(self, docstore, filterable_docs):
|
||||
def test_gte_filter_table(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
with pytest.raises(StoreError, match="Can't evaluate"):
|
||||
docstore.filter_documents(filters={"content": {"$gte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gte_filter_tensor(self, docstore, filterable_docs):
|
||||
def test_gte_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
embedding_zeros = np.zeros([768, 1]).astype(np.float32)
|
||||
with pytest.raises(StoreError, match="Can't evaluate"):
|
||||
docstore.filter_documents(filters={"embedding": {"$gte": embedding_zeros}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_lt_filter(self, docstore, filterable_docs):
|
||||
def test_lt_filter(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"number": {"$lt": 0.0}})
|
||||
assert self.contains_same_docs(
|
||||
@ -405,26 +405,26 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_lt_filter_non_numeric(self, docstore, filterable_docs):
|
||||
def test_lt_filter_non_numeric(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
with pytest.raises(StoreError, match="Can't evaluate"):
|
||||
docstore.filter_documents(filters={"page": {"$lt": "100"}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_lt_filter_table(self, docstore, filterable_docs):
|
||||
def test_lt_filter_table(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
with pytest.raises(StoreError, match="Can't evaluate"):
|
||||
docstore.filter_documents(filters={"content": {"$lt": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_lt_filter_tensor(self, docstore, filterable_docs):
|
||||
def test_lt_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
embedding_ones = np.ones([768, 1]).astype(np.float32)
|
||||
with pytest.raises(StoreError, match="Can't evaluate"):
|
||||
docstore.filter_documents(filters={"embedding": {"$lt": embedding_ones}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_lte_filter(self, docstore, filterable_docs):
|
||||
def test_lte_filter(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"number": {"$lte": 2.0}})
|
||||
assert self.contains_same_docs(
|
||||
@ -432,26 +432,26 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_lte_filter_non_numeric(self, docstore, filterable_docs):
|
||||
def test_lte_filter_non_numeric(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
with pytest.raises(StoreError, match="Can't evaluate"):
|
||||
docstore.filter_documents(filters={"page": {"$lte": "100"}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_lte_filter_table(self, docstore, filterable_docs):
|
||||
def test_lte_filter_table(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
with pytest.raises(StoreError, match="Can't evaluate"):
|
||||
docstore.filter_documents(filters={"content": {"$lte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_lte_filter_tensor(self, docstore, filterable_docs):
|
||||
def test_lte_filter_tensor(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
embedding_ones = np.ones([768, 1]).astype(np.float32)
|
||||
with pytest.raises(StoreError, match="Can't evaluate"):
|
||||
docstore.filter_documents(filters={"embedding": {"$lte": embedding_ones}})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_simple_implicit_and_with_multi_key_dict(self, docstore, filterable_docs):
|
||||
def test_filter_simple_implicit_and_with_multi_key_dict(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"number": {"$lte": 2.0, "$gte": 0.0}})
|
||||
assert self.contains_same_docs(
|
||||
@ -464,7 +464,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_simple_explicit_and_with_multikey_dict(self, docstore, filterable_docs):
|
||||
def test_filter_simple_explicit_and_with_multikey_dict(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"number": {"$and": {"$lte": 0, "$gte": -2}}})
|
||||
assert self.contains_same_docs(
|
||||
@ -477,7 +477,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_simple_explicit_and_with_list(self, docstore, filterable_docs):
|
||||
def test_filter_simple_explicit_and_with_list(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"number": {"$and": [{"$lte": 2}, {"$gte": 0}]}})
|
||||
assert self.contains_same_docs(
|
||||
@ -490,7 +490,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_simple_implicit_and(self, docstore, filterable_docs):
|
||||
def test_filter_simple_implicit_and(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
result = docstore.filter_documents(filters={"number": {"$lte": 2.0, "$gte": 0}})
|
||||
assert self.contains_same_docs(
|
||||
@ -503,7 +503,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_nested_explicit_and(self, docstore, filterable_docs):
|
||||
def test_filter_nested_explicit_and(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
filters = {"$and": {"number": {"$and": {"$lte": 2, "$gte": 0}}, "name": {"$in": ["name_0", "name_1"]}}}
|
||||
result = docstore.filter_documents(filters=filters)
|
||||
@ -522,7 +522,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_nested_implicit_and(self, docstore, filterable_docs):
|
||||
def test_filter_nested_implicit_and(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
filters_simplified = {"number": {"$lte": 2, "$gte": 0}, "name": ["name_0", "name_1"]}
|
||||
result = docstore.filter_documents(filters=filters_simplified)
|
||||
@ -541,7 +541,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_simple_or(self, docstore, filterable_docs):
|
||||
def test_filter_simple_or(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
filters = {"$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}}
|
||||
result = docstore.filter_documents(filters=filters)
|
||||
@ -558,7 +558,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_nested_or(self, docstore, filterable_docs):
|
||||
def test_filter_nested_or(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
filters = {"$or": {"name": {"$or": [{"$eq": "name_0"}, {"$eq": "name_1"}]}, "number": {"$lt": 1.0}}}
|
||||
result = docstore.filter_documents(filters=filters)
|
||||
@ -575,7 +575,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_nested_and_or_explicit(self, docstore, filterable_docs):
|
||||
def test_filter_nested_and_or_explicit(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
filters_simplified = {
|
||||
"$and": {"page": {"$eq": "123"}, "$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}}
|
||||
@ -597,7 +597,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_nested_and_or_implicit(self, docstore, filterable_docs):
|
||||
def test_filter_nested_and_or_implicit(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
filters_simplified = {
|
||||
"page": {"$eq": "123"},
|
||||
@ -620,7 +620,7 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_nested_or_and(self, docstore, filterable_docs):
|
||||
def test_filter_nested_or_and(self, docstore: Store, filterable_docs: List[Document]):
|
||||
docstore.write_documents(filterable_docs)
|
||||
filters_simplified = {
|
||||
"$or": {
|
||||
@ -645,7 +645,9 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_filter_nested_multiple_identical_operators_same_level(self, docstore, filterable_docs):
|
||||
def test_filter_nested_multiple_identical_operators_same_level(
|
||||
self, docstore: Store, filterable_docs: List[Document]
|
||||
):
|
||||
docstore.write_documents(filterable_docs)
|
||||
filters = {
|
||||
"$or": [
|
||||
@ -667,54 +669,54 @@ class DocumentStoreBaseTests:
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_write(self, docstore):
|
||||
def test_write(self, docstore: Store):
|
||||
doc = Document(content="test doc")
|
||||
docstore.write_documents([doc])
|
||||
assert docstore.filter_documents(filters={"id": doc.id}) == [doc]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_write_duplicate_fail(self, docstore):
|
||||
def test_write_duplicate_fail(self, docstore: Store):
|
||||
doc = Document(content="test doc")
|
||||
docstore.write_documents([doc])
|
||||
with pytest.raises(DuplicateDocumentError, match=f"ID '{doc.id}' already exists."):
|
||||
docstore.write_documents(documents=[doc])
|
||||
docstore.write_documents(documents=[doc], policy=DuplicatePolicy.FAIL)
|
||||
assert docstore.filter_documents(filters={"id": doc.id}) == [doc]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_write_duplicate_skip(self, docstore):
|
||||
def test_write_duplicate_skip(self, docstore: Store):
|
||||
doc = Document(content="test doc")
|
||||
docstore.write_documents([doc])
|
||||
docstore.write_documents(documents=[doc], duplicates="skip")
|
||||
docstore.write_documents(documents=[doc], policy=DuplicatePolicy.SKIP)
|
||||
assert docstore.filter_documents(filters={"id": doc.id}) == [doc]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_write_duplicate_overwrite(self, docstore):
|
||||
def test_write_duplicate_overwrite(self, docstore: Store):
|
||||
doc1 = Document(content="test doc 1")
|
||||
doc2 = Document(content="test doc 2")
|
||||
object.__setattr__(doc2, "id", doc1.id) # Make two docs with different content but same ID
|
||||
|
||||
docstore.write_documents([doc2])
|
||||
docstore.filter_documents(filters={"id": doc1.id}) == [doc2]
|
||||
docstore.write_documents(documents=[doc1], duplicates="overwrite")
|
||||
docstore.write_documents(documents=[doc1], policy=DuplicatePolicy.OVERWRITE)
|
||||
assert docstore.filter_documents(filters={"id": doc1.id}) == [doc1]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_write_not_docs(self, docstore):
|
||||
def test_write_not_docs(self, docstore: Store):
|
||||
with pytest.raises(ValueError, match="Please provide a list of Documents"):
|
||||
docstore.write_documents(["not a document for sure"])
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_write_not_list(self, docstore):
|
||||
def test_write_not_list(self, docstore: Store):
|
||||
with pytest.raises(ValueError, match="Please provide a list of Documents"):
|
||||
docstore.write_documents("not a list actually")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_delete_empty(self, docstore):
|
||||
def test_delete_empty(self, docstore: Store):
|
||||
with pytest.raises(MissingDocumentError):
|
||||
docstore.delete_documents(["test"])
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_delete_not_empty(self, docstore):
|
||||
def test_delete_not_empty(self, docstore: Store):
|
||||
doc = Document(content="test doc")
|
||||
docstore.write_documents([doc])
|
||||
|
||||
@ -724,7 +726,7 @@ class DocumentStoreBaseTests:
|
||||
assert docstore.filter_documents(filters={"id": doc.id})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_delete_not_empty_nonexisting(self, docstore):
|
||||
def test_delete_not_empty_nonexisting(self, docstore: Store):
|
||||
doc = Document(content="test doc")
|
||||
docstore.write_documents([doc])
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ import pandas as pd
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
from haystack.preview.document_stores import Store, MemoryDocumentStore
|
||||
|
||||
from test.preview.document_stores._base import DocumentStoreBaseTests
|
||||
|
||||
@ -19,7 +19,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
return MemoryDocumentStore()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval(self, docstore):
|
||||
def test_bm25_retrieval(self, docstore: Store):
|
||||
docstore = MemoryDocumentStore()
|
||||
# Tests if the bm25_retrieval method returns the correct document based on the input query.
|
||||
docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
|
||||
@ -29,7 +29,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
assert results[0].content == "Haystack supports multiple languages"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval_with_empty_document_store(self, docstore, caplog):
|
||||
def test_bm25_retrieval_with_empty_document_store(self, docstore: Store, caplog):
|
||||
caplog.set_level(logging.INFO)
|
||||
# Tests if the bm25_retrieval method correctly returns an empty list when there are no documents in the store.
|
||||
results = docstore.bm25_retrieval(query="How to test this?", top_k=2)
|
||||
@ -37,7 +37,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
assert "No documents found for BM25 retrieval. Returning empty list." in caplog.text
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval_empty_query(self, docstore):
|
||||
def test_bm25_retrieval_empty_query(self, docstore: Store):
|
||||
# Tests if the bm25_retrieval method returns a document when the query is an empty string.
|
||||
docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
|
||||
docstore.write_documents(docs)
|
||||
@ -45,7 +45,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
docstore.bm25_retrieval(query="", top_k=1)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_string(self, docstore):
|
||||
def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_string(self, docstore: Store):
|
||||
# Tests if the bm25_retrieval method returns a document when the query is an empty string.
|
||||
docs = [
|
||||
Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}),
|
||||
@ -57,7 +57,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
assert results[0].content == "Haystack supports multiple languages"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_list(self, docstore):
|
||||
def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_list(self, docstore: Store):
|
||||
# Tests if the bm25_retrieval method returns a document when the query is an empty string.
|
||||
docs = [
|
||||
Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}),
|
||||
@ -69,7 +69,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
assert results[0].content == "Haystack supports multiple languages"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval_filter_two_allowed_doc_type_as_list(self, docstore):
|
||||
def test_bm25_retrieval_filter_two_allowed_doc_type_as_list(self, docstore: Store):
|
||||
# Tests if the bm25_retrieval method returns a document when the query is an empty string.
|
||||
docs = [
|
||||
Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}),
|
||||
@ -80,7 +80,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
assert len(results) == 2
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_string(self, docstore):
|
||||
def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_string(self, docstore: Store):
|
||||
# Tests if the bm25_retrieval method returns a document when the query is an empty string.
|
||||
docs = [
|
||||
Document(content=pd.DataFrame({"language": ["Python", "Java"]}), content_type="table"),
|
||||
@ -93,7 +93,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": "audio"})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_list(self, docstore):
|
||||
def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_list(self, docstore: Store):
|
||||
# Tests if the bm25_retrieval method returns a document when the query is an empty string.
|
||||
docs = [
|
||||
Document(content=pd.DataFrame({"language": ["Python", "Java"]}), content_type="table"),
|
||||
@ -106,7 +106,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["audio"]})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval_filter_two_not_all_allowed_doc_type_as_list(self, docstore):
|
||||
def test_bm25_retrieval_filter_two_not_all_allowed_doc_type_as_list(self, docstore: Store):
|
||||
# Tests if the bm25_retrieval method returns a document when the query is an empty string.
|
||||
docs = [
|
||||
Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}),
|
||||
@ -119,7 +119,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["text", "audio"]})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval_with_different_top_k(self, docstore):
|
||||
def test_bm25_retrieval_with_different_top_k(self, docstore: Store):
|
||||
# Tests if the bm25_retrieval method correctly changes the number of returned documents
|
||||
# based on the top_k parameter.
|
||||
docs = [
|
||||
@ -139,7 +139,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
|
||||
# Test two queries and make sure the results are different
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval_with_two_queries(self, docstore):
|
||||
def test_bm25_retrieval_with_two_queries(self, docstore: Store):
|
||||
# Tests if the bm25_retrieval method returns different documents for different queries.
|
||||
docs = [
|
||||
Document(content="Javascript is a popular programming language"),
|
||||
@ -158,7 +158,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
|
||||
# Test a query, add a new document and make sure results are appropriately updated
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval_with_updated_docs(self, docstore):
|
||||
def test_bm25_retrieval_with_updated_docs(self, docstore: Store):
|
||||
# Tests if the bm25_retrieval method correctly updates the retrieved documents when new
|
||||
# documents are added to the store.
|
||||
docs = [Document(content="Hello world")]
|
||||
@ -178,7 +178,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
assert results[0].content == "Python is a popular programming language"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval_with_scale_score(self, docstore):
|
||||
def test_bm25_retrieval_with_scale_score(self, docstore: Store):
|
||||
docs = [Document(content="Python programming"), Document(content="Java programming")]
|
||||
docstore.write_documents(docs)
|
||||
|
||||
@ -191,7 +191,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
assert results[0].score != results1[0].score
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bm25_retrieval_with_table_content(self, docstore):
|
||||
def test_bm25_retrieval_with_table_content(self, docstore: Store):
|
||||
# Tests if the bm25_retrieval method correctly returns a dataframe when the content_type is table.
|
||||
table_content = pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web Development"]})
|
||||
docs = [
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user