mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 12:07:04 +00:00
feat: Allow Connection of ChatGenerator to AnswerBuilder (#7897)
* initial implementation * add support for meta and add ChatMessage tests * explicitly cast types for mypy and update reno * leave inputs unchanged avoiding side effects --------- Co-authored-by: Julian Risch <julian.risch@deepset.ai>
This commit is contained in:
parent
61de1dcc61
commit
e92a0e4beb
@ -3,9 +3,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from haystack import Document, GeneratedAnswer, component, logging
|
||||
from haystack.dataclasses.chat_message import ChatMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -56,7 +57,7 @@ class AnswerBuilder:
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
replies: List[str],
|
||||
replies: Union[List[str], List[ChatMessage]],
|
||||
meta: Optional[List[Dict[str, Any]]] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
pattern: Optional[str] = None,
|
||||
@ -68,7 +69,7 @@ class AnswerBuilder:
|
||||
:param query:
|
||||
The query used in the prompts for the Generator.
|
||||
:param replies:
|
||||
The output of the Generator.
|
||||
The output of the Generator. Can be a list of strings or a list of ChatMessage objects.
|
||||
:param meta:
|
||||
The metadata returned by the Generator. If not specified, the generated answer will contain no metadata.
|
||||
:param documents:
|
||||
@ -103,14 +104,15 @@ class AnswerBuilder:
|
||||
|
||||
pattern = pattern or self.pattern
|
||||
reference_pattern = reference_pattern or self.reference_pattern
|
||||
|
||||
all_answers = []
|
||||
for reply, metadata in zip(replies, meta):
|
||||
# Extract content from ChatMessage objects if reply is a ChatMessages, else use the string as is
|
||||
extracted_reply: str = reply.content if isinstance(reply, ChatMessage) else reply # type: ignore
|
||||
extracted_metadata = reply.meta if isinstance(reply, ChatMessage) else metadata
|
||||
referenced_docs = []
|
||||
if documents:
|
||||
reference_idxs = []
|
||||
if reference_pattern:
|
||||
reference_idxs = AnswerBuilder._extract_reference_idxs(reply, reference_pattern)
|
||||
reference_idxs = AnswerBuilder._extract_reference_idxs(extracted_reply, reference_pattern)
|
||||
else:
|
||||
reference_idxs = [doc_idx for doc_idx, _ in enumerate(documents)]
|
||||
|
||||
@ -122,8 +124,10 @@ class AnswerBuilder:
|
||||
"Document index '{index}' referenced in Generator output is out of range. ", index=idx + 1
|
||||
)
|
||||
|
||||
answer_string = AnswerBuilder._extract_answer_string(reply, pattern)
|
||||
answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, meta=metadata)
|
||||
answer_string = AnswerBuilder._extract_answer_string(extracted_reply, pattern)
|
||||
answer = GeneratedAnswer(
|
||||
data=answer_string, query=query, documents=referenced_docs, meta=extracted_metadata
|
||||
)
|
||||
all_answers.append(answer)
|
||||
|
||||
return {"answers": all_answers}
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
AnswerBuilder can now accept ChatMessages as input in addition to strings. When using ChatMessages, the metadata of each ChatMessage will be automatically added to the corresponding answer.
|
||||
@ -5,8 +5,9 @@ import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import GeneratedAnswer, Document
|
||||
from haystack import Document, GeneratedAnswer
|
||||
from haystack.components.builders.answer_builder import AnswerBuilder
|
||||
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
|
||||
|
||||
|
||||
class TestAnswerBuilder:
|
||||
@ -150,3 +151,124 @@ class TestAnswerBuilder:
|
||||
assert len(answers[0].documents) == 2
|
||||
assert answers[0].documents[0].content == "test doc 2"
|
||||
assert answers[0].documents[1].content == "test doc 3"
|
||||
|
||||
def test_run_with_chat_message_replies_without_pattern(self):
    """A ChatMessage reply is passed through unchanged when no pattern is configured.

    The answer's data is the full message content, and the answer's meta equals
    the ChatMessage's meta even though an empty dict is supplied via ``meta``.
    """
    message_meta = {
        "model": "gpt-3.5-turbo-0613",
        "index": 0,
        "finish_reason": "stop",
        "usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
    }
    message = ChatMessage(
        content="Answer: AnswerString", role=ChatRole.ASSISTANT, name=None, meta=message_meta
    )
    builder = AnswerBuilder()

    result = builder.run(query="test query", replies=[message], meta=[{}])

    built = result["answers"]
    assert len(built) == 1
    first = built[0]
    assert first.data == "Answer: AnswerString"
    # Compared against a separate literal so the check does not depend on
    # the dict object handed to the ChatMessage.
    assert first.meta == {
        "model": "gpt-3.5-turbo-0613",
        "index": 0,
        "finish_reason": "stop",
        "usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
    }
    assert first.query == "test query"
    assert first.documents == []
    assert isinstance(first, GeneratedAnswer)
|
||||
|
||||
def test_run_with_chat_message_replies_with_pattern(self):
    """The init-time pattern is applied to the content of a ChatMessage reply.

    Only the captured group ends up in the answer's data; the answer's meta
    equals the ChatMessage's meta.
    """
    message_meta = {
        "model": "gpt-3.5-turbo-0613",
        "index": 0,
        "finish_reason": "stop",
        "usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
    }
    message = ChatMessage(
        content="Answer: AnswerString", role=ChatRole.ASSISTANT, name=None, meta=message_meta
    )
    builder = AnswerBuilder(pattern=r"Answer: (.*)")

    result = builder.run(query="test query", replies=[message], meta=[{}])

    built = result["answers"]
    assert len(built) == 1
    first = built[0]
    assert first.data == "AnswerString"
    # Compared against a separate literal so the check does not depend on
    # the dict object handed to the ChatMessage.
    assert first.meta == {
        "model": "gpt-3.5-turbo-0613",
        "index": 0,
        "finish_reason": "stop",
        "usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
    }
    assert first.query == "test query"
    assert first.documents == []
    assert isinstance(first, GeneratedAnswer)
|
||||
|
||||
def test_run_with_chat_message_replies_with_documents(self):
    """Document references are extracted from the content of a ChatMessage reply.

    The reply cites ``[2]``, so only the second of the two supplied documents
    is attached to the answer.
    """
    message_meta = {
        "model": "gpt-3.5-turbo-0613",
        "index": 0,
        "finish_reason": "stop",
        "usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
    }
    message = ChatMessage(
        content="Answer: AnswerString[2]", role=ChatRole.ASSISTANT, name=None, meta=message_meta
    )
    candidate_docs = [Document(content="test doc 1"), Document(content="test doc 2")]
    builder = AnswerBuilder(reference_pattern="\\[(\\d+)\\]")

    result = builder.run(query="test query", replies=[message], meta=[{}], documents=candidate_docs)

    built = result["answers"]
    assert len(built) == 1
    first = built[0]
    assert first.data == "Answer: AnswerString[2]"
    # Compared against a separate literal so the check does not depend on
    # the dict object handed to the ChatMessage.
    assert first.meta == {
        "model": "gpt-3.5-turbo-0613",
        "index": 0,
        "finish_reason": "stop",
        "usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
    }
    assert first.query == "test query"
    assert len(first.documents) == 1
    assert first.documents[0].content == "test doc 2"
|
||||
|
||||
def test_run_with_chat_message_replies_with_pattern_set_at_runtime(self):
    """A pattern passed to ``run`` takes precedence over the init-time pattern.

    The component is built with a pattern that does not match the reply; the
    runtime pattern is the one whose captured group ends up in the answer.
    """
    message_meta = {
        "model": "gpt-3.5-turbo-0613",
        "index": 0,
        "finish_reason": "stop",
        "usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
    }
    message = ChatMessage(
        content="Answer: AnswerString", role=ChatRole.ASSISTANT, name=None, meta=message_meta
    )
    builder = AnswerBuilder(pattern="unused pattern")

    result = builder.run(query="test query", replies=[message], meta=[{}], pattern=r"Answer: (.*)")

    built = result["answers"]
    assert len(built) == 1
    first = built[0]
    assert first.data == "AnswerString"
    # Compared against a separate literal so the check does not depend on
    # the dict object handed to the ChatMessage.
    assert first.meta == {
        "model": "gpt-3.5-turbo-0613",
        "index": 0,
        "finish_reason": "stop",
        "usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
    }
    assert first.query == "test query"
    assert first.documents == []
    assert isinstance(first, GeneratedAnswer)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user