# pylint: disable=too-many-public-methods
import sys

import pytest
import numpy as np

from haystack.schema import Document, Label, Answer, Span
from haystack.errors import DuplicateDocumentError
from haystack.document_stores import BaseDocumentStore


@pytest.mark.document_store
class DocumentStoreBaseTestAbstract:
    """
    This is a base class to test abstract methods from BaseDocumentStore, to be inherited by any Document Store
    test suite. It doesn't have the `Test` prefix in its name so that its methods won't be collected for this
    class but only for its subclasses.
    """
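
    # Note: the fixture below yields 9 documents in total. For each i in 0..2 it
    # creates one "Foo" document (year 2020, month 01, positive numbers), one
    # "Bar" document (year 2021, month 02, negative numbers) and one document
    # without an embedding (month 03). The count assertions in the tests below
    # (3, 6 and 9) all derive from this layout.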
    @pytest.fixture
    def documents(self):
        documents = []
        for i in range(3):
            documents.append(
                Document(
                    content=f"A Foo Document {i}",
                    meta={"name": f"name_{i}", "year": "2020", "month": "01", "numbers": [2, 4]},
                    embedding=np.random.rand(768).astype(np.float32),
                )
            )

            documents.append(
                Document(
                    content=f"A Bar Document {i}",
                    meta={"name": f"name_{i}", "year": "2021", "month": "02", "numbers": [-2, -4]},
                    embedding=np.random.rand(768).astype(np.float32),
                )
            )

            documents.append(
                Document(
                    content=f"Document {i} without embeddings",
                    meta={"name": f"name_{i}", "no_embedding": True, "month": "03"},
                )
            )

        return documents
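
    # One label per document (9 in total): origins alternate between "gold-label"
    # and "user-feedback", and only label 0 carries no answer, so both answered
    # and unanswered labels are covered.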
    @pytest.fixture
    def labels(self, documents):
        labels = []
        for i, d in enumerate(documents):
            labels.append(
                Label(
                    query=f"query_{i}",
                    document=d,
                    is_correct_document=True,
                    is_correct_answer=False,
                    # create a mixed set of labels
                    origin="user-feedback" if i % 2 else "gold-label",
                    answer=None if not i else Answer(f"the answer is {i}", document_ids=[d.id]),
                    meta={"name": f"label_{i}", "year": f"{2020 + i}"},
                )
            )
        return labels

    #
    # Integration tests
    #

    @pytest.mark.integration
    def test_write_documents(self, ds, documents):
        ds.write_documents(documents)
        docs = ds.get_all_documents()
        assert len(docs) == len(documents)
        expected_ids = set(doc.id for doc in documents)
        ids = set(doc.id for doc in docs)
        assert ids == expected_ids

    @pytest.mark.integration
    def test_write_labels(self, ds, labels):
        ds.write_labels(labels)
        assert ds.get_all_labels() == labels
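
    # With id_hash_keys=["content"], a document's id is derived from a hash of its
    # content, so the two identical documents below end up sharing the same id.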
    @pytest.mark.integration
    def test_write_with_duplicate_doc_ids(self, ds):
        duplicate_documents = [
            Document(content="Doc1", id_hash_keys=["content"], meta={"key1": "value1"}),
            Document(content="Doc1", id_hash_keys=["content"], meta={"key1": "value1"}),
        ]
        ds.write_documents(duplicate_documents, duplicate_documents="skip")
        results = ds.get_all_documents()
        assert len(results) == 1
        assert results[0] == duplicate_documents[0]
        with pytest.raises(Exception):
            ds.write_documents(duplicate_documents, duplicate_documents="fail")

    @pytest.mark.integration
    def test_get_embedding_count(self, ds, documents):
        """
        We expect 6 docs with embeddings because only 6 of the 9 documents in the
        documents fixture for this class contain embeddings.
        """
        ds.write_documents(documents)
        assert ds.get_embedding_count() == 6

    @pytest.mark.skip
    @pytest.mark.integration
    def test_get_all_documents_without_filters(self, ds, documents):
        ds.write_documents(documents)
        out = ds.get_all_documents()
        assert out == documents

    @pytest.mark.integration
    def test_get_all_documents_without_embeddings(self, ds, documents):
        ds.write_documents(documents)
        out = ds.get_all_documents(return_embedding=False)
        for doc in out:
            assert doc.embedding is None

    @pytest.mark.integration
    def test_get_all_document_filter_duplicate_text_value(self, ds):
        documents = [
            Document(content="duplicated", meta={"meta_field": "0"}, id_hash_keys=["meta"]),
            Document(content="duplicated", meta={"meta_field": "1", "name": "file.txt"}, id_hash_keys=["meta"]),
            Document(content="Doc2", meta={"name": "file_2.txt"}, id_hash_keys=["meta"]),
        ]
        ds.write_documents(documents)
        documents = ds.get_all_documents(filters={"meta_field": ["1"]})
        assert len(documents) == 1
        assert documents[0].content == "duplicated"
        assert documents[0].meta["name"] == "file.txt"

        documents = ds.get_all_documents(filters={"meta_field": ["0"]})
        assert len(documents) == 1
        assert documents[0].content == "duplicated"
        assert documents[0].meta.get("name") is None

        documents = ds.get_all_documents(filters={"name": ["file_2.txt"]})
        assert len(documents) == 1
        assert documents[0].content == "Doc2"
        assert documents[0].meta.get("meta_field") is None

    @pytest.mark.integration
    def test_get_all_documents_with_correct_filters(self, ds, documents):
        ds.write_documents(documents)
        result = ds.get_all_documents(filters={"year": ["2020"]})
        assert len(result) == 3

        documents = ds.get_all_documents(filters={"year": ["2020", "2021"]})
        assert len(documents) == 6

    @pytest.mark.integration
    def test_get_all_documents_with_incorrect_filter_name(self, ds, documents):
        ds.write_documents(documents)
        result = ds.get_all_documents(filters={"non_existing_meta_field": ["whatever"]})
        assert len(result) == 0

    @pytest.mark.integration
    def test_get_all_documents_with_incorrect_filter_value(self, ds, documents):
        ds.write_documents(documents)
        result = ds.get_all_documents(filters={"year": ["nope"]})
        assert len(result) == 0

    @pytest.mark.integration
    def test_eq_filters(self, ds, documents):
        ds.write_documents(documents)

        result = ds.get_all_documents(filters={"year": {"$eq": "2020"}})
        assert len(result) == 3
        result = ds.get_all_documents(filters={"year": "2020"})
        assert len(result) == 3

    @pytest.mark.integration
    def test_in_filters(self, ds, documents):
        ds.write_documents(documents)

        result = ds.get_all_documents(filters={"year": {"$in": ["2020", "2021", "n.a."]}})
        assert len(result) == 6
        result = ds.get_all_documents(filters={"year": ["2020", "2021", "n.a."]})
        assert len(result) == 6

    @pytest.mark.integration
    def test_ne_filters(self, ds, documents):
        ds.write_documents(documents)

        result = ds.get_all_documents(filters={"year": {"$ne": "2020"}})
        assert len(result) == 6

    @pytest.mark.integration
    def test_nin_filters(self, ds, documents):
        ds.write_documents(documents)

        result = ds.get_all_documents(filters={"year": {"$nin": ["2020", "2021", "n.a."]}})
        assert len(result) == 3
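
    # The "numbers" field is [2, 4] on the three Foo documents, [-2, -4] on the
    # three Bar documents, and absent from the remaining three; a comparison
    # operator matches a document if any value in the list satisfies it, which
    # is what produces the 3/6 counts below.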
    @pytest.mark.integration
    def test_comparison_filters(self, ds, documents):
        ds.write_documents(documents)

        result = ds.get_all_documents(filters={"numbers": {"$gt": 0.0}})
        assert len(result) == 3

        result = ds.get_all_documents(filters={"numbers": {"$gte": -2.0}})
        assert len(result) == 6

        result = ds.get_all_documents(filters={"numbers": {"$lt": 0.0}})
        assert len(result) == 3

        result = ds.get_all_documents(filters={"numbers": {"$lte": 2.0}})
        assert len(result) == 6

    @pytest.mark.integration
    def test_compound_filters(self, ds, documents):
        ds.write_documents(documents)

        result = ds.get_all_documents(filters={"year": {"$lte": "2021", "$gte": "2020"}})
        assert len(result) == 6
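
    # Simplified filter syntax: sibling conditions at the top level are implicitly
    # "$and"-ed, and a bare list of values is shorthand for an "$in" clause, so
    # both variants below select the same 4 documents.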
    @pytest.mark.integration
    def test_simplified_filters(self, ds, documents):
        ds.write_documents(documents)

        filters = {"$and": {"year": {"$lte": "2021", "$gte": "2020"}, "name": {"$in": ["name_0", "name_1"]}}}
        result = ds.get_all_documents(filters=filters)
        assert len(result) == 4

        filters_simplified = {"year": {"$lte": "2021", "$gte": "2020"}, "name": ["name_0", "name_1"]}
        result = ds.get_all_documents(filters=filters_simplified)
        assert len(result) == 4

    @pytest.mark.integration
    def test_nested_condition_filters(self, ds, documents):
        ds.write_documents(documents)
        filters = {
            "$and": {
                "year": {"$lte": "2021", "$gte": "2020"},
                "$or": {"name": {"$in": ["name_0", "name_1"]}, "numbers": {"$lt": 5.0}},
            }
        }
        result = ds.get_all_documents(filters=filters)
        assert len(result) == 6

        filters_simplified = {
            "year": {"$lte": "2021", "$gte": "2020"},
            "$or": {"name": {"$in": ["name_0", "name_2"]}, "numbers": {"$lt": 5.0}},
        }
        result = ds.get_all_documents(filters=filters_simplified)
        assert len(result) == 6

        filters = {
            "$and": {
                "year": {"$lte": "2021", "$gte": "2020"},
                "$or": {
                    "name": {"$in": ["name_0", "name_1"]},
                    "$and": {"numbers": {"$lt": 5.0}, "$not": {"month": {"$eq": "01"}}},
                },
            }
        }
        result = ds.get_all_documents(filters=filters)
        assert len(result) == 5

        filters_simplified = {
            "year": {"$lte": "2021", "$gte": "2020"},
            "$or": {"name": ["name_0", "name_1"], "$and": {"numbers": {"$lt": 5.0}, "$not": {"month": {"$eq": "01"}}}},
        }
        result = ds.get_all_documents(filters=filters_simplified)
        assert len(result) == 5
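
    # WeaviateDocumentStore rewrites "$not" using De Morgan's laws
    # (not (A or B) == (not A) and (not B)), so negations wrapping nested
    # "$or"/"$and" blocks like the one below are the interesting edge cases.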
    @pytest.mark.integration
    def test_nested_condition_not_filters(self, ds, documents):
        """
        Test nested logical operations within "$not", important as we apply De Morgan's laws in WeaviateDocumentStore
        """
        ds.write_documents(documents)
        filters = {
            "$not": {
                "$or": {
                    "$and": {"numbers": {"$lt": 5.0}, "month": {"$ne": "01"}},
                    "$not": {"year": {"$lte": "2021", "$gte": "2020"}},
                }
            }
        }
        result = ds.get_all_documents(filters=filters)
        assert len(result) == 3

        docs_meta = result[0].meta["numbers"]
        assert [2, 4] == docs_meta

        # Test the same logical operator twice on the same level
        filters = {
            "$or": [
                {"$and": {"name": {"$in": ["name_0", "name_1"]}, "year": {"$gte": "2020"}}},
                {"$and": {"name": {"$in": ["name_0", "name_1"]}, "year": {"$lt": "2021"}}},
            ]
        }
        result = ds.get_all_documents(filters=filters)
        docs_meta = [doc.meta["name"] for doc in result]
        assert len(result) == 4
        assert "name_0" in docs_meta
        assert "name_2" not in docs_meta

    @pytest.mark.integration
    def test_get_document_by_id(self, ds, documents):
        ds.write_documents(documents)
        doc = ds.get_document_by_id(documents[0].id)
        assert doc.id == documents[0].id
        assert doc.content == documents[0].content

    @pytest.mark.integration
    def test_get_documents_by_id(self, ds, documents):
        ds.write_documents(documents)
        ids = [doc.id for doc in documents]
        result = {doc.id for doc in ds.get_documents_by_id(ids, batch_size=2)}
        assert set(ids) == result

    @pytest.mark.integration
    def test_get_document_count(self, ds, documents):
        ds.write_documents(documents)
        assert ds.get_document_count() == len(documents)
        assert ds.get_document_count(filters={"year": ["2020"]}) == 3
        assert ds.get_document_count(filters={"month": ["02"]}) == 3

    @pytest.mark.integration
    def test_get_all_documents_generator(self, ds, documents):
        ds.write_documents(documents)
        assert len(list(ds.get_all_documents_generator(batch_size=2))) == 9
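
    # The next three tests cover the duplicate_documents policies of
    # write_documents: "skip" keeps the stored version, "overwrite" replaces it,
    # and "fail" raises a DuplicateDocumentError.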
    @pytest.mark.integration
    def test_duplicate_documents_skip(self, ds, documents):
        ds.write_documents(documents)

        updated_docs = []
        for d in documents:
            updated_d = Document.from_dict(d.to_dict())
            updated_d.meta["name"] = "Updated"
            updated_docs.append(updated_d)

        ds.write_documents(updated_docs, duplicate_documents="skip")
        for d in ds.get_all_documents():
            assert d.meta.get("name") != "Updated"

    @pytest.mark.integration
    def test_duplicate_documents_overwrite(self, ds, documents):
        ds.write_documents(documents)

        updated_docs = []
        for d in documents:
            updated_d = Document.from_dict(d.to_dict())
            updated_d.meta["name"] = "Updated"
            updated_docs.append(updated_d)

        ds.write_documents(updated_docs, duplicate_documents="overwrite")
        for doc in ds.get_all_documents():
            assert doc.meta["name"] == "Updated"

    @pytest.mark.integration
    def test_duplicate_documents_fail(self, ds, documents):
        ds.write_documents(documents)

        updated_docs = []
        for d in documents:
            updated_d = Document.from_dict(d.to_dict())
            updated_d.meta["name"] = "Updated"
            updated_docs.append(updated_d)

        with pytest.raises(DuplicateDocumentError):
            ds.write_documents(updated_docs, duplicate_documents="fail")

    @pytest.mark.integration
    def test_write_document_meta(self, ds):
        ds.write_documents(
            [
                {"content": "dict_without_meta", "id": "1"},
                {"content": "dict_with_meta", "meta_field": "test2", "id": "2"},
                Document(content="document_object_without_meta", id="3"),
                Document(content="document_object_with_meta", meta={"meta_field": "test4"}, id="4"),
            ]
        )
        assert not ds.get_document_by_id("1").meta
        assert ds.get_document_by_id("2").meta["meta_field"] == "test2"
        assert not ds.get_document_by_id("3").meta
        assert ds.get_document_by_id("4").meta["meta_field"] == "test4"

    @pytest.mark.integration
    def test_delete_documents(self, ds, documents):
        ds.write_documents(documents)
        ds.delete_documents()
        assert ds.get_document_count() == 0

    @pytest.mark.integration
    def test_delete_documents_with_filters(self, ds, documents):
        ds.write_documents(documents)
        ds.delete_documents(filters={"year": ["2020", "2021"]})
        assert ds.get_document_count() == 3

    @pytest.mark.integration
    def test_delete_documents_by_id(self, ds, documents):
        ds.write_documents(documents)
        docs_to_delete = ds.get_all_documents(filters={"year": ["2020"]})
        ds.delete_documents(ids=[doc.id for doc in docs_to_delete])
        assert ds.get_document_count() == 6

    @pytest.mark.integration
    def test_delete_documents_by_id_with_filters(self, ds, documents):
        ds.write_documents(documents)
        docs_to_delete = ds.get_all_documents(filters={"year": ["2020"]})
        # this should delete only 1 document out of the 3 ids passed
        ds.delete_documents(ids=[doc.id for doc in docs_to_delete], filters={"name": ["name_0"]})
        assert ds.get_document_count() == 8

    @pytest.mark.integration
    def test_write_get_all_labels(self, ds, labels):
        ds.write_labels(labels)
        ds.write_labels(labels[:3], index="custom_index")
        assert len(ds.get_all_labels()) == 9
        assert len(ds.get_all_labels(index="custom_index")) == 3
        # remove the index we created in this test
        ds.delete_index("custom_index")

    @pytest.mark.integration
    def test_delete_labels(self, ds, labels):
        ds.write_labels(labels)
        ds.write_labels(labels[:3], index="custom_index")
        ds.delete_labels()
        ds.delete_labels(index="custom_index")
        assert len(ds.get_all_labels()) == 0
        assert len(ds.get_all_labels(index="custom_index")) == 0
        # remove the index we created in this test
        ds.delete_index("custom_index")

    @pytest.mark.integration
    def test_write_labels_duplicate(self, ds, labels):
        # create a duplicate
        dupe = Label.from_dict(labels[0].to_dict())

        ds.write_labels(labels + [dupe])

        # ensure the duplicate was discarded
        assert len(ds.get_all_labels()) == len(labels)

    @pytest.mark.integration
    def test_delete_labels_by_id(self, ds, labels):
        ds.write_labels(labels)
        ds.delete_labels(ids=[labels[0].id])
        assert len(ds.get_all_labels()) == len(labels) - 1

    @pytest.mark.integration
    def test_delete_labels_by_filter(self, ds, labels):
        ds.write_labels(labels)
        ds.delete_labels(filters={"query": "query_1"})
        assert len(ds.get_all_labels()) == len(labels) - 1

    @pytest.mark.integration
    def test_delete_labels_by_filter_id(self, ds, labels):
        ds.write_labels(labels)

        # ids and filters are ANDed, the following should have no effect
        ds.delete_labels(ids=[labels[0].id], filters={"query": "query_9"})
        assert len(ds.get_all_labels()) == len(labels)

        # both the id and the filter match the first label, so exactly one label is deleted
        ds.delete_labels(ids=[labels[0].id], filters={"query": "query_0"})
        assert len(ds.get_all_labels()) == len(labels) - 1

    @pytest.mark.integration
    def test_get_label_count(self, ds, labels):
        ds.write_labels(labels)
        assert ds.get_label_count() == len(labels)

    @pytest.mark.integration
    def test_delete_index(self, ds, documents):
        ds.write_documents(documents, index="custom_index")
        assert ds.get_document_count(index="custom_index") == len(documents)
        ds.delete_index(index="custom_index")
        with pytest.raises(Exception):
            ds.get_document_count(index="custom_index")

    @pytest.mark.integration
    def test_update_meta(self, ds, documents):
        ds.write_documents(documents)
        doc = documents[0]
        ds.update_document_meta(doc.id, meta={"year": "2099", "month": "12"})
        doc = ds.get_document_by_id(doc.id)
        assert doc.meta["year"] == "2099"
        assert doc.meta["month"] == "12"

    @pytest.mark.integration
    def test_labels_with_long_texts(self, ds, documents):
        label = Label(
            query="question1",
            answer=Answer(
                answer="answer",
                type="extractive",
                score=0.0,
                context="something " * 10_000,
                offsets_in_document=[Span(start=12, end=14)],
                offsets_in_context=[Span(start=12, end=14)],
            ),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="something " * 10_000, id="123"),
            origin="gold-label",
        )
        ds.write_labels(labels=[label])
        labels = ds.get_all_labels()
        assert len(labels) == 1
        assert label == labels[0]

    @pytest.mark.integration
    @pytest.mark.skipif(sys.platform == "win32", reason="_get_documents_meta() fails with 'too many SQL variables'")
    def test_get_all_documents_large_quantities(self, ds):
        # Test to exclude situations like Weaviate not returning more than 100 docs by default
        # https://github.com/deepset-ai/haystack/issues/1893
        docs_to_write = [
            {"meta": {"name": f"name_{i}"}, "content": f"text_{i}", "embedding": np.random.rand(768).astype(np.float32)}
            for i in range(1000)
        ]
        ds.write_documents(docs_to_write)
        documents = ds.get_all_documents()
        assert all(isinstance(d, Document) for d in documents)
        assert len(documents) == len(docs_to_write)

    @pytest.mark.integration
    def test_custom_embedding_field(self, ds):
        ds.embedding_field = "custom_embedding_field"
        doc_to_write = {"content": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)}
        ds.write_documents([doc_to_write])
        documents = ds.get_all_documents(return_embedding=True)
        assert len(documents) == 1
        assert documents[0].content == "test"
        # Some document stores normalize the embedding on save, so we just compare the shape
        assert doc_to_write["custom_embedding_field"].shape == documents[0].embedding.shape

    #
    # Unit tests
    #
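
    # BaseDocumentStore.normalize_embedding scales the vector in place to unit L2
    # norm (v <- v / ||v||), which is why the asserts below check the array that
    # was passed in rather than a return value.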
    @pytest.mark.unit
    def test_normalize_embeddings_diff_shapes(self):
        VEC_1 = np.array([0.1, 0.2, 0.3], dtype="float32")
        BaseDocumentStore.normalize_embedding(VEC_1)
        assert abs(np.linalg.norm(VEC_1) - 1) < 0.01

        VEC_1 = np.array([0.1, 0.2, 0.3], dtype="float32").reshape(1, -1)
        BaseDocumentStore.normalize_embedding(VEC_1)
        assert abs(np.linalg.norm(VEC_1) - 1) < 0.01