Add tests on MultiLabel's meta and filter aggregation (#2169)

This commit is contained in:
tstadel 2022-02-11 17:42:47 +01:00 committed by GitHub
parent fdc36292f1
commit db4d6f43ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -968,6 +968,194 @@ def test_multilabel_no_answer(document_store):
assert len(multi_labels[0].answers) == 1
# exclude weaviate because it does not support storing labels
# exclude faiss and milvus as label metadata is not implemented
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
def test_multilabel_filter_aggregations(document_store):
labels = [
Label(
id="standard",
query="question",
answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
no_answer=False,
origin="gold-label",
filters={"name": ["123"]},
),
# different answer in same doc
Label(
id="diff-answer-same-doc",
query="question",
answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
no_answer=False,
origin="gold-label",
filters={"name": ["123"]},
),
# answer in different doc
Label(
id="diff-answer-diff-doc",
query="question",
answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some other", id="333"),
is_correct_answer=True,
is_correct_document=True,
no_answer=False,
origin="gold-label",
filters={"name": ["333"]},
),
# 'no answer', should be excluded from MultiLabel
Label(
id="4-no-answer",
query="question",
answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
document=Document(content="some", id="777"),
is_correct_answer=True,
is_correct_document=True,
no_answer=True,
origin="gold-label",
filters={"name": ["777"]},
),
# is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
Label(
id="5-negative",
query="question",
answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=False,
is_correct_document=True,
no_answer=False,
origin="gold-label",
filters={"name": ["123"]},
),
]
document_store.write_labels(labels, index="haystack_test_multilabel")
# regular labels - not aggregated
list_labels = document_store.get_all_labels(index="haystack_test_multilabel")
assert list_labels == labels
assert len(list_labels) == 5
# Multi labels (open domain)
multi_labels_open = document_store.get_all_labels_aggregated(
index="haystack_test_multilabel", open_domain=True, drop_negative_labels=True
)
# for open-domain we group all together as long as they have the same question and filters
assert len(multi_labels_open) == 3
label_counts = set([len(ml.labels) for ml in multi_labels_open])
assert label_counts == set([2, 1, 1])
# all labels are in there except the negative one and the no_answer
assert "5-negative" not in [l.id for multi_label in multi_labels_open for l in multi_label.labels]
assert len(multi_labels_open[0].answers) == len(multi_labels_open[0].document_ids)
# for closed domain we group by document so we expect the same as with filters
multi_labels = document_store.get_all_labels_aggregated(
index="haystack_test_multilabel", open_domain=False, drop_negative_labels=True
)
assert len(multi_labels) == 3
label_counts = set([len(ml.labels) for ml in multi_labels])
assert label_counts == set([2, 1, 1])
assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)
# exclude weaviate because it does not support storing labels
# exclude faiss and milvus as label metadata is not implemented
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
def test_multilabel_meta_aggregations(document_store):
labels = [
Label(
id="standard",
query="question",
answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
no_answer=False,
origin="gold-label",
meta={"file_id": ["123"]},
),
# different answer in same doc
Label(
id="diff-answer-same-doc",
query="question",
answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
no_answer=False,
origin="gold-label",
meta={"file_id": ["123"]},
),
# answer in different doc
Label(
id="diff-answer-diff-doc",
query="question",
answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some other", id="333"),
is_correct_answer=True,
is_correct_document=True,
no_answer=False,
origin="gold-label",
meta={"file_id": ["333"]},
),
# 'no answer', should be excluded from MultiLabel
Label(
id="4-no-answer",
query="question",
answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]),
document=Document(content="some", id="777"),
is_correct_answer=True,
is_correct_document=True,
no_answer=True,
origin="gold-label",
meta={"file_id": ["777"]},
),
# is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True"
Label(
id="5-888",
query="question",
answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]),
document=Document(content="some", id="123"),
is_correct_answer=True,
is_correct_document=True,
no_answer=False,
origin="gold-label",
meta={"file_id": ["888"]},
),
]
document_store.write_labels(labels, index="haystack_test_multilabel")
# regular labels - not aggregated
list_labels = document_store.get_all_labels(index="haystack_test_multilabel")
assert list_labels == labels
assert len(list_labels) == 5
# Multi labels (open domain)
multi_labels_open = document_store.get_all_labels_aggregated(
index="haystack_test_multilabel", open_domain=True, drop_negative_labels=True
)
# for open-domain we group all together as long as they have the same question and filters
assert len(multi_labels_open) == 1
assert len(multi_labels_open[0].labels) == 5
multi_labels = document_store.get_all_labels_aggregated(
index="haystack_test_multilabel", open_domain=True, drop_negative_labels=True, aggregate_by_meta="file_id"
)
assert len(multi_labels) == 4
label_counts = set([len(ml.labels) for ml in multi_labels])
assert label_counts == set([2, 1, 1, 1])
for multi_label in multi_labels:
for l in multi_label.labels:
assert l.filters == l.meta
assert multi_label.filters == l.filters
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "milvus", "weaviate"], indirect=True)
# Currently update_document_meta() is not implemented for Memory doc store
def test_update_meta(document_store):