From db4d6f43baf4501a775812fc4e15889623ddf893 Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Fri, 11 Feb 2022 17:42:47 +0100 Subject: [PATCH] Add tests on MultiLabel's meta and filter aggregation (#2169) --- test/test_document_store.py | 188 ++++++++++++++++++++++++++++++++++++ 1 file changed, 188 insertions(+) diff --git a/test/test_document_store.py b/test/test_document_store.py index 5eb3b49d6..d76b23f19 100644 --- a/test/test_document_store.py +++ b/test/test_document_store.py @@ -968,6 +968,194 @@ def test_multilabel_no_answer(document_store): assert len(multi_labels[0].answers) == 1 +# exclude weaviate because it does not support storing labels +# exclude faiss and milvus as label metadata is not implemented +@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True) +def test_multilabel_filter_aggregations(document_store): + labels = [ + Label( + id="standard", + query="question", + answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some", id="123"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + filters={"name": ["123"]}, + ), + # different answer in same doc + Label( + id="diff-answer-same-doc", + query="question", + answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some", id="123"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + filters={"name": ["123"]}, + ), + # answer in different doc + Label( + id="diff-answer-diff-doc", + query="question", + answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some other", id="333"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + filters={"name": ["333"]}, + ), + # 'no answer', should be excluded from MultiLabel + Label( + id="4-no-answer", + query="question", + answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]), + document=Document(content="some", id="777"), + is_correct_answer=True, + is_correct_document=True, + no_answer=True, + origin="gold-label", + filters={"name": ["777"]}, + ), + # is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True" + Label( + id="5-negative", + query="question", + answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some", id="123"), + is_correct_answer=False, + is_correct_document=True, + no_answer=False, + origin="gold-label", + filters={"name": ["123"]}, + ), + ] + document_store.write_labels(labels, index="haystack_test_multilabel") + # regular labels - not aggregated + list_labels = document_store.get_all_labels(index="haystack_test_multilabel") + assert list_labels == labels + assert len(list_labels) == 5 + + # Multi labels (open domain) + multi_labels_open = document_store.get_all_labels_aggregated( + index="haystack_test_multilabel", open_domain=True, drop_negative_labels=True + ) + + # for open-domain we group all together as long as they have the same question and filters + assert len(multi_labels_open) == 3 + label_counts = set([len(ml.labels) for ml in multi_labels_open]) + assert label_counts == set([2, 1, 1]) + # all labels are in there except the negative one and the no_answer + assert "5-negative" not in [l.id for multi_label in multi_labels_open for l in multi_label.labels] + + assert len(multi_labels_open[0].answers) == len(multi_labels_open[0].document_ids) + + # for closed domain we group by document so we expect the same as with filters + multi_labels = document_store.get_all_labels_aggregated( + index="haystack_test_multilabel", open_domain=False, drop_negative_labels=True + ) + assert len(multi_labels) == 3 + label_counts = set([len(ml.labels) for ml in multi_labels]) + assert label_counts == set([2, 1, 1]) + + assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids) + + +# exclude weaviate because it does not support storing labels +# exclude faiss and milvus as label metadata is not implemented +@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True) +def test_multilabel_meta_aggregations(document_store): + labels = [ + Label( + id="standard", + query="question", + answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some", id="123"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + meta={"file_id": ["123"]}, + ), + # different answer in same doc + Label( + id="diff-answer-same-doc", + query="question", + answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some", id="123"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + meta={"file_id": ["123"]}, + ), + # answer in different doc + Label( + id="diff-answer-diff-doc", + query="question", + answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some other", id="333"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + meta={"file_id": ["333"]}, + ), + # 'no answer', should be excluded from MultiLabel + Label( + id="4-no-answer", + query="question", + answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]), + document=Document(content="some", id="777"), + is_correct_answer=True, + is_correct_document=True, + no_answer=True, + origin="gold-label", + meta={"file_id": ["777"]}, + ), + # is_correct_answer=False, should be excluded from MultiLabel if "drop_negatives = True" + Label( + id="5-888", + query="question", + answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some", id="123"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + meta={"file_id": ["888"]}, + ), + ] + document_store.write_labels(labels, index="haystack_test_multilabel") + # regular labels - not aggregated + list_labels = document_store.get_all_labels(index="haystack_test_multilabel") + assert list_labels == labels + assert len(list_labels) == 5 + + # Multi labels (open domain) + multi_labels_open = document_store.get_all_labels_aggregated( + index="haystack_test_multilabel", open_domain=True, drop_negative_labels=True + ) + + # for open-domain we group all together as long as they have the same question and filters + assert len(multi_labels_open) == 1 + assert len(multi_labels_open[0].labels) == 5 + + multi_labels = document_store.get_all_labels_aggregated( + index="haystack_test_multilabel", open_domain=True, drop_negative_labels=True, aggregate_by_meta="file_id" + ) + assert len(multi_labels) == 4 + label_counts = set([len(ml.labels) for ml in multi_labels]) + assert label_counts == set([2, 1, 1, 1]) + for multi_label in multi_labels: + for l in multi_label.labels: + assert l.filters == l.meta + assert multi_label.filters == l.filters + + @pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "milvus", "weaviate"], indirect=True) # Currently update_document_meta() is not implemented for Memory doc store def test_update_meta(document_store):