refactor: move more tests to the base class (#3637)

* move more tests to the base class

* skip tests where unsupported

* do not pass index label explicitly

* skip test for Pinecone
This commit is contained in:
Massimiliano Pippi 2022-11-29 08:43:27 +01:00 committed by GitHub
parent 839eef6695
commit b20f808119
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 432 additions and 433 deletions

View File

@ -3,7 +3,7 @@ import sys
import pytest import pytest
import numpy as np import numpy as np
from haystack.schema import Document, Label, Answer from haystack.schema import Document, Label, Answer, Span
from haystack.errors import DuplicateDocumentError from haystack.errors import DuplicateDocumentError
from haystack.document_stores import BaseDocumentStore from haystack.document_stores import BaseDocumentStore
@ -101,6 +101,13 @@ class DocumentStoreBaseTestAbstract:
out = ds.get_all_documents() out = ds.get_all_documents()
assert out == documents assert out == documents
@pytest.mark.integration
def test_get_all_documents_without_embeddings(self, ds, documents):
    """Documents fetched with ``return_embedding=False`` must carry no embedding."""
    ds.write_documents(documents)
    fetched = ds.get_all_documents(return_embedding=False)
    assert all(doc.embedding is None for doc in fetched)
@pytest.mark.integration @pytest.mark.integration
def test_get_all_document_filter_duplicate_text_value(self, ds): def test_get_all_document_filter_duplicate_text_value(self, ds):
documents = [ documents = [
@ -386,6 +393,14 @@ class DocumentStoreBaseTestAbstract:
ds.delete_documents(ids=[doc.id for doc in docs_to_delete]) ds.delete_documents(ids=[doc.id for doc in docs_to_delete])
assert ds.get_document_count() == 6 assert ds.get_document_count() == 6
@pytest.mark.integration
def test_delete_documents_by_id_with_filters(self, ds, documents):
    """Delete with both ``ids`` and ``filters`` and check the remaining document count."""
    ds.write_documents(documents)
    candidate_ids = [doc.id for doc in ds.get_all_documents(filters={"year": ["2020"]})]
    # this should delete only 1 document out of the 3 ids passed
    ds.delete_documents(ids=candidate_ids, filters={"name": ["name_0"]})
    assert ds.get_document_count() == 8
@pytest.mark.integration @pytest.mark.integration
def test_write_get_all_labels(self, ds, labels): def test_write_get_all_labels(self, ds, labels):
ds.write_labels(labels) ds.write_labels(labels)
@ -462,6 +477,28 @@ class DocumentStoreBaseTestAbstract:
assert doc.meta["year"] == "2099" assert doc.meta["year"] == "2099"
assert doc.meta["month"] == "12" assert doc.meta["month"] == "12"
@pytest.mark.integration
def test_labels_with_long_texts(self, ds, documents):
    """A label with a very long context and document content survives a write/read round trip."""
    # The same long string is used for the answer context and the document content.
    long_text = "something " * 10_000
    answer = Answer(
        answer="answer",
        type="extractive",
        score=0.0,
        context=long_text,
        offsets_in_document=[Span(start=12, end=14)],
        offsets_in_context=[Span(start=12, end=14)],
    )
    label = Label(
        query="question1",
        answer=answer,
        is_correct_answer=True,
        is_correct_document=True,
        document=Document(content=long_text, id="123"),
        origin="gold-label",
    )

    ds.write_labels(labels=[label])

    stored = ds.get_all_labels()
    assert len(stored) == 1
    assert label == stored[0]
@pytest.mark.integration @pytest.mark.integration
@pytest.mark.skipif(sys.platform == "win32", reason="_get_documents_meta() fails with 'too many SQL variables'") @pytest.mark.skipif(sys.platform == "win32", reason="_get_documents_meta() fails with 'too many SQL variables'")
def test_get_all_documents_large_quantities(self, ds): def test_get_all_documents_large_quantities(self, ds):
@ -476,6 +513,309 @@ class DocumentStoreBaseTestAbstract:
assert all(isinstance(d, Document) for d in documents) assert all(isinstance(d, Document) for d in documents)
assert len(documents) == len(docs_to_write) assert len(documents) == len(docs_to_write)
@pytest.mark.integration
def test_multilabel(self, ds):
    """Check how ``get_all_labels_aggregated`` groups labels that share one query.

    Five labels for the same question are written: three regular answers spread
    over two documents, one 'no answer' label and one negative label
    (``is_correct_answer=False``). The aggregation is then verified for different
    combinations of ``open_domain``, ``drop_no_answers`` and ``drop_negative_labels``.
    """
    labels = [
        Label(
            id="standard",
            query="question",
            answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
        ),
        # different answer in same doc
        Label(
            id="diff-answer-same-doc",
            query="question",
            answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
        ),
        # answer in different doc
        Label(
            id="diff-answer-diff-doc",
            query="question",
            answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some other", id="333"),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
        ),
        # 'no answer': only excluded from the MultiLabel when drop_no_answers=True
        Label(
            id="4-no-answer",
            query="question",
            answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
            document=Document(content="some", id="777"),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
        ),
        # is_correct_answer=False: excluded from the MultiLabel when drop_negative_labels=True
        Label(
            id="5-negative",
            query="question",
            answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=False,
            is_correct_document=True,
            origin="gold-label",
        ),
    ]
    ds.write_labels(labels)

    # Multi labels (open domain)
    multi_labels_open = ds.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)

    # for open-domain we group all together as long as they have the same question
    assert len(multi_labels_open) == 1
    # only the negative label is dropped here; the no-answer label is still part of
    # the group because drop_no_answers is not requested
    assert len(multi_labels_open[0].labels) == 4
    assert len(multi_labels_open[0].answers) == 3
    assert "5-negative" not in [label.id for label in multi_labels_open[0].labels]
    assert len(multi_labels_open[0].document_ids) == 3

    # Don't drop the negative label
    multi_labels_open = ds.get_all_labels_aggregated(
        open_domain=True, drop_no_answers=False, drop_negative_labels=False
    )
    assert len(multi_labels_open[0].labels) == 5
    assert len(multi_labels_open[0].answers) == 4
    assert len(multi_labels_open[0].document_ids) == 4

    # Drop no answer + negative
    multi_labels_open = ds.get_all_labels_aggregated(
        open_domain=True, drop_no_answers=True, drop_negative_labels=True
    )
    assert len(multi_labels_open[0].labels) == 3
    assert len(multi_labels_open[0].answers) == 3
    assert len(multi_labels_open[0].document_ids) == 3

    # for closed domain we group by document so we expect 3 multilabels with 2,1,1 labels each
    # (negative dropped again); the comparison is on the set of distinct group sizes
    multi_labels = ds.get_all_labels_aggregated(open_domain=False, drop_negative_labels=True)
    assert len(multi_labels) == 3
    label_counts = {len(ml.labels) for ml in multi_labels}
    assert label_counts == {1, 2}

    assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
@pytest.mark.integration
def test_multilabel_no_answer(self, ds):
    """Check aggregation when every label for a query is a 'no answer' label.

    Four no-answer labels are written for the same question; the aggregated
    MultiLabel is expected to report ``no_answer``, expose no document ids and
    collapse the answers into a single entry.
    """
    labels = [
        Label(
            query="question",
            answer=Answer(answer=""),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            origin="gold-label",
        ),
        # no answer in different doc
        Label(
            query="question",
            answer=Answer(answer=""),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="123"),
            origin="gold-label",
        ),
        # no answer in same doc, should be excluded
        Label(
            query="question",
            answer=Answer(answer=""),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            origin="gold-label",
        ),
        # no answer with is_correct_answer=False, should be excluded
        Label(
            query="question",
            answer=Answer(answer=""),
            is_correct_answer=False,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            origin="gold-label",
        ),
    ]
    ds.write_labels(labels)

    # negative labels dropped
    multi_labels = ds.get_all_labels_aggregated(open_domain=True, drop_no_answers=False, drop_negative_labels=True)
    assert len(multi_labels) == 1
    # plain truthiness instead of `== True` (PEP 8: don't compare booleans with ==)
    assert multi_labels[0].no_answer
    assert len(multi_labels[0].document_ids) == 0
    assert len(multi_labels[0].answers) == 1

    # negative labels kept
    multi_labels = ds.get_all_labels_aggregated(open_domain=True, drop_no_answers=False, drop_negative_labels=False)
    assert len(multi_labels) == 1
    assert multi_labels[0].no_answer
    assert len(multi_labels[0].document_ids) == 0
    assert len(multi_labels[0].labels) == 3
    assert len(multi_labels[0].answers) == 1
@pytest.mark.integration
def test_multilabel_filter_aggregations(self, ds):
    """Check that labels carrying different ``filters`` end up in separate MultiLabels.

    The five labels share one question but use three distinct ``filters`` values
    ("123", "333", "777"), so three groups are expected in both open-domain and
    closed-domain aggregation.
    """
    labels = [
        Label(
            id="standard",
            query="question",
            answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
            filters={"name": ["123"]},
        ),
        # different answer in same doc
        Label(
            id="diff-answer-same-doc",
            query="question",
            answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
            filters={"name": ["123"]},
        ),
        # answer in different doc
        Label(
            id="diff-answer-diff-doc",
            query="question",
            answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some other", id="333"),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
            filters={"name": ["333"]},
        ),
        # 'no answer' label with its own filters value
        Label(
            id="4-no-answer",
            query="question",
            answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
            document=Document(content="some", id="777"),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
            filters={"name": ["777"]},
        ),
        # is_correct_answer=False: excluded from the MultiLabel when drop_negative_labels=True
        Label(
            id="5-negative",
            query="question",
            answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=False,
            is_correct_document=True,
            origin="gold-label",
            filters={"name": ["123"]},
        ),
    ]
    ds.write_labels(labels)

    # Multi labels (open domain)
    multi_labels_open = ds.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)

    # for open-domain we group all together as long as they have the same question and filters
    assert len(multi_labels_open) == 3
    # the comparison is on the set of distinct group sizes (2, 1 and 1 labels)
    label_counts = {len(ml.labels) for ml in multi_labels_open}
    assert label_counts == {1, 2}
    # the negative label must not show up in any group
    assert "5-negative" not in [label.id for multi_label in multi_labels_open for label in multi_label.labels]

    assert len(multi_labels_open[0].answers) == len(multi_labels_open[0].document_ids)

    # for closed domain we group by document so we expect the same as with filters
    multi_labels = ds.get_all_labels_aggregated(open_domain=False, drop_negative_labels=True)
    assert len(multi_labels) == 3
    label_counts = {len(ml.labels) for ml in multi_labels}
    assert label_counts == {1, 2}

    assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
@pytest.mark.integration
def test_multilabel_meta_aggregations(self, ds):
    """Check ``aggregate_by_meta`` grouping in ``get_all_labels_aggregated``.

    The five labels share one question and carry four distinct ``file_id`` meta
    values. Without ``aggregate_by_meta`` a single group is expected; with
    ``aggregate_by_meta="file_id"`` one group per ``file_id`` is expected.
    """
    labels = [
        Label(
            id="standard",
            query="question",
            answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
            meta={"file_id": ["123"]},
        ),
        # different answer in same doc
        Label(
            id="diff-answer-same-doc",
            query="question",
            answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
            meta={"file_id": ["123"]},
        ),
        # answer in different doc
        Label(
            id="diff-answer-diff-doc",
            query="question",
            answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some other", id="333"),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
            meta={"file_id": ["333"]},
        ),
        # 'no answer' label with its own file_id
        Label(
            id="4-no-answer",
            query="question",
            answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
            document=Document(content="some", id="777"),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
            meta={"file_id": ["777"]},
        ),
        # positive label on doc "123" but tagged with a different file_id ("888")
        Label(
            id="5-888",
            query="question",
            answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
            meta={"file_id": ["888"]},
        ),
    ]
    ds.write_labels(labels)

    # Multi labels (open domain)
    multi_labels_open = ds.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)

    # for open-domain we group all together as long as they have the same question and filters;
    # none of the labels sets filters, so a single group holding all five is expected
    assert len(multi_labels_open) == 1
    assert len(multi_labels_open[0].labels) == 5

    multi_labels = ds.get_all_labels_aggregated(
        open_domain=True, drop_negative_labels=True, aggregate_by_meta="file_id"
    )
    assert len(multi_labels) == 4
    # the comparison is on the set of distinct group sizes (2, 1, 1 and 1 labels)
    label_counts = {len(ml.labels) for ml in multi_labels}
    assert label_counts == {1, 2}
    for multi_label in multi_labels:
        for label in multi_label.labels:
            assert label.filters == label.meta
            assert multi_label.filters == label.filters
# #
# Unit tests # Unit tests
# #

View File

@ -3,7 +3,6 @@ import math
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from rank_bm25 import BM25
import pytest import pytest
from unittest.mock import Mock from unittest.mock import Mock
@ -20,7 +19,6 @@ from haystack.document_stores import (
from haystack.document_stores.base import BaseDocumentStore from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.es_converter import elasticsearch_index_to_document_store from haystack.document_stores.es_converter import elasticsearch_index_to_document_store
from haystack.errors import DuplicateDocumentError
from haystack.schema import Document, Label, Answer, Span from haystack.schema import Document, Label, Answer, Span
from haystack.nodes import EmbeddingRetriever, PreProcessor from haystack.nodes import EmbeddingRetriever, PreProcessor
from haystack.pipelines import DocumentSearchPipeline from haystack.pipelines import DocumentSearchPipeline
@ -60,53 +58,6 @@ DOCUMENTS = [
] ]
@pytest.mark.parametrize(
"document_store", ["elasticsearch", "faiss", "memory", "milvus", "weaviate", "pinecone"], indirect=True
)
def test_write_with_duplicate_doc_ids_custom_index(document_store: BaseDocumentStore):
duplicate_documents = [
Document(content="Doc1", id_hash_keys=["content"]),
Document(content="Doc1", id_hash_keys=["content"]),
]
document_store.delete_index(index="haystack_custom_test")
document_store.write_documents(duplicate_documents, index="haystack_custom_test", duplicate_documents="skip")
assert len(document_store.get_all_documents(index="haystack_custom_test")) == 1
with pytest.raises(DuplicateDocumentError):
document_store.write_documents(duplicate_documents, index="haystack_custom_test", duplicate_documents="fail")
# Weaviate manipulates document objects in-place when writing them to an index.
# It generates a uuid based on the provided id and the index name where the document is added to.
# We need to get rid of these generated uuids for this test and therefore reset the document objects.
# As a result, the documents will receive a fresh uuid based on their id_hash_keys and a different index name.
if isinstance(document_store, WeaviateDocumentStore):
duplicate_documents = [
Document(content="Doc1", id_hash_keys=["content"]),
Document(content="Doc1", id_hash_keys=["content"]),
]
# writing to the default, empty index should still work
document_store.write_documents(duplicate_documents, duplicate_documents="fail")
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "weaviate"], indirect=True)
def test_document_with_embeddings(document_store: BaseDocumentStore):
documents = [
{"content": "text1", "id": "1", "embedding": np.random.rand(768).astype(np.float32)},
{"content": "text2", "id": "2", "embedding": np.random.rand(768).astype(np.float64)},
{"content": "text3", "id": "3", "embedding": np.random.rand(768).astype(np.float32).tolist()},
{"content": "text4", "id": "4", "embedding": np.random.rand(768).astype(np.float32)},
]
document_store.write_documents(documents)
assert len(document_store.get_all_documents()) == 4
if not isinstance(document_store, WeaviateDocumentStore):
# weaviate is excluded because it would return dummy vectors instead of None
documents_without_embedding = document_store.get_all_documents(return_embedding=False)
assert documents_without_embedding[0].embedding is None
documents_with_embedding = document_store.get_all_documents(return_embedding=True)
assert isinstance(documents_with_embedding[0].embedding, (list, np.ndarray))
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "weaviate"], indirect=True) @pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "weaviate"], indirect=True)
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True) @pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
def test_update_embeddings(document_store, retriever): def test_update_embeddings(document_store, retriever):
@ -268,381 +219,6 @@ def test_update_embeddings_table_text_retriever(document_store, retriever):
) )
def test_delete_documents_by_id_with_filters(document_store_with_docs):
docs_to_delete = document_store_with_docs.get_all_documents(filters={"meta_field": ["test1", "test2"]})
docs_not_to_delete = document_store_with_docs.get_all_documents(filters={"meta_field": ["test3"]})
document_store_with_docs.delete_documents(ids=[doc.id for doc in docs_to_delete], filters={"meta_field": ["test1"]})
all_docs_left = document_store_with_docs.get_all_documents()
assert len(all_docs_left) == 4
assert all(doc.meta["meta_field"] != "test1" for doc in all_docs_left)
all_ids_left = [doc.id for doc in all_docs_left]
assert all(doc.id in all_ids_left for doc in docs_not_to_delete)
@pytest.mark.parametrize("document_store", ["elasticsearch", "opensearch"], indirect=True)
def test_labels_with_long_texts(document_store: BaseDocumentStore):
document_store.delete_index("label")
label = Label(
query="question1",
answer=Answer(
answer="answer",
type="extractive",
score=0.0,
context="something " * 10_000,
offsets_in_document=[Span(start=12, end=14)],
offsets_in_context=[Span(start=12, end=14)],
),
is_correct_answer=True,
is_correct_document=True,
document=Document(content="something " * 10_000, id="123"),
origin="gold-label",
)
document_store.write_labels(labels=[label], index="label")
labels = document_store.get_all_labels(index="label")
assert len(labels) == 1
assert label == labels[0]
# exclude weaviate because it does not support storing labels
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "pinecone"], indirect=True)
def test_multilabel(document_store: BaseDocumentStore):
labels = [
Label(
id="standard",
query="question",
answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
),
# different answer in same doc
Label(
id="diff-answer-same-doc",
query="question",
answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
),
# answer in different doc
Label(
id="diff-answer-diff-doc",
query="question",
answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some other", id="333"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
),
# 'no answer', should be excluded from MultiLabel
Label(
id="4-no-answer",
query="question",
answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
document=Document(content="some", id="777"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
),
# is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
Label(
id="5-negative",
query="question",
answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=False,
is_correct_document=True,
origin="gold-label",
),
]
document_store.write_labels(labels)
# regular labels - not aggregated
list_labels = document_store.get_all_labels()
assert set(list_labels) == set(labels)
assert len(list_labels) == 5
# Currently we don't enforce writing (missing) docs automatically when adding labels and there's no DB relationship between the two.
# We should introduce this when we refactored the logic of "index" to be rather a "collection" of labels+documents
# docs = document_store.get_all_documents()
# assert len(docs) == 3
# Multi labels (open domain)
multi_labels_open = document_store.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)
# for open-domain we group all together as long as they have the same question
assert len(multi_labels_open) == 1
# all labels are in there except the negative one and the no_answer
assert len(multi_labels_open[0].labels) == 4
assert len(multi_labels_open[0].answers) == 3
assert "5-negative" not in [l.id for l in multi_labels_open[0].labels]
assert len(multi_labels_open[0].document_ids) == 3
# Don't drop the negative label
multi_labels_open = document_store.get_all_labels_aggregated(
open_domain=True, drop_no_answers=False, drop_negative_labels=False
)
assert len(multi_labels_open[0].labels) == 5
assert len(multi_labels_open[0].answers) == 4
assert len(multi_labels_open[0].document_ids) == 4
# Drop no answer + negative
multi_labels_open = document_store.get_all_labels_aggregated(
open_domain=True, drop_no_answers=True, drop_negative_labels=True
)
assert len(multi_labels_open[0].labels) == 3
assert len(multi_labels_open[0].answers) == 3
assert len(multi_labels_open[0].document_ids) == 3
# for closed domain we group by document so we expect 3 multilabels with 2,1,1 labels each (negative dropped again)
multi_labels = document_store.get_all_labels_aggregated(open_domain=False, drop_negative_labels=True)
assert len(multi_labels) == 3
label_counts = set([len(ml.labels) for ml in multi_labels])
assert label_counts == set([2, 1, 1])
assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
# exclude weaviate because it does not support storing labels
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "pinecone"], indirect=True)
def test_multilabel_no_answer(document_store: BaseDocumentStore):
labels = [
Label(
query="question",
answer=Answer(answer=""),
is_correct_answer=True,
is_correct_document=True,
document=Document(content="some", id="777"),
origin="gold-label",
),
# no answer in different doc
Label(
query="question",
answer=Answer(answer=""),
is_correct_answer=True,
is_correct_document=True,
document=Document(content="some", id="123"),
origin="gold-label",
),
# no answer in same doc, should be excluded
Label(
query="question",
answer=Answer(answer=""),
is_correct_answer=True,
is_correct_document=True,
document=Document(content="some", id="777"),
origin="gold-label",
),
# no answer with is_correct_answer=False, should be excluded
Label(
query="question",
answer=Answer(answer=""),
is_correct_answer=False,
is_correct_document=True,
document=Document(content="some", id="777"),
origin="gold-label",
),
]
document_store.write_labels(labels)
labels = document_store.get_all_labels()
assert len(labels) == 4
multi_labels = document_store.get_all_labels_aggregated(
open_domain=True, drop_no_answers=False, drop_negative_labels=True
)
assert len(multi_labels) == 1
assert multi_labels[0].no_answer == True
assert len(multi_labels[0].document_ids) == 0
assert len(multi_labels[0].answers) == 1
multi_labels = document_store.get_all_labels_aggregated(
open_domain=True, drop_no_answers=False, drop_negative_labels=False
)
assert len(multi_labels) == 1
assert multi_labels[0].no_answer == True
assert len(multi_labels[0].document_ids) == 0
assert len(multi_labels[0].labels) == 3
assert len(multi_labels[0].answers) == 1
# exclude weaviate because it does not support storing labels
# exclude faiss and milvus as label metadata is not implemented
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
def test_multilabel_filter_aggregations(document_store: BaseDocumentStore):
labels = [
Label(
id="standard",
query="question",
answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
filters={"name": ["123"]},
),
# different answer in same doc
Label(
id="diff-answer-same-doc",
query="question",
answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
filters={"name": ["123"]},
),
# answer in different doc
Label(
id="diff-answer-diff-doc",
query="question",
answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some other", id="333"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
filters={"name": ["333"]},
),
# 'no answer', should be excluded from MultiLabel
Label(
id="4-no-answer",
query="question",
answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
document=Document(content="some", id="777"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
filters={"name": ["777"]},
),
# is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
Label(
id="5-negative",
query="question",
answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=False,
is_correct_document=True,
origin="gold-label",
filters={"name": ["123"]},
),
]
document_store.write_labels(labels)
# regular labels - not aggregated
list_labels = document_store.get_all_labels()
assert list_labels == labels
assert len(list_labels) == 5
# Multi labels (open domain)
multi_labels_open = document_store.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)
# for open-domain we group all together as long as they have the same question and filters
assert len(multi_labels_open) == 3
label_counts = set([len(ml.labels) for ml in multi_labels_open])
assert label_counts == set([2, 1, 1])
# all labels are in there except the negative one and the no_answer
assert "5-negative" not in [l.id for multi_label in multi_labels_open for l in multi_label.labels]
assert len(multi_labels_open[0].answers) == len(multi_labels_open[0].document_ids)
# for closed domain we group by document so we expect the same as with filters
multi_labels = document_store.get_all_labels_aggregated(open_domain=False, drop_negative_labels=True)
assert len(multi_labels) == 3
label_counts = set([len(ml.labels) for ml in multi_labels])
assert label_counts == set([2, 1, 1])
assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
# exclude weaviate because it does not support storing labels
# exclude faiss and milvus as label metadata is not implemented
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
def test_multilabel_meta_aggregations(document_store: BaseDocumentStore):
labels = [
Label(
id="standard",
query="question",
answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["123"]},
),
# different answer in same doc
Label(
id="diff-answer-same-doc",
query="question",
answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["123"]},
),
# answer in different doc
Label(
id="diff-answer-diff-doc",
query="question",
answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some other", id="333"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["333"]},
),
# 'no answer', should be excluded from MultiLabel
Label(
id="4-no-answer",
query="question",
answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
document=Document(content="some", id="777"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["777"]},
),
# is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
Label(
id="5-888",
query="question",
answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["888"]},
),
]
document_store.write_labels(labels)
# regular labels - not aggregated
list_labels = document_store.get_all_labels()
assert list_labels == labels
assert len(list_labels) == 5
# Multi labels (open domain)
multi_labels_open = document_store.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)
# for open-domain we group all together as long as they have the same question and filters
assert len(multi_labels_open) == 1
assert len(multi_labels_open[0].labels) == 5
multi_labels = document_store.get_all_labels_aggregated(
open_domain=True, drop_negative_labels=True, aggregate_by_meta="file_id"
)
assert len(multi_labels) == 4
label_counts = set([len(ml.labels) for ml in multi_labels])
assert label_counts == set([2, 1, 1, 1])
for multi_label in multi_labels:
for l in multi_label.labels:
assert l.filters == l.meta
assert multi_label.filters == l.filters
@pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"]) @pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"])
def test_custom_embedding_field(document_store_type, tmp_path): def test_custom_embedding_field(document_store_type, tmp_path):
document_store = get_document_store( document_store = get_document_store(

View File

@ -245,12 +245,22 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract):
def test_nested_condition_not_filters(self, ds, documents): def test_nested_condition_not_filters(self, ds, documents):
pass pass
@pytest.mark.skip @pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration @pytest.mark.integration
def test_delete_labels_by_filter(self, ds, labels): def test_delete_labels_by_filter(self, ds, labels):
pass pass
@pytest.mark.skip @pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration @pytest.mark.integration
def test_delete_labels_by_filter_id(self, ds, labels): def test_delete_labels_by_filter_id(self, ds, labels):
pass pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_filter_aggregations(self):
    """Skipped override with an intentionally empty body; see the skip reason."""
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_meta_aggregations(self):
    """Skipped override with an intentionally empty body; see the skip reason."""

View File

@ -86,12 +86,22 @@ class TestMilvusDocumentStore(DocumentStoreBaseTestAbstract):
# NOTE: again inherited from the SQLDocumentStore, labels metadata are not supported # NOTE: again inherited from the SQLDocumentStore, labels metadata are not supported
@pytest.mark.skip @pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration @pytest.mark.integration
def test_delete_labels_by_filter(self, ds, labels): def test_delete_labels_by_filter(self, ds, labels):
pass pass
@pytest.mark.skip @pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration @pytest.mark.integration
def test_delete_labels_by_filter_id(self, ds, labels): def test_delete_labels_by_filter_id(self, ds, labels):
pass pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_filter_aggregations(self):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_meta_aggregations(self):
pass

View File

@ -190,6 +190,11 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
def test_nested_condition_not_filters(self, ds, documents): def test_nested_condition_not_filters(self, ds, documents):
pass pass
# NOTE(review): skipped without a `reason`, unlike the other skips added in this
# class, which pass reason="...". Presumably the Pinecone store cannot combine
# `ids` with `filters` when deleting documents -- TODO confirm and record that
# as the skip reason so it shows up in the test report.
@pytest.mark.skip
@pytest.mark.integration
def test_delete_documents_by_id_with_filters(self, ds, documents):
pass
# NOTE: labels metadata are not supported # NOTE: labels metadata are not supported
@pytest.mark.skip @pytest.mark.skip
@ -207,6 +212,31 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
def test_simplified_filters(self, ds, documents): def test_simplified_filters(self, ds, documents):
pass pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_labels_with_long_texts(self):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel(self):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_no_answer(self):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_filter_aggregations(self):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_meta_aggregations(self):
pass
# NOTE: Pinecone does not support dates, so it can't do lte or gte on date fields. When a new release introduces this feature, # NOTE: Pinecone does not support dates, so it can't do lte or gte on date fields. When a new release introduces this feature,
# the entire family of test_get_all_documents_extended_filter_* tests will become identical to the one present in the # the entire family of test_get_all_documents_extended_filter_* tests will become identical to the one present in the
# base document store suite, and can be removed from here. # base document store suite, and can be removed from here.

View File

@ -108,14 +108,22 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
def test_nested_condition_not_filters(self, ds, documents): def test_nested_condition_not_filters(self, ds, documents):
pass pass
# NOTE: labels metadata are not supported @pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.skip
@pytest.mark.integration @pytest.mark.integration
def test_delete_labels_by_filter(self, ds, labels): def test_delete_labels_by_filter(self, ds, labels):
pass pass
@pytest.mark.skip @pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration @pytest.mark.integration
def test_delete_labels_by_filter_id(self, ds, labels): def test_delete_labels_by_filter_id(self, ds, labels):
pass pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_filter_aggregations(self):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_meta_aggregations(self):
pass

View File

@ -103,6 +103,31 @@ class TestWeaviateDocumentStore(DocumentStoreBaseTestAbstract):
def test_write_get_all_labels(self): def test_write_get_all_labels(self):
pass pass
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_labels_with_long_texts(self):
pass
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_multilabel(self):
pass
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_multilabel_no_answer(self):
pass
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_multilabel_filter_aggregations(self):
pass
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_multilabel_meta_aggregations(self):
pass
@pytest.mark.integration @pytest.mark.integration
def test_ne_filters(self, ds, documents): def test_ne_filters(self, ds, documents):
""" """