feat: enable parallel tool execution in ToolInvoker (#9530)

* Enable parallel tool execution in ToolInvoker

* Update handling of errors

* Small fixes

* Small fixes

* Adapt number of executors

* Add release notes

* Add parallel tool calling to sync run

* Deprecate async_executor

* Deprecate async_executor

* Add thread lock

* extract methods

* Update release notes

* Update release notes

* Updates

* Add new tests

* Add test for async

* PR comments
This commit is contained in:
Amna Mubashar 2025-06-25 13:32:11 +02:00 committed by GitHub
parent 91094e1038
commit 1cd0a128d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 374 additions and 128 deletions

View File

@ -5,6 +5,7 @@
import asyncio
import inspect
import json
import warnings
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Dict, List, Optional, Set, Union
@ -169,6 +170,7 @@ class ToolInvoker:
streaming_callback: Optional[StreamingCallbackT] = None,
*,
enable_streaming_callback_passthrough: bool = False,
max_workers: int = 4,
async_executor: Optional[ThreadPoolExecutor] = None,
):
"""
@ -193,9 +195,15 @@ class ToolInvoker:
This allows tools to stream their results back to the client.
Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
If False, the `streaming_callback` will not be passed to the tool invocation.
:param max_workers:
The maximum number of workers to use in the thread pool executor.
:param async_executor:
Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be
initialized and used.
Optional `ThreadPoolExecutor` to use for asynchronous calls.
Note: As of Haystack 2.15.0, you no longer need to explicitly pass
`async_executor`. Instead, you can provide the `max_workers` parameter,
and a `ThreadPoolExecutor` will be created automatically for parallel tool invocations.
Support for `async_executor` will be removed in Haystack 2.16.0.
Please migrate to using `max_workers` instead.
:raises ValueError:
If no tools are provided or if duplicate tool names are found.
"""
@ -206,6 +214,7 @@ class ToolInvoker:
self.tools = tools
self.streaming_callback = streaming_callback
self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough
self.max_workers = max_workers
# Convert Toolset to list for internal use
if isinstance(tools, Toolset):
@ -223,8 +232,19 @@ class ToolInvoker:
self.raise_on_failure = raise_on_failure
self.convert_result_to_json_string = convert_result_to_json_string
self._owns_executor = async_executor is None
if self._owns_executor:
warnings.warn(
"'async_executor' is deprecated in favor of the 'max_workers' parameter. "
"ToolInvoker now creates its own thread pool executor by default using 'max_workers'. "
"Support for 'async_executor' will be removed in Haystack 2.16.0. "
"Please update your usage to pass 'max_workers' instead.",
DeprecationWarning,
)
self.executor = (
ThreadPoolExecutor(thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=1)
ThreadPoolExecutor(
thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=self.max_workers
)
if async_executor is None
else async_executor
)
@ -427,6 +447,61 @@ class ToolInvoker:
# Merge other outputs into the state
state.set(state_key, output_value, handler_override=handler)
def _prepare_tool_call_params(
    self,
    messages_with_tool_calls: List[ChatMessage],
    state: State,
    streaming_callback: Optional[StreamingCallbackT],
    enable_streaming_passthrough: bool,
) -> tuple[List[Dict[str, Any]], List[ChatMessage]]:
    """
    Build the parameter sets needed to execute each tool call and collect error messages for unknown tools.

    :param messages_with_tool_calls: Messages containing tool calls to process
    :param state: The current state for argument injection
    :param streaming_callback: Optional streaming callback to inject
    :param enable_streaming_passthrough: Whether to pass streaming callback to tools
    :returns: Tuple of (tool_call_params, error_messages)
    """
    prepared: List[Dict[str, Any]] = []
    errors: List[ChatMessage] = []

    # Flatten all tool calls across the incoming messages and process each one.
    for call in (tc for message in messages_with_tool_calls for tc in message.tool_calls):
        requested_name = call.tool_name

        # Unknown tool: record an error ChatMessage instead of a parameter set.
        if requested_name not in self._tools_with_names:
            err_text = self._handle_error(
                ToolNotFoundException(requested_name, list(self._tools_with_names.keys()))
            )
            errors.append(ChatMessage.from_tool(tool_result=err_text, origin=call, error=True))
            continue

        tool = self._tools_with_names[requested_name]

        # Merge the LLM-provided arguments with any state-injected ones.
        llm_args = call.arguments.copy()
        final_args = self._inject_state_args(tool, llm_args, state)

        # Optionally forward the streaming callback when the tool's invoke accepts one
        # and the caller did not already supply it.
        should_pass_callback = (
            enable_streaming_passthrough
            and streaming_callback is not None
            and "streaming_callback" not in final_args
        )
        if should_pass_callback and "streaming_callback" in self._get_func_params(tool):
            final_args["streaming_callback"] = streaming_callback

        prepared.append({"tool_call": call, "tool_to_invoke": tool, "final_args": final_args})

    return prepared, errors
@component.output_types(tool_messages=List[ChatMessage], state=State)
def run(
self,
@ -480,76 +555,69 @@ class ToolInvoker:
)
tool_messages = []
for message in messages_with_tool_calls:
for tool_call in message.tool_calls:
tool_name = tool_call.tool_name
# Check if the tool is available, otherwise return an error message
if tool_name not in self._tools_with_names:
error_message = self._handle_error(
ToolNotFoundException(tool_name, list(self._tools_with_names.keys()))
)
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
continue
# 1) Collect all tool calls and their parameters for parallel execution
tool_call_params, error_messages = self._prepare_tool_call_params(
messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough
)
tool_messages.extend(error_messages)
tool_to_invoke = self._tools_with_names[tool_name]
# 2) Execute valid tool calls in parallel
if tool_call_params:
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = []
for params in tool_call_params:
future = executor.submit(self._execute_single_tool_call, **params) # type: ignore[arg-type]
futures.append(future)
# 1) Combine user + state inputs
llm_args = tool_call.arguments.copy()
final_args = self._inject_state_args(tool_to_invoke, llm_args, state)
# 3) Process results in the order they are submitted
for future in futures:
result = future.result()
if isinstance(result, ChatMessage):
tool_messages.append(result)
else:
# Handle state merging and prepare tool result message
tool_call, tool_to_invoke, tool_result = result
# Check whether to inject streaming_callback
if (
resolved_enable_streaming_passthrough
and streaming_callback is not None
and "streaming_callback" not in final_args
):
invoke_params = self._get_func_params(tool_to_invoke)
if "streaming_callback" in invoke_params:
final_args["streaming_callback"] = streaming_callback
# 4) Merge outputs into state
try:
self._merge_tool_outputs(tool_to_invoke, tool_result, state)
except Exception as e:
try:
error_message = self._handle_error(
ToolOutputMergeError(
f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}"
)
)
tool_messages.append(
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
)
continue
except ToolOutputMergeError as propagated_e:
# Re-raise with proper error chain
raise propagated_e from e
# 2) Invoke the tool
try:
tool_result = tool_to_invoke.invoke(**final_args)
except ToolInvocationError as e:
error_message = self._handle_error(e)
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
continue
# 3) Merge outputs into state
try:
self._merge_tool_outputs(tool_to_invoke, tool_result, state)
except Exception as e:
try:
error_message = self._handle_error(
ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}")
)
# 5) Prepare the tool result ChatMessage message
tool_messages.append(
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
self._prepare_tool_result_message(
result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
)
)
continue
except ToolOutputMergeError as propagated_e:
# Re-raise with proper error chain
raise propagated_e from e
# 4) Prepare the tool result ChatMessage message
tool_messages.append(
self._prepare_tool_result_message(
result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
)
)
if streaming_callback is not None:
streaming_callback(
StreamingChunk(
content="",
index=len(tool_messages) - 1,
tool_call_result=tool_messages[-1].tool_call_results[0],
start=True,
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
)
)
# 6) Handle streaming callback
if streaming_callback is not None:
streaming_callback(
StreamingChunk(
content="",
index=len(tool_messages) - 1,
tool_call_result=tool_messages[-1].tool_call_results[0],
start=True,
meta={
"tool_result": tool_messages[-1].tool_call_results[0].result,
"tool_call": tool_call,
},
)
)
# We stream one more chunk that contains a finish_reason if tool_messages were generated
if len(tool_messages) > 0 and streaming_callback is not None:
@ -561,6 +629,31 @@ class ToolInvoker:
return {"tool_messages": tool_messages, "state": state}
def _execute_single_tool_call(self, tool_call: ToolCall, tool_to_invoke: Tool, final_args: Dict[str, Any]):
    """
    Run one tool invocation; intended to be submitted to a thread pool.

    :param tool_call: The ToolCall object containing the tool name and arguments.
    :param tool_to_invoke: The Tool object that should be invoked.
    :param final_args: The final arguments to pass to the tool.
    :returns: Either a ChatMessage with error or a tuple of (tool_call, tool_to_invoke, tool_result)
    """
    try:
        result = tool_to_invoke.invoke(**final_args)
    except ToolInvocationError as exc:
        # Convert invocation failures into an error ChatMessage so the caller
        # can keep processing the remaining futures instead of aborting.
        error_text = self._handle_error(exc)
        return ChatMessage.from_tool(tool_result=error_text, origin=tool_call, error=True)
    return (tool_call, tool_to_invoke, result)
@staticmethod
async def invoke_tool_safely(executor: ThreadPoolExecutor, tool_to_invoke: Tool, final_args: Dict[str, Any]) -> Any:
    """Safely invoke a tool with proper exception handling."""
    event_loop = asyncio.get_running_loop()
    bound_invoke = partial(tool_to_invoke.invoke, **final_args)
    try:
        # Run the (synchronous) tool invocation on the thread pool without blocking the loop.
        return await event_loop.run_in_executor(executor, bound_invoke)
    except ToolInvocationError as exc:
        # Return the error as a value so asyncio.gather does not cancel sibling calls.
        return exc
@component.output_types(tool_messages=List[ChatMessage], state=State)
async def run_async(
self,
@ -571,8 +664,9 @@ class ToolInvoker:
enable_streaming_callback_passthrough: Optional[bool] = None,
) -> Dict[str, Any]:
"""
Asynchronously processes ChatMessage objects containing tool calls and invokes the corresponding tools.
Asynchronously processes ChatMessage objects containing tool calls.
Multiple tool calls are performed concurrently.
:param messages:
A list of ChatMessage objects.
:param state: The runtime state that should be used by the tools.
@ -598,6 +692,7 @@ class ToolInvoker:
:raises ToolOutputMergeError:
If merging tool outputs into state fails and `raise_on_failure` is True.
"""
if state is None:
state = State(schema={})
@ -614,78 +709,78 @@ class ToolInvoker:
)
tool_messages = []
for message in messages_with_tool_calls:
for tool_call in message.tool_calls:
tool_name = tool_call.tool_name
# Check if the tool is available, otherwise return an error message
if tool_name not in self._tools_with_names:
error_message = self._handle_error(
ToolNotFoundException(tool_name, list(self._tools_with_names.keys()))
)
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
continue
# 1) Prepare tool call parameters for execution
tool_call_params, error_messages = self._prepare_tool_call_params(
messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough
)
tool_messages.extend(error_messages)
tool_to_invoke = self._tools_with_names[tool_name]
# 2) Execute valid tool calls in parallel
if tool_call_params:
with self.executor as executor:
tool_call_tasks = []
valid_tool_calls = []
# 1) Combine user + state inputs
llm_args = tool_call.arguments.copy()
final_args = self._inject_state_args(tool_to_invoke, llm_args, state)
# 3) Create async tasks for valid tool calls
for params in tool_call_params:
task = ToolInvoker.invoke_tool_safely(executor, params["tool_to_invoke"], params["final_args"])
tool_call_tasks.append(task)
valid_tool_calls.append((params["tool_call"], params["tool_to_invoke"]))
# Check whether to inject streaming_callback
if (
resolved_enable_streaming_passthrough
and streaming_callback is not None
and "streaming_callback" not in final_args
):
invoke_params = self._get_func_params(tool_to_invoke)
if "streaming_callback" in invoke_params:
final_args["streaming_callback"] = streaming_callback
if tool_call_tasks:
# 4) Gather results from all tool calls
tool_results = await asyncio.gather(*tool_call_tasks)
# 2) Invoke the tool asynchronously
try:
tool_result = await asyncio.get_running_loop().run_in_executor(
self.executor, partial(tool_to_invoke.invoke, **final_args)
)
# 5) Process results
for i, ((tool_call, tool_to_invoke), tool_result) in enumerate(zip(valid_tool_calls, tool_results)):
# Check if the tool_result is a ToolInvocationError (caught by our wrapper)
if isinstance(tool_result, ToolInvocationError):
error_message = self._handle_error(tool_result)
tool_messages.append(
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
)
continue
except ToolInvocationError as e:
error_message = self._handle_error(e)
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
continue
# 6) Merge outputs into state
try:
self._merge_tool_outputs(tool_to_invoke, tool_result, state)
except Exception as e:
try:
error_message = self._handle_error(
ToolOutputMergeError(
f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}"
)
)
tool_messages.append(
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
)
continue
except ToolOutputMergeError as propagated_e:
# Re-raise with proper error chain
raise propagated_e from e
# 3) Merge outputs into state
try:
self._merge_tool_outputs(tool_to_invoke, tool_result, state)
except Exception as e:
try:
error_message = self._handle_error(
ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}")
)
# 7) Prepare the tool result ChatMessage message
tool_messages.append(
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
self._prepare_tool_result_message(
result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
)
)
continue
except ToolOutputMergeError as propagated_e:
# Re-raise with proper error chain
raise propagated_e from e
# 4) Prepare the tool result ChatMessage message
tool_messages.append(
self._prepare_tool_result_message(
result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
)
)
if streaming_callback is not None:
await streaming_callback(
StreamingChunk(
content="",
index=len(tool_messages) - 1,
tool_call_result=tool_messages[-1].tool_call_results[0],
start=True,
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
)
)
# 8) Handle streaming callback
if streaming_callback is not None:
await streaming_callback(
StreamingChunk(
content="",
index=i,
tool_call_result=tool_messages[-1].tool_call_results[0],
start=True,
meta={
"tool_result": tool_messages[-1].tool_call_results[0].result,
"tool_call": tool_call,
},
)
)
# We stream one more chunk that contains a finish_reason if tool_messages were generated
if len(tool_messages) > 0 and streaming_callback is not None:

View File

@ -0,0 +1,9 @@
---
features:
- |
`ToolInvoker` now executes `tool_calls` in parallel for both sync and async mode.
deprecations:
- |
`async_executor` parameter in `ToolInvoker` is deprecated in favor of `max_workers` parameter and will be removed in Haystack 2.16.0.
You can use `max_workers` parameter to control the number of threads used for parallel tool calling.

View File

@ -6,6 +6,7 @@ from unittest.mock import patch
import pytest
import json
import datetime
import time
from haystack import Pipeline
from haystack.components.builders.prompt_builder import PromptBuilder
@ -676,6 +677,147 @@ class TestToolInvoker:
)
mock_run.assert_called_once_with(messages=[ChatMessage.from_user(text="Hello!")])
def test_parallel_tool_calling_with_state_updates(self):
    """Test that parallel tool execution with state updates works correctly with the state lock."""
    # Create a shared counter variable to simulate a state value that gets updated
    execution_log = []

    def function_1():
        # sleep so the three calls overlap when executed in parallel
        time.sleep(0.1)
        execution_log.append("tool_1_executed")
        return {"counter": 1, "tool_name": "tool_1"}

    def function_2():
        time.sleep(0.1)
        execution_log.append("tool_2_executed")
        return {"counter": 2, "tool_name": "tool_2"}

    def function_3():
        time.sleep(0.1)
        execution_log.append("tool_3_executed")
        return {"counter": 3, "tool_name": "tool_3"}

    # Create tools that all update the same state key
    tool_1 = Tool(
        name="state_tool_1",
        description="A tool that updates state counter",
        parameters={"type": "object", "properties": {}},
        function=function_1,
        outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
    )

    tool_2 = Tool(
        name="state_tool_2",
        description="A tool that updates state counter",
        parameters={"type": "object", "properties": {}},
        function=function_2,
        outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
    )

    tool_3 = Tool(
        name="state_tool_3",
        description="A tool that updates state counter",
        parameters={"type": "object", "properties": {}},
        function=function_3,
        outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
    )

    # Create ToolInvoker with all three tools
    invoker = ToolInvoker(tools=[tool_1, tool_2, tool_3], raise_on_failure=True)
    state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}})

    # One assistant message carrying all three tool calls, so they run in one batch
    tool_calls = [
        ToolCall(tool_name="state_tool_1", arguments={}),
        ToolCall(tool_name="state_tool_2", arguments={}),
        ToolCall(tool_name="state_tool_3", arguments={}),
    ]

    message = ChatMessage.from_assistant(tool_calls=tool_calls)
    result = invoker.run(messages=[message], state=state)

    # Verify that all three tools were executed
    assert len(execution_log) == 3
    assert "tool_1_executed" in execution_log
    assert "tool_2_executed" in execution_log
    assert "tool_3_executed" in execution_log

    # Verify that the state was updated correctly
    # Due to parallel execution, we can't predict which tool will be the last to update
    assert state.has("counter")
    assert state.has("last_tool")
    assert state.get("counter") in [1, 2, 3]  # Should be one of the tool values
    assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"]  # Should be one of the tool names
@pytest.mark.asyncio
async def test_async_parallel_tool_calling_with_state_updates(self):
    """Test that async parallel tool execution via run_async updates state correctly with the state lock."""
    # Create a shared counter variable to simulate a state value that gets updated
    execution_log = []

    def function_1():
        # sleep so the three calls overlap when executed concurrently
        time.sleep(0.1)
        execution_log.append("tool_1_executed")
        return {"counter": 1, "tool_name": "tool_1"}

    def function_2():
        time.sleep(0.1)
        execution_log.append("tool_2_executed")
        return {"counter": 2, "tool_name": "tool_2"}

    def function_3():
        time.sleep(0.1)
        execution_log.append("tool_3_executed")
        return {"counter": 3, "tool_name": "tool_3"}

    # Create tools that all update the same state key
    tool_1 = Tool(
        name="state_tool_1",
        description="A tool that updates state counter",
        parameters={"type": "object", "properties": {}},
        function=function_1,
        outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
    )

    tool_2 = Tool(
        name="state_tool_2",
        description="A tool that updates state counter",
        parameters={"type": "object", "properties": {}},
        function=function_2,
        outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
    )

    tool_3 = Tool(
        name="state_tool_3",
        description="A tool that updates state counter",
        parameters={"type": "object", "properties": {}},
        function=function_3,
        outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
    )

    # Create ToolInvoker with all three tools
    invoker = ToolInvoker(tools=[tool_1, tool_2, tool_3], raise_on_failure=True)
    state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}})

    # One assistant message carrying all three tool calls, so they run in one batch
    tool_calls = [
        ToolCall(tool_name="state_tool_1", arguments={}),
        ToolCall(tool_name="state_tool_2", arguments={}),
        ToolCall(tool_name="state_tool_3", arguments={}),
    ]

    message = ChatMessage.from_assistant(tool_calls=tool_calls)
    result = await invoker.run_async(messages=[message], state=state)

    # Verify that all three tools were executed
    assert len(execution_log) == 3
    assert "tool_1_executed" in execution_log
    assert "tool_2_executed" in execution_log
    assert "tool_3_executed" in execution_log

    # Verify that the state was updated correctly
    # Due to parallel execution, we can't predict which tool will be the last to update
    assert state.has("counter")
    assert state.has("last_tool")
    assert state.get("counter") in [1, 2, 3]  # Should be one of the tool values
    assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"]  # Should be one of the tool names
class TestMergeToolOutputs:
def test_merge_tool_outputs_result_not_a_dict(self, weather_tool):