feat: Allow Connection of ChatGenerator to AnswerBuilder (#7897)

* initial implementation

* add support for meta and add ChatMessage tests

* explicitly cast types for mypy and update reno

* leave inputs unchanged avoiding side effects

---------

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
Authored by Ulises M on 2024-07-04 23:21:53 -07:00, committed by GitHub
parent 61de1dcc61
commit e92a0e4beb
3 changed files with 139 additions and 9 deletions
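
For context, the connection named in the title can be sketched as follows. This is a minimal, illustrative example rather than code from the commit: it assumes haystack-ai 2.x with an OPENAI_API_KEY in the environment, and the component names "llm" and "answer_builder", the model name, and the question are made up.

from haystack import Pipeline
from haystack.components.builders import AnswerBuilder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage

pipeline = Pipeline()
pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo"))
pipeline.add_component("answer_builder", AnswerBuilder())

# With this change, the ChatGenerator's List[ChatMessage] "replies" output
# can be wired straight into AnswerBuilder's "replies" input.
pipeline.connect("llm.replies", "answer_builder.replies")

result = pipeline.run(
    {
        "llm": {"messages": [ChatMessage.from_user("What is Haystack?")]},
        "answer_builder": {"query": "What is Haystack?"},
    }
)
print(result["answer_builder"]["answers"])

Before this commit, AnswerBuilder's replies socket only accepted List[str], so the pipeline's type check rejected this connection; widening the type to also accept List[ChatMessage] is what makes it possible.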

haystack/components/builders/answer_builder.py

@@ -3,9 +3,10 @@
# SPDX-License-Identifier: Apache-2.0
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from haystack import Document, GeneratedAnswer, component, logging
from haystack.dataclasses.chat_message import ChatMessage
logger = logging.getLogger(__name__)
@@ -56,7 +57,7 @@ class AnswerBuilder:
def run(
self,
query: str,
replies: List[str],
replies: Union[List[str], List[ChatMessage]],
meta: Optional[List[Dict[str, Any]]] = None,
documents: Optional[List[Document]] = None,
pattern: Optional[str] = None,
@@ -68,7 +69,7 @@ class AnswerBuilder:
:param query:
The query used in the prompts for the Generator.
:param replies:
The output of the Generator.
The output of the Generator. Can be a list of strings or a list of ChatMessage objects.
:param meta:
The metadata returned by the Generator. If not specified, the generated answer will contain no metadata.
:param documents:
@@ -103,14 +104,15 @@
pattern = pattern or self.pattern
reference_pattern = reference_pattern or self.reference_pattern
all_answers = []
for reply, metadata in zip(replies, meta):
# Extract content from ChatMessage objects if reply is a ChatMessage, else use the string as is
extracted_reply: str = reply.content if isinstance(reply, ChatMessage) else reply # type: ignore
extracted_metadata = reply.meta if isinstance(reply, ChatMessage) else metadata
referenced_docs = []
if documents:
reference_idxs = []
if reference_pattern:
reference_idxs = AnswerBuilder._extract_reference_idxs(reply, reference_pattern)
reference_idxs = AnswerBuilder._extract_reference_idxs(extracted_reply, reference_pattern)
else:
reference_idxs = [doc_idx for doc_idx, _ in enumerate(documents)]
@@ -122,8 +124,10 @@
"Document index '{index}' referenced in Generator output is out of range. ", index=idx + 1
)
answer_string = AnswerBuilder._extract_answer_string(reply, pattern)
answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, meta=metadata)
answer_string = AnswerBuilder._extract_answer_string(extracted_reply, pattern)
answer = GeneratedAnswer(
data=answer_string, query=query, documents=referenced_docs, meta=extracted_metadata
)
all_answers.append(answer)
return {"answers": all_answers}
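
The heart of the change is the normalization added in the hunk above: a ChatMessage reply contributes its own content and meta, while a plain string keeps using the meta list passed to run(). A standalone sketch of the same idea, where normalize_reply is a hypothetical helper (not part of the component) and haystack-ai 2.x is assumed:

from haystack.dataclasses.chat_message import ChatMessage

def normalize_reply(reply, fallback_meta):
    # ChatMessage replies carry their own metadata; plain strings fall back
    # to the metadata passed alongside them.
    if isinstance(reply, ChatMessage):
        return reply.content, reply.meta
    return reply, fallback_meta

print(normalize_reply(ChatMessage.from_assistant("Answer: 42"), {}))         # ('Answer: 42', {})
print(normalize_reply("Answer: 42", {"model": "example-model"}))             # ('Answer: 42', {'model': 'example-model'})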

New release note under releasenotes/notes/

@@ -0,0 +1,4 @@
---
enhancements:
- |
AnswerBuilder can now accept ChatMessages as input in addition to strings. When using ChatMessages, metadata will be automatically added to the answer.
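
A rough usage sketch of what the note describes (not part of the commit; the pattern, question, and metadata values are invented):

from haystack.components.builders import AnswerBuilder
from haystack.dataclasses.chat_message import ChatMessage

builder = AnswerBuilder(pattern=r"Answer: (.*)")

reply = ChatMessage.from_assistant("Answer: Paris")
reply.meta["model"] = "gpt-3.5-turbo"  # metadata a ChatGenerator would typically attach

result = builder.run(query="What is the capital of France?", replies=[reply])
answer = result["answers"][0]
print(answer.data)  # Paris
print(answer.meta)  # {'model': 'gpt-3.5-turbo'}, carried over from the ChatMessage automatically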

test/components/builders/test_answer_builder.py

@@ -5,8 +5,9 @@ import logging
import pytest
from haystack import GeneratedAnswer, Document
from haystack import Document, GeneratedAnswer
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
class TestAnswerBuilder:
@@ -150,3 +151,124 @@ class TestAnswerBuilder:
assert len(answers[0].documents) == 2
assert answers[0].documents[0].content == "test doc 2"
assert answers[0].documents[1].content == "test doc 3"
def test_run_with_chat_message_replies_without_pattern(self):
component = AnswerBuilder()
replies = [
ChatMessage(
content="Answer: AnswerString",
role=ChatRole.ASSISTANT,
name=None,
meta={
"model": "gpt-3.5-turbo-0613",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
},
)
]
output = component.run(query="test query", replies=replies, meta=[{}])
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "Answer: AnswerString"
assert answers[0].meta == {
"model": "gpt-3.5-turbo-0613",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
}
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
def test_run_with_chat_message_replies_with_pattern(self):
component = AnswerBuilder(pattern=r"Answer: (.*)")
replies = [
ChatMessage(
content="Answer: AnswerString",
role=ChatRole.ASSISTANT,
name=None,
meta={
"model": "gpt-3.5-turbo-0613",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
},
)
]
output = component.run(query="test query", replies=replies, meta=[{}])
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "AnswerString"
assert answers[0].meta == {
"model": "gpt-3.5-turbo-0613",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
}
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
def test_run_with_chat_message_replies_with_documents(self):
component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]")
replies = [
ChatMessage(
content="Answer: AnswerString[2]",
role=ChatRole.ASSISTANT,
name=None,
meta={
"model": "gpt-3.5-turbo-0613",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
},
)
]
output = component.run(
query="test query",
replies=replies,
meta=[{}],
documents=[Document(content="test doc 1"), Document(content="test doc 2")],
)
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "Answer: AnswerString[2]"
assert answers[0].meta == {
"model": "gpt-3.5-turbo-0613",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
}
assert answers[0].query == "test query"
assert len(answers[0].documents) == 1
assert answers[0].documents[0].content == "test doc 2"
def test_run_with_chat_message_replies_with_pattern_set_at_runtime(self):
component = AnswerBuilder(pattern="unused pattern")
replies = [
ChatMessage(
content="Answer: AnswerString",
role=ChatRole.ASSISTANT,
name=None,
meta={
"model": "gpt-3.5-turbo-0613",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
},
)
]
output = component.run(query="test query", replies=replies, meta=[{}], pattern=r"Answer: (.*)")
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "AnswerString"
assert answers[0].meta == {
"model": "gpt-3.5-turbo-0613",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
}
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)