From 09707b576ae4cc9ad147b46bad9d3c78f1993e42 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Date: Tue, 9 Aug 2022 15:53:24 +0200 Subject: [PATCH] Make `MultiLabel` preserve order (#2956) * try simple approach * added test * add requested test --- haystack/schema.py | 2 +- test/others/test_schema.py | 122 +++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 1 deletion(-) diff --git a/haystack/schema.py b/haystack/schema.py index bd83def2e..1b21daa94 100644 --- a/haystack/schema.py +++ b/haystack/schema.py @@ -624,7 +624,7 @@ class MultiLabel: :param drop_no_answers: Whether to drop labels that specify the answer is impossible """ # drop duplicate labels and remove negative labels if needed. - labels = list(set(labels)) + labels = list(dict.fromkeys(labels)) if drop_negative_labels: labels = [l for l in labels if is_positive_label(l)] diff --git a/test/others/test_schema.py b/test/others/test_schema.py index 0078bdfd5..d1af0d02d 100644 --- a/test/others/test_schema.py +++ b/test/others/test_schema.py @@ -239,6 +239,128 @@ def test_aggregate_labels_with_labels(): label = MultiLabel(labels=[label1_with_filter1, label3_with_filter2]) +def test_multilabel_preserve_order(): + labels = [ + Label( + id="0", + query="question", + answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some", id="123"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + ), + Label( + id="1", + query="question", + answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some", id="123"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + ), + Label( + id="2", + query="question", + answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some other", id="333"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + ), + Label( + id="3", + query="question", + answer=Answer(answer="", offsets_in_document=[Span(start=0, end=0)]), + document=Document(content="some", id="777"), + is_correct_answer=True, + is_correct_document=True, + no_answer=True, + origin="gold-label", + ), + Label( + id="4", + query="question", + answer=Answer(answer="answer5", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some", id="123"), + is_correct_answer=False, + is_correct_document=True, + no_answer=False, + origin="gold-label", + ), + ] + + multilabel = MultiLabel(labels=labels) + + for i in range(0, 5): + assert multilabel.labels[i].id == str(i) + + +def test_multilabel_preserve_order_w_duplicates(): + labels = [ + Label( + id="0", + query="question", + answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some", id="123"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + ), + Label( + id="1", + query="question", + answer=Answer(answer="answer2", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some", id="123"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + ), + Label( + id="2", + query="question", + answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some other", id="333"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + ), + Label( + id="0", + query="question", + answer=Answer(answer="answer1", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some", id="123"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + ), + Label( + id="2", + query="question", + answer=Answer(answer="answer3", offsets_in_document=[Span(start=12, end=18)]), + document=Document(content="some other", id="333"), + is_correct_answer=True, + is_correct_document=True, + no_answer=False, + origin="gold-label", + ), + ] + + multilabel = MultiLabel(labels=labels) + + assert len(multilabel.document_ids) == 3 + + for i in range(0, 3): + assert multilabel.labels[i].id == str(i) + + def test_serialize_speech_document(): speech_doc = SpeechDocument( id=12345,