feat: initial implementation of MemoryDocumentStore for new Pipelines (#4447)

* add stub implementation

* reimplementation

* test files

* docstore tests

* tests for document

* better testing

* remove mmh3

* readme

* only store, no retrieval yet

* linting

* review feedback

* initial filters implementation

* working on filters

* linters

* filtering works and is isolated by document store

* simplify filters

* comments

* improve filters matching code

* review feedback

* pylint

* move logic into_create_id

* mypy
This commit is contained in:
ZanSara 2023-04-13 09:36:23 +02:00 committed by GitHub
parent db69141642
commit f2106ab37b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1026 additions and 1 deletions

View File

@ -0,0 +1 @@
# Haystack - Preview features

View File

@ -1,2 +1,3 @@
from canals import node
from haystack.preview.dataclasses import Document
from haystack.preview.pipeline import Pipeline, PipelineError, NoSuchStoreError, load_pipelines, save_pipelines

View File

@ -0,0 +1 @@
from haystack.preview.dataclasses.document import Document

View File

@ -0,0 +1,110 @@
from typing import List, Any, Dict, Literal, Optional, TYPE_CHECKING
import json
import hashlib
import logging
from pathlib import Path
from dataclasses import asdict, dataclass, field
from haystack.preview.utils.import_utils import optional_import
# We need to do this dance because ndarray is an optional dependency used as a type by dataclass
if TYPE_CHECKING:
from numpy import ndarray
else:
ndarray = optional_import("numpy", "ndarray", "You won't be able to use embeddings.", __name__)
DataFrame = optional_import("pandas", "DataFrame", "You won't be able to use table related features.", __name__)
logger = logging.getLogger(__name__)
ContentType = Literal["text", "table", "image", "audio"]
PYTHON_TYPES_FOR_CONTENT: Dict[ContentType, type] = {"text": str, "table": DataFrame, "image": Path, "audio": Path}
def _create_id(
classname: str, content: Any, metadata: Optional[Dict[str, Any]] = None, id_hash_keys: Optional[List[str]] = None
):
"""
Creates a hash of the content given that acts as the document's ID.
"""
content_to_hash = f"{classname}:{content}"
if id_hash_keys:
if not metadata:
raise ValueError("If 'id_hash_keys' is provided, you must provide 'metadata' too.")
content_to_hash = ":".join([content_to_hash, *[str(metadata.get(key, "")) for key in id_hash_keys]])
return hashlib.sha256(str(content_to_hash).encode("utf-8")).hexdigest()
@dataclass(frozen=True)
class Document:
"""
Base data class containing some data to be queried.
Can contain text snippets, tables, file paths to files like images or audios.
Documents can be sorted by score, serialized to/from dictionary and JSON, and are immutable.
Immutability is due to the fact that the document's ID depends on its content, so upon changing the content, also
the ID should change. To avoid keeping IDs in sync with the content by using properties, and asking docstores to
be aware of this corner case, we decide to make Documents immutable and remove the issue. If you need to modify a
Document, consider using `to_dict()`, modifying the dict, and then create a new Document object using
`Document.from_dict()`.
Note that `id_hash_keys` are referring to keys in the metadata. `content` is always included in the id hash.
In case of file-based documents (images, audios), the content that is hashed is the file paths,
so if the file is moved, the hash is different, but if the file is modified without renaming it, the has will
not differ.
"""
id: str = field(default_factory=str)
content: Any = field(default_factory=lambda: None)
content_type: ContentType = "text"
metadata: Dict[str, Any] = field(default_factory=dict, hash=False)
id_hash_keys: List[str] = field(default_factory=lambda: [], hash=False)
score: Optional[float] = field(default=None, compare=True)
embedding: Optional[ndarray] = field(default=None, repr=False)
def __str__(self):
return f"{self.__class__.__name__}('{self.content}')"
def __post_init__(self):
"""
Generate the ID based on the init parameters and make sure that content_type
matches the actual type of content.
"""
# Validate content_type
if not isinstance(self.content, PYTHON_TYPES_FOR_CONTENT[self.content_type]):
raise ValueError(
f"The type of content ({type(self.content)}) does not match the "
f"content type: '{self.content_type}' expects '{PYTHON_TYPES_FOR_CONTENT[self.content_type]}'."
)
# Check if id_hash_keys are all present in the meta
for key in self.id_hash_keys:
if key not in self.metadata:
raise ValueError(
f"'{key}' must be present in the metadata of the Document if you want to use it to generate the ID."
)
# Generate the ID
hashed_content = _create_id(
classname=self.__class__.__name__,
content=str(self.content),
metadata=self.metadata,
id_hash_keys=self.id_hash_keys,
)
# Note: we need to set the id this way because the dataclass is frozen. See the docstring.
object.__setattr__(self, "id", hashed_content)
def to_dict(self):
return asdict(self)
def to_json(self, **json_kwargs):
return json.dumps(self.to_dict(), *json_kwargs)
@classmethod
def from_dict(cls, dictionary):
return cls(**dictionary)
@classmethod
def from_json(cls, data, **json_kwargs):
dictionary = json.loads(data, **json_kwargs)
return cls.from_dict(dictionary=dictionary)

View File

@ -0,0 +1,2 @@
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore
from haystack.preview.document_stores.errors import StoreError, DuplicateDocumentError, MissingDocumentError

View File

@ -0,0 +1,10 @@
class StoreError(Exception):
pass
class DuplicateDocumentError(StoreError):
pass
class MissingDocumentError(StoreError):
pass

View File

@ -0,0 +1 @@
from haystack.preview.document_stores.memory.document_store import MemoryDocumentStore

View File

@ -0,0 +1,255 @@
from typing import List, Any
from haystack.preview.dataclasses import Document
def not_operation(conditions: List[Any], document: Document, _current_key: str):
"""
Applies a NOT to all the nested conditions.
:param conditions: the filters dictionary.
:param document: the document to test.
:param _current_key: internal, don't use.
:return: True if the document matches the negated filters, False otherwise
"""
return not and_operation(conditions=conditions, document=document, _current_key=_current_key)
def and_operation(conditions: List[Any], document: Document, _current_key: str):
"""
Applies an AND to all the nested conditions.
:param conditions: the filters dictionary.
:param document: the document to test.
:param _current_key: internal, don't use.
:return: True if the document matches all the filters, False otherwise
"""
for condition in conditions:
if not _match(conditions=condition, document=document, _current_key=_current_key):
return False
return True
def or_operation(conditions: List[Any], document: Document, _current_key: str):
"""
Applies an OR to all the nested conditions.
:param conditions: the filters dictionary.
:param document: the document to test.
:param _current_key: internal, don't use.
:return: True if the document matches ano of the filters, False otherwise
"""
for condition in conditions:
if _match(conditions=condition, document=document, _current_key=_current_key):
return True
return False
def eq_operation(fields, field_name, value):
"""
Checks for equality between the document's metadata value and a fixed value.
:param fields: all the document's metadata
:param field_name: the field to test
:param value; the fixed value to compare against
:return: True if the values are equal, False otherwise
"""
if not field_name in fields:
return False
return fields[field_name] == value
def in_operation(fields, field_name, value):
"""
Checks for whether the document's metadata value is present into the given list.
:param fields: all the document's metadata
:param field_name: the field to test
:param value; the fixed value to compare against
:return: True if the document's value is included in the given list, False otherwise
"""
if not field_name in fields:
return False
return fields[field_name] in value
def ne_operation(fields, field_name, value):
"""
Checks for inequality between the document's metadata value and a fixed value.
:param fields: all the document's metadata
:param field_name: the field to test
:param value; the fixed value to compare against
:return: True if the values are different, False otherwise
"""
if not field_name in fields:
return True
return fields[field_name] != value
def nin_operation(fields, field_name, value):
"""
Checks whether the document's metadata value is absent from the given list.
:param fields: all the document's metadata
:param field_name: the field to test
:param value; the fixed value to compare against
:return: True if the document's value is not included in the given list, False otherwise
"""
if not field_name in fields:
return True
return fields[field_name] not in value
def gt_operation(fields, field_name, value):
"""
Checks whether the document's metadata value is (strictly) larger than the given value.
:param fields: all the document's metadata
:param field_name: the field to test
:param value; the fixed value to compare against
:return: True if the document's value is strictly larger than the fixed value, False otherwise
"""
if not field_name in fields:
return False
return fields[field_name] > value
def gte_operation(fields, field_name, value):
"""
Checks whether the document's metadata value is larger than or equal to the given value.
:param fields: all the document's metadata
:param field_name: the field to test
:param value; the fixed value to compare against
:return: True if the document's value is larger than or equal to the fixed value, False otherwise
"""
if not field_name in fields:
return False
return fields[field_name] >= value
def lt_operation(fields, field_name, value):
"""
Checks whether the document's metadata value is (strictly) smaller than the given value.
:param fields: all the document's metadata
:param field_name: the field to test
:param value; the fixed value to compare against
:return: True if the document's value is strictly smaller than the fixed value, False otherwise
"""
if not field_name in fields:
return False
return fields[field_name] < value
def lte_operation(fields, field_name, value):
"""
Checks whether the document's metadata value is smaller than or equal to the given value.
:param fields: all the document's metadata
:param field_name: the field to test
:param value; the fixed value to compare against
:return: True if the document's value is smaller than or equal to the fixed value, False otherwise
"""
if not field_name in fields:
return False
return fields[field_name] <= value
LOGICAL_STATEMENTS = {"$not": not_operation, "$and": and_operation, "$or": or_operation}
OPERATORS = {
"$eq": eq_operation,
"$in": in_operation,
"$ne": ne_operation,
"$nin": nin_operation,
"$gt": gt_operation,
"$gte": gte_operation,
"$lt": lt_operation,
"$lte": lte_operation,
}
RESERVED_KEYS = [*LOGICAL_STATEMENTS.keys(), *OPERATORS.keys()]
def match(conditions: Any, document: Document):
"""
This method applies the filters to any given document and returns True when the documents
metadata matches the filters, False otherwise.
:param conditions: the filters dictionary.
:param document: the document to test.
:return: True if the document matches the filters, False otherwise
"""
if isinstance(conditions, list):
# The default operation for a list of sibling conditions is $and
return _match(conditions=conditions, document=document, _current_key="$and")
if isinstance(conditions, dict):
if len(conditions.keys()) > 1:
# The default operation for a list of sibling conditions is $and
return _match(conditions=conditions, document=document, _current_key="$and")
field_key, field_value = list(conditions.items())[0]
return _match(conditions=field_value, document=document, _current_key=field_key)
raise ValueError("Filters must be dictionaries or lists. See the examples in the documentation.")
def _match(conditions: Any, document: Document, _current_key: str):
"""
Recursive implementation of match().
"""
if isinstance(conditions, list):
# The default operation for a list of sibling conditions is $and
return _match(conditions={"$and": conditions}, document=document, _current_key=_current_key)
if isinstance(conditions, dict):
# Check for malformed filters, like {"name": {"year": "2020"}}
if _current_key not in RESERVED_KEYS and any(key not in RESERVED_KEYS for key in conditions.keys()):
raise ValueError(
f"This filter ({_current_key}, {conditions}) seems to be malformed. Comparisons with dictionaries are "
"not currently supported. Check the documentation to learn more about filters syntax."
)
# The default operation for a list of sibling conditions is $and
if len(conditions.keys()) > 1:
return and_operation(
conditions=_conditions_as_list(conditions), document=document, _current_key=_current_key
)
field_key, field_value = list(conditions.items())[0]
if field_key in LOGICAL_STATEMENTS.keys():
# It's a nested logical statement ($and, $or, $not)
return LOGICAL_STATEMENTS[field_key](
conditions=_conditions_as_list(field_value), document=document, _current_key=_current_key
)
if field_key in OPERATORS.keys():
# It's a comparison operator ($eq, $in, $gte, ...)
if not _current_key:
raise ValueError(
"Filters can't start with an operator like $eq and $in. You have to specify the field name first. "
"See the examples in the documentation."
)
return OPERATORS[field_key](fields=document.metadata, field_name=_current_key, value=field_value)
if isinstance(field_value, list):
# The default operator for a {key: [value1, value2]} filter is $in
return in_operation(fields=document.metadata, field_name=field_key, value=field_value)
# The default operator for a {key: value} filter is $eq
return eq_operation(fields=document.metadata, field_name=_current_key, value=conditions)
def _conditions_as_list(conditions: Any) -> List[Any]:
"""
Make sure all nested conditions are not dictionaries or single values, but always lists.
:param conditions: the conditions to transform into a list
:returns: a list of filters
"""
if isinstance(conditions, list):
return conditions
if isinstance(conditions, dict):
return [{key: value} for key, value in conditions.items()]
return [conditions]

View File

@ -0,0 +1,144 @@
from typing import Literal, Any, Dict, List, Optional, Iterable
import logging
from haystack.preview.dataclasses import Document
from haystack.preview.document_stores.memory._filters import match
from haystack.preview.document_stores.errors import DuplicateDocumentError, MissingDocumentError
logger = logging.getLogger(__name__)
DuplicatePolicy = Literal["skip", "overwrite", "fail"]
class MemoryDocumentStore:
"""
Stores data in-memory. It's ephemeral and cannot be saved to disk.
"""
def __init__(self):
"""
Initializes the store.
"""
self.storage = {}
def count_documents(self) -> int:
"""
Returns the number of how many documents are present in the document store.
"""
return len(self.storage.keys())
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
"""
Returns the documents that match the filters provided.
Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical operator (`"$and"`,
`"$or"`, `"$not"`), a comparison operator (`"$eq"`, `$ne`, `"$in"`, `$nin`, `"$gt"`, `"$gte"`, `"$lt"`,
`"$lte"`) or a metadata field name.
Logical operator keys take a dictionary of metadata field names and/or logical operators as value. Metadata
field names take a dictionary of comparison operators as value. Comparison operator keys take a single value or
(in case of `"$in"`) a list of values as value. If no logical operator is provided, `"$and"` is used as default
operation. If no comparison operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used
as default operation.
Example:
```python
filters = {
"$and": {
"type": {"$eq": "article"},
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
"rating": {"$gte": 3},
"$or": {
"genre": {"$in": ["economy", "politics"]},
"publisher": {"$eq": "nytimes"}
}
}
}
# or simpler using default operators
filters = {
"type": "article",
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
"rating": {"$gte": 3},
"$or": {
"genre": ["economy", "politics"],
"publisher": "nytimes"
}
}
```
To use the same logical operator multiple times on the same level, logical operators can take a list of
dictionaries as value.
Example:
```python
filters = {
"$or": [
{
"$and": {
"Type": "News Paper",
"Date": {
"$lt": "2019-01-01"
}
}
},
{
"$and": {
"Type": "Blog Post",
"Date": {
"$gte": "2019-01-01"
}
}
}
]
}
```
:param filters: the filters to apply to the document list.
:return: a list of Documents that match the given filters.
"""
if filters:
return [doc for doc in self.storage.values() if match(conditions=filters, document=doc)]
return list(self.storage.values())
def write_documents(self, documents: List[Document], duplicates: DuplicatePolicy = "fail") -> None:
"""
Writes (or overwrites) documents into the store.
:param documents: a list of documents.
:param duplicates: documents with the same ID count as duplicates. When duplicates are met,
the store can:
- skip: keep the existing document and ignore the new one.
- overwrite: remove the old document and write the new one.
- fail: an error is raised
:raises DuplicateError: Exception trigger on duplicate document if `duplicates="fail"`
:return: None
"""
if (
not isinstance(documents, Iterable)
or isinstance(documents, str)
or any(not isinstance(doc, Document) for doc in documents)
):
raise ValueError("Please provide a list of Documents.")
for document in documents:
if document.id in self.storage.keys():
if duplicates == "fail":
raise DuplicateDocumentError(f"ID '{document.id}' already exists.")
if duplicates == "skip":
logger.warning("ID '%s' already exists", document.id)
self.storage[document.id] = document
def delete_documents(self, document_ids: List[str]) -> None:
"""
Deletes all documents with a matching document_ids from the document store.
Fails with `MissingDocumentError` if no document with this id is present in the store.
:param object_ids: the object_ids to delete
"""
for doc_id in document_ids:
if not doc_id in self.storage.keys():
raise MissingDocumentError(f"ID '{doc_id}' not found, cannot delete it.")
del self.storage[doc_id]

View File

@ -21,7 +21,7 @@ class Pipeline(CanalsPipeline):
def __init__(self):
super().__init__()
self.stores: Dict[str, object] = {}
self.stores = {}
def add_store(self, name: str, store: object) -> None:
"""

View File

@ -0,0 +1,24 @@
from typing import Optional, Any
import importlib
import logging
def optional_import(import_path: str, import_target: Optional[str], error_msg: str, importer_module: str) -> Any:
"""
Imports an optional dependency. Emits a DEBUG log if the dependency is missing.
"""
try:
module = importlib.import_module(import_path)
if import_target:
return getattr(module, import_target)
return module
except ImportError as exc:
logging.getLogger(importer_module).debug(
"%s%s%s can't be imported: %s Error raised: %s",
import_path,
"." if import_target else "",
import_target,
error_msg,
exc,
)
return None

View File

@ -0,0 +1,152 @@
from pathlib import Path
import hashlib
import pandas as pd
import numpy as np
from haystack.preview import Document
from haystack.preview.dataclasses.document import _create_id
def test_default_text_document_to_dict():
assert Document(content="test content").to_dict() == {
"id": _create_id(classname=Document.__name__, content="test content"),
"content": "test content",
"content_type": "text",
"metadata": {},
"id_hash_keys": [],
"score": None,
"embedding": None,
}
def test_default_text_document_from_dict():
assert Document.from_dict(
{
"id": _create_id(classname=Document.__name__, content="test content"),
"content": "test content",
"content_type": "text",
"metadata": {},
"id_hash_keys": [],
"score": None,
"embedding": None,
}
) == Document(content="test content")
def test_default_table_document_to_dict():
df = pd.DataFrame([1, 2])
dictionary = Document(content=df, content_type="table").to_dict()
dataframe = dictionary.pop("content")
assert dataframe.equals(df)
assert dictionary == {
"id": _create_id(classname=Document.__name__, content=df),
"content_type": "table",
"metadata": {},
"id_hash_keys": [],
"score": None,
"embedding": None,
}
def test_default_table_document_from_dict():
df = pd.DataFrame([1, 2])
assert Document.from_dict(
{
"id": _create_id(classname=Document.__name__, content=df),
"content": df,
"content_type": "table",
"metadata": {},
"id_hash_keys": [],
"score": None,
"embedding": None,
}
) == Document(content=df, content_type="table")
def test_default_image_document_to_dict():
path = Path(__file__).parent / "test_files" / "apple.jpg"
assert Document(content=path, content_type="image").to_dict() == {
"id": _create_id(classname=Document.__name__, content=path),
"content": path,
"content_type": "image",
"metadata": {},
"id_hash_keys": [],
"score": None,
"embedding": None,
}
def test_default_image_document_from_dict():
path = Path(__file__).parent / "test_files" / "apple.jpg"
assert Document.from_dict(
{
"id": _create_id(classname=Document.__name__, content=path),
"content": path,
"content_type": "image",
"metadata": {},
"id_hash_keys": [],
"score": None,
"embedding": None,
}
) == Document(content=path, content_type="image")
def test_document_with_most_attributes_to_dict():
"""
This tests also id_hash_keys
"""
doc = Document(
content="test content",
content_type="text",
metadata={"some": "values", "test": 10},
id_hash_keys=["test"],
score=0.99,
embedding=np.zeros([10, 10]),
)
dictionary = doc.to_dict()
embedding = dictionary.pop("embedding")
assert (embedding == np.zeros([10, 10])).all()
assert dictionary == {
"id": _create_id(
classname=Document.__name__,
content="test content",
id_hash_keys=["test"],
metadata={"some": "values", "test": 10},
),
"content": "test content",
"content_type": "text",
"metadata": {"some": "values", "test": 10},
"id_hash_keys": ["test"],
"score": 0.99,
}
def test_document_with_most_attributes_from_dict():
embedding = np.zeros([10, 10])
assert Document.from_dict(
{
"id": _create_id(
classname=Document.__name__,
content="test content",
id_hash_keys=["test"],
metadata={"some": "values", "test": 10},
),
"content": "test content",
"content_type": "text",
"metadata": {"some": "values", "test": 10},
"id_hash_keys": ["test"],
"score": 0.99,
"embedding": embedding,
}
) == Document(
content="test content",
content_type="text",
metadata={"some": "values", "test": 10},
id_hash_keys=["test"],
score=0.99,
embedding=embedding,
)

View File

@ -0,0 +1,286 @@
import pytest
import numpy as np
from haystack.preview.dataclasses import Document
from haystack.preview.document_stores import MissingDocumentError, DuplicateDocumentError
class DocumentStoreBaseTests:
@pytest.fixture
def docstore(self):
raise NotImplementedError()
@pytest.fixture
def filterable_docs(self):
documents = []
for i in range(3):
documents.append(
Document(
content=f"A Foo Document {i}",
metadata={"name": f"name_{i}", "year": "2020", "month": "01", "number": 2},
embedding=np.random.rand(768).astype(np.float32),
)
)
documents.append(
Document(
content=f"A Bar Document {i}",
metadata={"name": f"name_{i}", "year": "2021", "month": "02", "number": -2},
embedding=np.random.rand(768).astype(np.float32),
)
)
documents.append(
Document(
content=f"A Foobar Document {i}",
metadata={"name": f"name_{i}", "year": "2000", "month": "03", "number": -10},
embedding=np.random.rand(768).astype(np.float32),
)
)
documents.append(
Document(
content=f"Document {i} without embedding",
metadata={"name": f"name_{i}", "no_embedding": True, "month": "03"},
)
)
return documents
def test_count_empty(self, docstore):
assert docstore.count_documents() == 0
def test_count_not_empty(self, docstore):
self.direct_write(
docstore, [Document(content="test doc 1"), Document(content="test doc 2"), Document(content="test doc 3")]
)
assert docstore.count_documents() == 3
def test_no_filter_empty(self, docstore):
assert docstore.filter_documents() == []
assert docstore.filter_documents(filters={}) == []
def test_no_filter_not_empty(self, docstore):
docs = [Document(content="test doc")]
self.direct_write(docstore, docs)
assert docstore.filter_documents() == docs
assert docstore.filter_documents(filters={}) == docs
def test_filter_simple_value(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"year": "2020"})
assert len(result) == 3
def test_filter_simple_list(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"year": ["2020"]})
assert all(doc.metadata["year"] == "2020" for doc in result)
result = docstore.filter_documents(filters={"year": ["2020", "2021"]})
assert all(doc.metadata["year"] in ["2020", "2021"] for doc in result)
def test_incorrect_filter_name(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"non_existing_meta_field": ["whatever"]})
assert len(result) == 0
def test_incorrect_filter_type(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
with pytest.raises(ValueError, match="dictionaries or lists"):
docstore.filter_documents(filters="something odd")
def test_incorrect_filter_value(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"year": ["nope"]})
assert len(result) == 0
def test_incorrect_filter_nesting(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
with pytest.raises(ValueError, match="malformed"):
docstore.filter_documents(filters={"number": {"year": "2020"}})
with pytest.raises(ValueError, match="malformed"):
docstore.filter_documents(filters={"number": {"year": {"month": "01"}}})
def test_eq_filter(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"year": {"$eq": "2020"}})
assert all(doc.metadata["year"] == "2020" for doc in result)
result = docstore.filter_documents(filters={"year": "2020"})
assert all(doc.metadata["year"] == "2020" for doc in result)
def test_in_filter(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"year": {"$in": ["2020", "2021", "n.a."]}})
assert all(doc.metadata["year"] in ["2020", "2021"] for doc in result)
result = docstore.filter_documents(filters={"year": ["2020", "2021", "n.a."]})
assert all(doc.metadata["year"] in ["2020", "2021"] for doc in result)
def test_ne_filter(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"year": {"$ne": "2020"}})
assert all(doc.metadata.get("year", None) != "2020" for doc in result)
def test_nin_filter(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"year": {"$nin": ["2020", "2021", "n.a."]}})
assert all(doc.metadata.get("year", None) not in ["2020", "2021"] for doc in result)
def test_gt_filter(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"number": {"$gt": 0.0}})
assert all(doc.metadata["number"] > 0 for doc in result)
def test_gte_filter(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"number": {"$gte": -2.0}})
assert all(doc.metadata["number"] >= -2.0 for doc in result)
def test_lt_filter(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"number": {"$lt": 0.0}})
assert all(doc.metadata["number"] < 0 for doc in result)
def test_lte_filter(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"number": {"$lte": 2.0}})
assert all(doc.metadata["number"] <= 2.0 for doc in result)
def test_filter_simple_explicit_and(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"year": {"$and": {"$lte": "2021", "$gte": "2020"}}})
assert all(int(doc.metadata["year"]) >= 2020 and int(doc.metadata["year"]) <= 2021 for doc in result)
result = docstore.filter_documents(filters={"year": {"$and": [{"$lte": "2021"}, {"$gte": "2020"}]}})
assert all(int(doc.metadata["year"]) >= 2020 and int(doc.metadata["year"]) <= 2021 for doc in result)
def test_filter_simple_implicit_and(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
result = docstore.filter_documents(filters={"year": {"$lte": "2021", "$gte": "2020"}})
assert all(int(doc.metadata["year"]) >= 2020 and int(doc.metadata["year"]) <= 2021 for doc in result)
def test_filter_nested_explicit_and(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
filters = {"$and": {"year": {"$and": {"$lte": "2021", "$gte": "2020"}}, "name": {"$in": ["name_0", "name_1"]}}}
result = docstore.filter_documents(filters=filters)
assert all(
int(doc.metadata["year"]) >= 2020
and int(doc.metadata["year"]) <= 2021
and doc.metadata["name"] in ["name_0", "name_1"]
for doc in result
)
def test_filter_nested_implicit_and(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
filters_simplified = {"year": {"$lte": "2021", "$gte": "2020"}, "name": ["name_0", "name_1"]}
result = docstore.filter_documents(filters=filters_simplified)
assert all(
int(doc.metadata["year"]) >= 2020
and int(doc.metadata["year"]) <= 2021
and doc.metadata["name"] in ["name_0", "name_1"]
for doc in result
)
def test_filter_simple_or(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
filters = {"$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}}
result = docstore.filter_documents(filters=filters)
assert all(doc.metadata["name"] in ["name_0", "name_1"] or doc.metadata["number"] < 1.0 for doc in result)
def test_filter_nested_or(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
filters = {"$or": {"name": {"$or": [{"$eq": "name_0"}, {"$eq": "name_1"}]}, "number": {"$lt": 1.0}}}
result = docstore.filter_documents(filters=filters)
assert all(doc.metadata["name"] in ["name_0", "name_1"] or doc.metadata["number"] < 1.0 for doc in result)
def test_filter_nested_and_or(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
filters_simplified = {
"year": {"$lte": "2021", "$gte": "2020"},
"$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}},
}
result = docstore.filter_documents(filters=filters_simplified)
assert all(
(int(doc.metadata["year"]) >= 2020 and int(doc.metadata["year"]) <= 2021)
and (doc.metadata["name"] in ["name_0", "name_1"] or doc.metadata["number"] < 1.0)
for doc in result
)
def test_filter_nested_or_and(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
filters_simplified = {
"$or": {
"number": {"$lt": 1.0},
"$and": {"name": {"$in": ["name_0", "name_1"]}, "$not": {"month": {"$eq": "01"}}},
}
}
result = docstore.filter_documents(filters=filters_simplified)
assert all(
doc.metadata.get("number", 2) < 1.0
or (doc.metadata["name"] in ["name_0", "name_1"] and doc.metadata["month"] != "01")
for doc in result
)
def test_filter_nested_multiple_identical_operators_same_level(self, docstore, filterable_docs):
self.direct_write(docstore, filterable_docs)
filters = {
"$or": [
{"$and": {"name": {"$in": ["name_0", "name_1"]}, "year": {"$gte": "2020"}}},
{"$and": {"name": {"$in": ["name_0", "name_1"]}, "year": {"$lt": "2021"}}},
]
}
result = docstore.filter_documents(filters=filters)
assert all(doc.metadata["name"] in ["name_0", "name_1"] for doc in result)
def test_write(self, docstore):
doc = Document(content="test doc")
docstore.write_documents(documents=[doc])
assert self.direct_access(docstore, doc_id=doc.id) == doc
def test_write_duplicate_fail(self, docstore):
doc = Document(content="test doc")
self.direct_write(docstore, [doc])
with pytest.raises(DuplicateDocumentError, match=f"ID '{doc.id}' already exists."):
docstore.write_documents(documents=[doc])
assert self.direct_access(docstore, doc_id=doc.id) == doc
def test_write_duplicate_skip(self, docstore):
doc = Document(content="test doc")
self.direct_write(docstore, [doc])
docstore.write_documents(documents=[doc], duplicates="skip")
assert self.direct_access(docstore, doc_id=doc.id) == doc
def test_write_duplicate_overwrite(self, docstore):
doc1 = Document(content="test doc 1")
doc2 = Document(content="test doc 2")
object.__setattr__(doc2, "id", doc1.id) # Make two docs with different content but same ID
self.direct_write(docstore, [doc2])
assert self.direct_access(docstore, doc_id=doc1.id) == doc2
docstore.write_documents(documents=[doc1], duplicates="overwrite")
assert self.direct_access(docstore, doc_id=doc1.id) == doc1
def test_write_not_docs(self, docstore):
with pytest.raises(ValueError, match="Please provide a list of Documents"):
docstore.write_documents(["not a document for sure"])
def test_write_not_list(self, docstore):
with pytest.raises(ValueError, match="Please provide a list of Documents"):
docstore.write_documents("not a list actually")
def test_delete_empty(self, docstore):
with pytest.raises(MissingDocumentError):
docstore.delete_documents(["test"])
def test_delete_not_empty(self, docstore):
doc = Document(content="test doc")
self.direct_write(docstore, [doc])
docstore.delete_documents([doc.id])
with pytest.raises(Exception):
assert self.direct_access(docstore, doc_id=doc.id)
def test_delete_not_empty_nonexisting(self, docstore):
doc = Document(content="test doc")
self.direct_write(docstore, [doc])
with pytest.raises(MissingDocumentError):
docstore.delete_documents(["non_existing"])
assert self.direct_access(docstore, doc_id=doc.id) == doc

View File

@ -0,0 +1,38 @@
import pytest
from haystack.preview.document_stores import MemoryDocumentStore
from test.preview.document_stores._base import DocumentStoreBaseTests
class TestMemoryDocumentStore(DocumentStoreBaseTests):
"""
Test MemoryDocumentStore's specific features
"""
@pytest.fixture
def docstore(self) -> MemoryDocumentStore:
return MemoryDocumentStore()
def direct_access(self, docstore, doc_id):
"""
Bypass `filter_documents()`
"""
return docstore.storage[doc_id]
def direct_write(self, docstore, documents):
"""
Bypass `write_documents()`
"""
for doc in documents:
docstore.storage[doc.id] = doc
def direct_delete(self, docstore, ids):
"""
Bypass `delete_documents()`
"""
for doc_id in ids:
del docstore.storage[doc_id]
#
# Test retrieval
#

Binary file not shown.

After

Width:  |  Height:  |  Size: 68 KiB