Make MultiLabel preserve order (#2956)

* try simple approach

* added test

* add requested test
This commit is contained in:
Stefano Fiorucci 2022-08-09 15:53:24 +02:00 committed by GitHub
parent dfeb171686
commit 09707b576a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 123 additions and 1 deletion

View File

@ -624,7 +624,7 @@ class MultiLabel:
:param drop_no_answers: Whether to drop labels that specify the answer is impossible
"""
# drop duplicate labels and remove negative labels if needed.
labels = list(set(labels))
labels = list(dict.fromkeys(labels))
if drop_negative_labels:
labels = [l for l in labels if is_positive_label(l)]

View File

@ -239,6 +239,128 @@ def test_aggregate_labels_with_labels():
label = MultiLabel(labels=[label1_with_filter1, label3_with_filter2])
def test_multilabel_preserve_order():
    """Verify that MultiLabel exposes its labels in the order they were supplied.

    Five labels with the distinct ids "0" through "4" are built (one of them is a
    no-answer label, another has ``is_correct_answer=False``) and handed to
    ``MultiLabel``. The ids are then read back by position from
    ``multilabel.labels`` and must match the input order.
    """

    def make_label(
        label_id, answer_text, span, doc_content, doc_id, correct_answer=True, no_answer=False
    ):
        # Query, origin and the document-correctness flag are the same for every
        # fixture in this test; only the arguments above vary.
        start, end = span
        return Label(
            id=label_id,
            query="question",
            answer=Answer(answer=answer_text, offsets_in_document=[Span(start=start, end=end)]),
            document=Document(content=doc_content, id=doc_id),
            is_correct_answer=correct_answer,
            is_correct_document=True,
            no_answer=no_answer,
            origin="gold-label",
        )

    ordered_labels = [
        make_label("0", "answer1", (12, 18), "some", "123"),
        make_label("1", "answer2", (12, 18), "some", "123"),
        make_label("2", "answer3", (12, 18), "some other", "333"),
        # No-answer label: empty answer text with a zero-length span.
        make_label("3", "", (0, 0), "some", "777", no_answer=True),
        # Label whose answer is marked as incorrect.
        make_label("4", "answer5", (12, 18), "some", "123", correct_answer=False),
    ]

    multilabel = MultiLabel(labels=ordered_labels)

    # Index by position so the check fails if any label moved or went missing.
    for position, expected_id in enumerate(["0", "1", "2", "3", "4"]):
        assert multilabel.labels[position].id == expected_id
def test_multilabel_preserve_order_w_duplicates():
    """Verify label order is kept when the input contains repeated labels.

    The input holds labels "0", "1", "2" followed by second copies of "0" and
    "2" built with identical field values. After constructing ``MultiLabel`` the
    test expects three entries in ``document_ids`` and the ids "0", "1", "2" at
    the first three positions of ``multilabel.labels``.
    """

    def make_label(label_id, answer_text, doc_content, doc_id):
        # All fixtures here are correct, answerable gold labels over the same
        # span; only id, answer text and document differ.
        return Label(
            id=label_id,
            query="question",
            answer=Answer(answer=answer_text, offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content=doc_content, id=doc_id),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
        )

    labels_with_repeats = [
        make_label("0", "answer1", "some", "123"),
        make_label("1", "answer2", "some", "123"),
        make_label("2", "answer3", "some other", "333"),
        # Repeats are built as separate objects with the same field values,
        # not as references to the instances above.
        make_label("0", "answer1", "some", "123"),
        make_label("2", "answer3", "some other", "333"),
    ]

    multilabel = MultiLabel(labels=labels_with_repeats)

    assert len(multilabel.document_ids) == 3
    for position, expected_id in enumerate(("0", "1", "2")):
        assert multilabel.labels[position].id == expected_id
def test_serialize_speech_document():
speech_doc = SpeechDocument(
id=12345,