diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 47f2f0eca..935623365 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -5,6 +5,7 @@ import asyncio import inspect import json +import warnings from concurrent.futures import ThreadPoolExecutor from functools import partial from typing import Any, Dict, List, Optional, Set, Union @@ -169,6 +170,7 @@ class ToolInvoker: streaming_callback: Optional[StreamingCallbackT] = None, *, enable_streaming_callback_passthrough: bool = False, + max_workers: int = 4, async_executor: Optional[ThreadPoolExecutor] = None, ): """ @@ -193,9 +195,15 @@ class ToolInvoker: This allows tools to stream their results back to the client. Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature. If False, the `streaming_callback` will not be passed to the tool invocation. + :param max_workers: + The maximum number of workers to use in the thread pool executor. :param async_executor: - Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be - initialized and used. + Optional `ThreadPoolExecutor` to use for asynchronous calls. + Note: As of Haystack 2.15.0, you no longer need to explicitly pass + `async_executor`. Instead, you can provide the `max_workers` parameter, + and a `ThreadPoolExecutor` will be created automatically for parallel tool invocations. + Support for `async_executor` will be removed in Haystack 2.16.0. + Please migrate to using `max_workers` instead. :raises ValueError: If no tools are provided or if duplicate tool names are found. 
""" @@ -206,6 +214,7 @@ class ToolInvoker: self.tools = tools self.streaming_callback = streaming_callback self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough + self.max_workers = max_workers # Convert Toolset to list for internal use if isinstance(tools, Toolset): @@ -223,8 +232,19 @@ class ToolInvoker: self.raise_on_failure = raise_on_failure self.convert_result_to_json_string = convert_result_to_json_string self._owns_executor = async_executor is None + if self._owns_executor: + warnings.warn( + "'async_executor' is deprecated in favor of the 'max_workers' parameter. " + "ToolInvoker now creates its own thread pool executor by default using 'max_workers'. " + "Support for 'async_executor' will be removed in Haystack 2.16.0. " + "Please update your usage to pass 'max_workers' instead.", + DeprecationWarning, + ) + self.executor = ( - ThreadPoolExecutor(thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=1) + ThreadPoolExecutor( + thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=self.max_workers + ) if async_executor is None else async_executor ) @@ -427,6 +447,61 @@ class ToolInvoker: # Merge other outputs into the state state.set(state_key, output_value, handler_override=handler) + def _prepare_tool_call_params( + self, + messages_with_tool_calls: List[ChatMessage], + state: State, + streaming_callback: Optional[StreamingCallbackT], + enable_streaming_passthrough: bool, + ) -> tuple[List[Dict[str, Any]], List[ChatMessage]]: + """ + Prepare tool call parameters for execution and collect any error messages. 
+ + :param messages_with_tool_calls: Messages containing tool calls to process + :param state: The current state for argument injection + :param streaming_callback: Optional streaming callback to inject + :param enable_streaming_passthrough: Whether to pass streaming callback to tools + :returns: Tuple of (tool_call_params, error_messages) + """ + tool_call_params = [] + error_messages = [] + + for message in messages_with_tool_calls: + for tool_call in message.tool_calls: + tool_name = tool_call.tool_name + + # Check if the tool is available, otherwise return an error message + if tool_name not in self._tools_with_names: + error_message = self._handle_error( + ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) + ) + error_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + continue + + tool_to_invoke = self._tools_with_names[tool_name] + + # Combine user + state inputs + llm_args = tool_call.arguments.copy() + final_args = self._inject_state_args(tool_to_invoke, llm_args, state) + + # Check whether to inject streaming_callback + if ( + enable_streaming_passthrough + and streaming_callback is not None + and "streaming_callback" not in final_args + ): + invoke_params = self._get_func_params(tool_to_invoke) + if "streaming_callback" in invoke_params: + final_args["streaming_callback"] = streaming_callback + + tool_call_params.append( + {"tool_call": tool_call, "tool_to_invoke": tool_to_invoke, "final_args": final_args} + ) + + return tool_call_params, error_messages + @component.output_types(tool_messages=List[ChatMessage], state=State) def run( self, @@ -480,76 +555,69 @@ class ToolInvoker: ) tool_messages = [] - for message in messages_with_tool_calls: - for tool_call in message.tool_calls: - tool_name = tool_call.tool_name - # Check if the tool is available, otherwise return an error message - if tool_name not in self._tools_with_names: - error_message = self._handle_error( - 
ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) - ) - tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) - continue + # 1) Collect all tool calls and their parameters for parallel execution + tool_call_params, error_messages = self._prepare_tool_call_params( + messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough + ) + tool_messages.extend(error_messages) - tool_to_invoke = self._tools_with_names[tool_name] + # 2) Execute valid tool calls in parallel + if tool_call_params: + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = [] + for params in tool_call_params: + future = executor.submit(self._execute_single_tool_call, **params) # type: ignore[arg-type] + futures.append(future) - # 1) Combine user + state inputs - llm_args = tool_call.arguments.copy() - final_args = self._inject_state_args(tool_to_invoke, llm_args, state) + # 3) Process results in the order they are submitted + for future in futures: + result = future.result() + if isinstance(result, ChatMessage): + tool_messages.append(result) + else: + # Handle state merging and prepare tool result message + tool_call, tool_to_invoke, tool_result = result - # Check whether to inject streaming_callback - if ( - resolved_enable_streaming_passthrough - and streaming_callback is not None - and "streaming_callback" not in final_args - ): - invoke_params = self._get_func_params(tool_to_invoke) - if "streaming_callback" in invoke_params: - final_args["streaming_callback"] = streaming_callback + # 4) Merge outputs into state + try: + self._merge_tool_outputs(tool_to_invoke, tool_result, state) + except Exception as e: + try: + error_message = self._handle_error( + ToolOutputMergeError( + f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}" + ) + ) + tool_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + 
continue + except ToolOutputMergeError as propagated_e: + # Re-raise with proper error chain + raise propagated_e from e - # 2) Invoke the tool - try: - tool_result = tool_to_invoke.invoke(**final_args) - - except ToolInvocationError as e: - error_message = self._handle_error(e) - tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) - continue - - # 3) Merge outputs into state - try: - self._merge_tool_outputs(tool_to_invoke, tool_result, state) - except Exception as e: - try: - error_message = self._handle_error( - ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}") - ) + # 5) Prepare the tool result ChatMessage message tool_messages.append( - ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + self._prepare_tool_result_message( + result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke + ) ) - continue - except ToolOutputMergeError as propagated_e: - # Re-raise with proper error chain - raise propagated_e from e - # 4) Prepare the tool result ChatMessage message - tool_messages.append( - self._prepare_tool_result_message( - result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke - ) - ) - - if streaming_callback is not None: - streaming_callback( - StreamingChunk( - content="", - index=len(tool_messages) - 1, - tool_call_result=tool_messages[-1].tool_call_results[0], - start=True, - meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call}, - ) - ) + # 6) Handle streaming callback + if streaming_callback is not None: + streaming_callback( + StreamingChunk( + content="", + index=len(tool_messages) - 1, + tool_call_result=tool_messages[-1].tool_call_results[0], + start=True, + meta={ + "tool_result": tool_messages[-1].tool_call_results[0].result, + "tool_call": tool_call, + }, + ) + ) # We stream one more chunk that contains a finish_reason if tool_messages were generated if 
len(tool_messages) > 0 and streaming_callback is not None: @@ -561,6 +629,31 @@ class ToolInvoker: return {"tool_messages": tool_messages, "state": state} + def _execute_single_tool_call(self, tool_call: ToolCall, tool_to_invoke: Tool, final_args: Dict[str, Any]): + """ + Execute a single tool call. This method is designed to be run in a thread pool. + + :param tool_call: The ToolCall object containing the tool name and arguments. + :param tool_to_invoke: The Tool object that should be invoked. + :param final_args: The final arguments to pass to the tool. + :returns: Either a ChatMessage with error or a tuple of (tool_call, tool_to_invoke, tool_result) + """ + try: + tool_result = tool_to_invoke.invoke(**final_args) + return (tool_call, tool_to_invoke, tool_result) + except ToolInvocationError as e: + error_message = self._handle_error(e) + return ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + + @staticmethod + async def invoke_tool_safely(executor: ThreadPoolExecutor, tool_to_invoke: Tool, final_args: Dict[str, Any]) -> Any: + """Safely invoke a tool with proper exception handling.""" + loop = asyncio.get_running_loop() + try: + return await loop.run_in_executor(executor, partial(tool_to_invoke.invoke, **final_args)) + except ToolInvocationError as e: + return e + @component.output_types(tool_messages=List[ChatMessage], state=State) async def run_async( self, @@ -571,8 +664,9 @@ class ToolInvoker: enable_streaming_callback_passthrough: Optional[bool] = None, ) -> Dict[str, Any]: """ - Asynchronously processes ChatMessage objects containing tool calls and invokes the corresponding tools. + Asynchronously processes ChatMessage objects containing tool calls. + Multiple tool calls are performed concurrently. :param messages: A list of ChatMessage objects. :param state: The runtime state that should be used by the tools. 
@@ -598,6 +692,7 @@ class ToolInvoker: :raises ToolOutputMergeError: If merging tool outputs into state fails and `raise_on_failure` is True. """ + if state is None: state = State(schema={}) @@ -614,78 +709,78 @@ class ToolInvoker: ) tool_messages = [] - for message in messages_with_tool_calls: - for tool_call in message.tool_calls: - tool_name = tool_call.tool_name - # Check if the tool is available, otherwise return an error message - if tool_name not in self._tools_with_names: - error_message = self._handle_error( - ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) - ) - tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) - continue + # 1) Prepare tool call parameters for execution + tool_call_params, error_messages = self._prepare_tool_call_params( + messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough + ) + tool_messages.extend(error_messages) - tool_to_invoke = self._tools_with_names[tool_name] + # 2) Execute valid tool calls in parallel + if tool_call_params: + with self.executor as executor: + tool_call_tasks = [] + valid_tool_calls = [] - # 1) Combine user + state inputs - llm_args = tool_call.arguments.copy() - final_args = self._inject_state_args(tool_to_invoke, llm_args, state) + # 3) Create async tasks for valid tool calls + for params in tool_call_params: + task = ToolInvoker.invoke_tool_safely(executor, params["tool_to_invoke"], params["final_args"]) + tool_call_tasks.append(task) + valid_tool_calls.append((params["tool_call"], params["tool_to_invoke"])) - # Check whether to inject streaming_callback - if ( - resolved_enable_streaming_passthrough - and streaming_callback is not None - and "streaming_callback" not in final_args - ): - invoke_params = self._get_func_params(tool_to_invoke) - if "streaming_callback" in invoke_params: - final_args["streaming_callback"] = streaming_callback + if tool_call_tasks: + # 4) Gather results from all tool 
calls + tool_results = await asyncio.gather(*tool_call_tasks) - # 2) Invoke the tool asynchronously - try: - tool_result = await asyncio.get_running_loop().run_in_executor( - self.executor, partial(tool_to_invoke.invoke, **final_args) - ) + # 5) Process results + for i, ((tool_call, tool_to_invoke), tool_result) in enumerate(zip(valid_tool_calls, tool_results)): + # Check if the tool_result is a ToolInvocationError (caught by our wrapper) + if isinstance(tool_result, ToolInvocationError): + error_message = self._handle_error(tool_result) + tool_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + continue - except ToolInvocationError as e: - error_message = self._handle_error(e) - tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) - continue + # 6) Merge outputs into state + try: + self._merge_tool_outputs(tool_to_invoke, tool_result, state) + except Exception as e: + try: + error_message = self._handle_error( + ToolOutputMergeError( + f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}" + ) + ) + tool_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + continue + except ToolOutputMergeError as propagated_e: + # Re-raise with proper error chain + raise propagated_e from e - # 3) Merge outputs into state - try: - self._merge_tool_outputs(tool_to_invoke, tool_result, state) - except Exception as e: - try: - error_message = self._handle_error( - ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}") - ) + # 7) Prepare the tool result ChatMessage message tool_messages.append( - ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + self._prepare_tool_result_message( + result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke + ) ) - continue - except ToolOutputMergeError as propagated_e: - # Re-raise with proper error 
chain - raise propagated_e from e - # 4) Prepare the tool result ChatMessage message - tool_messages.append( - self._prepare_tool_result_message( - result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke - ) - ) - - if streaming_callback is not None: - await streaming_callback( - StreamingChunk( - content="", - index=len(tool_messages) - 1, - tool_call_result=tool_messages[-1].tool_call_results[0], - start=True, - meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call}, - ) - ) + # 8) Handle streaming callback + if streaming_callback is not None: + await streaming_callback( + StreamingChunk( + content="", + index=i, + tool_call_result=tool_messages[-1].tool_call_results[0], + start=True, + meta={ + "tool_result": tool_messages[-1].tool_call_results[0].result, + "tool_call": tool_call, + }, + ) + ) # We stream one more chunk that contains a finish_reason if tool_messages were generated if len(tool_messages) > 0 and streaming_callback is not None: diff --git a/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml b/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml new file mode 100644 index 000000000..1ca33c77d --- /dev/null +++ b/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml @@ -0,0 +1,9 @@ +--- +features: + - | + `ToolInvoker` now executes `tool_calls` in parallel for both sync and async mode. + +deprecations: + - | + `async_executor` parameter in `ToolInvoker` is deprecated in favor of `max_workers` parameter and will be removed in Haystack 2.16.0. + You can use `max_workers` parameter to control the number of threads used for parallel tool calling. 
diff --git a/test/components/tools/test_tool_invoker.py b/test/components/tools/test_tool_invoker.py index 9e14fcac5..d6b26cb4c 100644 --- a/test/components/tools/test_tool_invoker.py +++ b/test/components/tools/test_tool_invoker.py @@ -6,6 +6,7 @@ from unittest.mock import patch import pytest import json import datetime +import time from haystack import Pipeline from haystack.components.builders.prompt_builder import PromptBuilder @@ -676,6 +677,147 @@ class TestToolInvoker: ) mock_run.assert_called_once_with(messages=[ChatMessage.from_user(text="Hello!")]) + def test_parallel_tool_calling_with_state_updates(self): + """Test that parallel tool execution with state updates works correctly with the state lock.""" + # Create a shared counter variable to simulate a state value that gets updated + execution_log = [] + + def function_1(): + time.sleep(0.1) + execution_log.append("tool_1_executed") + return {"counter": 1, "tool_name": "tool_1"} + + def function_2(): + time.sleep(0.1) + execution_log.append("tool_2_executed") + return {"counter": 2, "tool_name": "tool_2"} + + def function_3(): + time.sleep(0.1) + execution_log.append("tool_3_executed") + return {"counter": 3, "tool_name": "tool_3"} + + # Create tools that all update the same state key + tool_1 = Tool( + name="state_tool_1", + description="A tool that updates state counter", + parameters={"type": "object", "properties": {}}, + function=function_1, + outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, + ) + + tool_2 = Tool( + name="state_tool_2", + description="A tool that updates state counter", + parameters={"type": "object", "properties": {}}, + function=function_2, + outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, + ) + + tool_3 = Tool( + name="state_tool_3", + description="A tool that updates state counter", + parameters={"type": "object", "properties": {}}, + function=function_3, + outputs_to_state={"counter": {"source": 
"counter"}, "last_tool": {"source": "tool_name"}}, + ) + + # Create ToolInvoker with all three tools + invoker = ToolInvoker(tools=[tool_1, tool_2, tool_3], raise_on_failure=True) + + state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}}) + tool_calls = [ + ToolCall(tool_name="state_tool_1", arguments={}), + ToolCall(tool_name="state_tool_2", arguments={}), + ToolCall(tool_name="state_tool_3", arguments={}), + ] + message = ChatMessage.from_assistant(tool_calls=tool_calls) + result = invoker.run(messages=[message], state=state) + + # Verify that all three tools were executed + assert len(execution_log) == 3 + assert "tool_1_executed" in execution_log + assert "tool_2_executed" in execution_log + assert "tool_3_executed" in execution_log + + # Verify that the state was updated correctly + # Due to parallel execution, we can't predict which tool will be the last to update + assert state.has("counter") + assert state.has("last_tool") + assert state.get("counter") in [1, 2, 3] # Should be one of the tool values + assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"] # Should be one of the tool names + + @pytest.mark.asyncio + async def test_async_parallel_tool_calling_with_state_updates(self): + """Test that parallel tool execution with state updates works correctly with the state lock.""" + # Create a shared counter variable to simulate a state value that gets updated + execution_log = [] + + def function_1(): + time.sleep(0.1) + execution_log.append("tool_1_executed") + return {"counter": 1, "tool_name": "tool_1"} + + def function_2(): + time.sleep(0.1) + execution_log.append("tool_2_executed") + return {"counter": 2, "tool_name": "tool_2"} + + def function_3(): + time.sleep(0.1) + execution_log.append("tool_3_executed") + return {"counter": 3, "tool_name": "tool_3"} + + # Create tools that all update the same state key + tool_1 = Tool( + name="state_tool_1", + description="A tool that updates state counter", + parameters={"type": 
"object", "properties": {}}, + function=function_1, + outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, + ) + + tool_2 = Tool( + name="state_tool_2", + description="A tool that updates state counter", + parameters={"type": "object", "properties": {}}, + function=function_2, + outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, + ) + + tool_3 = Tool( + name="state_tool_3", + description="A tool that updates state counter", + parameters={"type": "object", "properties": {}}, + function=function_3, + outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, + ) + + # Create ToolInvoker with all three tools + invoker = ToolInvoker(tools=[tool_1, tool_2, tool_3], raise_on_failure=True) + + state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}}) + tool_calls = [ + ToolCall(tool_name="state_tool_1", arguments={}), + ToolCall(tool_name="state_tool_2", arguments={}), + ToolCall(tool_name="state_tool_3", arguments={}), + ] + message = ChatMessage.from_assistant(tool_calls=tool_calls) + result = await invoker.run_async(messages=[message], state=state) + + # Verify that all three tools were executed + assert len(execution_log) == 3 + assert "tool_1_executed" in execution_log + assert "tool_2_executed" in execution_log + assert "tool_3_executed" in execution_log + + # Verify that the state was updated correctly + # Due to parallel execution, we can't predict which tool will be the last to update + assert state.has("counter") + assert state.has("last_tool") + assert state.get("counter") in [1, 2, 3] # Should be one of the tool values + assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"] # Should be one of the tool names + class TestMergeToolOutputs: def test_merge_tool_outputs_result_not_a_dict(self, weather_tool):