feat: enable streaming ToolCall/Result from Agent (#9290)

* Testing solutions for streaming

* Remove unused methods

* Add fixes

* Update docstrings

* add release notes and test

* PR comments

* add a new util function

* Adjust emit_tool_info

* PR comments

* Remove emit function, add streaming for tool_call


---------

Co-authored-by: Sebastian Husch Lee <sjrl423@gmail.com>
Amna Mubashar 2025-05-05 19:23:44 +05:00 committed by GitHub
parent 7db719981d
commit 64f384b52d
7 changed files with 120 additions and 12 deletions
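For context, a minimal end-to-end sketch of what this commit enables, assuming the public APIs touched in the diffs below (Agent, OpenAIChatGenerator, print_streaming_chunk); the weather tool is a hypothetical example, not part of this commit:

from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage
from haystack.tools import Tool

def get_weather(location: str) -> str:
    # Hypothetical tool body, used only for illustration.
    return f"Sunny, 22°C in {location}"

weather_tool = Tool(
    name="weather_tool",
    description="Get the current weather for a location.",
    parameters={
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"],
    },
    function=get_weather,
)

agent = Agent(chat_generator=OpenAIChatGenerator(), tools=[weather_tool])
agent.warm_up()

# After this commit, the same callback receives LLM tokens, tool-call
# metadata, and tool results as they become available.
result = agent.run(
    [ChatMessage.from_user("What's the weather in Paris?")],
    streaming_callback=print_streaming_chunk,
)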


@@ -15,7 +15,7 @@ from haystack.core.serialization import component_to_dict
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
from haystack.dataclasses.state_utils import merge_lists
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback
from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
@@ -84,6 +84,7 @@ class Agent:
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
If set to False, the exception will be turned into a chat message and passed to the LLM.
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
The same callback can be configured to emit tool results when a tool is called.
:raises TypeError: If the chat_generator does not support tools parameter in its run method.
"""
# Check if chat_generator supports tools parameter
@@ -201,9 +202,8 @@
def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]:
"""Prepare inputs for the chat generator."""
generator_inputs: Dict[str, Any] = {"tools": self.tools}
selected_callback = streaming_callback or self.streaming_callback
if selected_callback is not None:
generator_inputs["streaming_callback"] = selected_callback
if streaming_callback is not None:
generator_inputs["streaming_callback"] = streaming_callback
return generator_inputs
def _create_agent_span(self) -> Any:
@@ -229,6 +229,7 @@
:param messages: List of chat messages to process
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
The same callback can be configured to emit tool results when a tool is called.
:param kwargs: Additional data to pass to the State schema used by the Agent.
The keys must match the schema defined in the Agent's `state_schema`.
:return: Dictionary containing messages and outputs matching the defined output types
@@ -239,6 +240,10 @@
if self.system_prompt is not None:
messages = [ChatMessage.from_system(self.system_prompt)] + messages
streaming_callback = select_streaming_callback(
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
)
input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})
state = State(schema=self.state_schema, data=kwargs)
@@ -271,7 +276,7 @@
tool_invoker_result = Pipeline._run_component(
component_name="tool_invoker",
component={"instance": self._tool_invoker},
inputs={"messages": llm_messages, "state": state},
inputs={"messages": llm_messages, "state": state, "streaming_callback": streaming_callback},
component_visits=component_visits,
parent_span=span,
)
@@ -312,6 +317,7 @@
:param messages: List of chat messages to process
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
The same callback can be configured to emit tool results when a tool is called.
:param kwargs: Additional data to pass to the State schema used by the Agent.
The keys must match the schema defined in the Agent's `state_schema`.
:return: Dictionary containing messages and outputs matching the defined output types
@@ -322,6 +328,10 @@
if self.system_prompt is not None:
messages = [ChatMessage.from_system(self.system_prompt)] + messages
streaming_callback = select_streaming_callback(
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
)
input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})
state = State(schema=self.state_schema, data=kwargs)
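The resolution rule applied in both run methods above is worth spelling out. A rough sketch of what select_streaming_callback is assumed to do, inferred from its call sites in this diff (the actual implementation lives in haystack.dataclasses.streaming_chunk):

import inspect
from typing import Optional

from haystack.dataclasses.streaming_chunk import StreamingCallbackT

def select_streaming_callback_sketch(
    init_callback: Optional[StreamingCallbackT],
    runtime_callback: Optional[StreamingCallbackT],
    requires_async: bool,
) -> Optional[StreamingCallbackT]:
    # The callback passed at runtime takes precedence over the one
    # configured at init time.
    callback = runtime_callback or init_callback
    # Guard against mixing sync and async: run_async requires a coroutine
    # callback, while run requires a plain callable.
    if callback is not None and inspect.iscoroutinefunction(callback) != requires_async:
        raise ValueError("Streaming callback does not match the sync/async mode of the run method.")
    return callback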


@@ -279,6 +279,7 @@ class OpenAIChatGenerator:
chat_completion, # type: ignore
streaming_callback, # type: ignore
)
else:
assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
completions = [
@@ -355,6 +356,7 @@ class OpenAIChatGenerator:
chat_completion, # type: ignore
streaming_callback, # type: ignore
)
else:
assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
completions = [


@@ -2,13 +2,47 @@
#
# SPDX-License-Identifier: Apache-2.0
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
from haystack.dataclasses import StreamingChunk
def print_streaming_chunk(chunk: StreamingChunk) -> None:
"""
Default callback function for streaming responses.
Callback function to handle and display streaming output chunks.
Prints the tokens of the first completion to stdout as soon as they are received
This function processes a `StreamingChunk` object by:
- Printing tool call metadata (if any), including function names and arguments, as they arrive.
- Printing tool call results when available.
- Printing the main content (e.g., text tokens) of the chunk as it is received.
The function outputs data directly to stdout and flushes output buffers to ensure immediate display during
streaming.
:param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and
tool results.
"""
print(chunk.content, flush=True, end="")
# Print tool call metadata if available (from ChatGenerator)
if chunk.meta.get("tool_calls"):
for tool_call in chunk.meta["tool_calls"]:
if isinstance(tool_call, ChoiceDeltaToolCall) and tool_call.function:
# print the tool name
if tool_call.function.name and not tool_call.function.arguments:
print("[TOOL CALL]\n", flush=True, end="")
print(f"Tool: {tool_call.function.name} ", flush=True, end="")
# print the tool arguments
if tool_call.function.arguments:
if tool_call.function.arguments.startswith("{"):
print("\nArguments: ", flush=True, end="")
print(tool_call.function.arguments, flush=True, end="")
if tool_call.function.arguments.endswith("}"):
print("\n\n", flush=True, end="")
# Print tool call results if available (from ToolInvoker)
if chunk.meta.get("tool_result"):
print(f"[TOOL RESULT]\n{chunk.meta['tool_result']}\n\n", flush=True, end="")
# Print the main content of the chunk (from ChatGenerator)
if chunk.content:
print(chunk.content, flush=True, end="")
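To make the new branches concrete, a small usage sketch with synthetic chunks; the meta keys match the ones produced by the generator and the ToolInvoker in this PR:

from haystack.dataclasses import StreamingChunk

# Plain content chunk: printed token by token.
print_streaming_chunk(StreamingChunk(content="Hello"))

# Tool result chunk, as emitted by ToolInvoker after a tool returns.
print_streaming_chunk(StreamingChunk(content="", meta={"tool_result": "Sunny, 22°C"}))
# Output:
# [TOOL RESULT]
# Sunny, 22°C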


@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.core.component.sockets import Sockets
from haystack.dataclasses import ChatMessage, State, ToolCall
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, StreamingChunk, select_streaming_callback
from haystack.tools import (
ComponentTool,
Tool,
@@ -159,6 +160,7 @@ class ToolInvoker:
tools: Union[List[Tool], Toolset],
raise_on_failure: bool = True,
convert_result_to_json_string: bool = False,
streaming_callback: Optional[StreamingCallbackT] = None,
):
"""
Initialize the ToolInvoker component.
@@ -173,6 +175,10 @@
:param convert_result_to_json_string:
If True, the tool invocation result will be converted to a string using `json.dumps`.
If False, the tool invocation result will be converted to a string using `str`.
:param streaming_callback:
A callback function that will be called to emit tool results.
Note that the result is only emitted once it becomes available; it is not
streamed incrementally in real time.
:raises ValueError:
If no tools are provided or if duplicate tool names are found.
"""
@@ -181,6 +187,7 @@
# could be a Toolset instance or a list of Tools
self.tools = tools
self.streaming_callback = streaming_callback
# Convert Toolset to list for internal use
if isinstance(tools, Toolset):
@@ -272,7 +279,6 @@
except StringConversionError as conversion_error:
# If _handle_error re-raises, this properly preserves the chain
raise conversion_error from e
return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)
@staticmethod
@@ -358,13 +364,21 @@
state.set(state_key, output_value, handler_override=handler)
@component.output_types(tool_messages=List[ChatMessage], state=State)
def run(self, messages: List[ChatMessage], state: Optional[State] = None) -> Dict[str, Any]:
def run(
self,
messages: List[ChatMessage],
state: Optional[State] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
) -> Dict[str, Any]:
"""
Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.
:param messages:
A list of ChatMessage objects.
:param state: The runtime state that should be used by the tools.
:param streaming_callback: A callback function that will be called to emit tool results.
Note that the result is only emitted once it becomes available; it is not
streamed incrementally in real time.
:returns:
A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
Each ChatMessage object wraps the result of a tool invocation.
@@ -383,6 +397,9 @@
# Only keep messages with tool calls
messages_with_tool_calls = [message for message in messages if message.tool_calls]
streaming_callback = select_streaming_callback(
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
)
tool_messages = []
for message in messages_with_tool_calls:
@@ -406,6 +423,7 @@
# 2) Invoke the tool
try:
tool_result = tool_to_invoke.invoke(**final_args)
except ToolInvocationError as e:
error_message = self._handle_error(e)
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
@@ -434,6 +452,11 @@
)
)
if streaming_callback is not None:
streaming_callback(
StreamingChunk(content="", meta={"tool_result": tool_result, "tool_call": tool_call})
)
return {"tool_messages": tool_messages, "state": state}
def to_dict(self) -> Dict[str, Any]:
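The new parameter also works when ToolInvoker is used on its own, outside an Agent. A minimal sketch, reusing the hypothetical weather_tool from the overview above:

from haystack.components.tools import ToolInvoker
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, ToolCall

invoker = ToolInvoker(tools=[weather_tool])
tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
message = ChatMessage.from_assistant(tool_calls=[tool_call])

# The callback fires once per tool call, after the tool has returned;
# the full result arrives in chunk.meta["tool_result"].
result = invoker.run(messages=[message], streaming_callback=print_streaming_chunk)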


@@ -0,0 +1,9 @@
---
features:
- |
Add a `streaming_callback` parameter to `ToolInvoker` to enable streaming of tool results.
Note that `tool_result` is emitted only after the tool execution completes and is not streamed incrementally.
- Update `print_streaming_chunk` to print ToolCall information if it is present in the chunk's metadata.
- Update `Agent` to forward the `streaming_callback` to `ToolInvoker` to emit tool results during tool invocation.


@@ -24,6 +24,7 @@ from haystack.core.component.types import OutputSocket
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.dataclasses.chat_message import ChatRole, TextContent
from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.tools import Tool, ComponentTool
from haystack.tools.toolset import Toolset
from haystack.utils import serialize_callable, Secret
@@ -778,6 +779,25 @@ class TestAgent:
assert all(isinstance(reply, ChatMessage) for reply in result["messages"])
assert "Hello from run_async" in result["messages"][1].text
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
def test_agent_streaming_with_tool_call(self, monkeypatch, weather_tool):
chat_generator = OpenAIChatGenerator()
agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
agent.warm_up()
streaming_callback_called = False
def streaming_callback(chunk: StreamingChunk) -> None:
nonlocal streaming_callback_called
streaming_callback_called = True
result = agent.run(
[ChatMessage.from_user("What's the weather in Paris?")], streaming_callback=streaming_callback
)
assert result is not None
assert result["messages"] is not None
assert streaming_callback_called
class TestAgentTracing:
def test_agent_tracing_span_run(self, caplog, monkeypatch, weather_tool):


@@ -11,6 +11,7 @@ from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole
from haystack.dataclasses.state import State
from haystack.tools import ComponentTool, Tool, Toolset
from haystack.tools.errors import ToolInvocationError
from haystack.dataclasses import StreamingChunk
def weather_function(location):
@@ -162,14 +163,23 @@ class TestToolInvoker:
args = ToolInvoker._inject_state_args(tool=weather_tool, llm_args={"location": "Paris"}, state=state)
assert args == {"location": "Paris"}
def test_run(self, invoker):
def test_run_with_streaming_callback(self, invoker):
streaming_callback_called = False
def streaming_callback(chunk: StreamingChunk) -> None:
nonlocal streaming_callback_called
streaming_callback_called = True
tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
message = ChatMessage.from_assistant(tool_calls=[tool_call])
result = invoker.run(messages=[message])
result = invoker.run(messages=[message], streaming_callback=streaming_callback)
assert "tool_messages" in result
assert len(result["tool_messages"]) == 1
# check we called the streaming callback
assert streaming_callback_called
tool_message = result["tool_messages"][0]
assert isinstance(tool_message, ChatMessage)
assert tool_message.is_from(ChatRole.TOOL)