diff --git a/haystack/components/builders/answer_builder.py b/haystack/components/builders/answer_builder.py index ede35dd19..81d1dfe5f 100644 --- a/haystack/components/builders/answer_builder.py +++ b/haystack/components/builders/answer_builder.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import re +from dataclasses import replace from typing import Any, Optional, Union from haystack import Document, GeneratedAnswer, component, logging @@ -29,10 +30,50 @@ class AnswerBuilder: builder = AnswerBuilder(pattern="Answer: (.*)") builder.run(query="What's the answer?", replies=["This is an argument. Answer: This is the answer."]) ``` + + ### Usage example with documents and reference pattern + + ```python + from haystack import Document + from haystack.components.builders import AnswerBuilder + + replies = ["The capital of France is Paris [2]."] + + docs = [ + Document(content="Berlin is the capital of Germany."), + Document(content="Paris is the capital of France."), + Document(content="Rome is the capital of Italy."), + ] + + builder = AnswerBuilder(reference_pattern="\\[(\\d+)\\]", return_only_referenced_documents=False) + result = builder.run(query="What is the capital of France?", replies=replies, documents=docs)["answers"][0] + + print(f"Answer: {result.data}") + print("References:") + for doc in result.documents: + if doc.meta["referenced"]: + print(f"[{doc.meta['source_index']}] {doc.content}") + print("Other sources:") + for doc in result.documents: + if not doc.meta["referenced"]: + print(f"[{doc.meta['source_index']}] {doc.content}") + + # Answer: The capital of France is Paris + # References: + # [2] Paris is the capital of France. + # Other sources: + # [1] Berlin is the capital of Germany. + # [3] Rome is the capital of Italy. + ``` """ def __init__( - self, pattern: Optional[str] = None, reference_pattern: Optional[str] = None, last_message_only: bool = False + self, + pattern: Optional[str] = None, + reference_pattern: Optional[str] = None, + last_message_only: bool = False, + *, + return_only_referenced_documents: bool = True, ): """ Creates an instance of the AnswerBuilder component. @@ -49,13 +90,20 @@ class AnswerBuilder: :param reference_pattern: The regular expression pattern used for parsing the document references. - If not specified, no parsing is done, and all documents are referenced. + If not specified, no parsing is done, and all documents are returned. References need to be specified as indices of the input documents and start at [1]. Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]". + If this parameter is provided, documents metadata will contain a "referenced" key with a boolean value. :param last_message_only: If False (default value), all messages are used as the answer. If True, only the last message is used as the answer. + + :param return_only_referenced_documents: + To be used in conjunction with `reference_pattern`. + If True (default value), only the documents that were actually referenced in `replies` are returned. + If False, all documents are returned. + If `reference_pattern` is not provided, this parameter has no effect, and all documents are returned. """ if pattern: AnswerBuilder._check_num_groups_in_regex(pattern) @@ -63,6 +111,7 @@ class AnswerBuilder: self.pattern = pattern self.reference_pattern = reference_pattern self.last_message_only = last_message_only + self.return_only_referenced_documents = return_only_referenced_documents @component.output_types(answers=list[GeneratedAnswer]) def run( # pylint: disable=too-many-positional-arguments @@ -85,9 +134,12 @@ class AnswerBuilder: The metadata returned by the Generator. If not specified, the generated answer will contain no metadata. :param documents: The documents used as the Generator inputs. If specified, they are added to - the`GeneratedAnswer` objects. - If both `documents` and `reference_pattern` are specified, the documents referenced in the - Generator output are extracted from the input documents and added to the `GeneratedAnswer` objects. + the `GeneratedAnswer` objects. + Each Document.meta includes a "source_index" key, representing its 1-based position in the input list. + When `reference_pattern` is provided: + - "referenced" key is added to the Document.meta, indicating if the document was referenced in the output. + - `return_only_referenced_documents` init parameter controls if all or only referenced documents are + returned. :param pattern: The regular expression pattern to extract the answer text from the Generator. If not specified, the entire response is used as the answer. @@ -100,7 +152,7 @@ class AnswerBuilder: "this is an argument. Answer: this is an answer". :param reference_pattern: The regular expression pattern used for parsing the document references. - If not specified, no parsing is done, and all documents are referenced. + If not specified, no parsing is done, and all documents are returned. References need to be specified as indices of the input documents and start at [1]. Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]". @@ -117,21 +169,14 @@ class AnswerBuilder: pattern = pattern or self.pattern reference_pattern = reference_pattern or self.reference_pattern + + replies_to_iterate = replies[-1:] if self.last_message_only and replies else replies + meta_to_iterate = meta[-1:] if self.last_message_only and meta else meta + all_answers = [] - - replies_to_iterate = replies - meta_to_iterate = meta - - if self.last_message_only and replies: - replies_to_iterate = replies[-1:] - meta_to_iterate = meta[-1:] - for reply, given_metadata in zip(replies_to_iterate, meta_to_iterate): # Extract content from ChatMessage objects if reply is a ChatMessages, else use the string as is - if isinstance(reply, ChatMessage): - extracted_reply = reply.text or "" - else: - extracted_reply = str(reply) + extracted_reply = reply.text or "" if isinstance(reply, ChatMessage) else str(reply) extracted_metadata = reply.meta if isinstance(reply, ChatMessage) else {} extracted_metadata = {**extracted_metadata, **given_metadata} @@ -139,18 +184,31 @@ class AnswerBuilder: referenced_docs = [] if documents: - if reference_pattern: - reference_idxs = AnswerBuilder._extract_reference_idxs(extracted_reply, reference_pattern) - else: - reference_idxs = [doc_idx for doc_idx, _ in enumerate(documents)] + referenced_idxs = ( + AnswerBuilder._extract_reference_idxs(extracted_reply, reference_pattern) + if reference_pattern + else set() + ) + doc_idxs = ( + referenced_idxs + if reference_pattern and self.return_only_referenced_documents + else set(range(len(documents))) + ) - for idx in reference_idxs: + for idx in doc_idxs: try: - referenced_docs.append(documents[idx]) + doc = documents[idx] except IndexError: logger.warning( "Document index '{index}' referenced in Generator output is out of range. ", index=idx + 1 ) + continue + + doc_meta: dict[str, Any] = doc.meta or {} + doc_meta["source_index"] = idx + 1 + if reference_pattern: + doc_meta["referenced"] = idx in referenced_idxs + referenced_docs.append(replace(doc, meta=doc_meta)) answer_string = AnswerBuilder._extract_answer_string(extracted_reply, pattern) answer = GeneratedAnswer( @@ -184,9 +242,9 @@ class AnswerBuilder: return "" @staticmethod - def _extract_reference_idxs(reply: str, reference_pattern: str) -> list[int]: + def _extract_reference_idxs(reply: str, reference_pattern: str) -> set[int]: document_idxs = re.findall(reference_pattern, reply) - return [int(idx) - 1 for idx in document_idxs] + return {int(idx) - 1 for idx in document_idxs} @staticmethod def _check_num_groups_in_regex(pattern: str): diff --git a/releasenotes/notes/answer-builder-refactoring-dc02d2285aebea32.yaml b/releasenotes/notes/answer-builder-refactoring-dc02d2285aebea32.yaml new file mode 100644 index 000000000..50f94fe51 --- /dev/null +++ b/releasenotes/notes/answer-builder-refactoring-dc02d2285aebea32.yaml @@ -0,0 +1,10 @@ +--- +features: + - | + The `AnswerBuilder` component now exposes a new parameter `return_only_referenced_documents` (default: True) that + controls if only documents referenced in the `replies` are returned. + Returned documents include two new fields in the `meta` dictionary: + - `source_index`: the 1-based index of the document in the input list + - `referenced`: a boolean value indicating if the document was referenced in the `replies` (only present + if the `reference_pattern` parameter is provided). + These additions make it easier to display references and other sources within a RAG pipeline. diff --git a/test/components/builders/test_answer_builder.py b/test/components/builders/test_answer_builder.py index e85e4d71c..d42e17730 100644 --- a/test/components/builders/test_answer_builder.py +++ b/test/components/builders/test_answer_builder.py @@ -143,6 +143,30 @@ class TestAnswerBuilder: assert answers[0].query == "test query" assert len(answers[0].documents) == 1 assert answers[0].documents[0].content == "test doc 2" + assert answers[0].documents[0].meta["referenced"] is True + assert answers[0].documents[0].meta["source_index"] == 2 + + def test_run_with_documents_with_reference_pattern_return_all_documents(self): + component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]", return_only_referenced_documents=False) + output = component.run( + query="test query", + replies=["Answer: AnswerString[2]"], + meta=[{}], + documents=[Document(content="test doc 1"), Document(content="test doc 2")], + ) + answers = output["answers"] + assert len(answers) == 1 + assert answers[0].data == "Answer: AnswerString[2]" + _check_metadata_excluding_all_messages(answers[0].meta, {}) + assert "all_messages" in answers[0].meta + assert answers[0].query == "test query" + assert len(answers[0].documents) == 2 + assert answers[0].documents[0].content == "test doc 1" + assert answers[0].documents[0].meta["referenced"] is False + assert answers[0].documents[0].meta["source_index"] == 1 + assert answers[0].documents[1].content == "test doc 2" + assert answers[0].documents[1].meta["referenced"] is True + assert answers[0].documents[1].meta["source_index"] == 2 def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog): component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]") diff --git a/test/core/pipeline/features/test_run.py b/test/core/pipeline/features/test_run.py index 38c19e197..b25c03e0b 100644 --- a/test/core/pipeline/features/test_run.py +++ b/test/core/pipeline/features/test_run.py @@ -951,11 +951,13 @@ def pipeline_that_has_a_component_with_only_default_inputs(pipeline_class): id="413dccdf51a54cca75b7ed2eddac04e6e58560bd2f0caf4106a3efc023fe3651", content="Paris is the capital of France", score=1.600237583702734, + meta={"source_index": 1}, ), Document( id="a4a874fc2ef75015da7924d709fbdd2430e46a8e94add6e0f26cd32c1c03435d", content="Rome is the capital of Italy", score=1.2536639934227616, + meta={"source_index": 2}, ), ], meta={"all_messages": ["Paris"]},