mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-05 19:36:55 +00:00
feat: improve AnswerBuilder to support showing RAG references (#9933)
* draft * improve * refactor * improvs + usage ex * relnote * pipeline test fix
This commit is contained in:
parent
ce260b14c6
commit
35fb6c6f01
@ -3,6 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import re
|
||||
from dataclasses import replace
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from haystack import Document, GeneratedAnswer, component, logging
|
||||
@ -29,10 +30,50 @@ class AnswerBuilder:
|
||||
builder = AnswerBuilder(pattern="Answer: (.*)")
|
||||
builder.run(query="What's the answer?", replies=["This is an argument. Answer: This is the answer."])
|
||||
```
|
||||
|
||||
### Usage example with documents and reference pattern
|
||||
|
||||
```python
|
||||
from haystack import Document
|
||||
from haystack.components.builders import AnswerBuilder
|
||||
|
||||
replies = ["The capital of France is Paris [2]."]
|
||||
|
||||
docs = [
|
||||
Document(content="Berlin is the capital of Germany."),
|
||||
Document(content="Paris is the capital of France."),
|
||||
Document(content="Rome is the capital of Italy."),
|
||||
]
|
||||
|
||||
builder = AnswerBuilder(reference_pattern="\\[(\\d+)\\]", return_only_referenced_documents=False)
|
||||
result = builder.run(query="What is the capital of France?", replies=replies, documents=docs)["answers"][0]
|
||||
|
||||
print(f"Answer: {result.data}")
|
||||
print("References:")
|
||||
for doc in result.documents:
|
||||
if doc.meta["referenced"]:
|
||||
print(f"[{doc.meta['source_index']}] {doc.content}")
|
||||
print("Other sources:")
|
||||
for doc in result.documents:
|
||||
if not doc.meta["referenced"]:
|
||||
print(f"[{doc.meta['source_index']}] {doc.content}")
|
||||
|
||||
# Answer: The capital of France is Paris [2].
|
||||
# References:
|
||||
# [2] Paris is the capital of France.
|
||||
# Other sources:
|
||||
# [1] Berlin is the capital of Germany.
|
||||
# [3] Rome is the capital of Italy.
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, pattern: Optional[str] = None, reference_pattern: Optional[str] = None, last_message_only: bool = False
|
||||
self,
|
||||
pattern: Optional[str] = None,
|
||||
reference_pattern: Optional[str] = None,
|
||||
last_message_only: bool = False,
|
||||
*,
|
||||
return_only_referenced_documents: bool = True,
|
||||
):
|
||||
"""
|
||||
Creates an instance of the AnswerBuilder component.
|
||||
@ -49,13 +90,20 @@ class AnswerBuilder:
|
||||
|
||||
:param reference_pattern:
|
||||
The regular expression pattern used for parsing the document references.
|
||||
If not specified, no parsing is done, and all documents are referenced.
|
||||
If not specified, no parsing is done, and all documents are returned.
|
||||
References need to be specified as indices of the input documents and start at [1].
|
||||
Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]".
|
||||
If this parameter is provided, each document's metadata will contain a "referenced" key with a boolean value.
|
||||
|
||||
:param last_message_only:
|
||||
If False (default value), all messages are used as the answer.
|
||||
If True, only the last message is used as the answer.
|
||||
|
||||
:param return_only_referenced_documents:
|
||||
To be used in conjunction with `reference_pattern`.
|
||||
If True (default value), only the documents that were actually referenced in `replies` are returned.
|
||||
If False, all documents are returned.
|
||||
If `reference_pattern` is not provided, this parameter has no effect, and all documents are returned.
|
||||
"""
|
||||
if pattern:
|
||||
AnswerBuilder._check_num_groups_in_regex(pattern)
|
||||
@ -63,6 +111,7 @@ class AnswerBuilder:
|
||||
self.pattern = pattern
|
||||
self.reference_pattern = reference_pattern
|
||||
self.last_message_only = last_message_only
|
||||
self.return_only_referenced_documents = return_only_referenced_documents
|
||||
|
||||
@component.output_types(answers=list[GeneratedAnswer])
|
||||
def run( # pylint: disable=too-many-positional-arguments
|
||||
@ -85,9 +134,12 @@ class AnswerBuilder:
|
||||
The metadata returned by the Generator. If not specified, the generated answer will contain no metadata.
|
||||
:param documents:
|
||||
The documents used as the Generator inputs. If specified, they are added to
|
||||
the`GeneratedAnswer` objects.
|
||||
If both `documents` and `reference_pattern` are specified, the documents referenced in the
|
||||
Generator output are extracted from the input documents and added to the `GeneratedAnswer` objects.
|
||||
the `GeneratedAnswer` objects.
|
||||
Each Document.meta includes a "source_index" key, representing its 1-based position in the input list.
|
||||
When `reference_pattern` is provided:
|
||||
- A "referenced" key is added to each Document.meta, indicating whether the document was referenced in the output.
|
||||
- `return_only_referenced_documents` init parameter controls if all or only referenced documents are
|
||||
returned.
|
||||
:param pattern:
|
||||
The regular expression pattern to extract the answer text from the Generator.
|
||||
If not specified, the entire response is used as the answer.
|
||||
@ -100,7 +152,7 @@ class AnswerBuilder:
|
||||
"this is an argument. Answer: this is an answer".
|
||||
:param reference_pattern:
|
||||
The regular expression pattern used for parsing the document references.
|
||||
If not specified, no parsing is done, and all documents are referenced.
|
||||
If not specified, no parsing is done, and all documents are returned.
|
||||
References need to be specified as indices of the input documents and start at [1].
|
||||
Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]".
|
||||
|
||||
@ -117,21 +169,14 @@ class AnswerBuilder:
|
||||
|
||||
pattern = pattern or self.pattern
|
||||
reference_pattern = reference_pattern or self.reference_pattern
|
||||
|
||||
replies_to_iterate = replies[-1:] if self.last_message_only and replies else replies
|
||||
meta_to_iterate = meta[-1:] if self.last_message_only and meta else meta
|
||||
|
||||
all_answers = []
|
||||
|
||||
replies_to_iterate = replies
|
||||
meta_to_iterate = meta
|
||||
|
||||
if self.last_message_only and replies:
|
||||
replies_to_iterate = replies[-1:]
|
||||
meta_to_iterate = meta[-1:]
|
||||
|
||||
for reply, given_metadata in zip(replies_to_iterate, meta_to_iterate):
|
||||
# Extract content from ChatMessage objects if reply is a ChatMessage, else use the string as is
|
||||
if isinstance(reply, ChatMessage):
|
||||
extracted_reply = reply.text or ""
|
||||
else:
|
||||
extracted_reply = str(reply)
|
||||
extracted_reply = reply.text or "" if isinstance(reply, ChatMessage) else str(reply)
|
||||
extracted_metadata = reply.meta if isinstance(reply, ChatMessage) else {}
|
||||
|
||||
extracted_metadata = {**extracted_metadata, **given_metadata}
|
||||
@ -139,18 +184,31 @@ class AnswerBuilder:
|
||||
|
||||
referenced_docs = []
|
||||
if documents:
|
||||
if reference_pattern:
|
||||
reference_idxs = AnswerBuilder._extract_reference_idxs(extracted_reply, reference_pattern)
|
||||
else:
|
||||
reference_idxs = [doc_idx for doc_idx, _ in enumerate(documents)]
|
||||
referenced_idxs = (
|
||||
AnswerBuilder._extract_reference_idxs(extracted_reply, reference_pattern)
|
||||
if reference_pattern
|
||||
else set()
|
||||
)
|
||||
doc_idxs = (
|
||||
referenced_idxs
|
||||
if reference_pattern and self.return_only_referenced_documents
|
||||
else set(range(len(documents)))
|
||||
)
|
||||
|
||||
for idx in reference_idxs:
|
||||
for idx in doc_idxs:
|
||||
try:
|
||||
referenced_docs.append(documents[idx])
|
||||
doc = documents[idx]
|
||||
except IndexError:
|
||||
logger.warning(
|
||||
"Document index '{index}' referenced in Generator output is out of range. ", index=idx + 1
|
||||
)
|
||||
continue
|
||||
|
||||
doc_meta: dict[str, Any] = doc.meta or {}
|
||||
doc_meta["source_index"] = idx + 1
|
||||
if reference_pattern:
|
||||
doc_meta["referenced"] = idx in referenced_idxs
|
||||
referenced_docs.append(replace(doc, meta=doc_meta))
|
||||
|
||||
answer_string = AnswerBuilder._extract_answer_string(extracted_reply, pattern)
|
||||
answer = GeneratedAnswer(
|
||||
@ -184,9 +242,9 @@ class AnswerBuilder:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _extract_reference_idxs(reply: str, reference_pattern: str) -> list[int]:
|
||||
def _extract_reference_idxs(reply: str, reference_pattern: str) -> set[int]:
|
||||
document_idxs = re.findall(reference_pattern, reply)
|
||||
return [int(idx) - 1 for idx in document_idxs]
|
||||
return {int(idx) - 1 for idx in document_idxs}
|
||||
|
||||
@staticmethod
|
||||
def _check_num_groups_in_regex(pattern: str):
|
||||
|
||||
@ -0,0 +1,10 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
The `AnswerBuilder` component now exposes a new parameter `return_only_referenced_documents` (default: True) that
|
||||
controls if only documents referenced in the `replies` are returned.
|
||||
Returned documents include two new fields in the `meta` dictionary:
|
||||
- `source_index`: the 1-based index of the document in the input list
|
||||
- `referenced`: a boolean value indicating if the document was referenced in the `replies` (only present
|
||||
if the `reference_pattern` parameter is provided).
|
||||
These additions make it easier to display references and other sources within a RAG pipeline.
|
||||
@ -143,6 +143,30 @@ class TestAnswerBuilder:
|
||||
assert answers[0].query == "test query"
|
||||
assert len(answers[0].documents) == 1
|
||||
assert answers[0].documents[0].content == "test doc 2"
|
||||
assert answers[0].documents[0].meta["referenced"] is True
|
||||
assert answers[0].documents[0].meta["source_index"] == 2
|
||||
|
||||
def test_run_with_documents_with_reference_pattern_return_all_documents(self):
|
||||
component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]", return_only_referenced_documents=False)
|
||||
output = component.run(
|
||||
query="test query",
|
||||
replies=["Answer: AnswerString[2]"],
|
||||
meta=[{}],
|
||||
documents=[Document(content="test doc 1"), Document(content="test doc 2")],
|
||||
)
|
||||
answers = output["answers"]
|
||||
assert len(answers) == 1
|
||||
assert answers[0].data == "Answer: AnswerString[2]"
|
||||
_check_metadata_excluding_all_messages(answers[0].meta, {})
|
||||
assert "all_messages" in answers[0].meta
|
||||
assert answers[0].query == "test query"
|
||||
assert len(answers[0].documents) == 2
|
||||
assert answers[0].documents[0].content == "test doc 1"
|
||||
assert answers[0].documents[0].meta["referenced"] is False
|
||||
assert answers[0].documents[0].meta["source_index"] == 1
|
||||
assert answers[0].documents[1].content == "test doc 2"
|
||||
assert answers[0].documents[1].meta["referenced"] is True
|
||||
assert answers[0].documents[1].meta["source_index"] == 2
|
||||
|
||||
def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog):
|
||||
component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]")
|
||||
|
||||
@ -951,11 +951,13 @@ def pipeline_that_has_a_component_with_only_default_inputs(pipeline_class):
|
||||
id="413dccdf51a54cca75b7ed2eddac04e6e58560bd2f0caf4106a3efc023fe3651",
|
||||
content="Paris is the capital of France",
|
||||
score=1.600237583702734,
|
||||
meta={"source_index": 1},
|
||||
),
|
||||
Document(
|
||||
id="a4a874fc2ef75015da7924d709fbdd2430e46a8e94add6e0f26cd32c1c03435d",
|
||||
content="Rome is the capital of Italy",
|
||||
score=1.2536639934227616,
|
||||
meta={"source_index": 2},
|
||||
),
|
||||
],
|
||||
meta={"all_messages": ["Paris"]},
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user