mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-25 14:08:27 +00:00
fix: Look through all streaming chunks for tools calls (#8829)
* Look through all streaming chunks for tools calls * Add reno note * mypy fixes * Improve robustness * Don't concatenate, use the last value * typing * Update releasenotes/notes/improve-tool-call-chunk-search-986474e814af17a7.yaml Co-authored-by: David S. Batista <dsbatista@gmail.com> * Small refactoring * isort --------- Co-authored-by: David S. Batista <dsbatista@gmail.com>
This commit is contained in:
parent
b5d2854b93
commit
a7c1661f13
@ -337,46 +337,46 @@ class OpenAIChatGenerator:
|
||||
finish_reason=meta["finish_reason"],
|
||||
)
|
||||
|
||||
def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
|
||||
def _convert_streaming_chunks_to_chat_message(
|
||||
self, chunk: ChatCompletionChunk, chunks: List[StreamingChunk]
|
||||
) -> ChatMessage:
|
||||
"""
|
||||
Connects the streaming chunks into a single ChatMessage.
|
||||
|
||||
:param chunk: The last chunk returned by the OpenAI API.
|
||||
:param chunks: The list of all `StreamingChunk` objects.
|
||||
"""
|
||||
|
||||
text = "".join([chunk.content for chunk in chunks])
|
||||
tool_calls = []
|
||||
|
||||
# if it's a tool call , we need to build the payload dict from all the chunks
|
||||
if bool(chunks[0].meta.get("tool_calls")):
|
||||
tools_len = len(chunks[0].meta.get("tool_calls", []))
|
||||
# Process tool calls if present in any chunk
|
||||
tool_call_data: Dict[str, Dict[str, str]] = {} # Track tool calls by ID
|
||||
for chunk_payload in chunks:
|
||||
tool_calls_meta = chunk_payload.meta.get("tool_calls")
|
||||
if tool_calls_meta:
|
||||
for delta in tool_calls_meta:
|
||||
if not delta.id in tool_call_data:
|
||||
tool_call_data[delta.id] = {"id": delta.id, "name": "", "arguments": ""}
|
||||
|
||||
payloads = [{"arguments": "", "name": ""} for _ in range(tools_len)]
|
||||
for chunk_payload in chunks:
|
||||
deltas = chunk_payload.meta.get("tool_calls") or []
|
||||
|
||||
# deltas is a list of ChoiceDeltaToolCall or ChoiceDeltaFunctionCall
|
||||
for i, delta in enumerate(deltas):
|
||||
payloads[i]["id"] = delta.id or payloads[i].get("id", "")
|
||||
if delta.function:
|
||||
payloads[i]["name"] += delta.function.name or ""
|
||||
payloads[i]["arguments"] += delta.function.arguments or ""
|
||||
if delta.function.name:
|
||||
tool_call_data[delta.id]["name"] = delta.function.name
|
||||
if delta.function.arguments:
|
||||
tool_call_data[delta.id]["arguments"] = delta.function.arguments
|
||||
|
||||
for payload in payloads:
|
||||
arguments_str = payload["arguments"]
|
||||
try:
|
||||
arguments = json.loads(arguments_str)
|
||||
tool_calls.append(ToolCall(id=payload["id"], tool_name=payload["name"], arguments=arguments))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"OpenAI returned a malformed JSON string for tool call arguments. This tool call "
|
||||
"will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
|
||||
"Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
|
||||
_id=payload["id"],
|
||||
_name=payload["name"],
|
||||
_arguments=arguments_str,
|
||||
)
|
||||
# Convert accumulated tool call data into ToolCall objects
|
||||
for call_data in tool_call_data.values():
|
||||
try:
|
||||
arguments = json.loads(call_data["arguments"])
|
||||
tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Skipping malformed tool call due to invalid JSON. Set `tools_strict=True` for valid JSON. "
|
||||
"Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
|
||||
_id=call_data["id"],
|
||||
_name=call_data["name"],
|
||||
_arguments=call_data["arguments"],
|
||||
)
|
||||
|
||||
meta = {
|
||||
"model": chunk.model,
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
Improved `OpenAIChatGenerator` streaming response tool call processing: The logic now scans all chunks to correctly identify the first chunk with tool calls, ensuring accurate payload construction and preventing errors when tool call data isn’t confined to the initial chunk.
|
||||
@ -570,3 +570,94 @@ class TestOpenAIChatGenerator:
|
||||
assert tool_call.tool_name == "weather"
|
||||
assert tool_call.arguments == {"city": "Paris"}
|
||||
assert message.meta["finish_reason"] == "tool_calls"
|
||||
|
||||
def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
|
||||
"""Test that tool calls can be found in any chunk of the streaming response."""
|
||||
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
|
||||
|
||||
# Create a list of chunks where tool calls appear in different positions
|
||||
chunks = [
|
||||
# First chunk has no tool calls
|
||||
StreamingChunk("Hello! Let me help you with that. "),
|
||||
# Second chunk has the first tool call
|
||||
StreamingChunk("I'll check the weather. "),
|
||||
# Third chunk has no tool calls
|
||||
StreamingChunk("Now, let me check another city. "),
|
||||
# Fourth chunk has another tool call
|
||||
StreamingChunk(""),
|
||||
]
|
||||
|
||||
# Add received_at to first chunk
|
||||
chunks[0].meta["received_at"] = "2024-02-07T14:21:47.446186Z"
|
||||
|
||||
# Add tool calls meta to second chunk
|
||||
chunks[1].meta["tool_calls"] = [
|
||||
chat_completion_chunk.ChoiceDeltaToolCall(
|
||||
index=0,
|
||||
id="call_1",
|
||||
type="function",
|
||||
function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
|
||||
name="get_weather", arguments='{"city": "London"}'
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
# Add tool calls meta to fourth chunk
|
||||
chunks[3].meta["tool_calls"] = [
|
||||
chat_completion_chunk.ChoiceDeltaToolCall(
|
||||
index=0, # Same index as first tool call since it's the same function
|
||||
id="call_1", # Same ID as first tool call since it's the same function
|
||||
type="function",
|
||||
function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
|
||||
name="get_weather", arguments='{"city": "Paris"}'
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
# Add required meta information to the last chunk
|
||||
chunks[-1].meta.update({"model": "gpt-4", "index": 0, "finish_reason": "tool_calls"})
|
||||
|
||||
# Create the final ChatCompletionChunk that would be passed as the first parameter
|
||||
final_chunk = ChatCompletionChunk(
|
||||
id="chatcmpl-123",
|
||||
model="gpt-4",
|
||||
object="chat.completion.chunk",
|
||||
created=1234567890,
|
||||
choices=[
|
||||
chat_completion_chunk.Choice(
|
||||
index=0,
|
||||
finish_reason="tool_calls",
|
||||
delta=chat_completion_chunk.ChoiceDelta(
|
||||
tool_calls=[
|
||||
chat_completion_chunk.ChoiceDeltaToolCall(
|
||||
index=0,
|
||||
id="call_1",
|
||||
type="function",
|
||||
function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
|
||||
name="get_weather", arguments='{"city": "Paris"}'
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Convert chunks to a chat message
|
||||
result = component._convert_streaming_chunks_to_chat_message(final_chunk, chunks)
|
||||
|
||||
# Verify the content is concatenated correctly
|
||||
expected_text = "Hello! Let me help you with that. I'll check the weather. Now, let me check another city. "
|
||||
assert result.text == expected_text
|
||||
|
||||
# Verify both tool calls were found and processed
|
||||
assert len(result.tool_calls) == 1 # Now we expect only one tool call since they have the same ID
|
||||
assert result.tool_calls[0].id == "call_1"
|
||||
assert result.tool_calls[0].tool_name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"city": "Paris"} # The last value overwrites the previous one
|
||||
|
||||
# Verify meta information
|
||||
assert result.meta["model"] == "gpt-4"
|
||||
assert result.meta["finish_reason"] == "tool_calls"
|
||||
assert result.meta["index"] == 0
|
||||
assert result.meta["completion_start_time"] == "2024-02-07T14:21:47.446186Z"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user