mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 13:06:29 +00:00
fix: Fix _convert_streaming_chunks_to_chat_message (#9566)
* Fix conversion * Add reno * Add unit test
This commit is contained in:
parent
c54a68ab63
commit
fc64884819
@ -84,24 +84,20 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
|
||||
tool_call_data: Dict[int, Dict[str, str]] = {} # Track tool calls by index
|
||||
for chunk in chunks:
|
||||
if chunk.tool_calls:
|
||||
# We do this to make sure mypy is happy, but we enforce index is not None in the StreamingChunk dataclass if
|
||||
# tool_call is present
|
||||
assert chunk.index is not None
|
||||
|
||||
for tool_call in chunk.tool_calls:
|
||||
# We use the index of the tool_call to track the tool call across chunks since the ID is not always
|
||||
# provided
|
||||
if tool_call.index not in tool_call_data:
|
||||
tool_call_data[chunk.index] = {"id": "", "name": "", "arguments": ""}
|
||||
tool_call_data[tool_call.index] = {"id": "", "name": "", "arguments": ""}
|
||||
|
||||
# Save the ID if present
|
||||
if tool_call.id is not None:
|
||||
tool_call_data[chunk.index]["id"] = tool_call.id
|
||||
tool_call_data[tool_call.index]["id"] = tool_call.id
|
||||
|
||||
if tool_call.tool_name is not None:
|
||||
tool_call_data[chunk.index]["name"] += tool_call.tool_name
|
||||
tool_call_data[tool_call.index]["name"] += tool_call.tool_name
|
||||
if tool_call.arguments is not None:
|
||||
tool_call_data[chunk.index]["arguments"] += tool_call.arguments
|
||||
tool_call_data[tool_call.index]["arguments"] += tool_call.arguments
|
||||
|
||||
# Convert accumulated tool call data into ToolCall objects
|
||||
sorted_keys = sorted(tool_call_data.keys())
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
Fix `_convert_streaming_chunks_to_chat_message`, which is used to convert Haystack StreamingChunks into a Haystack ChatMessage. This fixes the scenario where one StreamingChunk contains two ToolCallDeltas in `StreamingChunk.tool_calls`. With this fix, both ToolCallDeltas are saved correctly, whereas before they were overwriting each other. This only occurs with some LLM providers like Mistral (and not OpenAI) due to how the provider returns tool calls.
|
||||
@ -325,3 +325,63 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
|
||||
},
|
||||
"prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
|
||||
}
|
||||
|
||||
|
||||
def test_convert_streaming_chunk_to_chat_message_two_tool_calls_in_same_chunk():
    """
    Two ToolCallDeltas delivered inside one StreamingChunk are converted into two separate tool calls.
    """
    # Both chunks report the same originating component type.
    mistral_generator_type = (
        "haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator"
    )

    # First chunk: empty content, no tool calls.
    opening_chunk = StreamingChunk(
        content="",
        meta={
            "model": "mistral-small-latest",
            "index": 0,
            "tool_calls": None,
            "finish_reason": None,
            "usage": None,
        },
        component_info=ComponentInfo(type=mistral_generator_type, name=None),
    )

    # Second chunk: carries both tool call deltas (index 0 and index 1) at once.
    tool_call_chunk = StreamingChunk(
        content="",
        meta={
            "model": "mistral-small-latest",
            "index": 0,
            "finish_reason": "tool_calls",
            "usage": {
                "completion_tokens": 35,
                "prompt_tokens": 77,
                "total_tokens": 112,
                "completion_tokens_details": None,
                "prompt_tokens_details": None,
            },
        },
        component_info=ComponentInfo(type=mistral_generator_type, name=None),
        index=0,
        tool_calls=[
            ToolCallDelta(index=0, tool_name="weather", arguments='{"city": "Paris"}', id="FL1FFlqUG"),
            ToolCallDelta(index=1, tool_name="weather", arguments='{"city": "Berlin"}', id="xSuhp66iB"),
        ],
        start=True,
        finish_reason="tool_calls",
    )

    # Convert chunks to a chat message
    result = _convert_streaming_chunks_to_chat_message(chunks=[opening_chunk, tool_call_chunk])

    # No text is expected, only tool calls
    assert not result.texts
    assert not result.text

    # Verify both tool calls were found and processed, in tool call index order
    expected_calls = [
        ("FL1FFlqUG", {"city": "Paris"}),
        ("xSuhp66iB", {"city": "Berlin"}),
    ]
    assert len(result.tool_calls) == 2
    for position, (expected_id, expected_arguments) in enumerate(expected_calls):
        assert result.tool_calls[position].id == expected_id
        assert result.tool_calls[position].tool_name == "weather"
        assert result.tool_calls[position].arguments == expected_arguments
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user