mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-29 08:49:07 +00:00
Add tests on MultiLabel's meta and filter aggregation (#2169)
This commit is contained in:
parent
fdc36292f1
commit
db4d6f43ba
@ -968,6 +968,194 @@ def test_multilabel_no_answer(document_store):
|
||||
assert len(multi_labels[0].answers) == 1
|
||||
|
||||
|
||||
# exclude weaviate because it does not support storing labels
|
||||
# exclude faiss and milvus as label metadata is not implemented
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
def test_multilabel_filter_aggregations(document_store):
    """Labels with different `filters` must end up in separate MultiLabels.

    Five labels sharing one query are written with three distinct filter
    values ("123", "333", "777"). Aggregation is then checked for both the
    open-domain and the closed-domain setting, with negative labels dropped.
    """
    labels = [
        Label(
            id="standard",
            query="question",
            answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            filters={"name": ["123"]},
        ),
        # different answer in same doc
        Label(
            id="diff-answer-same-doc",
            query="question",
            answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            filters={"name": ["123"]},
        ),
        # answer in different doc
        Label(
            id="diff-answer-diff-doc",
            query="question",
            answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some other", id="333"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            filters={"name": ["333"]},
        ),
        # 'no answer' label (empty answer text, zero-length span) with its own filter value
        Label(
            id="4-no-answer",
            query="question",
            answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
            document=Document(content="some", id="777"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=True,
            origin="gold-label",
            filters={"name": ["777"]},
        ),
        # is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
        Label(
            id="5-negative",
            query="question",
            answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=False,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            filters={"name": ["123"]},
        ),
    ]
    document_store.write_labels(labels, index="haystack_test_multilabel")

    # regular labels - not aggregated
    list_labels = document_store.get_all_labels(index="haystack_test_multilabel")
    assert list_labels == labels
    assert len(list_labels) == 5

    # Multi labels (open domain)
    multi_labels_open = document_store.get_all_labels_aggregated(
        index="haystack_test_multilabel", open_domain=True, drop_negative_labels=True
    )

    # for open-domain we group all together as long as they have the same question and filters
    assert len(multi_labels_open) == 3
    # Compare as a sorted list rather than a set: a set would collapse the two
    # single-label groups and could not tell [1, 1, 2] apart from [1, 2, 2].
    label_counts = sorted(len(ml.labels) for ml in multi_labels_open)
    assert label_counts == [1, 1, 2]
    # the negative label must have been dropped from every MultiLabel
    aggregated_ids = {label.id for multi_label in multi_labels_open for label in multi_label.labels}
    assert "5-negative" not in aggregated_ids

    assert len(multi_labels_open[0].answers) == len(multi_labels_open[0].document_ids)

    # for closed domain we group by document so we expect the same as with filters
    multi_labels = document_store.get_all_labels_aggregated(
        index="haystack_test_multilabel", open_domain=False, drop_negative_labels=True
    )
    assert len(multi_labels) == 3
    label_counts = sorted(len(ml.labels) for ml in multi_labels)
    assert label_counts == [1, 1, 2]

    assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
|
||||
|
||||
|
||||
# exclude weaviate because it does not support storing labels
|
||||
# exclude faiss and milvus as label metadata is not implemented
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
def test_multilabel_meta_aggregations(document_store):
    """Labels can be grouped into MultiLabels by a key of their `meta` dict.

    Five positive labels sharing one query carry four distinct `file_id`
    values in their meta. Without `aggregate_by_meta` they are expected in a
    single open-domain MultiLabel; with `aggregate_by_meta="file_id"` they are
    expected in one MultiLabel per `file_id`.
    """
    labels = [
        Label(
            id="standard",
            query="question",
            answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            meta={"file_id": ["123"]},
        ),
        # different answer in same doc
        Label(
            id="diff-answer-same-doc",
            query="question",
            answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            meta={"file_id": ["123"]},
        ),
        # answer in different doc
        Label(
            id="diff-answer-diff-doc",
            query="question",
            answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some other", id="333"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            meta={"file_id": ["333"]},
        ),
        # 'no answer' label (empty answer text, zero-length span) with its own file_id
        Label(
            id="4-no-answer",
            query="question",
            answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
            document=Document(content="some", id="777"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=True,
            origin="gold-label",
            meta={"file_id": ["777"]},
        ),
        # positive answer in doc "123" but with a different file_id in its meta
        Label(
            id="5-888",
            query="question",
            answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
            meta={"file_id": ["888"]},
        ),
    ]
    document_store.write_labels(labels, index="haystack_test_multilabel")

    # regular labels - not aggregated
    list_labels = document_store.get_all_labels(index="haystack_test_multilabel")
    assert list_labels == labels
    assert len(list_labels) == 5

    # Multi labels (open domain)
    multi_labels_open = document_store.get_all_labels_aggregated(
        index="haystack_test_multilabel", open_domain=True, drop_negative_labels=True
    )

    # for open-domain we group all together as long as they have the same question and filters
    assert len(multi_labels_open) == 1
    assert len(multi_labels_open[0].labels) == 5

    multi_labels = document_store.get_all_labels_aggregated(
        index="haystack_test_multilabel", open_domain=True, drop_negative_labels=True, aggregate_by_meta="file_id"
    )
    assert len(multi_labels) == 4
    # Compare as a sorted list rather than a set: a set would collapse the
    # three single-label groups and accept e.g. [1, 2, 2, 2] as well.
    label_counts = sorted(len(ml.labels) for ml in multi_labels)
    assert label_counts == [1, 1, 1, 2]

    # the meta key used for aggregation is expected to show up as the filters
    # of each label and of the MultiLabel that contains it
    for multi_label in multi_labels:
        for label in multi_label.labels:
            assert label.filters == label.meta
            assert multi_label.filters == label.filters
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "milvus", "weaviate"], indirect=True)
|
||||
# Currently update_document_meta() is not implemented for Memory doc store
|
||||
def test_update_meta(document_store):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user