diff --git a/haystack/preview/components/builders/answer_builder.py b/haystack/preview/components/builders/answer_builder.py index 043b08b6c..af644d22a 100644 --- a/haystack/preview/components/builders/answer_builder.py +++ b/haystack/preview/components/builders/answer_builder.py @@ -37,23 +37,23 @@ class AnswerBuilder: self.pattern = pattern self.reference_pattern = reference_pattern - @component.output_types(answers=List[List[GeneratedAnswer]]) + @component.output_types(answers=List[GeneratedAnswer]) def run( self, - queries: List[str], - replies: List[List[str]], - metadata: List[List[Dict[str, Any]]], - documents: Optional[List[List[Document]]] = None, + query: str, + replies: List[str], + metadata: List[Dict[str, Any]], + documents: Optional[List[Document]] = None, pattern: Optional[str] = None, reference_pattern: Optional[str] = None, ): """ Parse the output of a Generator to `Answer` objects using regular expressions. - :param queries: The queries used in the prompts for the Generator. A list of strings. - :param replies: The output of the Generator. A list of lists of strings. - :param metadata: The metadata returned by the Generator. A list of lists of dictionaries. - :param documents: The documents used as input to the Generator. A list of lists of `Document` objects. If + :param query: The query used in the prompts for the Generator. A strings. + :param replies: The output of the Generator. A list of strings. + :param metadata: The metadata returned by the Generator. A list of dictionaries. + :param documents: The documents used as input to the Generator. A list of `Document` objects. If `documents` are specified, they are added to the `Answer` objects. If both `documents` and `reference_pattern` are specified, the documents referenced in the Generator output are extracted from the input documents and added to the `Answer` objects. @@ -73,43 +73,33 @@ class AnswerBuilder: If not specified, no parsing is done, and all documents are referenced. Default: `None`. """ - if len(queries) != len(replies) != len(metadata): - raise ValueError( - f"Number of queries ({len(queries)}), replies ({len(replies)}), and metadata " - f"({len(metadata)}) must match." - ) - + if len(replies) != len(metadata): + raise ValueError(f"Number of replies ({len(replies)}), and metadata ({len(metadata)}) must match.") if pattern: AnswerBuilder._check_num_groups_in_regex(pattern) - documents = documents or [] pattern = pattern or self.pattern reference_pattern = reference_pattern or self.reference_pattern all_answers = [] - for i, (query, reply_list, meta_list) in enumerate(zip(queries, replies, metadata)): - doc_list = documents[i] if i < len(documents) else [] + for reply, meta in zip(replies, metadata): + referenced_docs = [] + if documents: + reference_idxs = [] + if reference_pattern: + reference_idxs = AnswerBuilder._extract_reference_idxs(reply, reference_pattern) + else: + reference_idxs = [doc_idx for doc_idx, _ in enumerate(documents)] - extracted_answer_strings = AnswerBuilder._extract_answer_strings(reply_list, pattern) - - if doc_list and reference_pattern: - reference_idxs = AnswerBuilder._extract_reference_idxs(reply_list, reference_pattern) - else: - reference_idxs = [[doc_idx for doc_idx, _ in enumerate(doc_list)] for _ in reply_list] - - answers_for_cur_query = [] - for answer_string, doc_idxs, meta in zip(extracted_answer_strings, reference_idxs, meta_list): - referenced_docs = [] - for idx in doc_idxs: - if idx < len(doc_list): - referenced_docs.append(doc_list[idx]) - else: + for idx in reference_idxs: + try: + referenced_docs.append(documents[idx]) + except IndexError: logger.warning("Document index '%s' referenced in Generator output is out of range. ", idx + 1) - answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, metadata=meta) - answers_for_cur_query.append(answer) - - all_answers.append(answers_for_cur_query) + answer_string = AnswerBuilder._extract_answer_string(reply, pattern) + answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, metadata=meta) + all_answers.append(answer) return {"answers": all_answers} @@ -127,39 +117,29 @@ class AnswerBuilder: return default_from_dict(cls, data) @staticmethod - def _extract_answer_strings(replies: List[str], pattern: Optional[str] = None) -> List[str]: + def _extract_answer_string(reply: str, pattern: Optional[str] = None) -> str: """ - Extract the answer strings from the generator output using the specified pattern. + Extract the answer string from the generator output using the specified pattern. If no pattern is specified, the whole string is used as the answer. - :param replies: The output of the Generator. A list of strings. + :param replies: The output of the Generator. A string. :param pattern: The regular expression pattern to use to extract the answer text from the generator output. """ if pattern is None: - return replies + return reply - extracted_answers = [] - for reply in replies: - if match := re.search(pattern, reply): - # No capture group in pattern -> use the whole match as answer - if not match.lastindex: - extracted_answers.append(match.group(0)) - # One capture group in pattern -> use the capture group as answer - else: - extracted_answers.append(match.group(1)) - else: - extracted_answers.append("") - - return extracted_answers + if match := re.search(pattern, reply): + # No capture group in pattern -> use the whole match as answer + if not match.lastindex: + return match.group(0) + # One capture group in pattern -> use the capture group as answer + return match.group(1) + return "" @staticmethod - def _extract_reference_idxs(replies: List[str], reference_pattern: str) -> List[List[int]]: - reference_idxs = [] - for reply in replies: - document_idxs = re.findall(reference_pattern, reply) - reference_idxs.append([int(idx) - 1 for idx in document_idxs]) - - return reference_idxs + def _extract_reference_idxs(reply: str, reference_pattern: str) -> List[int]: + document_idxs = re.findall(reference_pattern, reply) + return [int(idx) - 1 for idx in document_idxs] @staticmethod def _check_num_groups_in_regex(pattern: str): diff --git a/test/preview/components/builders/test_answer_builder.py b/test/preview/components/builders/test_answer_builder.py index eccec5b2e..03b4a42f9 100644 --- a/test/preview/components/builders/test_answer_builder.py +++ b/test/preview/components/builders/test_answer_builder.py @@ -36,130 +36,122 @@ class TestAnswerBuilder: def test_run_unmatching_input_len(self): component = AnswerBuilder() with pytest.raises(ValueError): - component.run(queries=["query"], replies=[["reply1"], ["reply2"]], metadata=[[]]) + component.run(query="query", replies=["reply1", "reply2"], metadata=[]) def test_run_without_pattern(self): component = AnswerBuilder() - output = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]]) + output = component.run(query="test query", replies=["Answer: AnswerString"], metadata=[{}]) answers = output["answers"] assert len(answers) == 1 - assert len(answers[0]) == 1 - assert answers[0][0].data == "Answer: AnswerString" - assert answers[0][0].metadata == {} - assert answers[0][0].query == "test query" - assert answers[0][0].documents == [] - assert isinstance(answers[0][0], GeneratedAnswer) + assert answers[0].data == "Answer: AnswerString" + assert answers[0].metadata == {} + assert answers[0].query == "test query" + assert answers[0].documents == [] + assert isinstance(answers[0], GeneratedAnswer) def test_run_with_pattern_with_capturing_group(self): component = AnswerBuilder(pattern=r"Answer: (.*)") - output = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]]) + output = component.run(query="test query", replies=["Answer: AnswerString"], metadata=[{}]) answers = output["answers"] assert len(answers) == 1 - assert len(answers[0]) == 1 - assert answers[0][0].data == "AnswerString" - assert answers[0][0].metadata == {} - assert answers[0][0].query == "test query" - assert answers[0][0].documents == [] - assert isinstance(answers[0][0], GeneratedAnswer) + assert answers[0].data == "AnswerString" + assert answers[0].metadata == {} + assert answers[0].query == "test query" + assert answers[0].documents == [] + assert isinstance(answers[0], GeneratedAnswer) def test_run_with_pattern_without_capturing_group(self): component = AnswerBuilder(pattern=r"'.*'") - output = component.run(queries=["test query"], replies=[["Answer: 'AnswerString'"]], metadata=[[{}]]) + output = component.run(query="test query", replies=["Answer: 'AnswerString'"], metadata=[{}]) answers = output["answers"] assert len(answers) == 1 - assert len(answers[0]) == 1 - assert answers[0][0].data == "'AnswerString'" - assert answers[0][0].metadata == {} - assert answers[0][0].query == "test query" - assert answers[0][0].documents == [] - assert isinstance(answers[0][0], GeneratedAnswer) + assert answers[0].data == "'AnswerString'" + assert answers[0].metadata == {} + assert answers[0].query == "test query" + assert answers[0].documents == [] + assert isinstance(answers[0], GeneratedAnswer) def test_run_with_pattern_with_more_than_one_capturing_group(self): with pytest.raises(ValueError, match="contains multiple capture groups"): - component = AnswerBuilder(pattern=r"Answer: (.*), (.*)") + AnswerBuilder(pattern=r"Answer: (.*), (.*)") def test_run_with_pattern_set_at_runtime(self): component = AnswerBuilder(pattern="unused pattern") output = component.run( - queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]], pattern=r"Answer: (.*)" + query="test query", replies=["Answer: AnswerString"], metadata=[{}], pattern=r"Answer: (.*)" ) answers = output["answers"] assert len(answers) == 1 - assert len(answers[0]) == 1 - assert answers[0][0].data == "AnswerString" - assert answers[0][0].metadata == {} - assert answers[0][0].query == "test query" - assert answers[0][0].documents == [] - assert isinstance(answers[0][0], GeneratedAnswer) + assert answers[0].data == "AnswerString" + assert answers[0].metadata == {} + assert answers[0].query == "test query" + assert answers[0].documents == [] + assert isinstance(answers[0], GeneratedAnswer) def test_run_with_documents_without_reference_pattern(self): component = AnswerBuilder() output = component.run( - queries=["test query"], - replies=[["Answer: AnswerString"]], - metadata=[[{}]], - documents=[[Document(text="test doc 1"), Document(text="test doc 2")]], + query="test query", + replies=["Answer: AnswerString"], + metadata=[{}], + documents=[Document(text="test doc 1"), Document(text="test doc 2")], ) answers = output["answers"] assert len(answers) == 1 - assert len(answers[0]) == 1 - assert answers[0][0].data == "Answer: AnswerString" - assert answers[0][0].metadata == {} - assert answers[0][0].query == "test query" - assert len(answers[0][0].documents) == 2 - assert answers[0][0].documents[0].text == "test doc 1" - assert answers[0][0].documents[1].text == "test doc 2" + assert answers[0].data == "Answer: AnswerString" + assert answers[0].metadata == {} + assert answers[0].query == "test query" + assert len(answers[0].documents) == 2 + assert answers[0].documents[0].text == "test doc 1" + assert answers[0].documents[1].text == "test doc 2" def test_run_with_documents_with_reference_pattern(self): component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]") output = component.run( - queries=["test query"], - replies=[["Answer: AnswerString[2]"]], - metadata=[[{}]], - documents=[[Document(text="test doc 1"), Document(text="test doc 2")]], + query="test query", + replies=["Answer: AnswerString[2]"], + metadata=[{}], + documents=[Document(text="test doc 1"), Document(text="test doc 2")], ) answers = output["answers"] assert len(answers) == 1 - assert len(answers[0]) == 1 - assert answers[0][0].data == "Answer: AnswerString[2]" - assert answers[0][0].metadata == {} - assert answers[0][0].query == "test query" - assert len(answers[0][0].documents) == 1 - assert answers[0][0].documents[0].text == "test doc 2" + assert answers[0].data == "Answer: AnswerString[2]" + assert answers[0].metadata == {} + assert answers[0].query == "test query" + assert len(answers[0].documents) == 1 + assert answers[0].documents[0].text == "test doc 2" def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog): component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]") with caplog.at_level(logging.WARNING): output = component.run( - queries=["test query"], - replies=[["Answer: AnswerString[3]"]], - metadata=[[{}]], - documents=[[Document(text="test doc 1"), Document(text="test doc 2")]], + query="test query", + replies=["Answer: AnswerString[3]"], + metadata=[{}], + documents=[Document(text="test doc 1"), Document(text="test doc 2")], ) answers = output["answers"] assert len(answers) == 1 - assert len(answers[0]) == 1 - assert answers[0][0].data == "Answer: AnswerString[3]" - assert answers[0][0].metadata == {} - assert answers[0][0].query == "test query" - assert len(answers[0][0].documents) == 0 + assert answers[0].data == "Answer: AnswerString[3]" + assert answers[0].metadata == {} + assert answers[0].query == "test query" + assert len(answers[0].documents) == 0 assert "Document index '3' referenced in Generator output is out of range." in caplog.text def test_run_with_reference_pattern_set_at_runtime(self): component = AnswerBuilder(reference_pattern="unused pattern") output = component.run( - queries=["test query"], - replies=[["Answer: AnswerString[2][3]"]], - metadata=[[{}]], - documents=[[Document(text="test doc 1"), Document(text="test doc 2"), Document(text="test doc 3")]], + query="test query", + replies=["Answer: AnswerString[2][3]"], + metadata=[{}], + documents=[Document(text="test doc 1"), Document(text="test doc 2"), Document(text="test doc 3")], reference_pattern="\\[(\\d+)\\]", ) answers = output["answers"] assert len(answers) == 1 - assert len(answers[0]) == 1 - assert answers[0][0].data == "Answer: AnswerString[2][3]" - assert answers[0][0].metadata == {} - assert answers[0][0].query == "test query" - assert len(answers[0][0].documents) == 2 - assert answers[0][0].documents[0].text == "test doc 2" - assert answers[0][0].documents[1].text == "test doc 3" + assert answers[0].data == "Answer: AnswerString[2][3]" + assert answers[0].metadata == {} + assert answers[0].query == "test query" + assert len(answers[0].documents) == 2 + assert answers[0].documents[0].text == "test doc 2" + assert answers[0].documents[1].text == "test doc 3"