import pytest
import numpy as np

from haystack.schema import Document, Label, Answer
from haystack.errors import DuplicateDocumentError
from haystack.document_stores import BaseDocumentStore


@pytest.mark.document_store
class DocumentStoreBaseTestAbstract:
    """
    This is a base class to test abstract methods from DocumentStoreBase to be inherited by any Document Store
    testsuite. It doesn't have the `Test` prefix in the name so that its methods won't be collected for this
    class but only for its subclasses.

    Every integration test takes a `ds` fixture that is not defined in this class.
    """

    @pytest.fixture
    def documents(self):
        """
        Build 9 documents, three for each `i` in 0..2 (all sharing meta name `name_{i}`):

        - a "Foo" document: year "2020", month "01", numbers [2, 4], random float32 embedding of size 768
        - a "Bar" document: year "2021", month "02", numbers [-2, -4], random float32 embedding of size 768
        - a document without embedding and without year/numbers: month "03", no_embedding True
        """
        documents = []
        for i in range(3):
            documents.append(
                Document(
                    content=f"A Foo Document {i}",
                    meta={"name": f"name_{i}", "year": "2020", "month": "01", "numbers": [2, 4]},
                    embedding=np.random.rand(768).astype(np.float32),
                )
            )

            documents.append(
                Document(
                    content=f"A Bar Document {i}",
                    meta={"name": f"name_{i}", "year": "2021", "month": "02", "numbers": [-2, -4]},
                    embedding=np.random.rand(768).astype(np.float32),
                )
            )

            documents.append(
                Document(
                    content=f"Document {i} without embeddings",
                    meta={"name": f"name_{i}", "no_embedding": True, "month": "03"},
                )
            )

        return documents

    @pytest.fixture
    def labels(self, documents):
        """
        Build one label per document of the `documents` fixture, with query `query_{i}`.

        The first label (i == 0) has no answer; the others carry an `Answer`.
        """
        return [
            Label(
                query=f"query_{i}",
                document=d,
                is_correct_document=True,
                is_correct_answer=False,
                # create a mix set of labels
                origin="user-feedback" if i % 2 else "gold-label",
                answer=None if not i else Answer(f"the answer is {i}"),
                meta={"name": f"label_{i}", "year": f"{2020 + i}"},
            )
            for i, d in enumerate(documents)
        ]

    @staticmethod
    def _copies_with_updated_name(documents):
        """
        Return copies of `documents` (round-tripped through `to_dict`/`from_dict`) whose meta "name"
        is set to "Updated". The name has no `test` prefix, so pytest does not collect it.
        """
        updated_docs = []
        for d in documents:
            updated_d = Document.from_dict(d.to_dict())
            updated_d.meta["name"] = "Updated"
            updated_docs.append(updated_d)
        return updated_docs

    #
    # Integration tests
    #

    @pytest.mark.integration
    def test_write_documents(self, ds, documents):
        """Written documents can be read back: same count and same set of ids."""
        ds.write_documents(documents)
        docs = ds.get_all_documents()
        assert len(docs) == len(documents)
        expected_ids = {doc.id for doc in documents}
        ids = {doc.id for doc in docs}
        assert ids == expected_ids

    @pytest.mark.integration
    def test_write_labels(self, ds, labels):
        """Written labels are returned unchanged by `get_all_labels`."""
        ds.write_labels(labels)
        assert ds.get_all_labels() == labels

    @pytest.mark.integration
    def test_write_with_duplicate_doc_ids(self, ds):
        """Two documents with identical content-based ids: "skip" keeps one, "fail" raises."""
        duplicate_documents = [
            Document(content="Doc1", id_hash_keys=["content"], meta={"key1": "value1"}),
            Document(content="Doc1", id_hash_keys=["content"], meta={"key1": "value1"}),
        ]
        ds.write_documents(duplicate_documents, duplicate_documents="skip")
        results = ds.get_all_documents()
        assert len(results) == 1
        assert results[0] == duplicate_documents[0]
        with pytest.raises(Exception):
            ds.write_documents(duplicate_documents, duplicate_documents="fail")

    @pytest.mark.skip
    @pytest.mark.integration
    def test_get_all_documents_without_filters(self, ds, documents):
        """Currently skipped: `get_all_documents` without filters returns exactly what was written."""
        ds.write_documents(documents)
        out = ds.get_all_documents()
        assert out == documents

    @pytest.mark.integration
    def test_get_all_document_filter_duplicate_text_value(self, ds):
        """Documents sharing the same content but differing in meta are filtered independently."""
        documents = [
            Document(content="duplicated", meta={"meta_field": "0"}, id_hash_keys=["meta"]),
            Document(content="duplicated", meta={"meta_field": "1", "name": "file.txt"}, id_hash_keys=["meta"]),
            Document(content="Doc2", meta={"name": "file_2.txt"}, id_hash_keys=["meta"]),
        ]
        ds.write_documents(documents)

        documents = ds.get_all_documents(filters={"meta_field": ["1"]})
        assert len(documents) == 1
        assert documents[0].content == "duplicated"
        assert documents[0].meta["name"] == "file.txt"

        documents = ds.get_all_documents(filters={"meta_field": ["0"]})
        assert len(documents) == 1
        assert documents[0].content == "duplicated"
        assert documents[0].meta.get("name") is None

        documents = ds.get_all_documents(filters={"name": ["file_2.txt"]})
        assert len(documents) == 1
        assert documents[0].content == "Doc2"
        assert documents[0].meta.get("meta_field") is None

    @pytest.mark.integration
    def test_get_all_documents_with_correct_filters(self, ds, documents):
        """List-valued filters select the documents whose meta value is in the list."""
        ds.write_documents(documents)
        result = ds.get_all_documents(filters={"year": ["2020"]})
        assert len(result) == 3

        documents = ds.get_all_documents(filters={"year": ["2020", "2021"]})
        assert len(documents) == 6

    @pytest.mark.integration
    def test_get_all_documents_with_incorrect_filter_name(self, ds, documents):
        """Filtering on a meta field no document has yields no results."""
        ds.write_documents(documents)
        result = ds.get_all_documents(filters={"non_existing_meta_field": ["whatever"]})
        assert len(result) == 0

    @pytest.mark.integration
    def test_get_all_documents_with_incorrect_filter_value(self, ds, documents):
        """Filtering on a value no document has yields no results."""
        ds.write_documents(documents)
        result = ds.get_all_documents(filters={"year": ["nope"]})
        assert len(result) == 0

    @pytest.mark.integration
    def test_eq_filters(self, ds, documents):
        """`$eq` and its shorthand (a bare value) select the same 3 documents."""
        ds.write_documents(documents)

        result = ds.get_all_documents(filters={"year": {"$eq": "2020"}})
        assert len(result) == 3
        result = ds.get_all_documents(filters={"year": "2020"})
        assert len(result) == 3

    @pytest.mark.integration
    def test_in_filters(self, ds, documents):
        """`$in` and its shorthand (a bare list) select the same 6 documents."""
        ds.write_documents(documents)

        result = ds.get_all_documents(filters={"year": {"$in": ["2020", "2021", "n.a."]}})
        assert len(result) == 6
        result = ds.get_all_documents(filters={"year": ["2020", "2021", "n.a."]})
        assert len(result) == 6

    @pytest.mark.integration
    def test_ne_filters(self, ds, documents):
        """`$ne` is expected to also match the 3 documents that have no "year" at all (6 in total)."""
        ds.write_documents(documents)

        result = ds.get_all_documents(filters={"year": {"$ne": "2020"}})
        assert len(result) == 6

    @pytest.mark.integration
    def test_nin_filters(self, ds, documents):
        """`$nin` is expected to match the 3 documents that have no "year" at all."""
        ds.write_documents(documents)

        result = ds.get_all_documents(filters={"year": {"$nin": ["2020", "2021", "n.a."]}})
        assert len(result) == 3

    @pytest.mark.integration
    def test_comparison_filters(self, ds, documents):
        """`$gt`, `$gte`, `$lt` and `$lte` applied to the list-valued "numbers" meta field."""
        ds.write_documents(documents)

        result = ds.get_all_documents(filters={"numbers": {"$gt": 0.0}})
        assert len(result) == 3

        result = ds.get_all_documents(filters={"numbers": {"$gte": -2.0}})
        assert len(result) == 6

        result = ds.get_all_documents(filters={"numbers": {"$lt": 0.0}})
        assert len(result) == 3

        result = ds.get_all_documents(filters={"numbers": {"$lte": 2.0}})
        assert len(result) == 6

    @pytest.mark.integration
    def test_compound_filters(self, ds, documents):
        """Two comparison operators on the same field are combined."""
        ds.write_documents(documents)

        result = ds.get_all_documents(filters={"year": {"$lte": "2021", "$gte": "2020"}})
        assert len(result) == 6

    @pytest.mark.integration
    def test_simplified_filters(self, ds, documents):
        """An explicit `$and`/`$in` filter and its simplified form return the same number of documents."""
        ds.write_documents(documents)

        filters = {"$and": {"year": {"$lte": "2021", "$gte": "2020"}, "name": {"$in": ["name_0", "name_1"]}}}
        result = ds.get_all_documents(filters=filters)
        assert len(result) == 4

        filters_simplified = {"year": {"$lte": "2021", "$gte": "2020"}, "name": ["name_0", "name_1"]}
        result = ds.get_all_documents(filters=filters_simplified)
        assert len(result) == 4

    @pytest.mark.integration
    def test_nested_condition_filters(self, ds, documents):
        """Nested `$and` / `$or` / `$not` conditions, in explicit and simplified form."""
        ds.write_documents(documents)
        filters = {
            "$and": {
                "year": {"$lte": "2021", "$gte": "2020"},
                "$or": {"name": {"$in": ["name_0", "name_1"]}, "numbers": {"$lt": 5.0}},
            }
        }
        result = ds.get_all_documents(filters=filters)
        assert len(result) == 6

        filters_simplified = {
            "year": {"$lte": "2021", "$gte": "2020"},
            "$or": {"name": {"$in": ["name_0", "name_2"]}, "numbers": {"$lt": 5.0}},
        }
        result = ds.get_all_documents(filters=filters_simplified)
        assert len(result) == 6

        filters = {
            "$and": {
                "year": {"$lte": "2021", "$gte": "2020"},
                "$or": {
                    "name": {"$in": ["name_0", "name_1"]},
                    "$and": {"numbers": {"$lt": 5.0}, "$not": {"month": {"$eq": "01"}}},
                },
            }
        }
        result = ds.get_all_documents(filters=filters)
        assert len(result) == 5

        filters_simplified = {
            "year": {"$lte": "2021", "$gte": "2020"},
            "$or": {"name": ["name_0", "name_1"], "$and": {"numbers": {"$lt": 5.0}, "$not": {"month": {"$eq": "01"}}}},
        }
        result = ds.get_all_documents(filters=filters_simplified)
        assert len(result) == 5

    @pytest.mark.integration
    def test_nested_condition_not_filters(self, ds, documents):
        """
        Test nested logical operations within "$not", important as we apply De Morgan's laws in WeaviateDocumentstore
        """
        ds.write_documents(documents)
        filters = {
            "$not": {
                "$or": {
                    "$and": {"numbers": {"$lt": 5.0}, "month": {"$ne": "01"}},
                    "$not": {"year": {"$lte": "2021", "$gte": "2020"}},
                }
            }
        }
        result = ds.get_all_documents(filters=filters)
        assert len(result) == 3

        docs_meta = result[0].meta["numbers"]
        assert [2, 4] == docs_meta

        # Test same logical operator twice on same level
        filters = {
            "$or": [
                {"$and": {"name": {"$in": ["name_0", "name_1"]}, "year": {"$gte": "2020"}}},
                {"$and": {"name": {"$in": ["name_0", "name_1"]}, "year": {"$lt": "2021"}}},
            ]
        }
        result = ds.get_all_documents(filters=filters)
        docs_meta = [doc.meta["name"] for doc in result]
        assert len(result) == 4
        assert "name_0" in docs_meta
        assert "name_2" not in docs_meta

    @pytest.mark.integration
    def test_get_document_by_id(self, ds, documents):
        """A single document can be fetched by id and has the id and content that were written."""
        ds.write_documents(documents)
        doc = ds.get_document_by_id(documents[0].id)
        assert doc.id == documents[0].id
        assert doc.content == documents[0].content

    @pytest.mark.integration
    def test_get_documents_by_id(self, ds, documents):
        """All documents can be fetched by id, using a batch size smaller than the number of ids."""
        ds.write_documents(documents)
        ids = [doc.id for doc in documents]
        result = {doc.id for doc in ds.get_documents_by_id(ids, batch_size=2)}
        assert set(ids) == result

    @pytest.mark.integration
    def test_get_document_count(self, ds, documents):
        """`get_document_count` with and without filters."""
        ds.write_documents(documents)
        assert ds.get_document_count() == len(documents)
        assert ds.get_document_count(filters={"year": ["2020"]}) == 3
        assert ds.get_document_count(filters={"month": ["02"]}) == 3

    @pytest.mark.integration
    def test_get_all_documents_generator(self, ds, documents):
        """The generator yields all 9 documents when the batch size is smaller than the total."""
        ds.write_documents(documents)
        assert len(list(ds.get_all_documents_generator(batch_size=2))) == 9

    @pytest.mark.integration
    def test_duplicate_documents_skip(self, ds, documents):
        """With `duplicate_documents="skip"`, re-writing modified copies leaves the stored meta untouched."""
        ds.write_documents(documents)

        updated_docs = self._copies_with_updated_name(documents)

        ds.write_documents(updated_docs, duplicate_documents="skip")
        for d in ds.get_all_documents():
            assert d.meta.get("name") != "Updated"

    @pytest.mark.integration
    def test_duplicate_documents_overwrite(self, ds, documents):
        """With `duplicate_documents="overwrite"`, re-writing modified copies replaces the stored meta."""
        ds.write_documents(documents)

        updated_docs = self._copies_with_updated_name(documents)

        ds.write_documents(updated_docs, duplicate_documents="overwrite")
        for doc in ds.get_all_documents():
            assert doc.meta["name"] == "Updated"

    @pytest.mark.integration
    def test_duplicate_documents_fail(self, ds, documents):
        """With `duplicate_documents="fail"`, re-writing existing documents raises `DuplicateDocumentError`."""
        ds.write_documents(documents)

        updated_docs = self._copies_with_updated_name(documents)

        with pytest.raises(DuplicateDocumentError):
            ds.write_documents(updated_docs, duplicate_documents="fail")

    @pytest.mark.integration
    def test_write_document_meta(self, ds):
        """Meta is stored for both dict and `Document` inputs, and is empty when none was given."""
        ds.write_documents(
            [
                {"content": "dict_without_meta", "id": "1"},
                {"content": "dict_with_meta", "meta_field": "test2", "id": "2"},
                Document(content="document_object_without_meta", id="3"),
                Document(content="document_object_with_meta", meta={"meta_field": "test4"}, id="4"),
            ]
        )
        assert not ds.get_document_by_id("1").meta
        assert ds.get_document_by_id("2").meta["meta_field"] == "test2"
        assert not ds.get_document_by_id("3").meta
        assert ds.get_document_by_id("4").meta["meta_field"] == "test4"

    @pytest.mark.integration
    def test_delete_documents(self, ds, documents):
        """`delete_documents` without arguments empties the store."""
        ds.write_documents(documents)
        ds.delete_documents()
        assert ds.get_document_count() == 0

    @pytest.mark.integration
    def test_delete_documents_with_filters(self, ds, documents):
        """Deleting by filter removes the 6 documents with a year and leaves the other 3."""
        ds.write_documents(documents)
        ds.delete_documents(filters={"year": ["2020", "2021"]})
        # NOTE(review): the fetched documents are never inspected; only the count below is asserted.
        # Kept as in the original -- confirm whether this call can be dropped.
        documents = ds.get_all_documents()
        assert ds.get_document_count() == 3

    @pytest.mark.integration
    def test_delete_documents_by_id(self, ds, documents):
        """Deleting by id removes exactly the documents whose ids were passed."""
        ds.write_documents(documents)
        docs_to_delete = ds.get_all_documents(filters={"year": ["2020"]})
        ds.delete_documents(ids=[doc.id for doc in docs_to_delete])
        assert ds.get_document_count() == 6

    @pytest.mark.integration
    def test_write_get_all_labels(self, ds, labels):
        """Labels written to the default index and to a custom index are kept apart."""
        ds.write_labels(labels)
        ds.write_labels(labels[:3], index="custom_index")
        assert len(ds.get_all_labels()) == 9
        assert len(ds.get_all_labels(index="custom_index")) == 3
        # remove the index we created in this test
        ds.delete_index("custom_index")

    @pytest.mark.integration
    def test_delete_labels(self, ds, labels):
        """`delete_labels` empties the default index and, when asked, a custom index."""
        ds.write_labels(labels)
        ds.write_labels(labels[:3], index="custom_index")
        ds.delete_labels()
        ds.delete_labels(index="custom_index")
        assert len(ds.get_all_labels()) == 0
        assert len(ds.get_all_labels(index="custom_index")) == 0
        # remove the index we created in this test
        ds.delete_index("custom_index")

    @pytest.mark.integration
    def test_write_labels_duplicate(self, ds, labels):
        """A label that is a copy of an already written one is not stored twice."""
        # create a duplicate
        dupe = Label.from_dict(labels[0].to_dict())

        ds.write_labels(labels + [dupe])

        # ensure the duplicate was discarded
        assert len(ds.get_all_labels()) == len(labels)

    @pytest.mark.integration
    def test_delete_labels_by_id(self, ds, labels):
        """Deleting by id removes exactly one label."""
        ds.write_labels(labels)
        ds.delete_labels(ids=[labels[0].id])
        assert len(ds.get_all_labels()) == len(labels) - 1

    @pytest.mark.integration
    def test_delete_labels_by_filter(self, ds, labels):
        """Deleting by a query filter removes the one matching label."""
        ds.write_labels(labels)
        ds.delete_labels(filters={"query": "query_1"})
        assert len(ds.get_all_labels()) == len(labels) - 1

    @pytest.mark.integration
    def test_delete_labels_by_filter_id(self, ds, labels):
        """When both ids and filters are given, a label is deleted only if it matches both."""
        ds.write_labels(labels)

        # ids and filters are ANDed, the following should have no effect
        ds.delete_labels(ids=[labels[0].id], filters={"query": "query_9"})
        assert len(ds.get_all_labels()) == len(labels)

        # here the filter matches the label with the given id, so it gets deleted
        ds.delete_labels(ids=[labels[0].id], filters={"query": "query_0"})
        assert len(ds.get_all_labels()) == len(labels) - 1

    @pytest.mark.integration
    def test_get_label_count(self, ds, labels):
        """`get_label_count` reports the number of written labels."""
        ds.write_labels(labels)
        assert ds.get_label_count() == len(labels)

    @pytest.mark.integration
    def test_delete_index(self, ds, documents):
        """After `delete_index`, counting documents in the deleted index raises."""
        ds.write_documents(documents, index="custom_index")
        assert ds.get_document_count(index="custom_index") == len(documents)
        ds.delete_index(index="custom_index")
        with pytest.raises(Exception):
            ds.get_document_count(index="custom_index")

    @pytest.mark.integration
    def test_update_meta(self, ds, documents):
        """`update_document_meta` changes the stored meta of the given document."""
        ds.write_documents(documents)
        doc = documents[0]
        ds.update_document_meta(doc.id, meta={"year": "2099", "month": "12"})
        doc = ds.get_document_by_id(doc.id)
        assert doc.meta["year"] == "2099"
        assert doc.meta["month"] == "12"

    #
    # Unit tests
    #

    @pytest.mark.unit
    def test_normalize_embeddings_diff_shapes(self):
        """
        `BaseDocumentStore.normalize_embedding` is checked on a 1-D vector and on a (1, 3) array.

        The assertions inspect the array that was passed in, so the test expects the
        normalization to happen in place.
        """
        vec = np.array([0.1, 0.2, 0.3], dtype="float32")
        BaseDocumentStore.normalize_embedding(vec)
        # abs() is required: the raw input has a norm of ~0.37, so `norm - 1` would be negative
        # and a one-sided `< 0.01` check would pass even if nothing had been normalized.
        assert abs(np.linalg.norm(vec) - 1) < 0.01

        vec = np.array([0.1, 0.2, 0.3], dtype="float32").reshape(1, -1)
        BaseDocumentStore.normalize_embedding(vec)
        assert abs(np.linalg.norm(vec) - 1) < 0.01