refactor: move more tests to the base class (#3637)

* move more tests to the base class

* skip tests where unsupported

* do not pass index label explicitly

* skip test for Pinecone
This commit is contained in:
Massimiliano Pippi 2022-11-29 08:43:27 +01:00 committed by GitHub
parent 839eef6695
commit b20f808119
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 432 additions and 433 deletions

View File

@ -3,7 +3,7 @@ import sys
import pytest
import numpy as np
from haystack.schema import Document, Label, Answer
from haystack.schema import Document, Label, Answer, Span
from haystack.errors import DuplicateDocumentError
from haystack.document_stores import BaseDocumentStore
@ -101,6 +101,13 @@ class DocumentStoreBaseTestAbstract:
out = ds.get_all_documents()
assert out == documents
@pytest.mark.integration
def test_get_all_documents_without_embeddings(self, ds, documents):
ds.write_documents(documents)
out = ds.get_all_documents(return_embedding=False)
for doc in out:
assert doc.embedding is None
@pytest.mark.integration
def test_get_all_document_filter_duplicate_text_value(self, ds):
documents = [
@ -386,6 +393,14 @@ class DocumentStoreBaseTestAbstract:
ds.delete_documents(ids=[doc.id for doc in docs_to_delete])
assert ds.get_document_count() == 6
@pytest.mark.integration
def test_delete_documents_by_id_with_filters(self, ds, documents):
ds.write_documents(documents)
docs_to_delete = ds.get_all_documents(filters={"year": ["2020"]})
# this should delete only 1 document out of the 3 ids passed
ds.delete_documents(ids=[doc.id for doc in docs_to_delete], filters={"name": ["name_0"]})
assert ds.get_document_count() == 8
@pytest.mark.integration
def test_write_get_all_labels(self, ds, labels):
ds.write_labels(labels)
@ -462,6 +477,28 @@ class DocumentStoreBaseTestAbstract:
assert doc.meta["year"] == "2099"
assert doc.meta["month"] == "12"
@pytest.mark.integration
def test_labels_with_long_texts(self, ds, documents):
label = Label(
query="question1",
answer=Answer(
answer="answer",
type="extractive",
score=0.0,
context="something " * 10_000,
offsets_in_document=[Span(start=12, end=14)],
offsets_in_context=[Span(start=12, end=14)],
),
is_correct_answer=True,
is_correct_document=True,
document=Document(content="something " * 10_000, id="123"),
origin="gold-label",
)
ds.write_labels(labels=[label])
labels = ds.get_all_labels()
assert len(labels) == 1
assert label == labels[0]
@pytest.mark.integration
@pytest.mark.skipif(sys.platform == "win32", reason="_get_documents_meta() fails with 'too many SQL variables'")
def test_get_all_documents_large_quantities(self, ds):
@ -476,6 +513,309 @@ class DocumentStoreBaseTestAbstract:
assert all(isinstance(d, Document) for d in documents)
assert len(documents) == len(docs_to_write)
@pytest.mark.integration
def test_multilabel(self, ds):
labels = [
Label(
id="standard",
query="question",
answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
),
# different answer in same doc
Label(
id="diff-answer-same-doc",
query="question",
answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
),
# answer in different doc
Label(
id="diff-answer-diff-doc",
query="question",
answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some other", id="333"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
),
# 'no answer', should be excluded from MultiLabel
Label(
id="4-no-answer",
query="question",
answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
document=Document(content="some", id="777"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
),
# is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
Label(
id="5-negative",
query="question",
answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=False,
is_correct_document=True,
origin="gold-label",
),
]
ds.write_labels(labels)
# Multi labels (open domain)
multi_labels_open = ds.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)
# for open-domain we group all together as long as they have the same question
assert len(multi_labels_open) == 1
# all labels are in there except the negative one and the no_answer
assert len(multi_labels_open[0].labels) == 4
assert len(multi_labels_open[0].answers) == 3
assert "5-negative" not in [l.id for l in multi_labels_open[0].labels]
assert len(multi_labels_open[0].document_ids) == 3
# Don't drop the negative label
multi_labels_open = ds.get_all_labels_aggregated(
open_domain=True, drop_no_answers=False, drop_negative_labels=False
)
assert len(multi_labels_open[0].labels) == 5
assert len(multi_labels_open[0].answers) == 4
assert len(multi_labels_open[0].document_ids) == 4
# Drop no answer + negative
multi_labels_open = ds.get_all_labels_aggregated(
open_domain=True, drop_no_answers=True, drop_negative_labels=True
)
assert len(multi_labels_open[0].labels) == 3
assert len(multi_labels_open[0].answers) == 3
assert len(multi_labels_open[0].document_ids) == 3
# for closed domain we group by document so we expect 3 multilabels with 2,1,1 labels each (negative dropped again)
multi_labels = ds.get_all_labels_aggregated(open_domain=False, drop_negative_labels=True)
assert len(multi_labels) == 3
label_counts = set([len(ml.labels) for ml in multi_labels])
assert label_counts == set([2, 1, 1])
assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
@pytest.mark.integration
def test_multilabel_no_answer(self, ds):
labels = [
Label(
query="question",
answer=Answer(answer=""),
is_correct_answer=True,
is_correct_document=True,
document=Document(content="some", id="777"),
origin="gold-label",
),
# no answer in different doc
Label(
query="question",
answer=Answer(answer=""),
is_correct_answer=True,
is_correct_document=True,
document=Document(content="some", id="123"),
origin="gold-label",
),
# no answer in same doc, should be excluded
Label(
query="question",
answer=Answer(answer=""),
is_correct_answer=True,
is_correct_document=True,
document=Document(content="some", id="777"),
origin="gold-label",
),
# no answer with is_correct_answer=False, should be excluded
Label(
query="question",
answer=Answer(answer=""),
is_correct_answer=False,
is_correct_document=True,
document=Document(content="some", id="777"),
origin="gold-label",
),
]
ds.write_labels(labels)
multi_labels = ds.get_all_labels_aggregated(open_domain=True, drop_no_answers=False, drop_negative_labels=True)
assert len(multi_labels) == 1
assert multi_labels[0].no_answer == True
assert len(multi_labels[0].document_ids) == 0
assert len(multi_labels[0].answers) == 1
multi_labels = ds.get_all_labels_aggregated(open_domain=True, drop_no_answers=False, drop_negative_labels=False)
assert len(multi_labels) == 1
assert multi_labels[0].no_answer == True
assert len(multi_labels[0].document_ids) == 0
assert len(multi_labels[0].labels) == 3
assert len(multi_labels[0].answers) == 1
@pytest.mark.integration
def test_multilabel_filter_aggregations(self, ds):
labels = [
Label(
id="standard",
query="question",
answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
filters={"name": ["123"]},
),
# different answer in same doc
Label(
id="diff-answer-same-doc",
query="question",
answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
filters={"name": ["123"]},
),
# answer in different doc
Label(
id="diff-answer-diff-doc",
query="question",
answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some other", id="333"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
filters={"name": ["333"]},
),
# 'no answer', should be excluded from MultiLabel
Label(
id="4-no-answer",
query="question",
answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
document=Document(content="some", id="777"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
filters={"name": ["777"]},
),
# is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
Label(
id="5-negative",
query="question",
answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=False,
is_correct_document=True,
origin="gold-label",
filters={"name": ["123"]},
),
]
ds.write_labels(labels)
# Multi labels (open domain)
multi_labels_open = ds.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)
# for open-domain we group all together as long as they have the same question and filters
assert len(multi_labels_open) == 3
label_counts = set([len(ml.labels) for ml in multi_labels_open])
assert label_counts == set([2, 1, 1])
# all labels are in there except the negative one and the no_answer
assert "5-negative" not in [l.id for multi_label in multi_labels_open for l in multi_label.labels]
assert len(multi_labels_open[0].answers) == len(multi_labels_open[0].document_ids)
# for closed domain we group by document so we expect the same as with filters
multi_labels = ds.get_all_labels_aggregated(open_domain=False, drop_negative_labels=True)
assert len(multi_labels) == 3
label_counts = set([len(ml.labels) for ml in multi_labels])
assert label_counts == set([2, 1, 1])
assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
@pytest.mark.integration
def test_multilabel_meta_aggregations(self, ds):
labels = [
Label(
id="standard",
query="question",
answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["123"]},
),
# different answer in same doc
Label(
id="diff-answer-same-doc",
query="question",
answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["123"]},
),
# answer in different doc
Label(
id="diff-answer-diff-doc",
query="question",
answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some other", id="333"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["333"]},
),
# 'no answer', should be excluded from MultiLabel
Label(
id="4-no-answer",
query="question",
answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
document=Document(content="some", id="777"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["777"]},
),
# is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
Label(
id="5-888",
query="question",
answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["888"]},
),
]
ds.write_labels(labels)
# Multi labels (open domain)
multi_labels_open = ds.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)
# for open-domain we group all together as long as they have the same question and filters
assert len(multi_labels_open) == 1
assert len(multi_labels_open[0].labels) == 5
multi_labels = ds.get_all_labels_aggregated(
open_domain=True, drop_negative_labels=True, aggregate_by_meta="file_id"
)
assert len(multi_labels) == 4
label_counts = set([len(ml.labels) for ml in multi_labels])
assert label_counts == set([2, 1, 1, 1])
for multi_label in multi_labels:
for l in multi_label.labels:
assert l.filters == l.meta
assert multi_label.filters == l.filters
#
# Unit tests
#

View File

@ -3,7 +3,6 @@ import math
import numpy as np
import pandas as pd
from rank_bm25 import BM25
import pytest
from unittest.mock import Mock
@ -20,7 +19,6 @@ from haystack.document_stores import (
from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.es_converter import elasticsearch_index_to_document_store
from haystack.errors import DuplicateDocumentError
from haystack.schema import Document, Label, Answer, Span
from haystack.nodes import EmbeddingRetriever, PreProcessor
from haystack.pipelines import DocumentSearchPipeline
@ -60,53 +58,6 @@ DOCUMENTS = [
]
@pytest.mark.parametrize(
"document_store", ["elasticsearch", "faiss", "memory", "milvus", "weaviate", "pinecone"], indirect=True
)
def test_write_with_duplicate_doc_ids_custom_index(document_store: BaseDocumentStore):
duplicate_documents = [
Document(content="Doc1", id_hash_keys=["content"]),
Document(content="Doc1", id_hash_keys=["content"]),
]
document_store.delete_index(index="haystack_custom_test")
document_store.write_documents(duplicate_documents, index="haystack_custom_test", duplicate_documents="skip")
assert len(document_store.get_all_documents(index="haystack_custom_test")) == 1
with pytest.raises(DuplicateDocumentError):
document_store.write_documents(duplicate_documents, index="haystack_custom_test", duplicate_documents="fail")
# Weaviate manipulates document objects in-place when writing them to an index.
# It generates a uuid based on the provided id and the index name where the document is added to.
# We need to get rid of these generated uuids for this test and therefore reset the document objects.
# As a result, the documents will receive a fresh uuid based on their id_hash_keys and a different index name.
if isinstance(document_store, WeaviateDocumentStore):
duplicate_documents = [
Document(content="Doc1", id_hash_keys=["content"]),
Document(content="Doc1", id_hash_keys=["content"]),
]
# writing to the default, empty index should still work
document_store.write_documents(duplicate_documents, duplicate_documents="fail")
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "weaviate"], indirect=True)
def test_document_with_embeddings(document_store: BaseDocumentStore):
documents = [
{"content": "text1", "id": "1", "embedding": np.random.rand(768).astype(np.float32)},
{"content": "text2", "id": "2", "embedding": np.random.rand(768).astype(np.float64)},
{"content": "text3", "id": "3", "embedding": np.random.rand(768).astype(np.float32).tolist()},
{"content": "text4", "id": "4", "embedding": np.random.rand(768).astype(np.float32)},
]
document_store.write_documents(documents)
assert len(document_store.get_all_documents()) == 4
if not isinstance(document_store, WeaviateDocumentStore):
# weaviate is excluded because it would return dummy vectors instead of None
documents_without_embedding = document_store.get_all_documents(return_embedding=False)
assert documents_without_embedding[0].embedding is None
documents_with_embedding = document_store.get_all_documents(return_embedding=True)
assert isinstance(documents_with_embedding[0].embedding, (list, np.ndarray))
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "weaviate"], indirect=True)
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
def test_update_embeddings(document_store, retriever):
@ -268,381 +219,6 @@ def test_update_embeddings_table_text_retriever(document_store, retriever):
)
def test_delete_documents_by_id_with_filters(document_store_with_docs):
docs_to_delete = document_store_with_docs.get_all_documents(filters={"meta_field": ["test1", "test2"]})
docs_not_to_delete = document_store_with_docs.get_all_documents(filters={"meta_field": ["test3"]})
document_store_with_docs.delete_documents(ids=[doc.id for doc in docs_to_delete], filters={"meta_field": ["test1"]})
all_docs_left = document_store_with_docs.get_all_documents()
assert len(all_docs_left) == 4
assert all(doc.meta["meta_field"] != "test1" for doc in all_docs_left)
all_ids_left = [doc.id for doc in all_docs_left]
assert all(doc.id in all_ids_left for doc in docs_not_to_delete)
@pytest.mark.parametrize("document_store", ["elasticsearch", "opensearch"], indirect=True)
def test_labels_with_long_texts(document_store: BaseDocumentStore):
document_store.delete_index("label")
label = Label(
query="question1",
answer=Answer(
answer="answer",
type="extractive",
score=0.0,
context="something " * 10_000,
offsets_in_document=[Span(start=12, end=14)],
offsets_in_context=[Span(start=12, end=14)],
),
is_correct_answer=True,
is_correct_document=True,
document=Document(content="something " * 10_000, id="123"),
origin="gold-label",
)
document_store.write_labels(labels=[label], index="label")
labels = document_store.get_all_labels(index="label")
assert len(labels) == 1
assert label == labels[0]
# exclude weaviate because it does not support storing labels
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "pinecone"], indirect=True)
def test_multilabel(document_store: BaseDocumentStore):
labels = [
Label(
id="standard",
query="question",
answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
),
# different answer in same doc
Label(
id="diff-answer-same-doc",
query="question",
answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
),
# answer in different doc
Label(
id="diff-answer-diff-doc",
query="question",
answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some other", id="333"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
),
# 'no answer', should be excluded from MultiLabel
Label(
id="4-no-answer",
query="question",
answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
document=Document(content="some", id="777"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
),
# is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
Label(
id="5-negative",
query="question",
answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=False,
is_correct_document=True,
origin="gold-label",
),
]
document_store.write_labels(labels)
# regular labels - not aggregated
list_labels = document_store.get_all_labels()
assert set(list_labels) == set(labels)
assert len(list_labels) == 5
# Currently we don't enforce writing (missing) docs automatically when adding labels and there's no DB relationship between the two.
# We should introduce this when we refactored the logic of "index" to be rather a "collection" of labels+documents
# docs = document_store.get_all_documents()
# assert len(docs) == 3
# Multi labels (open domain)
multi_labels_open = document_store.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)
# for open-domain we group all together as long as they have the same question
assert len(multi_labels_open) == 1
# all labels are in there except the negative one and the no_answer
assert len(multi_labels_open[0].labels) == 4
assert len(multi_labels_open[0].answers) == 3
assert "5-negative" not in [l.id for l in multi_labels_open[0].labels]
assert len(multi_labels_open[0].document_ids) == 3
# Don't drop the negative label
multi_labels_open = document_store.get_all_labels_aggregated(
open_domain=True, drop_no_answers=False, drop_negative_labels=False
)
assert len(multi_labels_open[0].labels) == 5
assert len(multi_labels_open[0].answers) == 4
assert len(multi_labels_open[0].document_ids) == 4
# Drop no answer + negative
multi_labels_open = document_store.get_all_labels_aggregated(
open_domain=True, drop_no_answers=True, drop_negative_labels=True
)
assert len(multi_labels_open[0].labels) == 3
assert len(multi_labels_open[0].answers) == 3
assert len(multi_labels_open[0].document_ids) == 3
# for closed domain we group by document so we expect 3 multilabels with 2,1,1 labels each (negative dropped again)
multi_labels = document_store.get_all_labels_aggregated(open_domain=False, drop_negative_labels=True)
assert len(multi_labels) == 3
label_counts = set([len(ml.labels) for ml in multi_labels])
assert label_counts == set([2, 1, 1])
assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
# exclude weaviate because it does not support storing labels
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "pinecone"], indirect=True)
def test_multilabel_no_answer(document_store: BaseDocumentStore):
labels = [
Label(
query="question",
answer=Answer(answer=""),
is_correct_answer=True,
is_correct_document=True,
document=Document(content="some", id="777"),
origin="gold-label",
),
# no answer in different doc
Label(
query="question",
answer=Answer(answer=""),
is_correct_answer=True,
is_correct_document=True,
document=Document(content="some", id="123"),
origin="gold-label",
),
# no answer in same doc, should be excluded
Label(
query="question",
answer=Answer(answer=""),
is_correct_answer=True,
is_correct_document=True,
document=Document(content="some", id="777"),
origin="gold-label",
),
# no answer with is_correct_answer=False, should be excluded
Label(
query="question",
answer=Answer(answer=""),
is_correct_answer=False,
is_correct_document=True,
document=Document(content="some", id="777"),
origin="gold-label",
),
]
document_store.write_labels(labels)
labels = document_store.get_all_labels()
assert len(labels) == 4
multi_labels = document_store.get_all_labels_aggregated(
open_domain=True, drop_no_answers=False, drop_negative_labels=True
)
assert len(multi_labels) == 1
assert multi_labels[0].no_answer == True
assert len(multi_labels[0].document_ids) == 0
assert len(multi_labels[0].answers) == 1
multi_labels = document_store.get_all_labels_aggregated(
open_domain=True, drop_no_answers=False, drop_negative_labels=False
)
assert len(multi_labels) == 1
assert multi_labels[0].no_answer == True
assert len(multi_labels[0].document_ids) == 0
assert len(multi_labels[0].labels) == 3
assert len(multi_labels[0].answers) == 1
# exclude weaviate because it does not support storing labels
# exclude faiss and milvus as label metadata is not implemented
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
def test_multilabel_filter_aggregations(document_store: BaseDocumentStore):
labels = [
Label(
id="standard",
query="question",
answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
filters={"name": ["123"]},
),
# different answer in same doc
Label(
id="diff-answer-same-doc",
query="question",
answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
filters={"name": ["123"]},
),
# answer in different doc
Label(
id="diff-answer-diff-doc",
query="question",
answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some other", id="333"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
filters={"name": ["333"]},
),
# 'no answer', should be excluded from MultiLabel
Label(
id="4-no-answer",
query="question",
answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
document=Document(content="some", id="777"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
filters={"name": ["777"]},
),
# is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
Label(
id="5-negative",
query="question",
answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=False,
is_correct_document=True,
origin="gold-label",
filters={"name": ["123"]},
),
]
document_store.write_labels(labels)
# regular labels - not aggregated
list_labels = document_store.get_all_labels()
assert list_labels == labels
assert len(list_labels) == 5
# Multi labels (open domain)
multi_labels_open = document_store.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)
# for open-domain we group all together as long as they have the same question and filters
assert len(multi_labels_open) == 3
label_counts = set([len(ml.labels) for ml in multi_labels_open])
assert label_counts == set([2, 1, 1])
# all labels are in there except the negative one and the no_answer
assert "5-negative" not in [l.id for multi_label in multi_labels_open for l in multi_label.labels]
assert len(multi_labels_open[0].answers) == len(multi_labels_open[0].document_ids)
# for closed domain we group by document so we expect the same as with filters
multi_labels = document_store.get_all_labels_aggregated(open_domain=False, drop_negative_labels=True)
assert len(multi_labels) == 3
label_counts = set([len(ml.labels) for ml in multi_labels])
assert label_counts == set([2, 1, 1])
assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
# exclude weaviate because it does not support storing labels
# exclude faiss and milvus as label metadata is not implemented
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
def test_multilabel_meta_aggregations(document_store: BaseDocumentStore):
labels = [
Label(
id="standard",
query="question",
answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["123"]},
),
# different answer in same doc
Label(
id="diff-answer-same-doc",
query="question",
answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["123"]},
),
# answer in different doc
Label(
id="diff-answer-diff-doc",
query="question",
answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some other", id="333"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["333"]},
),
# 'no answer', should be excluded from MultiLabel
Label(
id="4-no-answer",
query="question",
answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
document=Document(content="some", id="777"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["777"]},
),
# is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
Label(
id="5-888",
query="question",
answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
origin="gold-label",
meta={"file_id": ["888"]},
),
]
document_store.write_labels(labels)
# regular labels - not aggregated
list_labels = document_store.get_all_labels()
assert list_labels == labels
assert len(list_labels) == 5
# Multi labels (open domain)
multi_labels_open = document_store.get_all_labels_aggregated(open_domain=True, drop_negative_labels=True)
# for open-domain we group all together as long as they have the same question and filters
assert len(multi_labels_open) == 1
assert len(multi_labels_open[0].labels) == 5
multi_labels = document_store.get_all_labels_aggregated(
open_domain=True, drop_negative_labels=True, aggregate_by_meta="file_id"
)
assert len(multi_labels) == 4
label_counts = set([len(ml.labels) for ml in multi_labels])
assert label_counts == set([2, 1, 1, 1])
for multi_label in multi_labels:
for l in multi_label.labels:
assert l.filters == l.meta
assert multi_label.filters == l.filters
@pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"])
def test_custom_embedding_field(document_store_type, tmp_path):
document_store = get_document_store(

View File

@ -245,12 +245,22 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract):
def test_nested_condition_not_filters(self, ds, documents):
pass
@pytest.mark.skip
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_delete_labels_by_filter(self, ds, labels):
pass
@pytest.mark.skip
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_delete_labels_by_filter_id(self, ds, labels):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_filter_aggregations(self):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_meta_aggregations(self):
pass

View File

@ -86,12 +86,22 @@ class TestMilvusDocumentStore(DocumentStoreBaseTestAbstract):
# NOTE: again inherithed from the SQLDocumentStore, labels metadata are not supported
@pytest.mark.skip
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_delete_labels_by_filter(self, ds, labels):
pass
@pytest.mark.skip
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_delete_labels_by_filter_id(self, ds, labels):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_filter_aggregations(self):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_meta_aggregations(self):
pass

View File

@ -190,6 +190,11 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
def test_nested_condition_not_filters(self, ds, documents):
pass
@pytest.mark.skip
@pytest.mark.integration
def test_delete_documents_by_id_with_filters(self, ds, documents):
pass
# NOTE: labels metadata are not supported
@pytest.mark.skip
@ -207,6 +212,31 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
def test_simplified_filters(self, ds, documents):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_labels_with_long_texts(self):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel(self):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_no_answer(self):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_filter_aggregations(self):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_meta_aggregations(self):
pass
# NOTE: Pinecone does not support dates, so it can't do lte or gte on date fields. When a new release introduces this feature,
# the entire family of test_get_all_documents_extended_filter_* tests will become identical to the one present in the
# base document store suite, and can be removed from here.

View File

@ -108,14 +108,22 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
def test_nested_condition_not_filters(self, ds, documents):
pass
# NOTE: labels metadata are not supported
@pytest.mark.skip
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_delete_labels_by_filter(self, ds, labels):
pass
@pytest.mark.skip
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_delete_labels_by_filter_id(self, ds, labels):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_filter_aggregations(self):
pass
@pytest.mark.skip(reason="labels metadata are not supported")
@pytest.mark.integration
def test_multilabel_meta_aggregations(self):
pass

View File

@ -103,6 +103,31 @@ class TestWeaviateDocumentStore(DocumentStoreBaseTestAbstract):
def test_write_get_all_labels(self):
pass
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_labels_with_long_texts(self):
pass
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_multilabel(self):
pass
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_multilabel_no_answer(self):
pass
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_multilabel_filter_aggregations(self):
pass
@pytest.mark.skip(reason="Weaviate does not support labels")
@pytest.mark.integration
def test_multilabel_meta_aggregations(self):
pass
@pytest.mark.integration
def test_ne_filters(self, ds, documents):
"""