Mirror of https://github.com/deepset-ai/haystack.git (synced 2026-01-09 13:46:54 +00:00)
feat: enable streaming ToolCall/Result from Agent (#9290)
* Testing solutions for streaming
* Remove unused methods
* Add fixes
* Update docstrings
* add release notes and test
* PR comments
* add a new util function
* Adjust emit_tool_info
* PR comments
* Remove emit function, add streaming for tool_call

---------

Co-authored-by: Sebastian Husch Lee <sjrl423@gmail.com>
parent 7db719981d
commit 64f384b52d
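For context, a minimal usage sketch of what this commit enables: one streaming callback that receives both LLM tokens and tool results. This is a sketch, not part of the diff; the weather tool below (name, schema, toy function) is an illustrative assumption, and an OpenAI API key is assumed to be configured.

    # Minimal sketch; the weather tool is illustrative, OPENAI_API_KEY is assumed set.
    from haystack.components.agents import Agent
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.components.generators.utils import print_streaming_chunk
    from haystack.dataclasses import ChatMessage
    from haystack.tools import Tool

    def get_weather(location: str) -> str:
        # Toy tool used only for illustration.
        return f"Sunny in {location}"

    weather_tool = Tool(
        name="weather_tool",
        description="Get the weather for a location.",
        parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]},
        function=get_weather,
    )

    agent = Agent(chat_generator=OpenAIChatGenerator(), tools=[weather_tool])
    agent.warm_up()
    # After this commit, the same callback also receives tool-call metadata and tool results.
    agent.run(
        [ChatMessage.from_user("What's the weather in Paris?")],
        streaming_callback=print_streaming_chunk,
    )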
@@ -15,7 +15,7 @@ from haystack.core.serialization import component_to_dict
 from haystack.dataclasses import ChatMessage
 from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
 from haystack.dataclasses.state_utils import merge_lists
-from haystack.dataclasses.streaming_chunk import StreamingCallbackT
+from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback
 from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 from haystack.utils.deserialization import deserialize_chatgenerator_inplace
@@ -84,6 +84,7 @@ class Agent:
         :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
             If set to False, the exception will be turned into a chat message and passed to the LLM.
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
+            The same callback can be configured to emit tool results when a tool is called.
         :raises TypeError: If the chat_generator does not support tools parameter in its run method.
         """
         # Check if chat_generator supports tools parameter
@@ -201,9 +202,8 @@ class Agent:
     def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]:
         """Prepare inputs for the chat generator."""
         generator_inputs: Dict[str, Any] = {"tools": self.tools}
-        selected_callback = streaming_callback or self.streaming_callback
-        if selected_callback is not None:
-            generator_inputs["streaming_callback"] = selected_callback
+        if streaming_callback is not None:
+            generator_inputs["streaming_callback"] = streaming_callback
         return generator_inputs

     def _create_agent_span(self) -> Any:
@@ -229,6 +229,7 @@ class Agent:

         :param messages: List of chat messages to process
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
+            The same callback can be configured to emit tool results when a tool is called.
         :param kwargs: Additional data to pass to the State schema used by the Agent.
             The keys must match the schema defined in the Agent's `state_schema`.
         :return: Dictionary containing messages and outputs matching the defined output types
@@ -239,6 +240,10 @@
         if self.system_prompt is not None:
             messages = [ChatMessage.from_system(self.system_prompt)] + messages

+        streaming_callback = select_streaming_callback(
+            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
+        )
+
         input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})

         state = State(schema=self.state_schema, data=kwargs)
@@ -271,7 +276,7 @@
                 tool_invoker_result = Pipeline._run_component(
                     component_name="tool_invoker",
                     component={"instance": self._tool_invoker},
-                    inputs={"messages": llm_messages, "state": state},
+                    inputs={"messages": llm_messages, "state": state, "streaming_callback": streaming_callback},
                     component_visits=component_visits,
                     parent_span=span,
                 )
@@ -312,6 +317,7 @@

         :param messages: List of chat messages to process
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
+            The same callback can be configured to emit tool results when a tool is called.
         :param kwargs: Additional data to pass to the State schema used by the Agent.
             The keys must match the schema defined in the Agent's `state_schema`.
         :return: Dictionary containing messages and outputs matching the defined output types
@@ -322,6 +328,10 @@
         if self.system_prompt is not None:
             messages = [ChatMessage.from_system(self.system_prompt)] + messages

+        streaming_callback = select_streaming_callback(
+            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
+        )
+
         input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})

         state = State(schema=self.state_schema, data=kwargs)
@@ -279,6 +279,7 @@ class OpenAIChatGenerator:
                 chat_completion,  # type: ignore
+                streaming_callback,  # type: ignore
             )

         else:
             assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
             completions = [
@@ -355,6 +356,7 @@ class OpenAIChatGenerator:
                 chat_completion,  # type: ignore
+                streaming_callback,  # type: ignore
             )

         else:
             assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
             completions = [
@@ -2,13 +2,47 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
+
 from haystack.dataclasses import StreamingChunk


 def print_streaming_chunk(chunk: StreamingChunk) -> None:
     """
-    Default callback function for streaming responses.
+    Callback function to handle and display streaming output chunks.

-    Prints the tokens of the first completion to stdout as soon as they are received
+    This function processes a `StreamingChunk` object by:
+    - Printing tool call metadata (if any), including function names and arguments, as they arrive.
+    - Printing tool call results when available.
+    - Printing the main content (e.g., text tokens) of the chunk as it is received.
+
+    The function outputs data directly to stdout and flushes output buffers to ensure immediate display during
+    streaming.
+
+    :param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and
+        tool results.
     """
-    print(chunk.content, flush=True, end="")
+    # Print tool call metadata if available (from ChatGenerator)
+    if chunk.meta.get("tool_calls"):
+        for tool_call in chunk.meta["tool_calls"]:
+            if isinstance(tool_call, ChoiceDeltaToolCall) and tool_call.function:
+                # print the tool name
+                if tool_call.function.name and not tool_call.function.arguments:
+                    print("[TOOL CALL]\n", flush=True, end="")
+                    print(f"Tool: {tool_call.function.name} ", flush=True, end="")
+
+                # print the tool arguments
+                if tool_call.function.arguments:
+                    if tool_call.function.arguments.startswith("{"):
+                        print("\nArguments: ", flush=True, end="")
+                    print(tool_call.function.arguments, flush=True, end="")
+                    if tool_call.function.arguments.endswith("}"):
+                        print("\n\n", flush=True, end="")
+
+    # Print tool call results if available (from ToolInvoker)
+    if chunk.meta.get("tool_result"):
+        print(f"[TOOL RESULT]\n{chunk.meta['tool_result']}\n\n", flush=True, end="")
+
+    # Print the main content of the chunk (from ChatGenerator)
+    if chunk.content:
+        print(chunk.content, flush=True, end="")
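For illustration, with the handling above a streamed tool call and its result would render on stdout roughly like this (tool name, arguments, and result are hypothetical):

    [TOOL CALL]
    Tool: weather_tool
    Arguments: {"location": "Paris"}

    [TOOL RESULT]
    Sunny in Paris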
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union
 from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.core.component.sockets import Sockets
 from haystack.dataclasses import ChatMessage, State, ToolCall
+from haystack.dataclasses.streaming_chunk import StreamingCallbackT, StreamingChunk, select_streaming_callback
 from haystack.tools import (
     ComponentTool,
     Tool,
@@ -159,6 +160,7 @@ class ToolInvoker:
         tools: Union[List[Tool], Toolset],
         raise_on_failure: bool = True,
         convert_result_to_json_string: bool = False,
+        streaming_callback: Optional[StreamingCallbackT] = None,
     ):
         """
         Initialize the ToolInvoker component.
@@ -173,6 +175,10 @@
         :param convert_result_to_json_string:
             If True, the tool invocation result will be converted to a string using `json.dumps`.
             If False, the tool invocation result will be converted to a string using `str`.
+        :param streaming_callback:
+            A callback function that will be called to emit tool results.
+            Note that the result is only emitted once it becomes available — it is not
+            streamed incrementally in real time.
         :raises ValueError:
             If no tools are provided or if duplicate tool names are found.
         """
@@ -181,6 +187,7 @@

         # could be a Toolset instance or a list of Tools
         self.tools = tools
+        self.streaming_callback = streaming_callback

         # Convert Toolset to list for internal use
         if isinstance(tools, Toolset):
@@ -272,7 +279,6 @@
         except StringConversionError as conversion_error:
             # If _handle_error re-raises, this properly preserves the chain
             raise conversion_error from e
-
         return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)

     @staticmethod
@@ -358,13 +364,21 @@
             state.set(state_key, output_value, handler_override=handler)

     @component.output_types(tool_messages=List[ChatMessage], state=State)
-    def run(self, messages: List[ChatMessage], state: Optional[State] = None) -> Dict[str, Any]:
+    def run(
+        self,
+        messages: List[ChatMessage],
+        state: Optional[State] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
+    ) -> Dict[str, Any]:
         """
         Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.

         :param messages:
             A list of ChatMessage objects.
         :param state: The runtime state that should be used by the tools.
+        :param streaming_callback: A callback function that will be called to emit tool results.
+            Note that the result is only emitted once it becomes available — it is not
+            streamed incrementally in real time.
         :returns:
             A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
             Each ChatMessage objects wraps the result of a tool invocation.
@@ -383,6 +397,9 @@

         # Only keep messages with tool calls
         messages_with_tool_calls = [message for message in messages if message.tool_calls]
+        streaming_callback = select_streaming_callback(
+            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
+        )

         tool_messages = []
         for message in messages_with_tool_calls:
@@ -406,6 +423,7 @@
                 # 2) Invoke the tool
                 try:
                     tool_result = tool_to_invoke.invoke(**final_args)
+
                 except ToolInvocationError as e:
                     error_message = self._handle_error(e)
                     tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
@@ -434,6 +452,11 @@
                     )
                 )

+                if streaming_callback is not None:
+                    streaming_callback(
+                        StreamingChunk(content="", meta={"tool_result": tool_result, "tool_call": tool_call})
+                    )
+
         return {"tool_messages": tool_messages, "state": state}

     def to_dict(self) -> Dict[str, Any]:
@@ -0,0 +1,9 @@
+---
+features:
+  - |
+    Add a `streaming_callback` parameter to `ToolInvoker` to enable streaming of tool results.
+    Note that tool_result is emitted only after the tool execution completes and is not streamed incrementally.
+
+  - Update `print_streaming_chunk` to print ToolCall information if it is present in the chunk's metadata.
+
+  - Update `Agent` to forward the `streaming_callback` to `ToolInvoker` to emit tool results during tool invocation.
@@ -24,6 +24,7 @@ from haystack.core.component.types import OutputSocket
 from haystack.dataclasses import ChatMessage, ToolCall
 from haystack.dataclasses.chat_message import ChatRole, TextContent
+from haystack.dataclasses.streaming_chunk import StreamingChunk

 from haystack.tools import Tool, ComponentTool
 from haystack.tools.toolset import Toolset
 from haystack.utils import serialize_callable, Secret
@@ -778,6 +779,25 @@ class TestAgent:
         assert [isinstance(reply, ChatMessage) for reply in result["messages"]]
         assert "Hello from run_async" in result["messages"][1].text

+    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+    def test_agent_streaming_with_tool_call(self, monkeypatch, weather_tool):
+        chat_generator = OpenAIChatGenerator()
+        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
+        agent.warm_up()
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
+        result = agent.run(
+            [ChatMessage.from_user("What's the weather in Paris?")], streaming_callback=streaming_callback
+        )
+
+        assert result is not None
+        assert result["messages"] is not None
+        assert streaming_callback_called
+

 class TestAgentTracing:
     def test_agent_tracing_span_run(self, caplog, monkeypatch, weather_tool):
@@ -11,6 +11,7 @@ from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole
 from haystack.dataclasses.state import State
 from haystack.tools import ComponentTool, Tool, Toolset
 from haystack.tools.errors import ToolInvocationError
+from haystack.dataclasses import StreamingChunk


 def weather_function(location):
@@ -162,14 +163,23 @@ class TestToolInvoker:
         args = ToolInvoker._inject_state_args(tool=weather_tool, llm_args={"location": "Paris"}, state=state)
         assert args == {"location": "Paris"}

-    def test_run(self, invoker):
+    def test_run_with_streaming_callback(self, invoker):
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
         tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
         message = ChatMessage.from_assistant(tool_calls=[tool_call])

-        result = invoker.run(messages=[message])
+        result = invoker.run(messages=[message], streaming_callback=streaming_callback)
         assert "tool_messages" in result
         assert len(result["tool_messages"]) == 1

+        # check we called the streaming callback
+        assert streaming_callback_called
+
         tool_message = result["tool_messages"][0]
         assert isinstance(tool_message, ChatMessage)
         assert tool_message.is_from(ChatRole.TOOL)
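As a usage note mirroring the test above, the new parameter can also be exercised on a standalone ToolInvoker. A minimal sketch, assuming an illustrative echo tool and a hypothetical on_chunk handler:

    from haystack.components.tools import ToolInvoker
    from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
    from haystack.tools import Tool

    def echo(text: str) -> str:
        # Toy tool used only for illustration.
        return text

    echo_tool = Tool(
        name="echo_tool",
        description="Echoes the given text.",
        parameters={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]},
        function=echo,
    )

    def on_chunk(chunk: StreamingChunk) -> None:
        # The tool result arrives in a single chunk after the tool finishes;
        # it is not streamed incrementally.
        if chunk.meta.get("tool_result"):
            print(f"[TOOL RESULT] {chunk.meta['tool_result']}")

    invoker = ToolInvoker(tools=[echo_tool])
    message = ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="echo_tool", arguments={"text": "hi"})])
    result = invoker.run(messages=[message], streaming_callback=on_chunk)
    # result["tool_messages"][0] wraps the tool's output as a tool-role ChatMessage.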