haystack/test/others/test_schema.py
camille f363b152ff
bug: make MultiLabel ids consistent across python interpreters (#2998)
* use hashlib.md5() instead of the (interpreter-dependent) hash() function to generate the MultiLabel id

* add tests to assess constancy of MultiLabel.id

* make test_multilabel_id test ensure that MultiLabel ids are always the same
2022-08-10 09:43:21 +02:00

456 lines
14 KiB
Python

from haystack.schema import Document, Label, Answer, Span, MultiLabel, SpeechDocument, SpeechAnswer
import pytest
import numpy as np
from ..conftest import SAMPLES_PATH
def _user_feedback_label(answer):
    """Wrap *answer* in a correct "user-feedback" Label for the query "some".

    Each call builds its own ``Document(content="some text", content_type="text")``.
    """
    return Label(
        query="some",
        answer=answer,
        document=Document(content="some text", content_type="text"),
        is_correct_answer=True,
        is_correct_document=True,
        origin="user-feedback",
    )


# Shared fixture labels. LABELS[0] and LABELS[2] are built from identical inputs;
# LABELS[1] differs in its answer text and carries no answer offsets.
LABELS = [
    _user_feedback_label(
        Answer(
            answer="an answer",
            type="extractive",
            score=0.1,
            document_id="123",
            offsets_in_document=[Span(start=1, end=3)],
        )
    ),
    _user_feedback_label(
        Answer(answer="annother answer", type="extractive", score=0.1, document_id="123")
    ),
    _user_feedback_label(
        Answer(
            answer="an answer",
            type="extractive",
            score=0.1,
            document_id="123",
            offsets_in_document=[Span(start=1, end=3)],
        )
    ),
]
def test_no_answer_label():
    """``Label.no_answer`` follows the answer text unless it is passed explicitly.

    Labels 0 and 2 leave ``no_answer`` unset (empty vs. non-empty answer text);
    labels 1 and 3 set it explicitly to the matching value.
    """
    labels = [
        # Empty answer text, no_answer not given.
        Label(
            query="question",
            answer=Answer(answer=""),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            origin="gold-label",
        ),
        # Empty answer text, no_answer given explicitly.
        Label(
            query="question",
            answer=Answer(answer=""),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            no_answer=True,
            origin="gold-label",
        ),
        # Non-empty answer text, no_answer not given.
        Label(
            query="question",
            answer=Answer(answer="some"),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            origin="gold-label",
        ),
        # Non-empty answer text, no_answer given explicitly.
        Label(
            query="question",
            answer=Answer(answer="some"),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            no_answer=False,
            origin="gold-label",
        ),
    ]
    # Identity checks (rather than ``== True``) also pin that the flag is a real bool.
    assert labels[0].no_answer is True
    assert labels[1].no_answer is True
    assert labels[2].no_answer is False
    assert labels[3].no_answer is False
def test_equal_label():
    """Labels built from identical inputs compare equal; a different answer breaks equality."""
    first, second, third = LABELS
    assert third == first
    assert second != first
def test_answer_to_json():
    """An Answer survives a ``to_json``/``from_json`` round trip, offsets included."""
    a = Answer(
        answer="an answer",
        type="extractive",
        score=0.1,
        context="abc",
        offsets_in_document=[Span(start=1, end=10)],
        offsets_in_context=[Span(start=3, end=5)],
        document_id="123",
    )
    j = a.to_json()
    assert isinstance(j, str)
    assert len(j) > 30  # sanity check that a non-trivial payload was produced
    a_new = Answer.from_json(j)
    # Offsets must be restored as Span objects after deserialization.
    assert isinstance(a_new.offsets_in_document[0], Span)
    assert a_new == a
def test_answer_to_dict():
    """An Answer survives a ``to_dict``/``from_dict`` round trip, offsets included."""
    a = Answer(
        answer="an answer",
        type="extractive",
        score=0.1,
        context="abc",
        offsets_in_document=[Span(start=1, end=10)],
        offsets_in_context=[Span(start=3, end=5)],
        document_id="123",
    )
    j = a.to_dict()
    assert isinstance(j, dict)
    a_new = Answer.from_dict(j)
    # Offsets must be restored as Span objects after deserialization.
    assert isinstance(a_new.offsets_in_document[0], Span)
    assert a_new == a
def test_label_to_json():
    """A Label survives a ``to_json``/``from_json`` round trip, including answer offsets."""
    j0 = LABELS[0].to_json()
    l_new = Label.from_json(j0)
    assert l_new == LABELS[0]
    # Nested answer offsets must be restored, not just the top-level fields.
    assert l_new.answer.offsets_in_document[0].start == 1
def test_label_to_dict():
    """A Label survives a ``to_dict``/``from_dict`` round trip, including answer offsets."""
    original = LABELS[0]
    restored = Label.from_dict(original.to_dict())
    assert restored == original
    assert restored.answer.offsets_in_document[0].start == 1
def test_doc_to_json():
    """A Document survives a ``to_json``/``from_json`` round trip with and without an embedding."""
    # With embedding. A seeded generator keeps the fixture reproducible across runs.
    rng = np.random.default_rng(seed=42)
    d = Document(
        content="some text",
        content_type="text",
        score=0.99988,
        meta={"name": "doc1"},
        embedding=rng.random(768, dtype=np.float32),
    )
    j0 = d.to_json()
    d_new = Document.from_json(j0)
    assert d == d_new

    # No embedding
    d = Document(content="some text", content_type="text", score=0.99988, meta={"name": "doc1"}, embedding=None)
    j0 = d.to_json()
    d_new = Document.from_json(j0)
    assert d == d_new
def test_answer_postinit():
    """Answer defaults ``meta`` to ``{}`` and converts offset dicts into Span objects."""
    answer = Answer(answer="test", offsets_in_document=[{"start": 10, "end": 20}])
    assert answer.meta == {}
    first_offset = answer.offsets_in_document[0]
    assert isinstance(first_offset, Span)
def test_generate_doc_id_using_text():
    """With default settings, documents with equal content share an id regardless of meta."""
    docs = {
        name: Document(content=content, meta={"name": name})
        for name, content in (("doc1", "text1"), ("doc2", "text1"), ("doc3", "text2"))
    }
    assert docs["doc1"].id == docs["doc2"].id
    assert docs["doc1"].id != docs["doc3"].id
def test_generate_doc_id_using_custom_list():
    """``id_hash_keys`` selects which Document fields contribute to the generated id."""
    first_text, second_text = "text1", "text2"

    # Hashing on content only: differing meta leaves the id unchanged.
    content_only_a = Document(content=first_text, meta={"name": "doc1"}, id_hash_keys=["content"])
    content_only_b = Document(content=first_text, meta={"name": "doc2"}, id_hash_keys=["content"])
    assert content_only_a.id == content_only_b.id

    # Hashing on content and meta: differing meta now yields different ids.
    with_meta_a = Document(content=first_text, meta={"name": "doc1"}, id_hash_keys=["content", "meta"])
    with_meta_b = Document(content=first_text, meta={"name": "doc2"}, id_hash_keys=["content", "meta"])
    assert with_meta_a.id != with_meta_b.id

    # Hashing on content only: differing content yields different ids.
    about_first = Document(content=first_text, meta={"name": "doc1"}, id_hash_keys=["content"])
    about_second = Document(content=second_text, meta={"name": "doc3"}, id_hash_keys=["content"])
    assert about_first.id != about_second.id

    # Unknown hash keys are rejected.
    with pytest.raises(ValueError):
        Document(content=first_text, meta={"name": "doc1"}, id_hash_keys=["content", "non_existing_field"])
def test_aggregate_labels_with_labels():
    """MultiLabel exposes its labels' shared filters and rejects labels whose filters differ."""

    def gold_label(answer_text, filename):
        # All labels share query/document/origin; only answer text and filter value vary.
        return Label(
            query="question",
            answer=Answer(answer=answer_text),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            origin="gold-label",
            filters={"name": [filename]},
        )

    answer1_file1 = gold_label("1", "filename1")
    answer2_file1 = gold_label("2", "filename1")
    answer2_file2 = gold_label("2", "filename2")

    agreeing = MultiLabel(labels=[answer1_file1, answer2_file1])
    assert agreeing.filters == {"name": ["filename1"]}
    with pytest.raises(ValueError):
        MultiLabel(labels=[answer1_file1, answer2_file2])
def test_multilabel_preserve_order():
    """MultiLabel keeps its labels in the order they were passed in."""
    # (label id, answer text, (offset start, offset end), doc content, doc id, is_correct_answer, no_answer)
    specs = [
        ("0", "answer1", (12, 18), "some", "123", True, False),
        ("1", "answer2", (12, 18), "some", "123", True, False),
        ("2", "answer3", (12, 18), "some other", "333", True, False),
        ("3", "", (0, 0), "some", "777", True, True),
        ("4", "answer5", (12, 18), "some", "123", False, False),
    ]
    labels = [
        Label(
            id=label_id,
            query="question",
            answer=Answer(answer=text, offsets_in_document=[Span(start=start, end=end)]),
            document=Document(content=content, id=doc_id),
            is_correct_answer=correct_answer,
            is_correct_document=True,
            no_answer=is_no_answer,
            origin="gold-label",
        )
        for label_id, text, (start, end), content, doc_id, correct_answer, is_no_answer in specs
    ]
    multilabel = MultiLabel(labels=labels)
    for position in range(5):
        assert multilabel.labels[position].id == str(position)
def test_multilabel_preserve_order_w_duplicates():
    """Duplicate labels are collapsed while first-seen order is preserved."""
    # (label id, answer text, doc content, doc id); entries 3 and 4 repeat entries 0 and 2.
    specs = [
        ("0", "answer1", "some", "123"),
        ("1", "answer2", "some", "123"),
        ("2", "answer3", "some other", "333"),
        ("0", "answer1", "some", "123"),
        ("2", "answer3", "some other", "333"),
    ]
    labels = [
        Label(
            id=label_id,
            query="question",
            answer=Answer(answer=text, offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content=content, id=doc_id),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
        )
        for label_id, text, content, doc_id in specs
    ]
    multilabel = MultiLabel(labels=labels)
    assert len(multilabel.document_ids) == 3
    for position in range(3):
        assert multilabel.labels[position].id == str(position)
def test_multilabel_id():
    """MultiLabel.id is a pinned digest that changes with the labels' query and filters.

    The literal expectations guard against ids that vary between Python interpreters.
    """
    shared_document = Document(content="something", id="1")
    shared_answer = Answer(answer="answer 1")
    name_filter = {"name": ["name 1"]}
    name_author_filter = {"name": ["name 1"], "author": ["author 1"]}

    def gold_label(query, filters):
        # Only query and filters vary; document and answer objects are shared.
        return Label(
            query=query,
            document=shared_document,
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
            answer=shared_answer,
            filters=filters,
        )

    q1_by_name = gold_label("question 1", name_filter)
    q2_by_name_author = gold_label("question 2", name_author_filter)
    q1_by_name_author = gold_label("question 1", name_author_filter)

    assert MultiLabel(labels=[q1_by_name]).id == "33a3e58e13b16e9d6ec682ffe59ccc89"
    assert MultiLabel(labels=[q2_by_name_author]).id == "1b3ad38b629db7b0e869373b01bc32b1"
    assert MultiLabel(labels=[q1_by_name_author]).id == "531445fa3bdf98b8598a3bea032bd605"
def test_serialize_speech_document():
    """SpeechDocument.to_dict keeps the text content and stores the audio path as an absolute string."""
    audio_path = SAMPLES_PATH / "audio" / "this is the content of the document.wav"
    speech_doc = SpeechDocument(
        id=12345,
        content_type="audio",
        content="this is the content of the document",
        content_audio=audio_path,
        meta={"some": "meta"},
    )
    serialized = speech_doc.to_dict()
    assert serialized["content"] == "this is the content of the document"
    assert serialized["content_audio"] == str(audio_path.absolute())
def test_deserialize_speech_document():
    """A SpeechDocument survives a ``to_dict``/``from_dict`` round trip."""
    original = SpeechDocument(
        id=12345,
        content_type="audio",
        content="this is the content of the document",
        content_audio=SAMPLES_PATH / "audio" / "this is the content of the document.wav",
        meta={"some": "meta"},
    )
    round_tripped = SpeechDocument.from_dict(original.to_dict())
    assert original == round_tripped
def test_serialize_speech_answer():
    """SpeechAnswer.to_dict keeps text fields and stores both audio paths as absolute strings."""
    audio_dir = SAMPLES_PATH / "audio"
    answer_wav = audio_dir / "answer.wav"
    context_wav = audio_dir / "the context for this answer is here.wav"
    speech_answer = SpeechAnswer(
        answer="answer",
        answer_audio=answer_wav,
        context="the context for this answer is here",
        context_audio=context_wav,
    )
    serialized = speech_answer.to_dict()
    assert serialized["answer"] == "answer"
    assert serialized["answer_audio"] == str(answer_wav.absolute())
    assert serialized["context"] == "the context for this answer is here"
    assert serialized["context_audio"] == str(context_wav.absolute())
def test_deserialize_speech_answer():
    """A SpeechAnswer survives a ``to_dict``/``from_dict`` round trip."""
    audio_dir = SAMPLES_PATH / "audio"
    original = SpeechAnswer(
        answer="answer",
        answer_audio=audio_dir / "answer.wav",
        context="the context for this answer is here",
        context_audio=audio_dir / "the context for this answer is here.wav",
    )
    round_tripped = SpeechAnswer.from_dict(original.to_dict())
    assert original == round_tripped