2022-02-03 19:19:05 +01:00
|
|
|
from haystack.schema import Document, Label, Answer, Span, MultiLabel
|
2022-01-03 16:58:19 +01:00
|
|
|
import pytest
|
2021-10-13 14:23:23 +02:00
|
|
|
import numpy as np
|
2021-05-17 21:21:52 +05:30
|
|
|
|
2021-10-13 14:23:23 +02:00
|
|
|
def _feedback_label(answer):
    """Return a correct user-feedback ``Label`` for the query "some" carrying ``answer``.

    All entries of ``LABELS`` share the query, the document text and the
    correctness flags; only the answer differs, so the common part lives here.
    """
    return Label(
        query="some",
        answer=answer,
        document=Document(content="some text", content_type="text"),
        is_correct_answer=True,
        is_correct_document=True,
        origin="user-feedback",
    )


# Fixture labels used by the equality and serialization tests in this module.
# LABELS[0] and LABELS[2] are built from identical inputs; LABELS[1] has a
# different answer text and no ``offsets_in_document``.
LABELS = [
    _feedback_label(
        Answer(
            answer="an answer",
            type="extractive",
            score=0.1,
            document_id="123",
            offsets_in_document=[Span(start=1, end=3)],
        )
    ),
    _feedback_label(Answer(answer="annother answer", type="extractive", score=0.1, document_id="123")),
    _feedback_label(
        Answer(
            answer="an answer",
            type="extractive",
            score=0.1,
            document_id="123",
            offsets_in_document=[Span(start=1, end=3)],
        )
    ),
]
|
2021-10-13 14:23:23 +02:00
|
|
|
|
|
|
|
|
|
|
|
def test_no_answer_label():
    """``Label.no_answer`` is True for an empty answer text and False otherwise.

    Each text is checked once without a ``no_answer`` argument and once with an
    explicit value that agrees with the text.
    """
    # (answer text, extra keyword arguments for Label, expected no_answer)
    cases = [
        ("", {}, True),
        ("", {"no_answer": True}, True),
        ("some", {}, False),
        ("some", {"no_answer": False}, False),
    ]

    labels = [
        Label(
            query="question",
            answer=Answer(answer=text),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            origin="gold-label",
            **extra_kwargs,
        )
        for text, extra_kwargs, _ in cases
    ]

    for label, (_, _, expected) in zip(labels, cases):
        assert label.no_answer == expected
|
|
|
|
|
|
|
|
|
|
|
|
def test_equal_label():
    """Labels built from identical inputs are equal; a different answer breaks equality."""
    reference = LABELS[0]
    different_answer = LABELS[1]
    duplicate = LABELS[2]

    assert duplicate == reference
    assert different_answer != reference
|
|
|
|
|
|
|
|
|
|
|
|
def test_answer_to_json():
    """An ``Answer`` survives a JSON round trip, including its ``Span`` offsets."""
    a = Answer(
        answer="an answer",
        type="extractive",
        score=0.1,
        context="abc",
        offsets_in_document=[Span(start=1, end=10)],
        offsets_in_context=[Span(start=3, end=5)],
        document_id="123",
    )

    j = a.to_json()
    # The serialized form must be a non-trivial JSON string.
    assert isinstance(j, str)
    assert len(j) > 30

    a_new = Answer.from_json(j)
    # Offsets must be restored as Span objects after deserialization.
    assert isinstance(a_new.offsets_in_document[0], Span)
    assert a_new == a
|
|
|
|
|
|
|
|
|
|
|
|
def test_answer_to_dict():
    """An ``Answer`` survives a dict round trip, including its ``Span`` offsets."""
    a = Answer(
        answer="an answer",
        type="extractive",
        score=0.1,
        context="abc",
        offsets_in_document=[Span(start=1, end=10)],
        offsets_in_context=[Span(start=3, end=5)],
        document_id="123",
    )

    j = a.to_dict()
    assert isinstance(j, dict)

    a_new = Answer.from_dict(j)
    # Offsets must be restored as Span objects after deserialization.
    assert isinstance(a_new.offsets_in_document[0], Span)
    assert a_new == a
|
|
|
|
|
|
|
|
|
|
|
|
def test_label_without_offsets_to_json():
    """A ``Label`` whose answer has no document offsets survives a JSON round trip.

    This function was previously a second ``test_label_to_json``. A later
    definition with the same name rebinds it, so pytest never collected it.
    It now has its own name and covers ``LABELS[1]``, the fixture label
    without ``offsets_in_document``.
    """
    j1 = LABELS[1].to_json()
    l_new = Label.from_json(j1)
    assert l_new == LABELS[1]
|
|
|
|
|
|
|
|
|
2021-10-18 14:38:14 +02:00
|
|
|
def test_label_to_json():
    """A ``Label`` survives a JSON round trip, keeping its answer's document offsets."""
    restored = Label.from_json(LABELS[0].to_json())
    assert restored == LABELS[0]

    first_offset = restored.answer.offsets_in_document[0]
    assert first_offset.start == 1
|
|
|
|
|
|
|
|
|
|
|
|
def test_label_to_dict():
    """A ``Label`` survives a dict round trip, keeping its answer's document offsets."""
    restored = Label.from_dict(LABELS[0].to_dict())
    assert restored == LABELS[0]

    first_offset = restored.answer.offsets_in_document[0]
    assert first_offset.start == 1
|
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2021-10-13 14:23:23 +02:00
|
|
|
def test_doc_to_json():
    """A ``Document`` survives a JSON round trip, both with and without an embedding."""
    # A seeded generator makes the embedding identical on every run, so any
    # serialization failure can be reproduced exactly (np.random.rand was unseeded).
    rng = np.random.default_rng(seed=42)
    embeddings = [
        rng.random(768, dtype=np.float32),  # with embedding
        None,  # no embedding
    ]

    for embedding in embeddings:
        d = Document(
            content="some text",
            content_type="text",
            score=0.99988,
            meta={"name": "doc1"},
            embedding=embedding,
        )
        j0 = d.to_json()
        d_new = Document.from_json(j0)
        assert d == d_new
|
|
|
|
|
|
|
|
|
|
|
|
def test_answer_postinit():
    """A new ``Answer`` gets an empty ``meta`` and has its dict offsets converted to ``Span``."""
    answer = Answer(answer="test", offsets_in_document=[{"start": 10, "end": 20}])

    assert answer.meta == {}

    first_offset = answer.offsets_in_document[0]
    assert isinstance(first_offset, Span)
|
2021-05-17 21:21:52 +05:30
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
|
2021-05-17 21:21:52 +05:30
|
|
|
def test_generate_doc_id_using_text():
    """Documents with the same content share an id whatever their meta name; new content gives a new id."""
    specs = (("text1", "doc1"), ("text1", "doc2"), ("text2", "doc3"))
    same_text_a, same_text_b, other_text = (
        Document(content=content, meta={"name": name}) for content, name in specs
    )

    assert same_text_a.id == same_text_b.id
    assert same_text_a.id != other_text.id
|
|
|
|
|
|
|
|
|
|
|
|
def test_generate_doc_id_using_custom_list():
    """``id_hash_keys`` controls which fields contribute to a generated document id."""
    text1 = "text1"
    text2 = "text2"

    # Hashing on content only: a different meta name does not change the id.
    content_keyed = [
        Document(content=text1, meta={"name": name}, id_hash_keys=["content"]) for name in ("doc1", "doc2")
    ]
    assert content_keyed[0].id == content_keyed[1].id

    # Hashing on content and meta: a different meta name changes the id.
    content_and_meta_keyed = [
        Document(content=text1, meta={"name": name}, id_hash_keys=["content", "meta"]) for name in ("doc1", "doc2")
    ]
    assert content_and_meta_keyed[0].id != content_and_meta_keyed[1].id

    # Hashing on content only: different content gives a different id.
    first = Document(content=text1, meta={"name": "doc1"}, id_hash_keys=["content"])
    second = Document(content=text2, meta={"name": "doc3"}, id_hash_keys=["content"])
    assert first.id != second.id

    # A field name that does not exist in id_hash_keys raises ValueError.
    with pytest.raises(ValueError):
        Document(content=text1, meta={"name": "doc1"}, id_hash_keys=["content", "non_existing_field"])
|
2022-02-03 19:19:05 +01:00
|
|
|
|
|
|
|
|
|
|
|
def test_aggregate_labels_with_labels():
    """``MultiLabel`` exposes the filters its labels share and rejects labels with differing filters."""

    def gold_label(answer_text, filename):
        # Correct gold label for "question" on document 777, restricted to one file name.
        return Label(
            query="question",
            answer=Answer(answer=answer_text),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            origin="gold-label",
            filters={"name": [filename]},
        )

    answer1_file1 = gold_label("1", "filename1")
    answer2_file1 = gold_label("2", "filename1")
    answer2_file2 = gold_label("2", "filename2")

    aggregated = MultiLabel(labels=[answer1_file1, answer2_file1])
    assert aggregated.filters == {"name": ["filename1"]}

    with pytest.raises(ValueError):
        MultiLabel(labels=[answer1_file1, answer2_file2])
|