mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-28 23:48:53 +00:00
Make MultiLabel preserve order (#2956)
* try simple approach * added test * add requested test
This commit is contained in:
parent
dfeb171686
commit
09707b576a
@ -624,7 +624,7 @@ class MultiLabel:
|
||||
:param drop_no_answers: Whether to drop labels that specify the answer is impossible
|
||||
"""
|
||||
# drop duplicate labels and remove negative labels if needed.
|
||||
labels = list(set(labels))
|
||||
labels = list(dict.fromkeys(labels))
|
||||
if drop_negative_labels:
|
||||
labels = [l for l in labels if is_positive_label(l)]
|
||||
|
||||
|
||||
@ -239,6 +239,128 @@ def test_aggregate_labels_with_labels():
|
||||
label = MultiLabel(labels=[label1_with_filter1, label3_with_filter2])
|
||||
|
||||
|
||||
def test_multilabel_preserve_order():
    """Check that MultiLabel keeps its labels in the order they were supplied.

    Five distinct labels with ids "0".."4" are passed in; the resulting
    ``multilabel.labels`` must expose exactly those ids in the same sequence.
    The fixture deliberately includes a no-answer label (id "3") and a label
    whose answer is marked incorrect (id "4"); both are expected to be kept.
    """

    def make_label(label_id, answer, document, *, span=(12, 18), is_correct_answer=True, no_answer=False):
        # Only the fields that vary between fixtures are parameters; everything
        # else is the shared boilerplate every label in this test uses.
        content, document_id = document
        start, end = span
        return Label(
            id=label_id,
            query="question",
            answer=Answer(answer=answer, offsets_in_document=[Span(start=start, end=end)]),
            document=Document(content=content, id=document_id),
            is_correct_answer=is_correct_answer,
            is_correct_document=True,
            no_answer=no_answer,
            origin="gold-label",
        )

    labels = [
        make_label("0", "answer1", ("some", "123")),
        make_label("1", "answer2", ("some", "123")),
        make_label("2", "answer3", ("some other", "333")),
        make_label("3", "", ("some", "777"), span=(0, 0), no_answer=True),
        make_label("4", "answer5", ("some", "123"), is_correct_answer=False),
    ]

    multilabel = MultiLabel(labels=labels)

    # Comparing the whole id sequence verifies order and count in one step and
    # reports the full actual ordering on failure instead of the first mismatch.
    assert [label.id for label in multilabel.labels] == ["0", "1", "2", "3", "4"]
|
||||
|
||||
|
||||
def test_multilabel_preserve_order_w_duplicates():
    """Check that MultiLabel drops duplicate labels without reordering the rest.

    Labels "0" and "2" are each supplied twice with identical field values.
    After construction only the first occurrence of each should remain, so the
    labels must read "0", "1", "2" and ``document_ids`` must have three entries.
    """

    def make_label(label_id, answer, document):
        # Every label in this test shares the same span and flags; only the id,
        # answer text and document differ.
        content, document_id = document
        return Label(
            id=label_id,
            query="question",
            answer=Answer(answer=answer, offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content=content, id=document_id),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
        )

    labels = [
        make_label("0", "answer1", ("some", "123")),
        make_label("1", "answer2", ("some", "123")),
        make_label("2", "answer3", ("some other", "333")),
        # Exact repeats of labels "0" and "2"; these are the ones expected to be dropped.
        make_label("0", "answer1", ("some", "123")),
        make_label("2", "answer3", ("some other", "333")),
    ]

    multilabel = MultiLabel(labels=labels)

    assert len(multilabel.document_ids) == 3

    # A single sequence comparison checks both that the duplicates are gone and
    # that the surviving labels kept their first-seen order.
    assert [label.id for label in multilabel.labels] == ["0", "1", "2"]
|
||||
|
||||
|
||||
def test_serialize_speech_document():
|
||||
speech_doc = SpeechDocument(
|
||||
id=12345,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user