refactor: Update to StreamingChunk, better index setting and change tool_call to tool_calls (#9525)

* Fix setting of StreamingChunk.index and refactor the conversion tests

* Make _convert_chat_completion_chunk_to_streaming_chunk a member of OpenAIChatGenerator so we can override it in integrations that inherit from it

* Fixes

* Modify StreamingChunk to accept a list of tool call deltas (see the sketch below)

* Fix tests

* Fix mypy and update original reno

* Undo change

* Update conversion to return a single streaming chunk

* Update print_streaming_chunk

* Fix types

* PR comments
Sebastian Husch Lee 2025-06-23 10:14:25 +02:00 committed by GitHub
parent f911459647
commit ec371387f0
8 changed files with 619 additions and 392 deletions
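
At a glance, the reshaped API. A minimal sketch, assuming only the field names in the `streaming_chunk.py` diff below and the import path used by the updated tests:

from haystack.dataclasses import StreamingChunk, ToolCallDelta

# ToolCallDelta now carries a required `index`, and StreamingChunk holds a
# list of deltas under `tool_calls` instead of a single `tool_call`.
chunk = StreamingChunk(
    content="",
    index=0,
    tool_calls=[ToolCallDelta(index=0, id="call_123", tool_name="weather", arguments='{"city": "Paris"}')],
    start=True,
)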

View File

@ -427,12 +427,11 @@ class OpenAIChatGenerator:
         chunks: List[StreamingChunk] = []
         for chunk in chat_completion:  # pylint: disable=not-an-iterable
             assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
-            chunk_deltas = _convert_chat_completion_chunk_to_streaming_chunk(
+            chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(
                 chunk=chunk, previous_chunks=chunks, component_info=component_info
             )
-            for chunk_delta in chunk_deltas:
-                chunks.append(chunk_delta)
-                callback(chunk_delta)
+            chunks.append(chunk_delta)
+            callback(chunk_delta)
         return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
async def _handle_async_stream_response(
@ -442,12 +441,11 @@ class OpenAIChatGenerator:
         chunks: List[StreamingChunk] = []
         async for chunk in chat_completion:  # pylint: disable=not-an-iterable
             assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
-            chunk_deltas = _convert_chat_completion_chunk_to_streaming_chunk(
+            chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(
                 chunk=chunk, previous_chunks=chunks, component_info=component_info
             )
-            for chunk_delta in chunk_deltas:
-                chunks.append(chunk_delta)
-                await callback(chunk_delta)
+            chunks.append(chunk_delta)
+            await callback(chunk_delta)
         return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
@ -509,7 +507,7 @@ def _convert_chat_completion_to_chat_message(completion: ChatCompletion, choice:

 def _convert_chat_completion_chunk_to_streaming_chunk(
     chunk: ChatCompletionChunk, previous_chunks: List[StreamingChunk], component_info: Optional[ComponentInfo] = None
-) -> List[StreamingChunk]:
+) -> StreamingChunk:
     """
     Converts the streaming response chunk from the OpenAI API to a StreamingChunk.

@ -521,61 +519,68 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
     :returns:
-        A list of StreamingChunk objects representing the content of the chunk from the OpenAI API.
+        A StreamingChunk object representing the content of the chunk from the OpenAI API.
     """
-    # Choices is empty on the very first chunk which provides role information (e.g. "assistant").
-    # It is also empty if include_usage is set to True where the usage information is returned.
+    # On the very first chunk (len(previous_chunks) == 0) the choices field only provides role info (e.g. "assistant").
+    # Choices is also empty if include_usage is set to True, in which case the usage information is returned.
     if len(chunk.choices) == 0:
-        return [
-            StreamingChunk(
-                content="",
-                component_info=component_info,
-                # Index is None since it's only set to an int when a content block is present
-                index=None,
-                meta={
-                    "model": chunk.model,
-                    "received_at": datetime.now().isoformat(),
-                    "usage": _serialize_usage(chunk.usage),
-                },
-            )
-        ]
+        return StreamingChunk(
+            content="",
+            component_info=component_info,
+            # Index is None since it's only set to an int when a content block is present
+            index=None,
+            meta={
+                "model": chunk.model,
+                "received_at": datetime.now().isoformat(),
+                "usage": _serialize_usage(chunk.usage),
+            },
+        )

     choice: ChunkChoice = chunk.choices[0]
-    content = choice.delta.content or ""

+    # create a list of ToolCallDelta objects from the tool calls
     if choice.delta.tool_calls:
-        chunk_messages = []
+        tool_calls_deltas = []
         for tool_call in choice.delta.tool_calls:
             function = tool_call.function
-            chunk_message = StreamingChunk(
-                content=content,
-                # We adopt the tool_call.index as the index of the chunk
-                component_info=component_info,
-                index=tool_call.index,
-                tool_call=ToolCallDelta(
-                    id=tool_call.id,
-                    tool_name=function.name if function else None,
-                    arguments=function.arguments if function and function.arguments else None,
-                ),
-                start=function.name is not None if function else False,
-                meta={
-                    "model": chunk.model,
-                    "index": choice.index,
-                    "tool_calls": choice.delta.tool_calls,
-                    "finish_reason": choice.finish_reason,
-                    "received_at": datetime.now().isoformat(),
-                    "usage": _serialize_usage(chunk.usage),
-                },
-            )
-            chunk_messages.append(chunk_message)
-        return chunk_messages
+            tool_calls_deltas.append(
+                ToolCallDelta(
+                    index=tool_call.index,
+                    id=tool_call.id,
+                    tool_name=function.name if function else None,
+                    arguments=function.arguments if function and function.arguments else None,
+                )
+            )
+        chunk_message = StreamingChunk(
+            content=choice.delta.content or "",
+            component_info=component_info,
+            # We adopt the first tool_calls_deltas.index as the overall index of the chunk.
+            index=tool_calls_deltas[0].index,
+            tool_calls=tool_calls_deltas,
+            start=tool_calls_deltas[0].tool_name is not None,
+            meta={
+                "model": chunk.model,
+                "index": choice.index,
+                "tool_calls": choice.delta.tool_calls,
+                "finish_reason": choice.finish_reason,
+                "received_at": datetime.now().isoformat(),
+                "usage": _serialize_usage(chunk.usage),
+            },
+        )
+        return chunk_message

-    chunk_message = StreamingChunk(
-        content=content,
-        component_info=component_info,
-        # We set the index to be 0 since if text content is being streamed then no tool calls are being streamed
-        # NOTE: We may need to revisit this if OpenAI allows planning/thinking content before tool calls like
-        # Anthropic Claude
-        index=0,
+    # On the very first chunk the choice field only provides role info (e.g. "assistant"), so we set index to None.
+    # We also set index to None for chunks missing the content field, e.g. a chunk that only contains a finish
+    # reason.
+    if choice.delta.content is None or choice.delta.role is not None:
+        resolved_index = None
+    else:
+        # We set the index to 0 since if text content is being streamed then no tool calls are being streamed.
+        # NOTE: We may need to revisit this if OpenAI allows planning/thinking content before tool calls like
+        # Anthropic Claude.
+        resolved_index = 0
+    chunk_message = StreamingChunk(
+        content=choice.delta.content or "",
+        component_info=component_info,
+        index=resolved_index,
         # The first chunk is always a start message chunk that only contains role information, so if we reach here
         # and previous_chunks is length 1 then this is the start of text content.
         start=len(previous_chunks) == 1,
@ -588,7 +593,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
             "usage": _serialize_usage(chunk.usage),
         },
     )
-    return [chunk_message]
+    return chunk_message


 def _serialize_usage(usage):
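
A usage sketch of the new contract: one OpenAI chunk in, one StreamingChunk out. The payload is abridged from the test fixtures further down; the import path of the private helper is an assumption based on where the tests call it:

from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction

# Assumed import path for the private helper shown above.
from haystack.components.generators.chat.openai import _convert_chat_completion_chunk_to_streaming_chunk

openai_chunk = ChatCompletionChunk(
    id="chatcmpl-example",
    created=1747834733,
    model="gpt-4o-mini-2024-07-18",
    object="chat.completion.chunk",
    choices=[
        chat_completion_chunk.Choice(
            index=0,
            delta=ChoiceDelta(
                tool_calls=[
                    ChoiceDeltaToolCall(
                        index=0,
                        id="call_example",
                        type="function",
                        function=ChoiceDeltaToolCallFunction(name="weather", arguments=""),
                    )
                ]
            ),
        )
    ],
)

streaming_chunk = _convert_chat_completion_chunk_to_streaming_chunk(chunk=openai_chunk, previous_chunks=[])
assert streaming_chunk.index == 0  # adopted from the first tool call delta
assert streaming_chunk.start  # a named function marks the start of a tool call
assert len(streaming_chunk.tool_calls) == 1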

View File

@ -249,7 +249,7 @@ class OpenAIGenerator:
                 chunk=chunk,  # type: ignore
                 previous_chunks=chunks,
                 component_info=component_info,
-            )[0]
+            )
             chunks.append(chunk_delta)
             streaming_callback(chunk_delta)

View File

@ -31,17 +31,24 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
         print("\n\n", flush=True, end="")

     ## Tool Call streaming
-    if chunk.tool_call:
-        # If chunk.start is True indicates beginning of a tool call
-        # Also presence of chunk.tool_call.name indicates the start of a tool call too
-        if chunk.start:
-            print("[TOOL CALL]\n", flush=True, end="")
-            print(f"Tool: {chunk.tool_call.tool_name} ", flush=True, end="")
-            print("\nArguments: ", flush=True, end="")
+    if chunk.tool_calls:
+        # Typically, if there are multiple tool calls in the chunk, the tool calls are fully formed and not just
+        # deltas.
+        for tool_call in chunk.tool_calls:
+            # chunk.start being True indicates the beginning of a tool call; the presence of tool_call.tool_name
+            # indicates the start of a tool call too.
+            if chunk.start:
+                # If there is more than one tool call in the chunk, we print two new lines to separate them.
+                # We know there is more than one tool call if the index of the tool call is greater than the index
+                # of the chunk.
+                if chunk.index and tool_call.index > chunk.index:
+                    print("\n\n", flush=True, end="")

-        # print the tool arguments
-        if chunk.tool_call.arguments:
-            print(chunk.tool_call.arguments, flush=True, end="")
+                print(f"[TOOL CALL]\nTool: {tool_call.tool_name} \nArguments: ", flush=True, end="")

+            # print the tool arguments
+            if tool_call.arguments:
+                print(tool_call.arguments, flush=True, end="")

     ## Tool Call Result streaming
     # Print tool call results if available (from ToolInvoker)
@ -76,39 +83,41 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
     # Process tool calls if present in any chunk
     tool_call_data: Dict[int, Dict[str, str]] = {}  # Track tool calls by index
     for chunk in chunks:
-        if chunk.tool_call:
-            # We do this to make sure mypy is happy, but we enforce index is not None in the StreamingChunk dataclass if
-            # tool_call is present
-            assert chunk.index is not None
-
-            # We use the index of the chunk to track the tool call across chunks since the ID is not always provided
-            if chunk.index not in tool_call_data:
-                tool_call_data[chunk.index] = {"id": "", "name": "", "arguments": ""}
-
-            # Save the ID if present
-            if chunk.tool_call.id is not None:
-                tool_call_data[chunk.index]["id"] = chunk.tool_call.id
-
-            if chunk.tool_call.tool_name is not None:
-                tool_call_data[chunk.index]["name"] += chunk.tool_call.tool_name
-            if chunk.tool_call.arguments is not None:
-                tool_call_data[chunk.index]["arguments"] += chunk.tool_call.arguments
+        if chunk.tool_calls:
+            for tool_call in chunk.tool_calls:
+                # We use the index of the tool_call to track the tool call across chunks since the ID is not always
+                # provided
+                if tool_call.index not in tool_call_data:
+                    tool_call_data[tool_call.index] = {"id": "", "name": "", "arguments": ""}
+
+                # Save the ID if present
+                if tool_call.id is not None:
+                    tool_call_data[tool_call.index]["id"] = tool_call.id
+
+                if tool_call.tool_name is not None:
+                    tool_call_data[tool_call.index]["name"] += tool_call.tool_name
+                if tool_call.arguments is not None:
+                    tool_call_data[tool_call.index]["arguments"] += tool_call.arguments

     # Convert accumulated tool call data into ToolCall objects
     sorted_keys = sorted(tool_call_data.keys())
     for key in sorted_keys:
-        tool_call = tool_call_data[key]
+        tool_call_dict = tool_call_data[key]
         try:
-            arguments = json.loads(tool_call["arguments"])
-            tool_calls.append(ToolCall(id=tool_call["id"], tool_name=tool_call["name"], arguments=arguments))
+            arguments = json.loads(tool_call_dict["arguments"])
+            tool_calls.append(ToolCall(id=tool_call_dict["id"], tool_name=tool_call_dict["name"], arguments=arguments))
         except json.JSONDecodeError:
             logger.warning(
                 "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
                 "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
                 "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
-                _id=tool_call["id"],
-                _name=tool_call["name"],
-                _arguments=tool_call["arguments"],
+                _id=tool_call_dict["id"],
+                _name=tool_call_dict["name"],
+                _arguments=tool_call_dict["arguments"],
             )

     # finish_reason can appear in different places so we look for the last one
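
In miniature, the accumulation above: argument fragments are concatenated per tool-call index and parsed once at the end (the fragments here are hypothetical):

import json

fragments = ['{"ci', 'ty": ', '"Paris', '"}']  # hypothetical per-index argument deltas
entry = {"id": "call_example", "name": "weather", "arguments": ""}
for piece in fragments:
    entry["arguments"] += piece  # same `+=` accumulation as the loop above
assert json.loads(entry["arguments"]) == {"city": "Paris"}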

View File

@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from dataclasses import dataclass, field
-from typing import Any, Awaitable, Callable, Dict, Literal, Optional, Union, overload
+from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union, overload

 from haystack.core.component import Component
 from haystack.dataclasses.chat_message import ToolCallResult

@ -15,11 +15,13 @@ class ToolCallDelta:
     """
     Represents a Tool call prepared by the model, usually contained in an assistant message.

+    :param index: The index of the Tool call in the list of Tool calls.
     :param tool_name: The name of the Tool to call.
     :param arguments: Either the full arguments in JSON format or a delta of the arguments.
     :param id: The ID of the Tool call.
     """

+    index: int
     tool_name: Optional[str] = field(default=None)
     arguments: Optional[str] = field(default=None)
     id: Optional[str] = field(default=None)  # noqa: A003

@ -71,7 +73,8 @@ class StreamingChunk:
     :param component_info: A `ComponentInfo` object containing information about the component that generated the chunk,
         such as the component name and type.
     :param index: An optional integer index representing which content block this chunk belongs to.
-    :param tool_call: An optional ToolCallDelta object representing a tool call associated with the message chunk.
+    :param tool_calls: An optional list of ToolCallDelta objects representing tool calls associated with the message
+        chunk.
     :param tool_call_result: An optional ToolCallResult object representing the result of a tool call.
     :param start: A boolean indicating whether this chunk marks the start of a content block.
     """

@ -80,21 +83,21 @@ class StreamingChunk:
     meta: Dict[str, Any] = field(default_factory=dict, hash=False)
     component_info: Optional[ComponentInfo] = field(default=None)
     index: Optional[int] = field(default=None)
-    tool_call: Optional[ToolCallDelta] = field(default=None)
+    tool_calls: Optional[List[ToolCallDelta]] = field(default=None)
     tool_call_result: Optional[ToolCallResult] = field(default=None)
     start: bool = field(default=False)

     def __post_init__(self):
-        fields_set = sum(bool(x) for x in (self.content, self.tool_call, self.tool_call_result))
+        fields_set = sum(bool(x) for x in (self.content, self.tool_calls, self.tool_call_result))
         if fields_set > 1:
             raise ValueError(
-                "Only one of `content`, `tool_call`, or `tool_call_result` may be set in a StreamingChunk. "
-                f"Got content: '{self.content}', tool_call: '{self.tool_call}', "
+                "Only one of `content`, `tool_calls`, or `tool_call_result` may be set in a StreamingChunk. "
+                f"Got content: '{self.content}', tool_calls: '{self.tool_calls}', "
                 f"tool_call_result: '{self.tool_call_result}'"
             )

         # NOTE: We don't enforce this for self.content otherwise it would be a breaking change
-        if (self.tool_call or self.tool_call_result) and self.index is None:
-            raise ValueError("If `tool_call`, or `tool_call_result` is set, `index` must also be set.")
+        if (self.tool_calls or self.tool_call_result) and self.index is None:
+            raise ValueError("If `tool_calls` or `tool_call_result` is set, `index` must also be set.")

View File

@ -1,8 +1,8 @@
 ---
 features:
   - |
-    Updated StreamingChunk to add the fields `tool_call`, `tool_call_result`, `index`, and `start` to make it easier to format the stream in a streaming callback.
-    - Added new dataclass ToolCallDelta for the `StreamingChunk.tool_call` field to reflect that the arguments can be a string delta.
+    Updated StreamingChunk to add the fields `tool_calls`, `tool_call_result`, `index`, and `start` to make it easier to format the stream in a streaming callback.
+    - Added new dataclass ToolCallDelta for the `StreamingChunk.tool_calls` field to reflect that the arguments can be a string delta.
     - Updated `print_streaming_chunk` and `_convert_streaming_chunks_to_chat_message` utility methods to use these new fields. This especially improves the formatting when using `print_streaming_chunk` with Agent.
     - Updated `OpenAIGenerator`, `OpenAIChatGenerator`, `HuggingFaceAPIGenerator`, `HuggingFaceAPIChatGenerator`, `HuggingFaceLocalGenerator` and `HuggingFaceLocalChatGenerator` to follow the new dataclasses.
     - Updated `ToolInvoker` to follow the StreamingChunk dataclass.
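
The Agent/streaming improvement the note mentions boils down to wiring the callback in; a minimal sketch (model name and prompt are illustrative):

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage

# Each StreamingChunk is handed to the callback as it arrives; tool call deltas
# are printed under a [TOOL CALL] header, text content inline.
llm = OpenAIChatGenerator(model="gpt-4o-mini", streaming_callback=print_streaming_chunk)
result = llm.run(messages=[ChatMessage.from_user("What is the weather in Paris and in Berlin?")])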

View File

@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from unittest.mock import patch, MagicMock
+from unittest.mock import patch, ANY, MagicMock

 import pytest

@ -21,7 +21,7 @@ from openai.types.chat import chat_completion_chunk

 from haystack import component
 from haystack.components.generators.utils import print_streaming_chunk
-from haystack.dataclasses import StreamingChunk
+from haystack.dataclasses import StreamingChunk, ToolCallDelta
 from haystack.utils.auth import Secret
 from haystack.dataclasses import ChatMessage, ToolCall
 from haystack.tools import ComponentTool, Tool
@ -598,295 +598,6 @@ class TestOpenAIChatGenerator:
assert message.meta["finish_reason"] == "tool_calls"
assert message.meta["usage"]["completion_tokens"] == 47
def test_handle_stream_response(self):
openai_chunks = [
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[chat_completion_chunk.Choice(delta=ChoiceDelta(role="assistant"), index=0)],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(
index=0,
id="call_zcvlnVaTeJWRjLAFfYxX69z4",
function=ChoiceDeltaToolCallFunction(arguments="", name="weather"),
type="function",
)
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"ci'))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='ty": '))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"Paris'))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"}'))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(
index=1,
id="call_C88m67V16CrETq6jbNXjdZI9",
function=ChoiceDeltaToolCallFunction(arguments="", name="weather"),
type="function",
)
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='{"ci'))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='ty": '))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='"Berli'))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='n"}'))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[chat_completion_chunk.Choice(delta=ChoiceDelta(), finish_reason="tool_calls", index=0)],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
usage=CompletionUsage(
completion_tokens=42,
prompt_tokens=282,
total_tokens=324,
completion_tokens_details=CompletionTokensDetails(
accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0
),
prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0),
),
),
]
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
result = component._handle_stream_response(openai_chunks, callback=lambda chunk: None)[0] # type: ignore
assert not result.texts
assert not result.text
# Verify both tool calls were found and processed
assert len(result.tool_calls) == 2
assert result.tool_calls[0].id == "call_zcvlnVaTeJWRjLAFfYxX69z4"
assert result.tool_calls[0].tool_name == "weather"
assert result.tool_calls[0].arguments == {"city": "Paris"}
assert result.tool_calls[1].id == "call_C88m67V16CrETq6jbNXjdZI9"
assert result.tool_calls[1].tool_name == "weather"
assert result.tool_calls[1].arguments == {"city": "Berlin"}
# Verify meta information
assert result.meta["model"] == "gpt-4o-mini-2024-07-18"
assert result.meta["finish_reason"] == "tool_calls"
assert result.meta["index"] == 0
assert result.meta["completion_start_time"] is not None
assert result.meta["usage"] == {
"completion_tokens": 42,
"prompt_tokens": 282,
"total_tokens": 324,
"completion_tokens_details": {
"accepted_prediction_tokens": 0,
"audio_tokens": 0,
"reasoning_tokens": 0,
"rejected_prediction_tokens": 0,
},
"prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
}
def test_convert_usage_chunk_to_streaming_chunk(self):
chunk = ChatCompletionChunk(
id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw",
choices=[],
created=1742207200,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_06737a9306",
usage=CompletionUsage(
completion_tokens=8,
prompt_tokens=13,
total_tokens=21,
completion_tokens_details=CompletionTokensDetails(
accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0
),
prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0),
),
)
result = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, previous_chunks=[])[0]
assert result.content == ""
assert result.start is False
assert result.tool_call is None
assert result.tool_call_result is None
assert result.meta["model"] == "gpt-4o-mini-2024-07-18"
assert result.meta["received_at"] is not None
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
@ -1032,3 +743,497 @@ class TestOpenAIChatGenerator:
assert tool_call.tool_name == "weather"
assert tool_call.arguments == {"city": "Paris"}
assert message.meta["finish_reason"] == "tool_calls"
@pytest.fixture
def chat_completion_chunks():
return [
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[chat_completion_chunk.Choice(delta=ChoiceDelta(role="assistant"), index=0)],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(
index=0,
id="call_zcvlnVaTeJWRjLAFfYxX69z4",
function=ChoiceDeltaToolCallFunction(arguments="", name="weather"),
type="function",
)
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"ci'))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='ty": '))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"Paris'))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"}'))]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(
index=1,
id="call_C88m67V16CrETq6jbNXjdZI9",
function=ChoiceDeltaToolCallFunction(arguments="", name="weather"),
type="function",
)
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='{"ci'))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='ty": '))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[
ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='"Berli'))
]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[
chat_completion_chunk.Choice(
delta=ChoiceDelta(
tool_calls=[ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='n"}'))]
),
index=0,
)
],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[chat_completion_chunk.Choice(delta=ChoiceDelta(), finish_reason="tool_calls", index=0)],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
),
ChatCompletionChunk(
id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F",
choices=[],
created=1747834733,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_54eb4bd693",
usage=CompletionUsage(
completion_tokens=42,
prompt_tokens=282,
total_tokens=324,
completion_tokens_details=CompletionTokensDetails(
accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0
),
prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0),
),
),
]
@pytest.fixture
def streaming_chunks():
return [
StreamingChunk(
content="",
meta={
"model": "gpt-4o-mini-2024-07-18",
"index": 0,
"tool_calls": None,
"finish_reason": None,
"received_at": ANY,
"usage": None,
},
),
StreamingChunk(
content="",
meta={
"model": "gpt-4o-mini-2024-07-18",
"index": 0,
"tool_calls": [
ChoiceDeltaToolCall(
index=0,
id="call_zcvlnVaTeJWRjLAFfYxX69z4",
function=ChoiceDeltaToolCallFunction(arguments="", name="weather"),
type="function",
)
],
"finish_reason": None,
"received_at": ANY,
"usage": None,
},
index=0,
tool_calls=[ToolCallDelta(tool_name="weather", id="call_zcvlnVaTeJWRjLAFfYxX69z4", index=0)],
start=True,
),
StreamingChunk(
content="",
meta={
"model": "gpt-4o-mini-2024-07-18",
"index": 0,
"tool_calls": [ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"ci'))],
"finish_reason": None,
"received_at": ANY,
"usage": None,
},
index=0,
tool_calls=[ToolCallDelta(arguments='{"ci', index=0)],
),
StreamingChunk(
content="",
meta={
"model": "gpt-4o-mini-2024-07-18",
"index": 0,
"tool_calls": [ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='ty": '))],
"finish_reason": None,
"received_at": ANY,
"usage": None,
},
index=0,
tool_calls=[ToolCallDelta(arguments='ty": ', index=0)],
),
StreamingChunk(
content="",
meta={
"model": "gpt-4o-mini-2024-07-18",
"index": 0,
"tool_calls": [ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"Paris'))],
"finish_reason": None,
"received_at": ANY,
"usage": None,
},
index=0,
tool_calls=[ToolCallDelta(arguments='"Paris', index=0)],
),
StreamingChunk(
content="",
meta={
"model": "gpt-4o-mini-2024-07-18",
"index": 0,
"tool_calls": [ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"}'))],
"finish_reason": None,
"received_at": ANY,
"usage": None,
},
index=0,
tool_calls=[ToolCallDelta(arguments='"}', index=0)],
),
StreamingChunk(
content="",
meta={
"model": "gpt-4o-mini-2024-07-18",
"index": 0,
"tool_calls": [
ChoiceDeltaToolCall(
index=1,
id="call_C88m67V16CrETq6jbNXjdZI9",
function=ChoiceDeltaToolCallFunction(arguments="", name="weather"),
type="function",
)
],
"finish_reason": None,
"received_at": ANY,
"usage": None,
},
index=1,
tool_calls=[ToolCallDelta(tool_name="weather", id="call_C88m67V16CrETq6jbNXjdZI9", index=1)],
start=True,
),
StreamingChunk(
content="",
meta={
"model": "gpt-4o-mini-2024-07-18",
"index": 0,
"tool_calls": [ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='{"ci'))],
"finish_reason": None,
"received_at": ANY,
"usage": None,
},
index=1,
tool_calls=[ToolCallDelta(arguments='{"ci', index=1)],
),
StreamingChunk(
content="",
meta={
"model": "gpt-4o-mini-2024-07-18",
"index": 0,
"tool_calls": [ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='ty": '))],
"finish_reason": None,
"received_at": ANY,
"usage": None,
},
index=1,
tool_calls=[ToolCallDelta(arguments='ty": ', index=1)],
),
StreamingChunk(
content="",
meta={
"model": "gpt-4o-mini-2024-07-18",
"index": 0,
"tool_calls": [ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='"Berli'))],
"finish_reason": None,
"received_at": ANY,
"usage": None,
},
index=1,
tool_calls=[ToolCallDelta(arguments='"Berli', index=1)],
),
StreamingChunk(
content="",
meta={
"model": "gpt-4o-mini-2024-07-18",
"index": 0,
"tool_calls": [ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='n"}'))],
"finish_reason": None,
"received_at": ANY,
"usage": None,
},
index=1,
tool_calls=[ToolCallDelta(arguments='n"}', index=1)],
),
StreamingChunk(
content="",
meta={
"model": "gpt-4o-mini-2024-07-18",
"index": 0,
"tool_calls": None,
"finish_reason": "tool_calls",
"received_at": ANY,
"usage": None,
},
),
StreamingChunk(
content="",
meta={
"model": "gpt-4o-mini-2024-07-18",
"received_at": ANY,
"usage": {
"completion_tokens": 42,
"prompt_tokens": 282,
"total_tokens": 324,
"completion_tokens_details": {
"accepted_prediction_tokens": 0,
"audio_tokens": 0,
"reasoning_tokens": 0,
"rejected_prediction_tokens": 0,
},
"prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
},
},
),
]
class TestChatCompletionChunkConversion:
def test_convert_chat_completion_chunk_to_streaming_chunk(self, chat_completion_chunks, streaming_chunks):
previous_chunks = []
for openai_chunk, haystack_chunk in zip(chat_completion_chunks, streaming_chunks):
stream_chunk = _convert_chat_completion_chunk_to_streaming_chunk(
chunk=openai_chunk, previous_chunks=previous_chunks
)
assert stream_chunk == haystack_chunk
previous_chunks.append(openai_chunk)
def test_handle_stream_response(self, chat_completion_chunks):
openai_chunks = chat_completion_chunks
comp = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
result = comp._handle_stream_response(openai_chunks, callback=lambda chunk: None)[0] # type: ignore
assert not result.texts
assert not result.text
# Verify both tool calls were found and processed
assert len(result.tool_calls) == 2
assert result.tool_calls[0].id == "call_zcvlnVaTeJWRjLAFfYxX69z4"
assert result.tool_calls[0].tool_name == "weather"
assert result.tool_calls[0].arguments == {"city": "Paris"}
assert result.tool_calls[1].id == "call_C88m67V16CrETq6jbNXjdZI9"
assert result.tool_calls[1].tool_name == "weather"
assert result.tool_calls[1].arguments == {"city": "Berlin"}
# Verify meta information
assert result.meta["model"] == "gpt-4o-mini-2024-07-18"
assert result.meta["finish_reason"] == "tool_calls"
assert result.meta["index"] == 0
assert result.meta["completion_start_time"] is not None
assert result.meta["usage"] == {
"completion_tokens": 42,
"prompt_tokens": 282,
"total_tokens": 324,
"completion_tokens_details": {
"accepted_prediction_tokens": 0,
"audio_tokens": 0,
"reasoning_tokens": 0,
"rejected_prediction_tokens": 0,
},
"prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
}
def test_convert_usage_chunk_to_streaming_chunk(self):
usage_chunk = ChatCompletionChunk(
id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw",
choices=[],
created=1742207200,
model="gpt-4o-mini-2024-07-18",
object="chat.completion.chunk",
service_tier="default",
system_fingerprint="fp_06737a9306",
usage=CompletionUsage(
completion_tokens=8,
prompt_tokens=13,
total_tokens=21,
completion_tokens_details=CompletionTokensDetails(
accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0
),
prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0),
),
)
result = _convert_chat_completion_chunk_to_streaming_chunk(chunk=usage_chunk, previous_chunks=[])
assert result.content == ""
assert result.start is False
assert result.tool_calls is None
assert result.tool_call_result is None
assert result.meta["model"] == "gpt-4o-mini-2024-07-18"
assert result.meta["received_at"] is not None

View File

@ -42,7 +42,9 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
component_info=ComponentInfo(name="test", type="test"),
index=0,
start=True,
tool_call=ToolCallDelta(id="call_ZOj5l67zhZOx6jqjg7ATQwb6", tool_name="rag_pipeline_tool", arguments=""),
tool_calls=[
ToolCallDelta(id="call_ZOj5l67zhZOx6jqjg7ATQwb6", tool_name="rag_pipeline_tool", arguments="", index=0)
],
),
StreamingChunk(
content="",
@ -59,7 +61,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
},
component_info=ComponentInfo(name="test", type="test"),
index=0,
tool_call=ToolCallDelta(arguments='{"qu'),
tool_calls=[ToolCallDelta(arguments='{"qu', index=0)],
),
StreamingChunk(
content="",
@ -76,7 +78,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
},
component_info=ComponentInfo(name="test", type="test"),
index=0,
tool_call=ToolCallDelta(arguments='ery":'),
tool_calls=[ToolCallDelta(arguments='ery":', index=0)],
),
StreamingChunk(
content="",
@ -93,7 +95,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
},
component_info=ComponentInfo(name="test", type="test"),
index=0,
tool_call=ToolCallDelta(arguments=' "Wher'),
tool_calls=[ToolCallDelta(arguments=' "Wher', index=0)],
),
StreamingChunk(
content="",
@ -110,7 +112,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
},
component_info=ComponentInfo(name="test", type="test"),
index=0,
tool_call=ToolCallDelta(arguments="e do"),
tool_calls=[ToolCallDelta(arguments="e do", index=0)],
),
StreamingChunk(
content="",
@ -127,7 +129,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
},
component_info=ComponentInfo(name="test", type="test"),
index=0,
tool_call=ToolCallDelta(arguments="es Ma"),
tool_calls=[ToolCallDelta(arguments="es Ma", index=0)],
),
StreamingChunk(
content="",
@ -144,7 +146,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
},
component_info=ComponentInfo(name="test", type="test"),
index=0,
tool_call=ToolCallDelta(arguments="rk liv"),
tool_calls=[ToolCallDelta(arguments="rk liv", index=0)],
),
StreamingChunk(
content="",
@ -161,7 +163,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
},
component_info=ComponentInfo(name="test", type="test"),
index=0,
tool_call=ToolCallDelta(arguments='e?"}'),
tool_calls=[ToolCallDelta(arguments='e?"}', index=0)],
),
StreamingChunk(
content="",
@ -182,7 +184,9 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
component_info=ComponentInfo(name="test", type="test"),
index=1,
start=True,
tool_call=ToolCallDelta(id="call_STxsYY69wVOvxWqopAt3uWTB", tool_name="get_weather", arguments=""),
tool_calls=[
ToolCallDelta(id="call_STxsYY69wVOvxWqopAt3uWTB", tool_name="get_weather", arguments="", index=1)
],
),
StreamingChunk(
content="",
@ -199,7 +203,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
},
component_info=ComponentInfo(name="test", type="test"),
index=1,
tool_call=ToolCallDelta(arguments='{"ci'),
tool_calls=[ToolCallDelta(arguments='{"ci', index=1)],
),
StreamingChunk(
content="",
@ -216,7 +220,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
},
component_info=ComponentInfo(name="test", type="test"),
index=1,
tool_call=ToolCallDelta(arguments='ty": '),
tool_calls=[ToolCallDelta(arguments='ty": ', index=1)],
),
StreamingChunk(
content="",
@ -233,7 +237,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
},
component_info=ComponentInfo(name="test", type="test"),
index=1,
tool_call=ToolCallDelta(arguments='"Berli'),
tool_calls=[ToolCallDelta(arguments='"Berli', index=1)],
),
StreamingChunk(
content="",
@ -250,7 +254,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
},
component_info=ComponentInfo(name="test", type="test"),
index=1,
tool_call=ToolCallDelta(arguments='n"}'),
tool_calls=[ToolCallDelta(arguments='n"}', index=1)],
),
StreamingChunk(
content="",

View File

@ -58,7 +58,7 @@ def test_create_chunk_with_content_and_tool_call():
StreamingChunk(
content="Test content",
meta={"key": "value"},
tool_call=ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}'),
tool_calls=[ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}', index=0)],
)
@ -92,12 +92,13 @@ def test_component_info_from_component_with_name_from_pipeline():
def test_tool_call_delta():
tool_call = ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}')
tool_call = ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}', index=0)
assert tool_call.id == "123"
assert tool_call.tool_name == "test_tool"
assert tool_call.arguments == '{"arg1": "value1"}'
+assert tool_call.index == 0
def test_tool_call_delta_with_missing_fields():
with pytest.raises(ValueError):
_ = ToolCallDelta(id="123")
_ = ToolCallDelta(id="123", index=0)