feat: enable parallel tool execution in ToolInvoker (#9530)

* Enable parallel tool execution in ToolInvoker

* Update handling of errors

* Small fixes

* Small fixes

* Adapt number of executors

* Add release notes

* Add parallel tool calling to sync run

* Deprecate async_executor

* Deprecate async_executor

* Add thread lock

* extract methods

* Update release notes

* Update release notes

* Updates

* Add new tests

* Add test for async

* PR comments
This commit is contained in:
Amna Mubashar 2025-06-25 13:32:11 +02:00 committed by GitHub
parent 91094e1038
commit 1cd0a128d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 374 additions and 128 deletions

View File

@ -5,6 +5,7 @@
import asyncio import asyncio
import inspect import inspect
import json import json
import warnings
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial from functools import partial
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Union
@ -169,6 +170,7 @@ class ToolInvoker:
streaming_callback: Optional[StreamingCallbackT] = None, streaming_callback: Optional[StreamingCallbackT] = None,
*, *,
enable_streaming_callback_passthrough: bool = False, enable_streaming_callback_passthrough: bool = False,
max_workers: int = 4,
async_executor: Optional[ThreadPoolExecutor] = None, async_executor: Optional[ThreadPoolExecutor] = None,
): ):
""" """
@ -193,9 +195,15 @@ class ToolInvoker:
This allows tools to stream their results back to the client. This allows tools to stream their results back to the client.
Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature. Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
If False, the `streaming_callback` will not be passed to the tool invocation. If False, the `streaming_callback` will not be passed to the tool invocation.
:param max_workers:
The maximum number of workers to use in the thread pool executor.
:param async_executor: :param async_executor:
Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be Optional `ThreadPoolExecutor` to use for asynchronous calls.
initialized and used. Note: As of Haystack 2.15.0, you no longer need to explicitly pass
`async_executor`. Instead, you can provide the `max_workers` parameter,
and a `ThreadPoolExecutor` will be created automatically for parallel tool invocations.
Support for `async_executor` will be removed in Haystack 2.16.0.
Please migrate to using `max_workers` instead.
:raises ValueError: :raises ValueError:
If no tools are provided or if duplicate tool names are found. If no tools are provided or if duplicate tool names are found.
""" """
@ -206,6 +214,7 @@ class ToolInvoker:
self.tools = tools self.tools = tools
self.streaming_callback = streaming_callback self.streaming_callback = streaming_callback
self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough
self.max_workers = max_workers
# Convert Toolset to list for internal use # Convert Toolset to list for internal use
if isinstance(tools, Toolset): if isinstance(tools, Toolset):
@ -223,8 +232,19 @@ class ToolInvoker:
self.raise_on_failure = raise_on_failure self.raise_on_failure = raise_on_failure
self.convert_result_to_json_string = convert_result_to_json_string self.convert_result_to_json_string = convert_result_to_json_string
self._owns_executor = async_executor is None self._owns_executor = async_executor is None
if not self._owns_executor:
warnings.warn(
"'async_executor' is deprecated in favor of the 'max_workers' parameter. "
"ToolInvoker now creates its own thread pool executor by default using 'max_workers'. "
"Support for 'async_executor' will be removed in Haystack 2.16.0. "
"Please update your usage to pass 'max_workers' instead.",
DeprecationWarning,
)
self.executor = ( self.executor = (
ThreadPoolExecutor(thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=1) ThreadPoolExecutor(
thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=self.max_workers
)
if async_executor is None if async_executor is None
else async_executor else async_executor
) )
@ -427,6 +447,61 @@ class ToolInvoker:
# Merge other outputs into the state # Merge other outputs into the state
state.set(state_key, output_value, handler_override=handler) state.set(state_key, output_value, handler_override=handler)
def _prepare_tool_call_params(
    self,
    messages_with_tool_calls: List[ChatMessage],
    state: State,
    streaming_callback: Optional[StreamingCallbackT],
    enable_streaming_passthrough: bool,
) -> tuple[List[Dict[str, Any]], List[ChatMessage]]:
    """
    Collect execution parameters for every tool call and error messages for unknown tools.

    :param messages_with_tool_calls: Messages whose tool calls should be prepared for execution.
    :param state: Runtime state used to inject additional tool arguments.
    :param streaming_callback: Optional callback that may be forwarded to tools.
    :param enable_streaming_passthrough: Whether the callback is forwarded to tools that accept it.
    :returns: Tuple of (tool_call_params, error_messages).
    """
    prepared: List[Dict[str, Any]] = []
    failures: List[ChatMessage] = []

    # Flatten all tool calls across the incoming messages.
    all_tool_calls = (tc for msg in messages_with_tool_calls for tc in msg.tool_calls)
    for tool_call in all_tool_calls:
        requested_name = tool_call.tool_name

        # Unknown tool: record an error ChatMessage instead of execution params.
        if requested_name not in self._tools_with_names:
            err_text = self._handle_error(
                ToolNotFoundException(requested_name, list(self._tools_with_names.keys()))
            )
            failures.append(ChatMessage.from_tool(tool_result=err_text, origin=tool_call, error=True))
            continue

        tool = self._tools_with_names[requested_name]

        # Merge the LLM-provided arguments with values injected from state.
        merged_args = self._inject_state_args(tool, tool_call.arguments.copy(), state)

        # Forward the streaming callback only when passthrough is enabled, a callback exists,
        # the caller did not already supply one, and the tool actually accepts the parameter.
        # Short-circuiting keeps _get_func_params from being called unnecessarily.
        if (
            enable_streaming_passthrough
            and streaming_callback is not None
            and "streaming_callback" not in merged_args
            and "streaming_callback" in self._get_func_params(tool)
        ):
            merged_args["streaming_callback"] = streaming_callback

        prepared.append({"tool_call": tool_call, "tool_to_invoke": tool, "final_args": merged_args})

    return prepared, failures
@component.output_types(tool_messages=List[ChatMessage], state=State) @component.output_types(tool_messages=List[ChatMessage], state=State)
def run( def run(
self, self,
@ -480,76 +555,69 @@ class ToolInvoker:
) )
tool_messages = [] tool_messages = []
for message in messages_with_tool_calls:
for tool_call in message.tool_calls:
tool_name = tool_call.tool_name
# Check if the tool is available, otherwise return an error message # 1) Collect all tool calls and their parameters for parallel execution
if tool_name not in self._tools_with_names: tool_call_params, error_messages = self._prepare_tool_call_params(
error_message = self._handle_error( messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough
ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) )
) tool_messages.extend(error_messages)
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
continue
tool_to_invoke = self._tools_with_names[tool_name] # 2) Execute valid tool calls in parallel
if tool_call_params:
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = []
for params in tool_call_params:
future = executor.submit(self._execute_single_tool_call, **params) # type: ignore[arg-type]
futures.append(future)
# 1) Combine user + state inputs # 3) Process results in the order they are submitted
llm_args = tool_call.arguments.copy() for future in futures:
final_args = self._inject_state_args(tool_to_invoke, llm_args, state) result = future.result()
if isinstance(result, ChatMessage):
tool_messages.append(result)
else:
# Handle state merging and prepare tool result message
tool_call, tool_to_invoke, tool_result = result
# Check whether to inject streaming_callback # 4) Merge outputs into state
if ( try:
resolved_enable_streaming_passthrough self._merge_tool_outputs(tool_to_invoke, tool_result, state)
and streaming_callback is not None except Exception as e:
and "streaming_callback" not in final_args try:
): error_message = self._handle_error(
invoke_params = self._get_func_params(tool_to_invoke) ToolOutputMergeError(
if "streaming_callback" in invoke_params: f"Failed to merge tool outputs fromtool {tool_call.tool_name} into State: {e}"
final_args["streaming_callback"] = streaming_callback )
)
tool_messages.append(
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
)
continue
except ToolOutputMergeError as propagated_e:
# Re-raise with proper error chain
raise propagated_e from e
# 2) Invoke the tool # 5) Prepare the tool result ChatMessage message
try:
tool_result = tool_to_invoke.invoke(**final_args)
except ToolInvocationError as e:
error_message = self._handle_error(e)
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
continue
# 3) Merge outputs into state
try:
self._merge_tool_outputs(tool_to_invoke, tool_result, state)
except Exception as e:
try:
error_message = self._handle_error(
ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}")
)
tool_messages.append( tool_messages.append(
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) self._prepare_tool_result_message(
result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
)
) )
continue
except ToolOutputMergeError as propagated_e:
# Re-raise with proper error chain
raise propagated_e from e
# 4) Prepare the tool result ChatMessage message # 6) Handle streaming callback
tool_messages.append( if streaming_callback is not None:
self._prepare_tool_result_message( streaming_callback(
result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke StreamingChunk(
) content="",
) index=len(tool_messages) - 1,
tool_call_result=tool_messages[-1].tool_call_results[0],
if streaming_callback is not None: start=True,
streaming_callback( meta={
StreamingChunk( "tool_result": tool_messages[-1].tool_call_results[0].result,
content="", "tool_call": tool_call,
index=len(tool_messages) - 1, },
tool_call_result=tool_messages[-1].tool_call_results[0], )
start=True, )
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
)
)
# We stream one more chunk that contains a finish_reason if tool_messages were generated # We stream one more chunk that contains a finish_reason if tool_messages were generated
if len(tool_messages) > 0 and streaming_callback is not None: if len(tool_messages) > 0 and streaming_callback is not None:
@ -561,6 +629,31 @@ class ToolInvoker:
return {"tool_messages": tool_messages, "state": state} return {"tool_messages": tool_messages, "state": state}
def _execute_single_tool_call(self, tool_call: ToolCall, tool_to_invoke: Tool, final_args: Dict[str, Any]):
    """
    Run one tool call; designed to be submitted to a thread pool.

    :param tool_call: The ToolCall describing the requested invocation.
    :param tool_to_invoke: The Tool object to run.
    :param final_args: Keyword arguments passed to the tool's invoke method.
    :returns: A (tool_call, tool_to_invoke, tool_result) tuple on success, or an error
        ChatMessage when the invocation raises ToolInvocationError.
    """
    try:
        outcome = tool_to_invoke.invoke(**final_args)
    except ToolInvocationError as exc:
        # _handle_error either re-raises (raise_on_failure=True) or returns an error string.
        return ChatMessage.from_tool(tool_result=self._handle_error(exc), origin=tool_call, error=True)
    return (tool_call, tool_to_invoke, outcome)
@staticmethod
async def invoke_tool_safely(executor: ThreadPoolExecutor, tool_to_invoke: Tool, final_args: Dict[str, Any]) -> Any:
    """
    Run a tool on *executor* without letting ToolInvocationError propagate.

    :param executor: Thread pool the blocking invocation is dispatched to.
    :param tool_to_invoke: The Tool object to run.
    :param final_args: Keyword arguments passed to the tool's invoke method.
    :returns: The tool result, or the caught ToolInvocationError instance for the caller to handle.
    """
    event_loop = asyncio.get_running_loop()
    blocking_call = partial(tool_to_invoke.invoke, **final_args)
    try:
        return await event_loop.run_in_executor(executor, blocking_call)
    except ToolInvocationError as caught:
        return caught
@component.output_types(tool_messages=List[ChatMessage], state=State) @component.output_types(tool_messages=List[ChatMessage], state=State)
async def run_async( async def run_async(
self, self,
@ -571,8 +664,9 @@ class ToolInvoker:
enable_streaming_callback_passthrough: Optional[bool] = None, enable_streaming_callback_passthrough: Optional[bool] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Asynchronously processes ChatMessage objects containing tool calls and invokes the corresponding tools. Asynchronously processes ChatMessage objects containing tool calls.
Multiple tool calls are performed concurrently.
:param messages: :param messages:
A list of ChatMessage objects. A list of ChatMessage objects.
:param state: The runtime state that should be used by the tools. :param state: The runtime state that should be used by the tools.
@ -598,6 +692,7 @@ class ToolInvoker:
:raises ToolOutputMergeError: :raises ToolOutputMergeError:
If merging tool outputs into state fails and `raise_on_failure` is True. If merging tool outputs into state fails and `raise_on_failure` is True.
""" """
if state is None: if state is None:
state = State(schema={}) state = State(schema={})
@ -614,78 +709,78 @@ class ToolInvoker:
) )
tool_messages = [] tool_messages = []
for message in messages_with_tool_calls:
for tool_call in message.tool_calls:
tool_name = tool_call.tool_name
# Check if the tool is available, otherwise return an error message # 1) Prepare tool call parameters for execution
if tool_name not in self._tools_with_names: tool_call_params, error_messages = self._prepare_tool_call_params(
error_message = self._handle_error( messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough
ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) )
) tool_messages.extend(error_messages)
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
continue
tool_to_invoke = self._tools_with_names[tool_name] # 2) Execute valid tool calls in parallel
if tool_call_params:
with self.executor as executor:
tool_call_tasks = []
valid_tool_calls = []
# 1) Combine user + state inputs # 3) Create async tasks for valid tool calls
llm_args = tool_call.arguments.copy() for params in tool_call_params:
final_args = self._inject_state_args(tool_to_invoke, llm_args, state) task = ToolInvoker.invoke_tool_safely(executor, params["tool_to_invoke"], params["final_args"])
tool_call_tasks.append(task)
valid_tool_calls.append((params["tool_call"], params["tool_to_invoke"]))
# Check whether to inject streaming_callback if tool_call_tasks:
if ( # 4) Gather results from all tool calls
resolved_enable_streaming_passthrough tool_results = await asyncio.gather(*tool_call_tasks)
and streaming_callback is not None
and "streaming_callback" not in final_args
):
invoke_params = self._get_func_params(tool_to_invoke)
if "streaming_callback" in invoke_params:
final_args["streaming_callback"] = streaming_callback
# 2) Invoke the tool asynchronously # 5) Process results
try: for i, ((tool_call, tool_to_invoke), tool_result) in enumerate(zip(valid_tool_calls, tool_results)):
tool_result = await asyncio.get_running_loop().run_in_executor( # Check if the tool_result is a ToolInvocationError (caught by our wrapper)
self.executor, partial(tool_to_invoke.invoke, **final_args) if isinstance(tool_result, ToolInvocationError):
) error_message = self._handle_error(tool_result)
tool_messages.append(
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
)
continue
except ToolInvocationError as e: # 6) Merge outputs into state
error_message = self._handle_error(e) try:
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) self._merge_tool_outputs(tool_to_invoke, tool_result, state)
continue except Exception as e:
try:
error_message = self._handle_error(
ToolOutputMergeError(
f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}"
)
)
tool_messages.append(
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
)
continue
except ToolOutputMergeError as propagated_e:
# Re-raise with proper error chain
raise propagated_e from e
# 3) Merge outputs into state # 7) Prepare the tool result ChatMessage message
try:
self._merge_tool_outputs(tool_to_invoke, tool_result, state)
except Exception as e:
try:
error_message = self._handle_error(
ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}")
)
tool_messages.append( tool_messages.append(
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) self._prepare_tool_result_message(
result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
)
) )
continue
except ToolOutputMergeError as propagated_e:
# Re-raise with proper error chain
raise propagated_e from e
# 4) Prepare the tool result ChatMessage message # 8) Handle streaming callback
tool_messages.append( if streaming_callback is not None:
self._prepare_tool_result_message( await streaming_callback(
result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke StreamingChunk(
) content="",
) index=i,
tool_call_result=tool_messages[-1].tool_call_results[0],
if streaming_callback is not None: start=True,
await streaming_callback( meta={
StreamingChunk( "tool_result": tool_messages[-1].tool_call_results[0].result,
content="", "tool_call": tool_call,
index=len(tool_messages) - 1, },
tool_call_result=tool_messages[-1].tool_call_results[0], )
start=True, )
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
)
)
# We stream one more chunk that contains a finish_reason if tool_messages were generated # We stream one more chunk that contains a finish_reason if tool_messages were generated
if len(tool_messages) > 0 and streaming_callback is not None: if len(tool_messages) > 0 and streaming_callback is not None:

View File

@ -0,0 +1,9 @@
---
features:
- |
`ToolInvoker` now executes `tool_calls` in parallel for both sync and async mode.
deprecations:
- |
The `async_executor` parameter of `ToolInvoker` is deprecated in favor of the `max_workers` parameter and will be removed in Haystack 2.16.0.
Use the `max_workers` parameter to control the number of threads used for parallel tool calling.

View File

@ -6,6 +6,7 @@ from unittest.mock import patch
import pytest import pytest
import json import json
import datetime import datetime
import time
from haystack import Pipeline from haystack import Pipeline
from haystack.components.builders.prompt_builder import PromptBuilder from haystack.components.builders.prompt_builder import PromptBuilder
@ -676,6 +677,147 @@ class TestToolInvoker:
) )
mock_run.assert_called_once_with(messages=[ChatMessage.from_user(text="Hello!")]) mock_run.assert_called_once_with(messages=[ChatMessage.from_user(text="Hello!")])
def test_parallel_tool_calling_with_state_updates(self):
    """Test that parallel tool execution with state updates works correctly with the state lock."""
    execution_log = []

    def make_function(idx):
        # Each tool sleeps briefly so the three invocations overlap in time.
        def _tool_fn():
            time.sleep(0.1)
            execution_log.append(f"tool_{idx}_executed")
            return {"counter": idx, "tool_name": f"tool_{idx}"}

        return _tool_fn

    # Three tools that all write to the same state keys.
    tools = [
        Tool(
            name=f"state_tool_{idx}",
            description="A tool that updates state counter",
            parameters={"type": "object", "properties": {}},
            function=make_function(idx),
            outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
        )
        for idx in (1, 2, 3)
    ]

    invoker = ToolInvoker(tools=tools, raise_on_failure=True)
    state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}})

    message = ChatMessage.from_assistant(
        tool_calls=[ToolCall(tool_name=f"state_tool_{idx}", arguments={}) for idx in (1, 2, 3)]
    )
    result = invoker.run(messages=[message], state=state)

    # All three tools must have run exactly once.
    assert sorted(execution_log) == ["tool_1_executed", "tool_2_executed", "tool_3_executed"]

    # The state holds the values of whichever tool wrote last; with parallel
    # execution the winner is nondeterministic, so only membership is checked.
    assert state.has("counter")
    assert state.has("last_tool")
    assert state.get("counter") in [1, 2, 3]
    assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"]
@pytest.mark.asyncio
async def test_async_parallel_tool_calling_with_state_updates(self):
    """Test that parallel tool execution with state updates works correctly with the state lock."""
    execution_log = []

    def make_function(idx):
        # Each tool sleeps briefly so the three invocations overlap in time.
        def _tool_fn():
            time.sleep(0.1)
            execution_log.append(f"tool_{idx}_executed")
            return {"counter": idx, "tool_name": f"tool_{idx}"}

        return _tool_fn

    # Three tools that all write to the same state keys.
    tools = [
        Tool(
            name=f"state_tool_{idx}",
            description="A tool that updates state counter",
            parameters={"type": "object", "properties": {}},
            function=make_function(idx),
            outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
        )
        for idx in (1, 2, 3)
    ]

    invoker = ToolInvoker(tools=tools, raise_on_failure=True)
    state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}})

    message = ChatMessage.from_assistant(
        tool_calls=[ToolCall(tool_name=f"state_tool_{idx}", arguments={}) for idx in (1, 2, 3)]
    )
    result = await invoker.run_async(messages=[message], state=state)

    # All three tools must have run exactly once.
    assert sorted(execution_log) == ["tool_1_executed", "tool_2_executed", "tool_3_executed"]

    # The state holds the values of whichever tool wrote last; with parallel
    # execution the winner is nondeterministic, so only membership is checked.
    assert state.has("counter")
    assert state.has("last_tool")
    assert state.get("counter") in [1, 2, 3]
    assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"]
class TestMergeToolOutputs: class TestMergeToolOutputs:
def test_merge_tool_outputs_result_not_a_dict(self, weather_tool): def test_merge_tool_outputs_result_not_a_dict(self, weather_tool):