diff --git a/test/document_stores/test_base.py b/test/document_stores/test_base.py
index f881097f4..f6bbc31b7 100644
--- a/test/document_stores/test_base.py
+++ b/test/document_stores/test_base.py
@@ -3,7 +3,7 @@ import sys
 
 import pytest
 import numpy as np
 
-from haystack.schema import Document, Label, Answer
+from haystack.schema import Document, Label, Answer, Span
 from haystack.errors import DuplicateDocumentError
 from haystack.document_stores import BaseDocumentStore
@@ -101,6 +101,13 @@ class DocumentStoreBaseTestAbstract:
         out = ds.get_all_documents()
         assert out == documents
 
+    @pytest.mark.integration
+    def test_get_all_documents_without_embeddings(self, ds, documents):
+        ds.write_documents(documents)
+        out = ds.get_all_documents(return_embedding=False)
+        for doc in out:
+            assert doc.embedding is None
+
     @pytest.mark.integration
     def test_get_all_document_filter_duplicate_text_value(self, ds):
         documents = [
@@ -386,6 +393,14 @@ class DocumentStoreBaseTestAbstract:
         ds.delete_documents(ids=[doc.id for doc in docs_to_delete])
         assert ds.get_document_count() == 6
 
+    @pytest.mark.integration
+    def test_delete_documents_by_id_with_filters(self, ds, documents):
+        ds.write_documents(documents)
+        docs_to_delete = ds.get_all_documents(filters={"year": ["2020"]})
+        # this should delete only 1 document out of the 3 ids passed
+        ds.delete_documents(ids=[doc.id for doc in docs_to_delete], filters={"name": ["name_0"]})
+        assert ds.get_document_count() == 8
+
     @pytest.mark.integration
     def test_write_get_all_labels(self, ds, labels):
         ds.write_labels(labels)
@@ -462,6 +477,28 @@ class DocumentStoreBaseTestAbstract:
         assert doc.meta["year"] == "2099"
         assert doc.meta["month"] == "12"
 
+    @pytest.mark.integration
+    def test_labels_with_long_texts(self, ds, documents):
+        label = Label(
+            query="question1",
+            answer=Answer(
+                answer="answer",
+                type="extractive",
+                score=0.0,
+                context="something " * 10_000,
+                offsets_in_document=[Span(start=12, end=14)],
+                offsets_in_context=[Span(start=12, end=14)],
+            ),
+            is_correct_answer=True,
+            is_correct_document=True,
+            document=Document(content="something " * 10_000, id="123"),
+            origin="gold-label",
+        )
+        ds.write_labels(labels=[label])
+        labels = ds.get_all_labels()
+        assert len(labels) == 1
+        assert label == labels[0]
+
     @pytest.mark.integration
     @pytest.mark.skipif(sys.platform == "win32", reason="_get_documents_meta() fails with 'too many SQL variables'")
     def test_get_all_documents_large_quantities(self, ds):
@@ -476,6 +513,309 @@ class DocumentStoreBaseTestAbstract:
         assert all(isinstance(d, Document) for d in documents)
         assert len(documents) == len(docs_to_write)
 
+    @pytest.mark.integration
+    def test_multilabel(self, ds):
+        labels = [
+            Label(
+                id="standard",
+                query="question",
+                answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
+                document=Document(content="some", id="123"),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+            ),
+            # different answer in same doc
+            Label(
+                id="diff-answer-same-doc",
+                query="question",
+                answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
+                document=Document(content="some", id="123"),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+            ),
+            # answer in different doc
+            Label(
+                id="diff-answer-diff-doc",
+                query="question",
+                answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
+                document=Document(content="some other", id="333"),
+                is_correct_answer=True,
+                is_correct_document=True,
origin="gold-label", + ), + # 'no answer', should be excluded from MultiLabel + Label( + id="4-no-answer", + query="question", + answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]), + document=Document(content="some", id="777"), + is_correct_answer=True, + is_correct_document=True, + origin="gold-label", + ), + # is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True" + Label( + id="5-negative", + query="question", + answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some", id="123"), + is_correct_answer=False, + is_correct_document=True, + origin="gold-label", + ), + ] + ds.write_labels(labels) + + # Multi labels (open domain) + multi_labels_open = ds.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True) + + # for open-domain we group all together as long as they have the same question + assert len(multi_labels_open) == 1 + # all labels are in there except the negative one and the no_answer + assert len(multi_labels_open[0].labels) == 4 + assert len(multi_labels_open[0].answers) == 3 + assert "5-negative" not in [l.id for l in multi_labels_open[0].labels] + assert len(multi_labels_open[0].document_ids) == 3 + + # Don't drop the negative label + multi_labels_open = ds.get_all_labels_aggregated( + open_domain=True, drop_no_answers=False, drop_negative_labels=False + ) + assert len(multi_labels_open[0].labels) == 5 + assert len(multi_labels_open[0].answers) == 4 + assert len(multi_labels_open[0].document_ids) == 4 + + # Drop no answer + negative + multi_labels_open = ds.get_all_labels_aggregated( + open_domain=True, drop_no_answers=True, drop_negative_labels=True + ) + assert len(multi_labels_open[0].labels) == 3 + assert len(multi_labels_open[0].answers) == 3 + assert len(multi_labels_open[0].document_ids) == 3 + + # for closed domain we group by document so we expect 3 multilabels with 2,1,1 labels each (negative dropped again) + multi_labels = ds.get_all_labels_aggregated(open_domain=False, drop_negative_labels=True) + assert len(multi_labels) == 3 + label_counts = set([len(ml.labels) for ml in multi_labels]) + assert label_counts == set([2, 1, 1]) + + assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids) + + @pytest.mark.integration + def test_multilabel_no_answer(self, ds): + labels = [ + Label( + query="question", + answer=Answer(answer=""), + is_correct_answer=True, + is_correct_document=True, + document=Document(content="some", id="777"), + origin="gold-label", + ), + # no answer in different doc + Label( + query="question", + answer=Answer(answer=""), + is_correct_answer=True, + is_correct_document=True, + document=Document(content="some", id="123"), + origin="gold-label", + ), + # no answer in same doc, should be excluded + Label( + query="question", + answer=Answer(answer=""), + is_correct_answer=True, + is_correct_document=True, + document=Document(content="some", id="777"), + origin="gold-label", + ), + # no answer with is_correct_answer=False, should be excluded + Label( + query="question", + answer=Answer(answer=""), + is_correct_answer=False, + is_correct_document=True, + document=Document(content="some", id="777"), + origin="gold-label", + ), + ] + + ds.write_labels(labels) + + multi_labels = ds.get_all_labels_aggregated(open_domain=True, drop_no_answers=False, drop_negative_labels=True) + assert len(multi_labels) == 1 + assert multi_labels[0].no_answer == True + assert len(multi_labels[0].document_ids) == 0 + assert 
+
+        multi_labels = ds.get_all_labels_aggregated(open_domain=True, drop_no_answers=False, drop_negative_labels=False)
+        assert len(multi_labels) == 1
+        assert multi_labels[0].no_answer == True
+        assert len(multi_labels[0].document_ids) == 0
+        assert len(multi_labels[0].labels) == 3
+        assert len(multi_labels[0].answers) == 1
+
+    @pytest.mark.integration
+    def test_multilabel_filter_aggregations(self, ds):
+        labels = [
+            Label(
+                id="standard",
+                query="question",
+                answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
+                document=Document(content="some", id="123"),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+                filters={"name": ["123"]},
+            ),
+            # different answer in same doc
+            Label(
+                id="diff-answer-same-doc",
+                query="question",
+                answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
+                document=Document(content="some", id="123"),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+                filters={"name": ["123"]},
+            ),
+            # answer in different doc
+            Label(
+                id="diff-answer-diff-doc",
+                query="question",
+                answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
+                document=Document(content="some other", id="333"),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+                filters={"name": ["333"]},
+            ),
+            # 'no answer', should be excluded from MultiLabel
+            Label(
+                id="4-no-answer",
+                query="question",
+                answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
+                document=Document(content="some", id="777"),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+                filters={"name": ["777"]},
+            ),
+            # is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
+            Label(
+                id="5-negative",
+                query="question",
+                answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
+                document=Document(content="some", id="123"),
+                is_correct_answer=False,
+                is_correct_document=True,
+                origin="gold-label",
+                filters={"name": ["123"]},
+            ),
+        ]
+        ds.write_labels(labels)
+
+        # Multi labels (open domain)
+        multi_labels_open = ds.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)
+
+        # for open-domain we group all together as long as they have the same question and filters
+        assert len(multi_labels_open) == 3
+        label_counts = set([len(ml.labels) for ml in multi_labels_open])
+        assert label_counts == set([2, 1, 1])
+        # all labels are in there except the negative one and the no_answer
+        assert "5-negative" not in [l.id for multi_label in multi_labels_open for l in multi_label.labels]
+
+        assert len(multi_labels_open[0].answers) == len(multi_labels_open[0].document_ids)
+
+        # for closed domain we group by document so we expect the same as with filters
+        multi_labels = ds.get_all_labels_aggregated(open_domain=False, drop_negative_labels=True)
+        assert len(multi_labels) == 3
+        label_counts = set([len(ml.labels) for ml in multi_labels])
+        assert label_counts == set([2, 1, 1])
+
+        assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
+
+    @pytest.mark.integration
+    def test_multilabel_meta_aggregations(self, ds):
+        labels = [
+            Label(
+                id="standard",
+                query="question",
+                answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
+                document=Document(content="some", id="123"),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+                meta={"file_id": ["123"]},
+            ),
+            # different answer in same doc
+            Label(
+                id="diff-answer-same-doc",
+                query="question",
+                answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
+                document=Document(content="some", id="123"),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+                meta={"file_id": ["123"]},
+            ),
+            # answer in different doc
+            Label(
+                id="diff-answer-diff-doc",
+                query="question",
+                answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
+                document=Document(content="some other", id="333"),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+                meta={"file_id": ["333"]},
+            ),
+            # 'no answer', should be excluded from MultiLabel
+            Label(
+                id="4-no-answer",
+                query="question",
+                answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
+                document=Document(content="some", id="777"),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+                meta={"file_id": ["777"]},
+            ),
+            # is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
+            Label(
+                id="5-888",
+                query="question",
+                answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
+                document=Document(content="some", id="123"),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+                meta={"file_id": ["888"]},
+            ),
+        ]
+        ds.write_labels(labels)
+
+        # Multi labels (open domain)
+        multi_labels_open = ds.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)
+
+        # for open-domain we group all together as long as they have the same question and filters
+        assert len(multi_labels_open) == 1
+        assert len(multi_labels_open[0].labels) == 5
+
+        multi_labels = ds.get_all_labels_aggregated(
+            open_domain=True, drop_negative_labels=True, aggregate_by_meta="file_id"
+        )
+        assert len(multi_labels) == 4
+        label_counts = set([len(ml.labels) for ml in multi_labels])
+        assert label_counts == set([2, 1, 1, 1])
+        for multi_label in multi_labels:
+            for l in multi_label.labels:
+                assert l.filters == l.meta
+                assert multi_label.filters == l.filters
+
 #
 # Unit tests
 #
diff --git a/test/document_stores/test_document_store.py b/test/document_stores/test_document_store.py
index f76339a32..f5d8c1a25 100644
--- a/test/document_stores/test_document_store.py
+++ b/test/document_stores/test_document_store.py
@@ -3,7 +3,6 @@ import math
 
 import numpy as np
 import pandas as pd
-from rank_bm25 import BM25
 import pytest
 
 from unittest.mock import Mock
@@ -20,7 +19,6 @@ from haystack.document_stores import (
 
 from haystack.document_stores.base import BaseDocumentStore
 from haystack.document_stores.es_converter import elasticsearch_index_to_document_store
-from haystack.errors import DuplicateDocumentError
 from haystack.schema import Document, Label, Answer, Span
 from haystack.nodes import EmbeddingRetriever, PreProcessor
 from haystack.pipelines import DocumentSearchPipeline
@@ -60,53 +58,6 @@ DOCUMENTS = [
 ]
 
 
-@pytest.mark.parametrize(
-    "document_store", ["elasticsearch", "faiss", "memory", "milvus", "weaviate", "pinecone"], indirect=True
-)
-def test_write_with_duplicate_doc_ids_custom_index(document_store: BaseDocumentStore):
-    duplicate_documents = [
-        Document(content="Doc1", id_hash_keys=["content"]),
-        Document(content="Doc1", id_hash_keys=["content"]),
-    ]
-    document_store.delete_index(index="haystack_custom_test")
-    document_store.write_documents(duplicate_documents, index="haystack_custom_test", duplicate_documents="skip")
-    assert len(document_store.get_all_documents(index="haystack_custom_test")) == 1
-    with pytest.raises(DuplicateDocumentError):
-        document_store.write_documents(duplicate_documents, index="haystack_custom_test", duplicate_documents="fail")
-
-    # Weaviate manipulates document objects in-place when writing them to an index.
-    # It generates a uuid based on the provided id and the index name where the document is added to.
-    # We need to get rid of these generated uuids for this test and therefore reset the document objects.
-    # As a result, the documents will receive a fresh uuid based on their id_hash_keys and a different index name.
-    if isinstance(document_store, WeaviateDocumentStore):
-        duplicate_documents = [
-            Document(content="Doc1", id_hash_keys=["content"]),
-            Document(content="Doc1", id_hash_keys=["content"]),
-        ]
-        # writing to the default, empty index should still work
-        document_store.write_documents(duplicate_documents, duplicate_documents="fail")
-
-
-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "weaviate"], indirect=True)
-def test_document_with_embeddings(document_store: BaseDocumentStore):
-    documents = [
-        {"content": "text1", "id": "1", "embedding": np.random.rand(768).astype(np.float32)},
-        {"content": "text2", "id": "2", "embedding": np.random.rand(768).astype(np.float64)},
-        {"content": "text3", "id": "3", "embedding": np.random.rand(768).astype(np.float32).tolist()},
-        {"content": "text4", "id": "4", "embedding": np.random.rand(768).astype(np.float32)},
-    ]
-    document_store.write_documents(documents)
-    assert len(document_store.get_all_documents()) == 4
-
-    if not isinstance(document_store, WeaviateDocumentStore):
-        # weaviate is excluded because it would return dummy vectors instead of None
-        documents_without_embedding = document_store.get_all_documents(return_embedding=False)
-        assert documents_without_embedding[0].embedding is None
-
-    documents_with_embedding = document_store.get_all_documents(return_embedding=True)
-    assert isinstance(documents_with_embedding[0].embedding, (list, np.ndarray))
-
-
 @pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "weaviate"], indirect=True)
 @pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
 def test_update_embeddings(document_store, retriever):
@@ -268,381 +219,6 @@ def test_update_embeddings_table_text_retriever(document_store, retriever):
     )
 
 
-def test_delete_documents_by_id_with_filters(document_store_with_docs):
-    docs_to_delete = document_store_with_docs.get_all_documents(filters={"meta_field": ["test1", "test2"]})
-    docs_not_to_delete = document_store_with_docs.get_all_documents(filters={"meta_field": ["test3"]})
-
-    document_store_with_docs.delete_documents(ids=[doc.id for doc in docs_to_delete], filters={"meta_field": ["test1"]})
-
-    all_docs_left = document_store_with_docs.get_all_documents()
-    assert len(all_docs_left) == 4
-    assert all(doc.meta["meta_field"] != "test1" for doc in all_docs_left)
-
-    all_ids_left = [doc.id for doc in all_docs_left]
-    assert all(doc.id in all_ids_left for doc in docs_not_to_delete)
-
-
-@pytest.mark.parametrize("document_store", ["elasticsearch", "opensearch"], indirect=True)
-def test_labels_with_long_texts(document_store: BaseDocumentStore):
-    document_store.delete_index("label")
-    label = Label(
-        query="question1",
-        answer=Answer(
-            answer="answer",
-            type="extractive",
-            score=0.0,
-            context="something " * 10_000,
-            offsets_in_document=[Span(start=12, end=14)],
-            offsets_in_context=[Span(start=12, end=14)],
-        ),
-        is_correct_answer=True,
-        is_correct_document=True,
-        document=Document(content="something " * 10_000, id="123"),
-        origin="gold-label",
-    )
-    document_store.write_labels(labels=[label], index="label")
-    labels = document_store.get_all_labels(index="label")
-    assert len(labels) == 1
-    assert label == labels[0]
-
-
-# exclude weaviate because it does not support storing labels
-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "pinecone"], indirect=True)
-def test_multilabel(document_store: BaseDocumentStore):
-    labels = [
-        Label(
-            id="standard",
-            query="question",
-            answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
-            document=Document(content="some", id="123"),
-            is_correct_answer=True,
-            is_correct_document=True,
-            origin="gold-label",
-        ),
-        # different answer in same doc
-        Label(
-            id="diff-answer-same-doc",
-            query="question",
-            answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
-            document=Document(content="some", id="123"),
-            is_correct_answer=True,
-            is_correct_document=True,
-            origin="gold-label",
-        ),
-        # answer in different doc
-        Label(
-            id="diff-answer-diff-doc",
-            query="question",
-            answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
-            document=Document(content="some other", id="333"),
-            is_correct_answer=True,
-            is_correct_document=True,
-            origin="gold-label",
-        ),
-        # 'no answer', should be excluded from MultiLabel
-        Label(
-            id="4-no-answer",
-            query="question",
-            answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
-            document=Document(content="some", id="777"),
-            is_correct_answer=True,
-            is_correct_document=True,
-            origin="gold-label",
-        ),
-        # is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
-        Label(
-            id="5-negative",
-            query="question",
-            answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
-            document=Document(content="some", id="123"),
-            is_correct_answer=False,
-            is_correct_document=True,
-            origin="gold-label",
-        ),
-    ]
-    document_store.write_labels(labels)
-    # regular labels - not aggregated
-    list_labels = document_store.get_all_labels()
-    assert set(list_labels) == set(labels)
-    assert len(list_labels) == 5
-
-    # Currently we don't enforce writing (missing) docs automatically when adding labels and there's no DB relationship between the two.
-    # We should introduce this when we refactored the logic of "index" to be rather a "collection" of labels+documents
-    # docs = document_store.get_all_documents()
-    # assert len(docs) == 3
-
-    # Multi labels (open domain)
-    multi_labels_open = document_store.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)
-
-    # for open-domain we group all together as long as they have the same question
-    assert len(multi_labels_open) == 1
-    # all labels are in there except the negative one and the no_answer
-    assert len(multi_labels_open[0].labels) == 4
-    assert len(multi_labels_open[0].answers) == 3
-    assert "5-negative" not in [l.id for l in multi_labels_open[0].labels]
-    assert len(multi_labels_open[0].document_ids) == 3
-
-    # Don't drop the negative label
-    multi_labels_open = document_store.get_all_labels_aggregated(
-        open_domain=True, drop_no_answers=False, drop_negative_labels=False
-    )
-    assert len(multi_labels_open[0].labels) == 5
-    assert len(multi_labels_open[0].answers) == 4
-    assert len(multi_labels_open[0].document_ids) == 4
-
-    # Drop no answer + negative
-    multi_labels_open = document_store.get_all_labels_aggregated(
-        open_domain=True, drop_no_answers=True, drop_negative_labels=True
-    )
-    assert len(multi_labels_open[0].labels) == 3
-    assert len(multi_labels_open[0].answers) == 3
-    assert len(multi_labels_open[0].document_ids) == 3
-
-    # for closed domain we group by document so we expect 3 multilabels with 2,1,1 labels each (negative dropped again)
-    multi_labels = document_store.get_all_labels_aggregated(open_domain=False, drop_negative_labels=True)
-    assert len(multi_labels) == 3
-    label_counts = set([len(ml.labels) for ml in multi_labels])
-    assert label_counts == set([2, 1, 1])
-
-    assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
-
-
-# exclude weaviate because it does not support storing labels
-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "pinecone"], indirect=True)
-def test_multilabel_no_answer(document_store: BaseDocumentStore):
-    labels = [
-        Label(
-            query="question",
-            answer=Answer(answer=""),
-            is_correct_answer=True,
-            is_correct_document=True,
-            document=Document(content="some", id="777"),
-            origin="gold-label",
-        ),
-        # no answer in different doc
-        Label(
-            query="question",
-            answer=Answer(answer=""),
-            is_correct_answer=True,
-            is_correct_document=True,
-            document=Document(content="some", id="123"),
-            origin="gold-label",
-        ),
-        # no answer in same doc, should be excluded
-        Label(
-            query="question",
-            answer=Answer(answer=""),
-            is_correct_answer=True,
-            is_correct_document=True,
-            document=Document(content="some", id="777"),
-            origin="gold-label",
-        ),
-        # no answer with is_correct_answer=False, should be excluded
-        Label(
-            query="question",
-            answer=Answer(answer=""),
-            is_correct_answer=False,
-            is_correct_document=True,
-            document=Document(content="some", id="777"),
-            origin="gold-label",
-        ),
-    ]
-
-    document_store.write_labels(labels)
-
-    labels = document_store.get_all_labels()
-    assert len(labels) == 4
-
-    multi_labels = document_store.get_all_labels_aggregated(
-        open_domain=True, drop_no_answers=False, drop_negative_labels=True
-    )
-    assert len(multi_labels) == 1
-    assert multi_labels[0].no_answer == True
-    assert len(multi_labels[0].document_ids) == 0
-    assert len(multi_labels[0].answers) == 1
-
-    multi_labels = document_store.get_all_labels_aggregated(
-        open_domain=True, drop_no_answers=False, drop_negative_labels=False
-    )
-    assert len(multi_labels) == 1
-    assert multi_labels[0].no_answer == True
-    assert len(multi_labels[0].document_ids) == 0
-    assert len(multi_labels[0].labels) == 3
-    assert len(multi_labels[0].answers) == 1
-
-
-# exclude weaviate because it does not support storing labels
-# exclude faiss and milvus as label metadata is not implemented
-@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
-def test_multilabel_filter_aggregations(document_store: BaseDocumentStore):
-    labels = [
-        Label(
-            id="standard",
-            query="question",
-            answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
-            document=Document(content="some", id="123"),
-            is_correct_answer=True,
-            is_correct_document=True,
-            origin="gold-label",
-            filters={"name": ["123"]},
-        ),
-        # different answer in same doc
-        Label(
-            id="diff-answer-same-doc",
-            query="question",
-            answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
-            document=Document(content="some", id="123"),
-            is_correct_answer=True,
-            is_correct_document=True,
-            origin="gold-label",
-            filters={"name": ["123"]},
-        ),
-        # answer in different doc
-        Label(
-            id="diff-answer-diff-doc",
-            query="question",
-            answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
-            document=Document(content="some other", id="333"),
-            is_correct_answer=True,
-            is_correct_document=True,
-            origin="gold-label",
-            filters={"name": ["333"]},
-        ),
-        # 'no answer', should be excluded from MultiLabel
-        Label(
-            id="4-no-answer",
-            query="question",
-            answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
-            document=Document(content="some", id="777"),
-            is_correct_answer=True,
-            is_correct_document=True,
-            origin="gold-label",
-            filters={"name": ["777"]},
-        ),
-        # is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
-        Label(
-            id="5-negative",
-            query="question",
-            answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
-            document=Document(content="some", id="123"),
-            is_correct_answer=False,
-            is_correct_document=True,
-            origin="gold-label",
-            filters={"name": ["123"]},
-        ),
-    ]
-    document_store.write_labels(labels)
-    # regular labels - not aggregated
-    list_labels = document_store.get_all_labels()
-    assert list_labels == labels
-    assert len(list_labels) == 5
-
-    # Multi labels (open domain)
-    multi_labels_open = document_store.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)
-
-    # for open-domain we group all together as long as they have the same question and filters
-    assert len(multi_labels_open) == 3
-    label_counts = set([len(ml.labels) for ml in multi_labels_open])
-    assert label_counts == set([2, 1, 1])
-    # all labels are in there except the negative one and the no_answer
-    assert "5-negative" not in [l.id for multi_label in multi_labels_open for l in multi_label.labels]
-
-    assert len(multi_labels_open[0].answers) == len(multi_labels_open[0].document_ids)
-
-    # for closed domain we group by document so we expect the same as with filters
-    multi_labels = document_store.get_all_labels_aggregated(open_domain=False, drop_negative_labels=True)
-    assert len(multi_labels) == 3
-    label_counts = set([len(ml.labels) for ml in multi_labels])
-    assert label_counts == set([2, 1, 1])
-
-    assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
-
-
-# exclude weaviate because it does not support storing labels
-# exclude faiss and milvus as label metadata is not implemented
-@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
["elasticsearch", "memory"], indirect=True) -def test_multilabel_meta_aggregations(document_store: BaseDocumentStore): - labels = [ - Label( - id="standard", - query="question", - answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]), - document=Document(content="some", id="123"), - is_correct_answer=True, - is_correct_document=True, - origin="gold-label", - meta={"file_id": ["123"]}, - ), - # different answer in same doc - Label( - id="diff-answer-same-doc", - query="question", - answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]), - document=Document(content="some", id="123"), - is_correct_answer=True, - is_correct_document=True, - origin="gold-label", - meta={"file_id": ["123"]}, - ), - # answer in different doc - Label( - id="diff-answer-diff-doc", - query="question", - answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]), - document=Document(content="some other", id="333"), - is_correct_answer=True, - is_correct_document=True, - origin="gold-label", - meta={"file_id": ["333"]}, - ), - # 'no answer', should be excluded from MultiLabel - Label( - id="4-no-answer", - query="question", - answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]), - document=Document(content="some", id="777"), - is_correct_answer=True, - is_correct_document=True, - origin="gold-label", - meta={"file_id": ["777"]}, - ), - # is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True" - Label( - id="5-888", - query="question", - answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]), - document=Document(content="some", id="123"), - is_correct_answer=True, - is_correct_document=True, - origin="gold-label", - meta={"file_id": ["888"]}, - ), - ] - document_store.write_labels(labels) - # regular labels - not aggregated - list_labels = document_store.get_all_labels() - assert list_labels == labels - assert len(list_labels) == 5 - - # Multi labels (open domain) - multi_labels_open = document_store.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True) - - # for open-domain we group all together as long as they have the same question and filters - assert len(multi_labels_open) == 1 - assert len(multi_labels_open[0].labels) == 5 - - multi_labels = document_store.get_all_labels_aggregated( - open_domain=True, drop_negative_labels=True, aggregate_by_meta="file_id" - ) - assert len(multi_labels) == 4 - label_counts = set([len(ml.labels) for ml in multi_labels]) - assert label_counts == set([2, 1, 1, 1]) - for multi_label in multi_labels: - for l in multi_label.labels: - assert l.filters == l.meta - assert multi_label.filters == l.filters - - @pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"]) def test_custom_embedding_field(document_store_type, tmp_path): document_store = get_document_store( diff --git a/test/document_stores/test_faiss.py b/test/document_stores/test_faiss.py index dc21e7662..a242e812a 100644 --- a/test/document_stores/test_faiss.py +++ b/test/document_stores/test_faiss.py @@ -245,12 +245,22 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract): def test_nested_condition_not_filters(self, ds, documents): pass - @pytest.mark.skip + @pytest.mark.skip(reason="labels metadata are not supported") @pytest.mark.integration def test_delete_labels_by_filter(self, ds, labels): pass - @pytest.mark.skip + @pytest.mark.skip(reason="labels metadata are not supported") @pytest.mark.integration def test_delete_labels_by_filter_id(self, 
         pass
+
+    @pytest.mark.skip(reason="labels metadata are not supported")
+    @pytest.mark.integration
+    def test_multilabel_filter_aggregations(self):
+        pass
+
+    @pytest.mark.skip(reason="labels metadata are not supported")
+    @pytest.mark.integration
+    def test_multilabel_meta_aggregations(self):
+        pass
diff --git a/test/document_stores/test_milvus.py b/test/document_stores/test_milvus.py
index bc7650a4e..13dc02ec2 100644
--- a/test/document_stores/test_milvus.py
+++ b/test/document_stores/test_milvus.py
@@ -86,12 +86,22 @@ class TestMilvusDocumentStore(DocumentStoreBaseTestAbstract):
 
     # NOTE: again inherithed from the SQLDocumentStore, labels metadata are not supported
 
-    @pytest.mark.skip
+    @pytest.mark.skip(reason="labels metadata are not supported")
     @pytest.mark.integration
     def test_delete_labels_by_filter(self, ds, labels):
         pass
 
-    @pytest.mark.skip
+    @pytest.mark.skip(reason="labels metadata are not supported")
     @pytest.mark.integration
     def test_delete_labels_by_filter_id(self, ds, labels):
         pass
+
+    @pytest.mark.skip(reason="labels metadata are not supported")
+    @pytest.mark.integration
+    def test_multilabel_filter_aggregations(self):
+        pass
+
+    @pytest.mark.skip(reason="labels metadata are not supported")
+    @pytest.mark.integration
+    def test_multilabel_meta_aggregations(self):
+        pass
diff --git a/test/document_stores/test_pinecone.py b/test/document_stores/test_pinecone.py
index 8b1dd85a3..a7abf28f6 100644
--- a/test/document_stores/test_pinecone.py
+++ b/test/document_stores/test_pinecone.py
@@ -190,6 +190,11 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
     def test_nested_condition_not_filters(self, ds, documents):
         pass
 
+    @pytest.mark.skip
+    @pytest.mark.integration
+    def test_delete_documents_by_id_with_filters(self, ds, documents):
+        pass
+
     # NOTE: labels metadata are not supported
 
     @pytest.mark.skip
@@ -207,6 +212,31 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
     def test_simplified_filters(self, ds, documents):
         pass
 
+    @pytest.mark.skip(reason="labels metadata are not supported")
+    @pytest.mark.integration
+    def test_labels_with_long_texts(self):
+        pass
+
+    @pytest.mark.skip(reason="labels metadata are not supported")
+    @pytest.mark.integration
+    def test_multilabel(self):
+        pass
+
+    @pytest.mark.skip(reason="labels metadata are not supported")
+    @pytest.mark.integration
+    def test_multilabel_no_answer(self):
+        pass
+
+    @pytest.mark.skip(reason="labels metadata are not supported")
+    @pytest.mark.integration
+    def test_multilabel_filter_aggregations(self):
+        pass
+
+    @pytest.mark.skip(reason="labels metadata are not supported")
+    @pytest.mark.integration
+    def test_multilabel_meta_aggregations(self):
+        pass
+
     # NOTE: Pinecone does not support dates, so it can't do lte or gte on date fields. When a new release introduces this feature,
     # the entire family of test_get_all_documents_extended_filter_* tests will become identical to the one present in the
     # base document store suite, and can be removed from here.
diff --git a/test/document_stores/test_sql.py b/test/document_stores/test_sql.py
index ee777c61c..66743d824 100644
--- a/test/document_stores/test_sql.py
+++ b/test/document_stores/test_sql.py
@@ -108,14 +108,22 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
     def test_nested_condition_not_filters(self, ds, documents):
         pass
 
-    # NOTE: labels metadata are not supported
-
-    @pytest.mark.skip
+    @pytest.mark.skip(reason="labels metadata are not supported")
     @pytest.mark.integration
     def test_delete_labels_by_filter(self, ds, labels):
         pass
 
-    @pytest.mark.skip
+    @pytest.mark.skip(reason="labels metadata are not supported")
     @pytest.mark.integration
     def test_delete_labels_by_filter_id(self, ds, labels):
         pass
+
+    @pytest.mark.skip(reason="labels metadata are not supported")
+    @pytest.mark.integration
+    def test_multilabel_filter_aggregations(self):
+        pass
+
+    @pytest.mark.skip(reason="labels metadata are not supported")
+    @pytest.mark.integration
+    def test_multilabel_meta_aggregations(self):
+        pass
diff --git a/test/document_stores/test_weaviate.py b/test/document_stores/test_weaviate.py
index 09412ac1d..a92ad9846 100644
--- a/test/document_stores/test_weaviate.py
+++ b/test/document_stores/test_weaviate.py
@@ -103,6 +103,31 @@ class TestWeaviateDocumentStore(DocumentStoreBaseTestAbstract):
     def test_write_get_all_labels(self):
         pass
 
+    @pytest.mark.skip(reason="Weaviate does not support labels")
+    @pytest.mark.integration
+    def test_labels_with_long_texts(self):
+        pass
+
+    @pytest.mark.skip(reason="Weaviate does not support labels")
+    @pytest.mark.integration
+    def test_multilabel(self):
+        pass
+
+    @pytest.mark.skip(reason="Weaviate does not support labels")
+    @pytest.mark.integration
+    def test_multilabel_no_answer(self):
+        pass
+
+    @pytest.mark.skip(reason="Weaviate does not support labels")
+    @pytest.mark.integration
+    def test_multilabel_filter_aggregations(self):
+        pass
+
+    @pytest.mark.skip(reason="Weaviate does not support labels")
+    @pytest.mark.integration
+    def test_multilabel_meta_aggregations(self):
+        pass
+
     @pytest.mark.integration
     def test_ne_filters(self, ds, documents):
         """