mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-06-26 22:00:13 +00:00
feat: enable parallel tool execution in ToolInvoker (#9530)
* Enable parallel tool execution in ToolInvoker * Update handling of errors * Small fixes * Small fixes * Adapt number of executors * Add release notes * Add parallel tool calling to sync run * Deprecate async_executor * Deprecate async_executor * Add thread lock * extract methods * Update release notes * Update release notes * Updates * Add new tests * Add test for async * PR comments
This commit is contained in:
parent
91094e1038
commit
1cd0a128d0
@ -5,6 +5,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
import warnings
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
@ -169,6 +170,7 @@ class ToolInvoker:
|
|||||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||||
*,
|
*,
|
||||||
enable_streaming_callback_passthrough: bool = False,
|
enable_streaming_callback_passthrough: bool = False,
|
||||||
|
max_workers: int = 4,
|
||||||
async_executor: Optional[ThreadPoolExecutor] = None,
|
async_executor: Optional[ThreadPoolExecutor] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -193,9 +195,15 @@ class ToolInvoker:
|
|||||||
This allows tools to stream their results back to the client.
|
This allows tools to stream their results back to the client.
|
||||||
Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
|
Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
|
||||||
If False, the `streaming_callback` will not be passed to the tool invocation.
|
If False, the `streaming_callback` will not be passed to the tool invocation.
|
||||||
|
:param max_workers:
|
||||||
|
The maximum number of workers to use in the thread pool executor.
|
||||||
:param async_executor:
|
:param async_executor:
|
||||||
Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be
|
Optional `ThreadPoolExecutor` to use for asynchronous calls.
|
||||||
initialized and used.
|
Note: As of Haystack 2.15.0, you no longer need to explicitly pass
|
||||||
|
`async_executor`. Instead, you can provide the `max_workers` parameter,
|
||||||
|
and a `ThreadPoolExecutor` will be created automatically for parallel tool invocations.
|
||||||
|
Support for `async_executor` will be removed in Haystack 2.16.0.
|
||||||
|
Please migrate to using `max_workers` instead.
|
||||||
:raises ValueError:
|
:raises ValueError:
|
||||||
If no tools are provided or if duplicate tool names are found.
|
If no tools are provided or if duplicate tool names are found.
|
||||||
"""
|
"""
|
||||||
@ -206,6 +214,7 @@ class ToolInvoker:
|
|||||||
self.tools = tools
|
self.tools = tools
|
||||||
self.streaming_callback = streaming_callback
|
self.streaming_callback = streaming_callback
|
||||||
self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough
|
self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough
|
||||||
|
self.max_workers = max_workers
|
||||||
|
|
||||||
# Convert Toolset to list for internal use
|
# Convert Toolset to list for internal use
|
||||||
if isinstance(tools, Toolset):
|
if isinstance(tools, Toolset):
|
||||||
@ -223,8 +232,19 @@ class ToolInvoker:
|
|||||||
self.raise_on_failure = raise_on_failure
|
self.raise_on_failure = raise_on_failure
|
||||||
self.convert_result_to_json_string = convert_result_to_json_string
|
self.convert_result_to_json_string = convert_result_to_json_string
|
||||||
self._owns_executor = async_executor is None
|
self._owns_executor = async_executor is None
|
||||||
|
if self._owns_executor:
|
||||||
|
warnings.warn(
|
||||||
|
"'async_executor' is deprecated in favor of the 'max_workers' parameter. "
|
||||||
|
"ToolInvoker now creates its own thread pool executor by default using 'max_workers'. "
|
||||||
|
"Support for 'async_executor' will be removed in Haystack 2.16.0. "
|
||||||
|
"Please update your usage to pass 'max_workers' instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
|
|
||||||
self.executor = (
|
self.executor = (
|
||||||
ThreadPoolExecutor(thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=1)
|
ThreadPoolExecutor(
|
||||||
|
thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=self.max_workers
|
||||||
|
)
|
||||||
if async_executor is None
|
if async_executor is None
|
||||||
else async_executor
|
else async_executor
|
||||||
)
|
)
|
||||||
@ -427,6 +447,61 @@ class ToolInvoker:
|
|||||||
# Merge other outputs into the state
|
# Merge other outputs into the state
|
||||||
state.set(state_key, output_value, handler_override=handler)
|
state.set(state_key, output_value, handler_override=handler)
|
||||||
|
|
||||||
|
def _prepare_tool_call_params(
    self,
    messages_with_tool_calls: List[ChatMessage],
    state: State,
    streaming_callback: Optional[StreamingCallbackT],
    enable_streaming_passthrough: bool,
) -> tuple[List[Dict[str, Any]], List[ChatMessage]]:
    """
    Build the invocation parameters for every tool call and collect error messages for unknown tools.

    Each tool call of each message is looked up in `self._tools_with_names`. Calls that name an unknown tool
    are passed through `self._handle_error` as a `ToolNotFoundException` and the returned text is wrapped in
    an error `ChatMessage`; they are not scheduled for execution. For known tools, a copy of the call's
    arguments is handed to `self._inject_state_args` together with the state to obtain the final arguments.

    :param messages_with_tool_calls: Messages whose `tool_calls` should be processed.
    :param state: The state passed to `self._inject_state_args` when building the final arguments.
    :param streaming_callback: Optional streaming callback that may be injected into the final arguments.
    :param enable_streaming_passthrough: Whether the streaming callback may be passed on to tools.
    :returns:
        Tuple of `(tool_call_params, error_messages)`. Every entry of `tool_call_params` is a dict with the
        keys `tool_call`, `tool_to_invoke` and `final_args`.
    """
    prepared_calls: List[Dict[str, Any]] = []
    failures: List[ChatMessage] = []

    # The first two injection conditions do not depend on the individual tool call, so evaluate them once.
    callback_may_be_injected = enable_streaming_passthrough and streaming_callback is not None

    all_calls = (call for msg in messages_with_tool_calls for call in msg.tool_calls)
    for call in all_calls:
        requested_name = call.tool_name

        if requested_name not in self._tools_with_names:
            # Unknown tool: report it as an error message instead of scheduling it
            not_found = ToolNotFoundException(requested_name, list(self._tools_with_names.keys()))
            failure_text = self._handle_error(not_found)
            failures.append(ChatMessage.from_tool(tool_result=failure_text, origin=call, error=True))
            continue

        tool = self._tools_with_names[requested_name]

        # Combine user + state inputs; the copy keeps the tool call's own arguments dict untouched
        call_args = self._inject_state_args(tool, call.arguments.copy(), state)

        # Inject the callback only if it is not already supplied and the tool accepts it
        if (
            callback_may_be_injected
            and "streaming_callback" not in call_args
            and "streaming_callback" in self._get_func_params(tool)
        ):
            call_args["streaming_callback"] = streaming_callback

        prepared_calls.append({"tool_call": call, "tool_to_invoke": tool, "final_args": call_args})

    return prepared_calls, failures
|
||||||
|
|
||||||
@component.output_types(tool_messages=List[ChatMessage], state=State)
|
@component.output_types(tool_messages=List[ChatMessage], state=State)
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
@ -480,76 +555,69 @@ class ToolInvoker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
tool_messages = []
|
tool_messages = []
|
||||||
for message in messages_with_tool_calls:
|
|
||||||
for tool_call in message.tool_calls:
|
|
||||||
tool_name = tool_call.tool_name
|
|
||||||
|
|
||||||
# Check if the tool is available, otherwise return an error message
|
# 1) Collect all tool calls and their parameters for parallel execution
|
||||||
if tool_name not in self._tools_with_names:
|
tool_call_params, error_messages = self._prepare_tool_call_params(
|
||||||
error_message = self._handle_error(
|
messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough
|
||||||
ToolNotFoundException(tool_name, list(self._tools_with_names.keys()))
|
)
|
||||||
)
|
tool_messages.extend(error_messages)
|
||||||
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
|
|
||||||
continue
|
|
||||||
|
|
||||||
tool_to_invoke = self._tools_with_names[tool_name]
|
# 2) Execute valid tool calls in parallel
|
||||||
|
if tool_call_params:
|
||||||
|
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||||
|
futures = []
|
||||||
|
for params in tool_call_params:
|
||||||
|
future = executor.submit(self._execute_single_tool_call, **params) # type: ignore[arg-type]
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
# 1) Combine user + state inputs
|
# 3) Process results in the order they are submitted
|
||||||
llm_args = tool_call.arguments.copy()
|
for future in futures:
|
||||||
final_args = self._inject_state_args(tool_to_invoke, llm_args, state)
|
result = future.result()
|
||||||
|
if isinstance(result, ChatMessage):
|
||||||
|
tool_messages.append(result)
|
||||||
|
else:
|
||||||
|
# Handle state merging and prepare tool result message
|
||||||
|
tool_call, tool_to_invoke, tool_result = result
|
||||||
|
|
||||||
# Check whether to inject streaming_callback
|
# 4) Merge outputs into state
|
||||||
if (
|
try:
|
||||||
resolved_enable_streaming_passthrough
|
self._merge_tool_outputs(tool_to_invoke, tool_result, state)
|
||||||
and streaming_callback is not None
|
except Exception as e:
|
||||||
and "streaming_callback" not in final_args
|
try:
|
||||||
):
|
error_message = self._handle_error(
|
||||||
invoke_params = self._get_func_params(tool_to_invoke)
|
ToolOutputMergeError(
|
||||||
if "streaming_callback" in invoke_params:
|
f"Failed to merge tool outputs fromtool {tool_call.tool_name} into State: {e}"
|
||||||
final_args["streaming_callback"] = streaming_callback
|
)
|
||||||
|
)
|
||||||
|
tool_messages.append(
|
||||||
|
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
except ToolOutputMergeError as propagated_e:
|
||||||
|
# Re-raise with proper error chain
|
||||||
|
raise propagated_e from e
|
||||||
|
|
||||||
# 2) Invoke the tool
|
# 5) Prepare the tool result ChatMessage message
|
||||||
try:
|
|
||||||
tool_result = tool_to_invoke.invoke(**final_args)
|
|
||||||
|
|
||||||
except ToolInvocationError as e:
|
|
||||||
error_message = self._handle_error(e)
|
|
||||||
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 3) Merge outputs into state
|
|
||||||
try:
|
|
||||||
self._merge_tool_outputs(tool_to_invoke, tool_result, state)
|
|
||||||
except Exception as e:
|
|
||||||
try:
|
|
||||||
error_message = self._handle_error(
|
|
||||||
ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}")
|
|
||||||
)
|
|
||||||
tool_messages.append(
|
tool_messages.append(
|
||||||
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
|
self._prepare_tool_result_message(
|
||||||
|
result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
|
||||||
|
)
|
||||||
)
|
)
|
||||||
continue
|
|
||||||
except ToolOutputMergeError as propagated_e:
|
|
||||||
# Re-raise with proper error chain
|
|
||||||
raise propagated_e from e
|
|
||||||
|
|
||||||
# 4) Prepare the tool result ChatMessage message
|
# 6) Handle streaming callback
|
||||||
tool_messages.append(
|
if streaming_callback is not None:
|
||||||
self._prepare_tool_result_message(
|
streaming_callback(
|
||||||
result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
|
StreamingChunk(
|
||||||
)
|
content="",
|
||||||
)
|
index=len(tool_messages) - 1,
|
||||||
|
tool_call_result=tool_messages[-1].tool_call_results[0],
|
||||||
if streaming_callback is not None:
|
start=True,
|
||||||
streaming_callback(
|
meta={
|
||||||
StreamingChunk(
|
"tool_result": tool_messages[-1].tool_call_results[0].result,
|
||||||
content="",
|
"tool_call": tool_call,
|
||||||
index=len(tool_messages) - 1,
|
},
|
||||||
tool_call_result=tool_messages[-1].tool_call_results[0],
|
)
|
||||||
start=True,
|
)
|
||||||
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# We stream one more chunk that contains a finish_reason if tool_messages were generated
|
# We stream one more chunk that contains a finish_reason if tool_messages were generated
|
||||||
if len(tool_messages) > 0 and streaming_callback is not None:
|
if len(tool_messages) > 0 and streaming_callback is not None:
|
||||||
@ -561,6 +629,31 @@ class ToolInvoker:
|
|||||||
|
|
||||||
return {"tool_messages": tool_messages, "state": state}
|
return {"tool_messages": tool_messages, "state": state}
|
||||||
|
|
||||||
|
def _execute_single_tool_call(self, tool_call: ToolCall, tool_to_invoke: Tool, final_args: Dict[str, Any]):
    """
    Invoke one tool with its final arguments. Intended to be submitted to a thread pool.

    Only `ToolInvocationError` is handled here; any other exception raised by the tool propagates to the
    caller (when submitted to an executor, it surfaces from the future's `result()`).

    :param tool_call: The ToolCall that triggered this invocation; used as origin of an error message.
    :param tool_to_invoke: The Tool whose `invoke` method is called.
    :param final_args: Keyword arguments passed to `tool_to_invoke.invoke`.
    :returns:
        A tuple `(tool_call, tool_to_invoke, tool_result)` on success, or an error ChatMessage built from
        the output of `self._handle_error` if the tool raised a `ToolInvocationError`.
    """
    try:
        outcome = tool_to_invoke.invoke(**final_args)
    except ToolInvocationError as invocation_error:
        failure_text = self._handle_error(invocation_error)
        return ChatMessage.from_tool(tool_result=failure_text, origin=tool_call, error=True)
    return (tool_call, tool_to_invoke, outcome)
|
||||||
|
|
||||||
|
@staticmethod
async def invoke_tool_safely(executor: ThreadPoolExecutor, tool_to_invoke: Tool, final_args: Dict[str, Any]) -> Any:
    """
    Run `tool_to_invoke.invoke(**final_args)` in `executor` and await its result.

    A `ToolInvocationError` raised by the tool is returned as a value instead of being raised, so the
    caller can tell failed invocations apart with an `isinstance` check. Other exceptions propagate.

    :param executor: The thread pool in which the blocking tool invocation runs.
    :param tool_to_invoke: The Tool whose `invoke` method is called.
    :param final_args: Keyword arguments passed to `tool_to_invoke.invoke`.
    :returns: The tool's result, or the caught `ToolInvocationError` instance.
    """
    running_loop = asyncio.get_running_loop()
    try:
        # run_in_executor only forwards positional arguments, so bind the keyword arguments up front
        bound_invoke = partial(tool_to_invoke.invoke, **final_args)
        return await running_loop.run_in_executor(executor, bound_invoke)
    except ToolInvocationError as invocation_error:
        return invocation_error
|
||||||
|
|
||||||
@component.output_types(tool_messages=List[ChatMessage], state=State)
|
@component.output_types(tool_messages=List[ChatMessage], state=State)
|
||||||
async def run_async(
|
async def run_async(
|
||||||
self,
|
self,
|
||||||
@ -571,8 +664,9 @@ class ToolInvoker:
|
|||||||
enable_streaming_callback_passthrough: Optional[bool] = None,
|
enable_streaming_callback_passthrough: Optional[bool] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Asynchronously processes ChatMessage objects containing tool calls and invokes the corresponding tools.
|
Asynchronously processes ChatMessage objects containing tool calls.
|
||||||
|
|
||||||
|
Multiple tool calls are performed concurrently.
|
||||||
:param messages:
|
:param messages:
|
||||||
A list of ChatMessage objects.
|
A list of ChatMessage objects.
|
||||||
:param state: The runtime state that should be used by the tools.
|
:param state: The runtime state that should be used by the tools.
|
||||||
@ -598,6 +692,7 @@ class ToolInvoker:
|
|||||||
:raises ToolOutputMergeError:
|
:raises ToolOutputMergeError:
|
||||||
If merging tool outputs into state fails and `raise_on_failure` is True.
|
If merging tool outputs into state fails and `raise_on_failure` is True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if state is None:
|
if state is None:
|
||||||
state = State(schema={})
|
state = State(schema={})
|
||||||
|
|
||||||
@ -614,78 +709,78 @@ class ToolInvoker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
tool_messages = []
|
tool_messages = []
|
||||||
for message in messages_with_tool_calls:
|
|
||||||
for tool_call in message.tool_calls:
|
|
||||||
tool_name = tool_call.tool_name
|
|
||||||
|
|
||||||
# Check if the tool is available, otherwise return an error message
|
# 1) Prepare tool call parameters for execution
|
||||||
if tool_name not in self._tools_with_names:
|
tool_call_params, error_messages = self._prepare_tool_call_params(
|
||||||
error_message = self._handle_error(
|
messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough
|
||||||
ToolNotFoundException(tool_name, list(self._tools_with_names.keys()))
|
)
|
||||||
)
|
tool_messages.extend(error_messages)
|
||||||
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
|
|
||||||
continue
|
|
||||||
|
|
||||||
tool_to_invoke = self._tools_with_names[tool_name]
|
# 2) Execute valid tool calls in parallel
|
||||||
|
if tool_call_params:
|
||||||
|
with self.executor as executor:
|
||||||
|
tool_call_tasks = []
|
||||||
|
valid_tool_calls = []
|
||||||
|
|
||||||
# 1) Combine user + state inputs
|
# 3) Create async tasks for valid tool calls
|
||||||
llm_args = tool_call.arguments.copy()
|
for params in tool_call_params:
|
||||||
final_args = self._inject_state_args(tool_to_invoke, llm_args, state)
|
task = ToolInvoker.invoke_tool_safely(executor, params["tool_to_invoke"], params["final_args"])
|
||||||
|
tool_call_tasks.append(task)
|
||||||
|
valid_tool_calls.append((params["tool_call"], params["tool_to_invoke"]))
|
||||||
|
|
||||||
# Check whether to inject streaming_callback
|
if tool_call_tasks:
|
||||||
if (
|
# 4) Gather results from all tool calls
|
||||||
resolved_enable_streaming_passthrough
|
tool_results = await asyncio.gather(*tool_call_tasks)
|
||||||
and streaming_callback is not None
|
|
||||||
and "streaming_callback" not in final_args
|
|
||||||
):
|
|
||||||
invoke_params = self._get_func_params(tool_to_invoke)
|
|
||||||
if "streaming_callback" in invoke_params:
|
|
||||||
final_args["streaming_callback"] = streaming_callback
|
|
||||||
|
|
||||||
# 2) Invoke the tool asynchronously
|
# 5) Process results
|
||||||
try:
|
for i, ((tool_call, tool_to_invoke), tool_result) in enumerate(zip(valid_tool_calls, tool_results)):
|
||||||
tool_result = await asyncio.get_running_loop().run_in_executor(
|
# Check if the tool_result is a ToolInvocationError (caught by our wrapper)
|
||||||
self.executor, partial(tool_to_invoke.invoke, **final_args)
|
if isinstance(tool_result, ToolInvocationError):
|
||||||
)
|
error_message = self._handle_error(tool_result)
|
||||||
|
tool_messages.append(
|
||||||
|
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
except ToolInvocationError as e:
|
# 6) Merge outputs into state
|
||||||
error_message = self._handle_error(e)
|
try:
|
||||||
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
|
self._merge_tool_outputs(tool_to_invoke, tool_result, state)
|
||||||
continue
|
except Exception as e:
|
||||||
|
try:
|
||||||
|
error_message = self._handle_error(
|
||||||
|
ToolOutputMergeError(
|
||||||
|
f"Failed to merge tool outputs fromtool {tool_call.tool_name} into State: {e}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tool_messages.append(
|
||||||
|
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
except ToolOutputMergeError as propagated_e:
|
||||||
|
# Re-raise with proper error chain
|
||||||
|
raise propagated_e from e
|
||||||
|
|
||||||
# 3) Merge outputs into state
|
# 7) Prepare the tool result ChatMessage message
|
||||||
try:
|
|
||||||
self._merge_tool_outputs(tool_to_invoke, tool_result, state)
|
|
||||||
except Exception as e:
|
|
||||||
try:
|
|
||||||
error_message = self._handle_error(
|
|
||||||
ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}")
|
|
||||||
)
|
|
||||||
tool_messages.append(
|
tool_messages.append(
|
||||||
ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
|
self._prepare_tool_result_message(
|
||||||
|
result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
|
||||||
|
)
|
||||||
)
|
)
|
||||||
continue
|
|
||||||
except ToolOutputMergeError as propagated_e:
|
|
||||||
# Re-raise with proper error chain
|
|
||||||
raise propagated_e from e
|
|
||||||
|
|
||||||
# 4) Prepare the tool result ChatMessage message
|
# 8) Handle streaming callback
|
||||||
tool_messages.append(
|
if streaming_callback is not None:
|
||||||
self._prepare_tool_result_message(
|
await streaming_callback(
|
||||||
result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
|
StreamingChunk(
|
||||||
)
|
content="",
|
||||||
)
|
index=i,
|
||||||
|
tool_call_result=tool_messages[-1].tool_call_results[0],
|
||||||
if streaming_callback is not None:
|
start=True,
|
||||||
await streaming_callback(
|
meta={
|
||||||
StreamingChunk(
|
"tool_result": tool_messages[-1].tool_call_results[0].result,
|
||||||
content="",
|
"tool_call": tool_call,
|
||||||
index=len(tool_messages) - 1,
|
},
|
||||||
tool_call_result=tool_messages[-1].tool_call_results[0],
|
)
|
||||||
start=True,
|
)
|
||||||
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# We stream one more chunk that contains a finish_reason if tool_messages were generated
|
# We stream one more chunk that contains a finish_reason if tool_messages were generated
|
||||||
if len(tool_messages) > 0 and streaming_callback is not None:
|
if len(tool_messages) > 0 and streaming_callback is not None:
|
||||||
|
@ -0,0 +1,9 @@
|
|||||||
|
---
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
`ToolInvoker` now executes `tool_calls` in parallel for both sync and async mode.
|
||||||
|
|
||||||
|
deprecations:
|
||||||
|
- |
|
||||||
|
The `async_executor` parameter in `ToolInvoker` is deprecated in favor of the `max_workers` parameter and will be removed in Haystack 2.16.0.
|
||||||
|
You can use the `max_workers` parameter to control the number of threads used for parallel tool calling.
|
@ -6,6 +6,7 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
import json
|
import json
|
||||||
import datetime
|
import datetime
|
||||||
|
import time
|
||||||
|
|
||||||
from haystack import Pipeline
|
from haystack import Pipeline
|
||||||
from haystack.components.builders.prompt_builder import PromptBuilder
|
from haystack.components.builders.prompt_builder import PromptBuilder
|
||||||
@ -676,6 +677,147 @@ class TestToolInvoker:
|
|||||||
)
|
)
|
||||||
mock_run.assert_called_once_with(messages=[ChatMessage.from_user(text="Hello!")])
|
mock_run.assert_called_once_with(messages=[ChatMessage.from_user(text="Hello!")])
|
||||||
|
|
||||||
|
def test_parallel_tool_calling_with_state_updates(self):
    """Run three state-updating tools through `ToolInvoker.run` in a single call and check they all ran."""
    # Records which tool functions were executed; the tool functions append to it when they run
    execution_log = []

    def make_state_tool(index: int) -> Tool:
        """Build a tool that sleeps briefly, logs its execution and returns values mapped into the state."""

        def tool_function():
            # The short sleep gives the tool invocations a chance to overlap
            time.sleep(0.1)
            execution_log.append(f"tool_{index}_executed")
            return {"counter": index, "tool_name": f"tool_{index}"}

        # All tools write to the same state keys: `counter` and `last_tool`
        return Tool(
            name=f"state_tool_{index}",
            description="A tool that updates state counter",
            parameters={"type": "object", "properties": {}},
            function=tool_function,
            outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
        )

    tool_indices = (1, 2, 3)
    invoker = ToolInvoker(tools=[make_state_tool(index) for index in tool_indices], raise_on_failure=True)

    state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}})
    tool_calls = [ToolCall(tool_name=f"state_tool_{index}", arguments={}) for index in tool_indices]
    message = ChatMessage.from_assistant(tool_calls=tool_calls)
    result = invoker.run(messages=[message], state=state)

    # Every tool call must produce exactly one tool message
    assert len(result["tool_messages"]) == 3

    # Verify that all three tools were executed exactly once (completion order is not asserted)
    assert sorted(execution_log) == ["tool_1_executed", "tool_2_executed", "tool_3_executed"]

    # Verify that the state was updated. The tools run concurrently, so the assertions only require
    # the final values to come from one of the tools.
    assert state.has("counter")
    assert state.has("last_tool")
    assert state.get("counter") in [1, 2, 3]  # Should be one of the tool values
    assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"]  # Should be one of the tool names
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_async_parallel_tool_calling_with_state_updates(self):
    """Run three state-updating tools through `ToolInvoker.run_async` in a single call and check they all ran."""
    # Records which tool functions were executed; the tool functions append to it when they run
    execution_log = []

    def make_state_tool(index: int) -> Tool:
        """Build a tool that sleeps briefly, logs its execution and returns values mapped into the state."""

        def tool_function():
            # The short sleep gives the tool invocations a chance to overlap
            time.sleep(0.1)
            execution_log.append(f"tool_{index}_executed")
            return {"counter": index, "tool_name": f"tool_{index}"}

        # All tools write to the same state keys: `counter` and `last_tool`
        return Tool(
            name=f"state_tool_{index}",
            description="A tool that updates state counter",
            parameters={"type": "object", "properties": {}},
            function=tool_function,
            outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
        )

    tool_indices = (1, 2, 3)
    invoker = ToolInvoker(tools=[make_state_tool(index) for index in tool_indices], raise_on_failure=True)

    state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}})
    tool_calls = [ToolCall(tool_name=f"state_tool_{index}", arguments={}) for index in tool_indices]
    message = ChatMessage.from_assistant(tool_calls=tool_calls)
    result = await invoker.run_async(messages=[message], state=state)

    # Every tool call must produce exactly one tool message
    assert len(result["tool_messages"]) == 3

    # Verify that all three tools were executed exactly once (completion order is not asserted)
    assert sorted(execution_log) == ["tool_1_executed", "tool_2_executed", "tool_3_executed"]

    # Verify that the state was updated. The tools run concurrently, so the assertions only require
    # the final values to come from one of the tools.
    assert state.has("counter")
    assert state.has("last_tool")
    assert state.get("counter") in [1, 2, 3]  # Should be one of the tool values
    assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"]  # Should be one of the tool names
|
||||||
|
|
||||||
|
|
||||||
class TestMergeToolOutputs:
|
class TestMergeToolOutputs:
|
||||||
def test_merge_tool_outputs_result_not_a_dict(self, weather_tool):
|
def test_merge_tool_outputs_result_not_a_dict(self, weather_tool):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user