2023-09-04 21:16:20 +02:00
|
|
|
import logging
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
from haystack.preview import GeneratedAnswer, Document
|
2023-09-05 14:34:22 +02:00
|
|
|
from haystack.preview.components.builders.answer_builder import AnswerBuilder
|
2023-09-04 21:16:20 +02:00
|
|
|
|
|
|
|
|
2023-09-05 14:34:22 +02:00
|
|
|
class TestAnswerBuilder:
    """Unit tests for the ``AnswerBuilder`` component.

    Every test is tagged ``@pytest.mark.unit`` so that the whole class is
    selected when the suite is filtered with ``pytest -m unit``; previously
    only the serialization tests carried the marker and the ``run`` tests
    were silently excluded from such runs.
    """

    @pytest.mark.unit
    def test_to_dict(self):
        """Default construction serializes both patterns as ``None``."""
        component = AnswerBuilder()
        data = component.to_dict()
        assert data == {"type": "AnswerBuilder", "init_parameters": {"pattern": None, "reference_pattern": None}}

    @pytest.mark.unit
    def test_to_dict_with_custom_init_parameters(self):
        """Patterns passed to the constructor appear verbatim in ``init_parameters``."""
        component = AnswerBuilder(pattern="pattern", reference_pattern="reference_pattern")
        data = component.to_dict()
        assert data == {
            "type": "AnswerBuilder",
            "init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
        }

    @pytest.mark.unit
    def test_from_dict(self):
        """``from_dict`` restores both patterns onto the component's attributes."""
        data = {
            "type": "AnswerBuilder",
            "init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
        }
        component = AnswerBuilder.from_dict(data)
        assert component.pattern == "pattern"
        assert component.reference_pattern == "reference_pattern"

    @pytest.mark.unit
    def test_run_unmatching_input_len(self):
        """Mismatched per-query input lengths (one query, two reply lists) raise ``ValueError``."""
        component = AnswerBuilder()
        with pytest.raises(ValueError):
            component.run(queries=["query"], replies=[["reply1"], ["reply2"]], metadata=[[]])

    @pytest.mark.unit
    def test_run_without_pattern(self):
        """Without a pattern, the whole reply text becomes the answer data."""
        component = AnswerBuilder()
        output = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
        answers = output["answers"]
        assert len(answers) == 1
        assert len(answers[0]) == 1
        answer = answers[0][0]
        assert answer.data == "Answer: AnswerString"
        assert answer.metadata == {}
        assert answer.query == "test query"
        assert answer.documents == []
        assert isinstance(answer, GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_pattern_with_capturing_group(self):
        """With a single capture group, only the captured text becomes the answer data."""
        component = AnswerBuilder(pattern=r"Answer: (.*)")
        output = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
        answers = output["answers"]
        assert len(answers) == 1
        assert len(answers[0]) == 1
        answer = answers[0][0]
        assert answer.data == "AnswerString"
        assert answer.metadata == {}
        assert answer.query == "test query"
        assert answer.documents == []
        assert isinstance(answer, GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_pattern_without_capturing_group(self):
        """Without a capture group, the entire regex match (quotes included) becomes the answer data."""
        component = AnswerBuilder(pattern=r"'.*'")
        output = component.run(queries=["test query"], replies=[["Answer: 'AnswerString'"]], metadata=[[{}]])
        answers = output["answers"]
        assert len(answers) == 1
        assert len(answers[0]) == 1
        answer = answers[0][0]
        assert answer.data == "'AnswerString'"
        assert answer.metadata == {}
        assert answer.query == "test query"
        assert answer.documents == []
        assert isinstance(answer, GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_pattern_with_more_than_one_capturing_group(self):
        """A pattern with more than one capture group is rejected when the component is constructed."""
        with pytest.raises(ValueError, match="contains multiple capture groups"):
            AnswerBuilder(pattern=r"Answer: (.*), (.*)")

    @pytest.mark.unit
    def test_run_with_pattern_set_at_runtime(self):
        """A ``pattern`` passed to ``run`` takes precedence over the one given at construction."""
        component = AnswerBuilder(pattern="unused pattern")
        output = component.run(
            queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]], pattern=r"Answer: (.*)"
        )
        answers = output["answers"]
        assert len(answers) == 1
        assert len(answers[0]) == 1
        answer = answers[0][0]
        assert answer.data == "AnswerString"
        assert answer.metadata == {}
        assert answer.query == "test query"
        assert answer.documents == []
        assert isinstance(answer, GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_documents_without_reference_pattern(self):
        """Without a reference pattern, every supplied document is attached to the answer, in order."""
        component = AnswerBuilder()
        output = component.run(
            queries=["test query"],
            replies=[["Answer: AnswerString"]],
            metadata=[[{}]],
            documents=[[Document(text="test doc 1"), Document(text="test doc 2")]],
        )
        answers = output["answers"]
        assert len(answers) == 1
        assert len(answers[0]) == 1
        answer = answers[0][0]
        assert answer.data == "Answer: AnswerString"
        assert answer.metadata == {}
        assert answer.query == "test query"
        assert len(answer.documents) == 2
        assert answer.documents[0].text == "test doc 1"
        assert answer.documents[1].text == "test doc 2"
        assert isinstance(answer, GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_documents_with_reference_pattern(self):
        """Only documents referenced in the reply are attached; references are 1-based ("[2]" -> second doc)."""
        component = AnswerBuilder(reference_pattern=r"\[(\d+)\]")
        output = component.run(
            queries=["test query"],
            replies=[["Answer: AnswerString[2]"]],
            metadata=[[{}]],
            documents=[[Document(text="test doc 1"), Document(text="test doc 2")]],
        )
        answers = output["answers"]
        assert len(answers) == 1
        assert len(answers[0]) == 1
        answer = answers[0][0]
        # The reference marker is kept in the answer text; it is only used to select documents.
        assert answer.data == "Answer: AnswerString[2]"
        assert answer.metadata == {}
        assert answer.query == "test query"
        assert len(answer.documents) == 1
        assert answer.documents[0].text == "test doc 2"
        assert isinstance(answer, GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog):
        """An out-of-range reference ("[3]" with two documents) is dropped and logged as a warning."""
        component = AnswerBuilder(reference_pattern=r"\[(\d+)\]")
        with caplog.at_level(logging.WARNING):
            output = component.run(
                queries=["test query"],
                replies=[["Answer: AnswerString[3]"]],
                metadata=[[{}]],
                documents=[[Document(text="test doc 1"), Document(text="test doc 2")]],
            )
            answers = output["answers"]
            assert len(answers) == 1
            assert len(answers[0]) == 1
            answer = answers[0][0]
            assert answer.data == "Answer: AnswerString[3]"
            assert answer.metadata == {}
            assert answer.query == "test query"
            assert len(answer.documents) == 0
            assert isinstance(answer, GeneratedAnswer)
            assert "Document index '3' referenced in Generator output is out of range." in caplog.text

    @pytest.mark.unit
    def test_run_with_reference_pattern_set_at_runtime(self):
        """A ``reference_pattern`` passed to ``run`` overrides the constructor one and may match several refs."""
        component = AnswerBuilder(reference_pattern="unused pattern")
        output = component.run(
            queries=["test query"],
            replies=[["Answer: AnswerString[2][3]"]],
            metadata=[[{}]],
            documents=[[Document(text="test doc 1"), Document(text="test doc 2"), Document(text="test doc 3")]],
            reference_pattern=r"\[(\d+)\]",
        )
        answers = output["answers"]
        assert len(answers) == 1
        assert len(answers[0]) == 1
        answer = answers[0][0]
        assert answer.data == "Answer: AnswerString[2][3]"
        assert answer.metadata == {}
        assert answer.query == "test query"
        assert len(answer.documents) == 2
        assert answer.documents[0].text == "test doc 2"
        assert answer.documents[1].text == "test doc 3"
        assert isinstance(answer, GeneratedAnswer)
|