fix: Look through all streaming chunks for tools calls (#8829)

* Look through all streaming chunks for tools calls

* Add reno note

* mypy fixes

* Improve robustness

* Don't concatenate, use the last value

* typing

* Update releasenotes/notes/improve-tool-call-chunk-search-986474e814af17a7.yaml

Co-authored-by: David S. Batista <dsbatista@gmail.com>

* Small refactoring

* isort

---------

Co-authored-by: David S. Batista <dsbatista@gmail.com>
This commit is contained in:
Vladimir Blagojevic 2025-02-11 13:25:39 +01:00 committed by GitHub
parent b5d2854b93
commit a7c1661f13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 123 additions and 28 deletions

View File

@@ -337,46 +337,46 @@ class OpenAIChatGenerator:
finish_reason=meta["finish_reason"],
)
def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
def _convert_streaming_chunks_to_chat_message(
self, chunk: ChatCompletionChunk, chunks: List[StreamingChunk]
) -> ChatMessage:
"""
Connects the streaming chunks into a single ChatMessage.
:param chunk: The last chunk returned by the OpenAI API.
:param chunks: The list of all `StreamingChunk` objects.
"""
text = "".join([chunk.content for chunk in chunks])
tool_calls = []
# if it's a tool call , we need to build the payload dict from all the chunks
if bool(chunks[0].meta.get("tool_calls")):
tools_len = len(chunks[0].meta.get("tool_calls", []))
# Process tool calls if present in any chunk
tool_call_data: Dict[str, Dict[str, str]] = {} # Track tool calls by ID
for chunk_payload in chunks:
tool_calls_meta = chunk_payload.meta.get("tool_calls")
if tool_calls_meta:
for delta in tool_calls_meta:
if not delta.id in tool_call_data:
tool_call_data[delta.id] = {"id": delta.id, "name": "", "arguments": ""}
payloads = [{"arguments": "", "name": ""} for _ in range(tools_len)]
for chunk_payload in chunks:
deltas = chunk_payload.meta.get("tool_calls") or []
# deltas is a list of ChoiceDeltaToolCall or ChoiceDeltaFunctionCall
for i, delta in enumerate(deltas):
payloads[i]["id"] = delta.id or payloads[i].get("id", "")
if delta.function:
payloads[i]["name"] += delta.function.name or ""
payloads[i]["arguments"] += delta.function.arguments or ""
if delta.function.name:
tool_call_data[delta.id]["name"] = delta.function.name
if delta.function.arguments:
tool_call_data[delta.id]["arguments"] = delta.function.arguments
for payload in payloads:
arguments_str = payload["arguments"]
try:
arguments = json.loads(arguments_str)
tool_calls.append(ToolCall(id=payload["id"], tool_name=payload["name"], arguments=arguments))
except json.JSONDecodeError:
logger.warning(
"OpenAI returned a malformed JSON string for tool call arguments. This tool call "
"will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
"Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
_id=payload["id"],
_name=payload["name"],
_arguments=arguments_str,
)
# Convert accumulated tool call data into ToolCall objects
for call_data in tool_call_data.values():
try:
arguments = json.loads(call_data["arguments"])
tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
except json.JSONDecodeError:
logger.warning(
"Skipping malformed tool call due to invalid JSON. Set `tools_strict=True` for valid JSON. "
"Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
_id=call_data["id"],
_name=call_data["name"],
_arguments=call_data["arguments"],
)
meta = {
"model": chunk.model,

View File

@@ -0,0 +1,4 @@
---
fixes:
- |
Improved `OpenAIChatGenerator` streaming response tool call processing: The logic now scans all chunks to correctly identify the first chunk with tool calls, ensuring accurate payload construction and preventing errors when tool call data isn't confined to the initial chunk.

View File

@@ -570,3 +570,94 @@ class TestOpenAIChatGenerator:
assert tool_call.tool_name == "weather"
assert tool_call.arguments == {"city": "Paris"}
assert message.meta["finish_reason"] == "tool_calls"
def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"""Test that tool calls can be found in any chunk of the streaming response."""
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
# Create a list of chunks where tool calls appear in different positions
chunks = [
# First chunk has no tool calls
StreamingChunk("Hello! Let me help you with that. "),
# Second chunk has the first tool call
StreamingChunk("I'll check the weather. "),
# Third chunk has no tool calls
StreamingChunk("Now, let me check another city. "),
# Fourth chunk has another tool call
StreamingChunk(""),
]
# Add received_at to first chunk
chunks[0].meta["received_at"] = "2024-02-07T14:21:47.446186Z"
# Add tool calls meta to second chunk
chunks[1].meta["tool_calls"] = [
chat_completion_chunk.ChoiceDeltaToolCall(
index=0,
id="call_1",
type="function",
function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
name="get_weather", arguments='{"city": "London"}'
),
)
]
# Add tool calls meta to fourth chunk
chunks[3].meta["tool_calls"] = [
chat_completion_chunk.ChoiceDeltaToolCall(
index=0, # Same index as first tool call since it's the same function
id="call_1", # Same ID as first tool call since it's the same function
type="function",
function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
name="get_weather", arguments='{"city": "Paris"}'
),
)
]
# Add required meta information to the last chunk
chunks[-1].meta.update({"model": "gpt-4", "index": 0, "finish_reason": "tool_calls"})
# Create the final ChatCompletionChunk that would be passed as the first parameter
final_chunk = ChatCompletionChunk(
id="chatcmpl-123",
model="gpt-4",
object="chat.completion.chunk",
created=1234567890,
choices=[
chat_completion_chunk.Choice(
index=0,
finish_reason="tool_calls",
delta=chat_completion_chunk.ChoiceDelta(
tool_calls=[
chat_completion_chunk.ChoiceDeltaToolCall(
index=0,
id="call_1",
type="function",
function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
name="get_weather", arguments='{"city": "Paris"}'
),
)
]
),
)
],
)
# Convert chunks to a chat message
result = component._convert_streaming_chunks_to_chat_message(final_chunk, chunks)
# Verify the content is concatenated correctly
expected_text = "Hello! Let me help you with that. I'll check the weather. Now, let me check another city. "
assert result.text == expected_text
# Verify both tool calls were found and processed
assert len(result.tool_calls) == 1 # Now we expect only one tool call since they have the same ID
assert result.tool_calls[0].id == "call_1"
assert result.tool_calls[0].tool_name == "get_weather"
assert result.tool_calls[0].arguments == {"city": "Paris"} # The last value overwrites the previous one
# Verify meta information
assert result.meta["model"] == "gpt-4"
assert result.meta["finish_reason"] == "tool_calls"
assert result.meta["index"] == 0
assert result.meta["completion_start_time"] == "2024-02-07T14:21:47.446186Z"