diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py
index ebebcae34..7d978ef5a 100644
--- a/haystack/components/agents/agent.py
+++ b/haystack/components/agents/agent.py
@@ -15,7 +15,7 @@ from haystack.core.serialization import component_to_dict
 from haystack.dataclasses import ChatMessage
 from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
 from haystack.dataclasses.state_utils import merge_lists
-from haystack.dataclasses.streaming_chunk import StreamingCallbackT
+from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback
 from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 from haystack.utils.deserialization import deserialize_chatgenerator_inplace
@@ -84,6 +84,7 @@ class Agent:
         :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
             If set to False, the exception will be turned into a chat message and passed to the LLM.
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
+            The same callback can be configured to emit tool results when a tool is called.
         :raises TypeError: If the chat_generator does not support tools parameter in its run method.
         """
         # Check if chat_generator supports tools parameter
@@ -201,9 +202,8 @@ class Agent:
     def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]:
         """Prepare inputs for the chat generator."""
         generator_inputs: Dict[str, Any] = {"tools": self.tools}
-        selected_callback = streaming_callback or self.streaming_callback
-        if selected_callback is not None:
-            generator_inputs["streaming_callback"] = selected_callback
+        if streaming_callback is not None:
+            generator_inputs["streaming_callback"] = streaming_callback
         return generator_inputs
 
     def _create_agent_span(self) -> Any:
@@ -229,6 +229,7 @@ class Agent:
 
         :param messages: List of chat messages to process
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
+            The same callback can be configured to emit tool results when a tool is called.
         :param kwargs: Additional data to pass to the State schema used by the Agent.
             The keys must match the schema defined in the Agent's `state_schema`.
         :return: Dictionary containing messages and outputs matching the defined output types
@@ -239,6 +240,10 @@ class Agent:
         if self.system_prompt is not None:
             messages = [ChatMessage.from_system(self.system_prompt)] + messages
 
+        streaming_callback = select_streaming_callback(
+            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
+        )
+
         input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})
 
         state = State(schema=self.state_schema, data=kwargs)
@@ -271,7 +276,7 @@ class Agent:
                     tool_invoker_result = Pipeline._run_component(
                         component_name="tool_invoker",
                         component={"instance": self._tool_invoker},
-                        inputs={"messages": llm_messages, "state": state},
+                        inputs={"messages": llm_messages, "state": state, "streaming_callback": streaming_callback},
                         component_visits=component_visits,
                         parent_span=span,
                     )
@@ -312,6 +317,7 @@ class Agent:
 
         :param messages: List of chat messages to process
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
+            The same callback can be configured to emit tool results when a tool is called.
         :param kwargs: Additional data to pass to the State schema used by the Agent.
             The keys must match the schema defined in the Agent's `state_schema`.
         :return: Dictionary containing messages and outputs matching the defined output types
@@ -322,6 +328,10 @@ class Agent:
         if self.system_prompt is not None:
             messages = [ChatMessage.from_system(self.system_prompt)] + messages
 
+        streaming_callback = select_streaming_callback(
+            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
+        )
+
         input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})
 
         state = State(schema=self.state_schema, data=kwargs)
diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py
index b5df7caaa..693b12be0 100644
--- a/haystack/components/generators/chat/openai.py
+++ b/haystack/components/generators/chat/openai.py
@@ -279,6 +279,7 @@ class OpenAIChatGenerator:
                 chat_completion,  # type: ignore
                 streaming_callback,  # type: ignore
             )
+
         else:
             assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
             completions = [
@@ -355,6 +356,7 @@ class OpenAIChatGenerator:
                 chat_completion,  # type: ignore
                 streaming_callback,  # type: ignore
             )
+
         else:
             assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
             completions = [
diff --git a/haystack/components/generators/utils.py b/haystack/components/generators/utils.py
index 4b31f0d06..54ff68bcb 100644
--- a/haystack/components/generators/utils.py
+++ b/haystack/components/generators/utils.py
@@ -2,13 +2,47 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
+
 from haystack.dataclasses import StreamingChunk
 
 
 def print_streaming_chunk(chunk: StreamingChunk) -> None:
     """
-    Default callback function for streaming responses.
+    Callback function to handle and display streaming output chunks.
 
-    Prints the tokens of the first completion to stdout as soon as they are received
+    This function processes a `StreamingChunk` object by:
+    - Printing tool call metadata (if any), including function names and arguments, as they arrive.
+    - Printing tool call results when available.
+    - Printing the main content (e.g., text tokens) of the chunk as it is received.
+
+    The function outputs data directly to stdout and flushes output buffers to ensure immediate display during
+    streaming.
+
+    :param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and
+        tool results.
""" - print(chunk.content, flush=True, end="") + # Print tool call metadata if available (from ChatGenerator) + if chunk.meta.get("tool_calls"): + for tool_call in chunk.meta["tool_calls"]: + if isinstance(tool_call, ChoiceDeltaToolCall) and tool_call.function: + # print the tool name + if tool_call.function.name and not tool_call.function.arguments: + print("[TOOL CALL]\n", flush=True, end="") + print(f"Tool: {tool_call.function.name} ", flush=True, end="") + + # print the tool arguments + if tool_call.function.arguments: + if tool_call.function.arguments.startswith("{"): + print("\nArguments: ", flush=True, end="") + print(tool_call.function.arguments, flush=True, end="") + if tool_call.function.arguments.endswith("}"): + print("\n\n", flush=True, end="") + + # Print tool call results if available (from ToolInvoker) + if chunk.meta.get("tool_result"): + print(f"[TOOL RESULT]\n{chunk.meta['tool_result']}\n\n", flush=True, end="") + + # Print the main content of the chunk (from ChatGenerator) + if chunk.content: + print(chunk.content, flush=True, end="") diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 43bc383f1..4e557630f 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging from haystack.core.component.sockets import Sockets from haystack.dataclasses import ChatMessage, State, ToolCall +from haystack.dataclasses.streaming_chunk import StreamingCallbackT, StreamingChunk, select_streaming_callback from haystack.tools import ( ComponentTool, Tool, @@ -159,6 +160,7 @@ class ToolInvoker: tools: Union[List[Tool], Toolset], raise_on_failure: bool = True, convert_result_to_json_string: bool = False, + streaming_callback: Optional[StreamingCallbackT] = None, ): """ Initialize the ToolInvoker component. @@ -173,6 +175,10 @@ class ToolInvoker: :param convert_result_to_json_string: If True, the tool invocation result will be converted to a string using `json.dumps`. If False, the tool invocation result will be converted to a string using `str`. + :param streaming_callback: + A callback function that will be called to emit tool results. + Note that the result is only emitted once it becomes available — it is not + streamed incrementally in real time. :raises ValueError: If no tools are provided or if duplicate tool names are found. 
""" @@ -181,6 +187,7 @@ class ToolInvoker: # could be a Toolset instance or a list of Tools self.tools = tools + self.streaming_callback = streaming_callback # Convert Toolset to list for internal use if isinstance(tools, Toolset): @@ -272,7 +279,6 @@ class ToolInvoker: except StringConversionError as conversion_error: # If _handle_error re-raises, this properly preserves the chain raise conversion_error from e - return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call) @staticmethod @@ -358,13 +364,21 @@ class ToolInvoker: state.set(state_key, output_value, handler_override=handler) @component.output_types(tool_messages=List[ChatMessage], state=State) - def run(self, messages: List[ChatMessage], state: Optional[State] = None) -> Dict[str, Any]: + def run( + self, + messages: List[ChatMessage], + state: Optional[State] = None, + streaming_callback: Optional[StreamingCallbackT] = None, + ) -> Dict[str, Any]: """ Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available. :param messages: A list of ChatMessage objects. :param state: The runtime state that should be used by the tools. + :param streaming_callback: A callback function that will be called to emit tool results. + Note that the result is only emitted once it becomes available — it is not + streamed incrementally in real time. :returns: A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role. Each ChatMessage objects wraps the result of a tool invocation. @@ -383,6 +397,9 @@ class ToolInvoker: # Only keep messages with tool calls messages_with_tool_calls = [message for message in messages if message.tool_calls] + streaming_callback = select_streaming_callback( + init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False + ) tool_messages = [] for message in messages_with_tool_calls: @@ -406,6 +423,7 @@ class ToolInvoker: # 2) Invoke the tool try: tool_result = tool_to_invoke.invoke(**final_args) + except ToolInvocationError as e: error_message = self._handle_error(e) tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) @@ -434,6 +452,11 @@ class ToolInvoker: ) ) + if streaming_callback is not None: + streaming_callback( + StreamingChunk(content="", meta={"tool_result": tool_result, "tool_call": tool_call}) + ) + return {"tool_messages": tool_messages, "state": state} def to_dict(self) -> Dict[str, Any]: diff --git a/releasenotes/notes/stream-tool-results-agent-7eaa5c2ccfa5e4bb.yaml b/releasenotes/notes/stream-tool-results-agent-7eaa5c2ccfa5e4bb.yaml new file mode 100644 index 000000000..4ebf93cee --- /dev/null +++ b/releasenotes/notes/stream-tool-results-agent-7eaa5c2ccfa5e4bb.yaml @@ -0,0 +1,9 @@ +--- +features: + - | + Add a `streaming_callback` parameter to `ToolInvoker` to enable streaming of tool results. + Note that tool_result is emitted only after the tool execution completes and is not streamed incrementally. + + - Update `print_streaming_chunk` to print ToolCall information if it is present in the chunk's metadata. + + - Update `Agent` to forward the `streaming_callback` to `ToolInvoker` to emit tool results during tool invocation. 
diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py
index abdb62d2e..0dd98aa06 100644
--- a/test/components/agents/test_agent.py
+++ b/test/components/agents/test_agent.py
@@ -24,6 +24,7 @@ from haystack.core.component.types import OutputSocket
 from haystack.dataclasses import ChatMessage, ToolCall
 from haystack.dataclasses.chat_message import ChatRole, TextContent
 from haystack.dataclasses.streaming_chunk import StreamingChunk
+
 from haystack.tools import Tool, ComponentTool
 from haystack.tools.toolset import Toolset
 from haystack.utils import serialize_callable, Secret
@@ -778,6 +779,25 @@ class TestAgent:
         assert [isinstance(reply, ChatMessage) for reply in result["messages"]]
         assert "Hello from run_async" in result["messages"][1].text
 
+    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+    def test_agent_streaming_with_tool_call(self, monkeypatch, weather_tool):
+        chat_generator = OpenAIChatGenerator()
+        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
+        agent.warm_up()
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
+        result = agent.run(
+            [ChatMessage.from_user("What's the weather in Paris?")], streaming_callback=streaming_callback
+        )
+
+        assert result is not None
+        assert result["messages"] is not None
+        assert streaming_callback_called
+
 
 class TestAgentTracing:
     def test_agent_tracing_span_run(self, caplog, monkeypatch, weather_tool):
diff --git a/test/components/tools/test_tool_invoker.py b/test/components/tools/test_tool_invoker.py
index f2dd0acbb..f85bdffc9 100644
--- a/test/components/tools/test_tool_invoker.py
+++ b/test/components/tools/test_tool_invoker.py
@@ -11,6 +11,7 @@ from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole
 from haystack.dataclasses.state import State
 from haystack.tools import ComponentTool, Tool, Toolset
 from haystack.tools.errors import ToolInvocationError
+from haystack.dataclasses import StreamingChunk
 
 
 def weather_function(location):
@@ -162,14 +163,23 @@ class TestToolInvoker:
         args = ToolInvoker._inject_state_args(tool=weather_tool, llm_args={"location": "Paris"}, state=state)
         assert args == {"location": "Paris"}
 
-    def test_run(self, invoker):
+    def test_run_with_streaming_callback(self, invoker):
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
         tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
         message = ChatMessage.from_assistant(tool_calls=[tool_call])
 
-        result = invoker.run(messages=[message])
+        result = invoker.run(messages=[message], streaming_callback=streaming_callback)
         assert "tool_messages" in result
         assert len(result["tool_messages"]) == 1
 
+        # check we called the streaming callback
+        assert streaming_callback_called
+
         tool_message = result["tool_messages"][0]
         assert isinstance(tool_message, ChatMessage)
         assert tool_message.is_from(ChatRole.TOOL)
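To close, a short sketch of a custom callback that separates incremental LLM tokens from one-shot tool results, relying only on the meta keys this diff introduces ("tool_result" and "tool_call"); the callback name and print format are illustrative, not part of the diff:

from haystack.dataclasses import StreamingChunk


def on_chunk(chunk: StreamingChunk) -> None:
    if chunk.meta.get("tool_result"):
        # Emitted once per tool invocation by ToolInvoker, after the tool completes.
        print(f"\n[tool {chunk.meta['tool_call'].tool_name} finished] {chunk.meta['tool_result']}")
    elif chunk.content:
        # LLM tokens still stream incrementally from the chat generator.
        print(chunk.content, end="", flush=True)

Passing on_chunk as streaming_callback to Agent.run, as in test_agent_streaming_with_tool_call above, routes both the token chunks from the chat generator and the tool-result chunks from ToolInvoker through the same callback.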