diff --git a/haystack/preview/components/builders/answer_builder.py b/haystack/preview/components/builders/answer_builder.py index 221c46a5a..043b08b6c 100644 --- a/haystack/preview/components/builders/answer_builder.py +++ b/haystack/preview/components/builders/answer_builder.py @@ -111,7 +111,7 @@ class AnswerBuilder: all_answers.append(answers_for_cur_query) - return all_answers + return {"answers": all_answers} def to_dict(self) -> Dict[str, Any]: """ diff --git a/test/preview/components/builders/test_answer_builder.py b/test/preview/components/builders/test_answer_builder.py index 78542255f..65fb26d8b 100644 --- a/test/preview/components/builders/test_answer_builder.py +++ b/test/preview/components/builders/test_answer_builder.py @@ -40,7 +40,8 @@ class TestAnswerBuilder: def test_run_without_pattern(self): component = AnswerBuilder() - answers = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]]) + output = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]]) + answers = output["answers"] assert len(answers) == 1 assert len(answers[0]) == 1 assert answers[0][0].data == "Answer: AnswerString" @@ -51,7 +52,8 @@ class TestAnswerBuilder: def test_run_with_pattern_with_capturing_group(self): component = AnswerBuilder(pattern=r"Answer: (.*)") - answers = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]]) + output = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]]) + answers = output["answers"] assert len(answers) == 1 assert len(answers[0]) == 1 assert answers[0][0].data == "AnswerString" @@ -62,7 +64,8 @@ class TestAnswerBuilder: def test_run_with_pattern_without_capturing_group(self): component = AnswerBuilder(pattern=r"'.*'") - answers = component.run(queries=["test query"], replies=[["Answer: 'AnswerString'"]], metadata=[[{}]]) + output = component.run(queries=["test query"], replies=[["Answer: 'AnswerString'"]], metadata=[[{}]]) + answers = output["answers"] assert len(answers) == 1 assert len(answers[0]) == 1 assert answers[0][0].data == "'AnswerString'" @@ -77,9 +80,10 @@ class TestAnswerBuilder: def test_run_with_pattern_set_at_runtime(self): component = AnswerBuilder(pattern="unused pattern") - answers = component.run( + output = component.run( queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]], pattern=r"Answer: (.*)" ) + answers = output["answers"] assert len(answers) == 1 assert len(answers[0]) == 1 assert answers[0][0].data == "AnswerString" @@ -90,12 +94,13 @@ class TestAnswerBuilder: def test_run_with_documents_without_reference_pattern(self): component = AnswerBuilder() - answers = component.run( + output = component.run( queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]], documents=[[Document(content="test doc 1"), Document(content="test doc 2")]], ) + answers = output["answers"] assert len(answers) == 1 assert len(answers[0]) == 1 assert answers[0][0].data == "Answer: AnswerString" @@ -107,12 +112,13 @@ class TestAnswerBuilder: def test_run_with_documents_with_reference_pattern(self): component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]") - answers = component.run( + output = component.run( queries=["test query"], replies=[["Answer: AnswerString[2]"]], metadata=[[{}]], documents=[[Document(content="test doc 1"), Document(content="test doc 2")]], ) + answers = output["answers"] assert len(answers) == 1 assert len(answers[0]) == 1 assert answers[0][0].data == "Answer: AnswerString[2]" @@ -124,12 +130,13 @@ class TestAnswerBuilder: def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog): component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]") with caplog.at_level(logging.WARNING): - answers = component.run( + output = component.run( queries=["test query"], replies=[["Answer: AnswerString[3]"]], metadata=[[{}]], documents=[[Document(content="test doc 1"), Document(content="test doc 2")]], ) + answers = output["answers"] assert len(answers) == 1 assert len(answers[0]) == 1 assert answers[0][0].data == "Answer: AnswerString[3]" @@ -140,7 +147,7 @@ class TestAnswerBuilder: def test_run_with_reference_pattern_set_at_runtime(self): component = AnswerBuilder(reference_pattern="unused pattern") - answers = component.run( + output = component.run( queries=["test query"], replies=[["Answer: AnswerString[2][3]"]], metadata=[[{}]], @@ -149,6 +156,7 @@ class TestAnswerBuilder: ], reference_pattern="\\[(\\d+)\\]", ) + answers = output["answers"] assert len(answers) == 1 assert len(answers[0]) == 1 assert answers[0][0].data == "Answer: AnswerString[2][3]"