haystack/test/others/test_schema.py
Stefano Fiorucci 54ec13eaf7
refactor: Change no_answer attribute (#3411)
* always run validation

* update schemas

* no_answer as a property. break things!

* forgotten schema

* fix

* update openapi

* removed my unnecessary test

* fix sql document store

Co-authored-by: Sara Zan <sara.zanzottera@deepset.ai>
2022-10-25 13:07:00 +02:00

487 lines
15 KiB
Python

from haystack.schema import Document, Label, Answer, Span, MultiLabel, SpeechDocument, SpeechAnswer
import pytest
import numpy as np
import pandas as pd
from ..conftest import SAMPLES_PATH
# Shared fixture: three user-feedback labels on the same query/document.
# LABELS[0] and LABELS[2] are built from identical fields; LABELS[1] differs in its
# answer text (and has no offsets), so test_equal_label can check both == and !=.
LABELS = [
    Label(
        query="some",
        answer=label_answer,
        # A fresh Document per label, exactly as if each Label were written out by hand.
        document=Document(content="some text", content_type="text"),
        is_correct_answer=True,
        is_correct_document=True,
        origin="user-feedback",
    )
    for label_answer in (
        Answer(
            answer="an answer",
            type="extractive",
            score=0.1,
            document_id="123",
            offsets_in_document=[Span(start=1, end=3)],
        ),
        Answer(answer="annother answer", type="extractive", score=0.1, document_id="123"),
        Answer(
            answer="an answer",
            type="extractive",
            score=0.1,
            document_id="123",
            offsets_in_document=[Span(start=1, end=3)],
        ),
    )
]
def test_no_answer_label():
    """Label.no_answer is True for an empty answer string and False for a non-empty one."""
    labels = [
        Label(
            query="question",
            answer=Answer(answer=""),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            origin="gold-label",
        ),
        Label(
            query="question",
            answer=Answer(answer=""),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            origin="gold-label",
        ),
        Label(
            query="question",
            answer=Answer(answer="some"),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            origin="gold-label",
        ),
        Label(
            query="question",
            answer=Answer(answer="some"),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            origin="gold-label",
        ),
    ]
    # Identity checks (rather than `== True`/`== False`) pin that the property yields real
    # booleans, and avoid the E712 lint the equality form triggers.
    assert labels[0].no_answer is True
    assert labels[1].no_answer is True
    assert labels[2].no_answer is False
    assert labels[3].no_answer is False
def test_equal_label():
    """Labels built from the same fields are equal; a different answer makes them unequal."""
    first, second, third = LABELS[0], LABELS[1], LABELS[2]
    assert third == first
    assert second != first
def test_answer_to_json():
    """An Answer survives a to_json/from_json round trip, with offsets rebuilt as Span objects."""
    a = Answer(
        answer="an answer",
        type="extractive",
        score=0.1,
        context="abc",
        offsets_in_document=[Span(start=1, end=10)],
        offsets_in_context=[Span(start=3, end=5)],
        document_id="123",
    )
    j = a.to_json()
    assert isinstance(j, str)
    # Sanity check that the serialized form is not trivially empty.
    assert len(j) > 30
    a_new = Answer.from_json(j)
    # Deserialization must turn plain offset mappings back into Span instances.
    assert isinstance(a_new.offsets_in_document[0], Span)
    assert a_new == a
def test_answer_to_dict():
    """An Answer survives a to_dict/from_dict round trip, with offsets rebuilt as Span objects."""
    a = Answer(
        answer="an answer",
        type="extractive",
        score=0.1,
        context="abc",
        offsets_in_document=[Span(start=1, end=10)],
        offsets_in_context=[Span(start=3, end=5)],
        document_id="123",
    )
    j = a.to_dict()
    assert isinstance(j, dict)
    a_new = Answer.from_dict(j)
    # Deserialization must turn plain offset mappings back into Span instances.
    assert isinstance(a_new.offsets_in_document[0], Span)
    assert a_new == a
def test_label_to_json():
    """A Label survives a to_json/from_json round trip, including its answer's document offsets."""
    j0 = LABELS[0].to_json()
    l_new = Label.from_json(j0)
    assert l_new == LABELS[0]
    assert l_new.answer.offsets_in_document[0].start == 1
def test_label_to_dict():
    """Round-trip a Label through to_dict/from_dict and confirm nested offsets are preserved."""
    original = LABELS[0]
    restored = Label.from_dict(original.to_dict())
    assert restored == original
    first_offset = restored.answer.offsets_in_document[0]
    assert first_offset.start == 1
def test_doc_to_json():
    """A Document survives a to_json/from_json round trip both with and without an embedding."""
    # First pass uses a random float32 embedding, second pass uses no embedding at all.
    candidate_embeddings = (np.random.rand(768).astype(np.float32), None)
    for embedding in candidate_embeddings:
        document = Document(
            content="some text",
            content_type="text",
            score=0.99988,
            meta={"name": "doc1"},
            embedding=embedding,
        )
        restored = Document.from_json(document.to_json())
        assert document == restored
def test_answer_postinit():
    """Constructing an Answer converts dict offsets into Span objects and leaves meta as {}."""
    answer = Answer(answer="test", offsets_in_document=[{"start": 10, "end": 20}])
    assert answer.meta == {}
    first_offset = answer.offsets_in_document[0]
    assert isinstance(first_offset, Span)
def test_generate_doc_id_using_text():
    """With default settings, equal content gives equal ids even when meta differs; other content differs."""
    shared_content, other_content = "text1", "text2"
    first = Document(content=shared_content, meta={"name": "doc1"})
    same_content = Document(content=shared_content, meta={"name": "doc2"})
    different_content = Document(content=other_content, meta={"name": "doc3"})
    assert first.id == same_content.id
    assert first.id != different_content.id
def test_generate_doc_id_using_custom_list():
    """id_hash_keys selects the fields hashed into the id; naming an unknown field raises ValueError."""
    text_a, text_b = "text1", "text2"
    content_only = ("content",)
    content_and_meta = ("content", "meta")

    def make_doc(content, name, hash_keys):
        # A fresh list per document, mirroring an inline list literal at each call site.
        return Document(content=content, meta={"name": name}, id_hash_keys=list(hash_keys))

    # Hashing content only: meta differences do not affect the id.
    assert make_doc(text_a, "doc1", content_only).id == make_doc(text_a, "doc2", content_only).id
    # Hashing content and meta: a meta difference changes the id.
    assert make_doc(text_a, "doc1", content_and_meta).id != make_doc(text_a, "doc2", content_and_meta).id
    # Hashing content only: a content difference changes the id.
    assert make_doc(text_a, "doc1", content_only).id != make_doc(text_b, "doc3", content_only).id
    with pytest.raises(ValueError):
        _ = make_doc(text_a, "doc1", ("content", "non_existing_field"))
def test_aggregate_labels_with_labels():
    """MultiLabel exposes the filters shared by its labels and raises ValueError on conflicting filters."""
    label1_with_filter1 = Label(
        query="question",
        answer=Answer(answer="1"),
        is_correct_answer=True,
        is_correct_document=True,
        document=Document(content="some", id="777"),
        origin="gold-label",
        filters={"name": ["filename1"]},
    )
    label2_with_filter1 = Label(
        query="question",
        answer=Answer(answer="2"),
        is_correct_answer=True,
        is_correct_document=True,
        document=Document(content="some", id="777"),
        origin="gold-label",
        filters={"name": ["filename1"]},
    )
    label3_with_filter2 = Label(
        query="question",
        answer=Answer(answer="2"),
        is_correct_answer=True,
        is_correct_document=True,
        document=Document(content="some", id="777"),
        origin="gold-label",
        filters={"name": ["filename2"]},
    )
    multilabel = MultiLabel(labels=[label1_with_filter1, label2_with_filter1])
    assert multilabel.filters == {"name": ["filename1"]}
    # Only construction matters here; the result is never used, so do not bind it.
    with pytest.raises(ValueError):
        MultiLabel(labels=[label1_with_filter1, label3_with_filter2])
def test_multilabel_preserve_order():
    """MultiLabel keeps its labels in the order they were passed in."""

    def make_label(label_id, answer_text, document, correct_answer=True, span=(12, 18)):
        # Every label shares the query, origin and document-correctness flag.
        start, end = span
        return Label(
            id=label_id,
            query="question",
            answer=Answer(answer=answer_text, offsets_in_document=[Span(start=start, end=end)]),
            document=document,
            is_correct_answer=correct_answer,
            is_correct_document=True,
            origin="gold-label",
        )

    labels = [
        make_label("0", "answer1", Document(content="some", id="123")),
        make_label("1", "answer2", Document(content="some", id="123")),
        make_label("2", "answer3", Document(content="some other", id="333")),
        make_label("3", "", Document(content="some", id="777"), span=(0, 0)),
        make_label("4", "answer5", Document(content="some", id="123"), correct_answer=False),
    ]
    multilabel = MultiLabel(labels=labels)
    observed_ids = [multilabel.labels[idx].id for idx in range(5)]
    assert observed_ids == ["0", "1", "2", "3", "4"]
def test_multilabel_preserve_order_w_duplicates():
    """Given repeated labels, MultiLabel keeps first occurrences in order (three distinct labels here)."""

    def make_label(label_id, answer_text, doc_content, doc_id):
        return Label(
            id=label_id,
            query="question",
            answer=Answer(answer=answer_text, offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content=doc_content, id=doc_id),
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
        )

    # The last two entries repeat labels "0" and "2" field-for-field.
    label_specs = [
        ("0", "answer1", "some", "123"),
        ("1", "answer2", "some", "123"),
        ("2", "answer3", "some other", "333"),
        ("0", "answer1", "some", "123"),
        ("2", "answer3", "some other", "333"),
    ]
    multilabel = MultiLabel(labels=[make_label(*spec) for spec in label_specs])
    assert len(multilabel.document_ids) == 3
    assert [multilabel.labels[idx].id for idx in range(3)] == ["0", "1", "2"]
def test_multilabel_id():
    """MultiLabel.id is a fixed hash that changes when either the query or the filters change."""
    shared_document = Document(content="something", id="1")
    shared_answer = Answer(answer="answer 1")
    name_filter = {"name": ["name 1"]}
    name_author_filter = {"name": ["name 1"], "author": ["author 1"]}

    def make_label(query, filters):
        return Label(
            query=query,
            document=shared_document,
            is_correct_answer=True,
            is_correct_document=True,
            origin="gold-label",
            answer=shared_answer,
            filters=filters,
        )

    # (query, filters, expected MultiLabel id)
    cases = [
        ("question 1", name_filter, "33a3e58e13b16e9d6ec682ffe59ccc89"),
        ("question 2", name_author_filter, "1b3ad38b629db7b0e869373b01bc32b1"),
        ("question 1", name_author_filter, "531445fa3bdf98b8598a3bea032bd605"),
    ]
    # Build every label before wrapping any of them in a MultiLabel.
    built_labels = [make_label(query, filters) for query, filters, _ in cases]
    for label, (_, _, expected_id) in zip(built_labels, cases):
        assert MultiLabel(labels=[label]).id == expected_id
def test_multilabel_with_doc_containing_dataframes():
    """A label whose document content is a DataFrame yields exactly one context, and it is a string."""
    label = Label(
        query="A question",
        document=Document(content=pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})),
        is_correct_answer=True,
        is_correct_document=True,
        origin="gold-label",
        answer=Answer(answer="answer 1"),
    )
    # Build the MultiLabel once and inspect it, instead of reconstructing it per assertion.
    multilabel = MultiLabel(labels=[label])
    assert len(multilabel.contexts) == 1
    assert isinstance(multilabel.contexts[0], str)
def test_serialize_speech_document():
    """SpeechDocument.to_dict keeps the text content and writes the audio path as an absolute string."""
    audio_file = SAMPLES_PATH / "audio" / "this is the content of the document.wav"
    speech_doc = SpeechDocument(
        id=12345,
        content_type="audio",
        content="this is the content of the document",
        content_audio=audio_file,
        meta={"some": "meta"},
    )
    serialized = speech_doc.to_dict()
    assert serialized["content"] == "this is the content of the document"
    assert serialized["content_audio"] == str(audio_file.absolute())
def test_deserialize_speech_document():
    """A SpeechDocument survives a to_dict/from_dict round trip unchanged."""
    audio_file = SAMPLES_PATH / "audio" / "this is the content of the document.wav"
    speech_doc = SpeechDocument(
        id=12345,
        content_type="audio",
        content="this is the content of the document",
        content_audio=audio_file,
        meta={"some": "meta"},
    )
    as_dict = speech_doc.to_dict()
    assert speech_doc == SpeechDocument.from_dict(as_dict)
def test_serialize_speech_answer():
    """SpeechAnswer.to_dict keeps text fields verbatim and writes both audio paths as absolute strings."""
    audio_dir = SAMPLES_PATH / "audio"
    answer_audio = audio_dir / "answer.wav"
    context_audio = audio_dir / "the context for this answer is here.wav"
    speech_answer = SpeechAnswer(
        answer="answer",
        answer_audio=answer_audio,
        context="the context for this answer is here",
        context_audio=context_audio,
    )
    serialized = speech_answer.to_dict()
    assert serialized["answer"] == "answer"
    assert serialized["answer_audio"] == str(answer_audio.absolute())
    assert serialized["context"] == "the context for this answer is here"
    assert serialized["context_audio"] == str(context_audio.absolute())
def test_deserialize_speech_answer():
    """A SpeechAnswer survives a to_dict/from_dict round trip unchanged."""
    audio_dir = SAMPLES_PATH / "audio"
    speech_answer = SpeechAnswer(
        answer="answer",
        answer_audio=audio_dir / "answer.wav",
        context="the context for this answer is here",
        context_audio=audio_dir / "the context for this answer is here.wav",
    )
    as_dict = speech_answer.to_dict()
    assert speech_answer == SpeechAnswer.from_dict(as_dict)
def test_span_in():
    """An integer inside a Span's range is a member; one past its end is not."""
    assert 10 in Span(5, 15)
    # `x not in y` is the idiomatic spelling of `not x in y` (pycodestyle E713).
    assert 20 not in Span(1, 15)
def test_span_in_edges():
    """Span membership includes the start position and excludes the end position."""
    assert 5 in Span(5, 15)
    assert 15 not in Span(5, 15)
def test_span_in_other_values():
    """Floats and numeric strings are accepted for membership; a non-numeric string raises ValueError."""
    span = Span(5, 15)
    assert 10.0 in span
    assert "10" in span
    with pytest.raises(ValueError):
        # Only the side effect (the raise) matters; the membership result is discarded.
        _ = "hello" in span
def test_assert_span_vs_span():
    """Span-in-Span containment.

    Per these assertions, an inner span counts as contained only when its end position is
    also strictly before the outer span's end, so a span is never contained in itself.
    """
    assert Span(10, 11) in Span(5, 15)
    assert Span(5, 10) in Span(5, 15)
    assert Span(10, 15) not in Span(5, 15)
    assert Span(5, 15) not in Span(5, 15)
    assert Span(5, 14) in Span(5, 15)
    assert Span(0, 1) not in Span(5, 15)
    assert Span(0, 10) not in Span(5, 15)
    assert Span(10, 20) not in Span(5, 15)