Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-06-26 22:00:13 +00:00)
feat: enable parallel tool execution in ToolInvoker (#9530)
* Enable parallel tool execution in ToolInvoker
* Update handling of errors
* Small fixes
* Small fixes
* Adapt number of executors
* Add release notes
* Add parallel tool calling to sync run
* Deprecate async_executor
* Deprecate async_executor
* Add thread lock
* extract methods
* Update release notes
* Update release notes
* Updates
* Add new tests
* Add test for async
* PR comments
This commit is contained in:
parent 91094e1038
commit 1cd0a128d0
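The change in a nutshell: `ToolInvoker` previously invoked the tool calls in a message one after another; this commit routes them through a thread pool so up to `max_workers` calls run concurrently. A minimal usage sketch (not part of the commit; the `add` tool is hypothetical, import paths are the Haystack 2.x locations):

```python
from haystack.components.tools import ToolInvoker
from haystack.tools import Tool

# Hypothetical tool, used only for illustration.
def add(a: int, b: int) -> dict:
    return {"sum": a + b}

add_tool = Tool(
    name="add",
    description="Add two integers",
    parameters={
        "type": "object",
        "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
        "required": ["a", "b"],
    },
    function=add,
)

# After this commit, up to `max_workers` tool calls run concurrently.
invoker = ToolInvoker(tools=[add_tool], max_workers=4)
```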
@@ -5,6 +5,7 @@
 import asyncio
 import inspect
 import json
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from typing import Any, Dict, List, Optional, Set, Union
@@ -169,6 +170,7 @@ class ToolInvoker:
         streaming_callback: Optional[StreamingCallbackT] = None,
         *,
         enable_streaming_callback_passthrough: bool = False,
+        max_workers: int = 4,
         async_executor: Optional[ThreadPoolExecutor] = None,
     ):
         """
@@ -193,9 +195,15 @@ class ToolInvoker:
             This allows tools to stream their results back to the client.
             Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
             If False, the `streaming_callback` will not be passed to the tool invocation.
+        :param max_workers:
+            The maximum number of workers to use in the thread pool executor.
         :param async_executor:
-            Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be
-            initialized and used.
+            Optional `ThreadPoolExecutor` to use for asynchronous calls.
+            Note: As of Haystack 2.15.0, you no longer need to explicitly pass
+            `async_executor`. Instead, you can provide the `max_workers` parameter,
+            and a `ThreadPoolExecutor` will be created automatically for parallel tool invocations.
+            Support for `async_executor` will be removed in Haystack 2.16.0.
+            Please migrate to using `max_workers` instead.
         :raises ValueError:
             If no tools are provided or if duplicate tool names are found.
         """
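A hedged migration sketch for the deprecation described in the docstring above (the `echo` tool is a placeholder; any `Tool` works the same way):

```python
from concurrent.futures import ThreadPoolExecutor
from haystack.components.tools import ToolInvoker
from haystack.tools import Tool

# Minimal placeholder tool so the constructor calls are runnable.
echo_tool = Tool(
    name="echo",
    description="Echo the input text",
    parameters={"type": "object", "properties": {"text": {"type": "string"}}},
    function=lambda text: text,
)

# Deprecated: passing an executor explicitly (removal planned for 2.16.0).
invoker = ToolInvoker(tools=[echo_tool], async_executor=ThreadPoolExecutor(max_workers=4))

# Preferred as of 2.15.0: ToolInvoker builds its own pool from max_workers.
invoker = ToolInvoker(tools=[echo_tool], max_workers=4)
```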
@@ -206,6 +214,7 @@ class ToolInvoker:
         self.tools = tools
         self.streaming_callback = streaming_callback
         self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough
+        self.max_workers = max_workers

         # Convert Toolset to list for internal use
         if isinstance(tools, Toolset):
@@ -223,8 +232,19 @@ class ToolInvoker:
         self.raise_on_failure = raise_on_failure
         self.convert_result_to_json_string = convert_result_to_json_string
         self._owns_executor = async_executor is None
+        if self._owns_executor:
+            warnings.warn(
+                "'async_executor' is deprecated in favor of the 'max_workers' parameter. "
+                "ToolInvoker now creates its own thread pool executor by default using 'max_workers'. "
+                "Support for 'async_executor' will be removed in Haystack 2.16.0. "
+                "Please update your usage to pass 'max_workers' instead.",
+                DeprecationWarning,
+            )
+
         self.executor = (
-            ThreadPoolExecutor(thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=1)
+            ThreadPoolExecutor(
+                thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=self.max_workers
+            )
             if async_executor is None
             else async_executor
         )
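The construction above follows a common owns-the-executor pattern: build a named pool only when the caller did not supply one, and remember ownership so that only self-created pools get shut down. A standalone sketch of that pattern (plain Python, no Haystack types):

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

class PoolOwner:
    """Sketch of the owns-executor pattern from the diff above."""

    def __init__(self, max_workers: int = 4, executor: Optional[ThreadPoolExecutor] = None):
        # Remember whether we created the pool ourselves.
        self._owns_executor = executor is None
        self.executor = (
            ThreadPoolExecutor(thread_name_prefix=f"pool-owner-{id(self)}", max_workers=max_workers)
            if executor is None
            else executor
        )

    def close(self) -> None:
        # Only shut down a pool we created ourselves; a borrowed one
        # may still be in use by the caller.
        if self._owns_executor:
            self.executor.shutdown(wait=True)
```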
@@ -427,6 +447,61 @@ class ToolInvoker:
             # Merge other outputs into the state
             state.set(state_key, output_value, handler_override=handler)

+    def _prepare_tool_call_params(
+        self,
+        messages_with_tool_calls: List[ChatMessage],
+        state: State,
+        streaming_callback: Optional[StreamingCallbackT],
+        enable_streaming_passthrough: bool,
+    ) -> tuple[List[Dict[str, Any]], List[ChatMessage]]:
+        """
+        Prepare tool call parameters for execution and collect any error messages.
+
+        :param messages_with_tool_calls: Messages containing tool calls to process
+        :param state: The current state for argument injection
+        :param streaming_callback: Optional streaming callback to inject
+        :param enable_streaming_passthrough: Whether to pass streaming callback to tools
+        :returns: Tuple of (tool_call_params, error_messages)
+        """
+        tool_call_params = []
+        error_messages = []
+
+        for message in messages_with_tool_calls:
+            for tool_call in message.tool_calls:
+                tool_name = tool_call.tool_name
+
+                # Check if the tool is available, otherwise return an error message
+                if tool_name not in self._tools_with_names:
+                    error_message = self._handle_error(
+                        ToolNotFoundException(tool_name, list(self._tools_with_names.keys()))
+                    )
+                    error_messages.append(
+                        ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
+                    )
+                    continue
+
+                tool_to_invoke = self._tools_with_names[tool_name]
+
+                # Combine user + state inputs
+                llm_args = tool_call.arguments.copy()
+                final_args = self._inject_state_args(tool_to_invoke, llm_args, state)
+
+                # Check whether to inject streaming_callback
+                if (
+                    enable_streaming_passthrough
+                    and streaming_callback is not None
+                    and "streaming_callback" not in final_args
+                ):
+                    invoke_params = self._get_func_params(tool_to_invoke)
+                    if "streaming_callback" in invoke_params:
+                        final_args["streaming_callback"] = streaming_callback
+
+                tool_call_params.append(
+                    {"tool_call": tool_call, "tool_to_invoke": tool_to_invoke, "final_args": final_args}
+                )
+
+        return tool_call_params, error_messages
+
     @component.output_types(tool_messages=List[ChatMessage], state=State)
     def run(
         self,
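Each dict collected here mirrors the keyword parameters of `_execute_single_tool_call`, so it can later be splatted straight into `executor.submit(...)`. A toy version of that dispatch-by-dict pattern (names are illustrative, not Haystack APIs):

```python
from concurrent.futures import ThreadPoolExecutor

def run_job(job_id: str, func, args: dict):
    # Mirrors _execute_single_tool_call: never raises, returns a tagged result.
    try:
        return (job_id, func(**args))
    except Exception as e:
        return (job_id, e)

jobs = [
    {"job_id": "a", "func": lambda x: x * 2, "args": {"x": 3}},
    {"job_id": "b", "func": lambda x: x + 1, "args": {"x": 10}},
]

with ThreadPoolExecutor(max_workers=2) as executor:
    # Each prepared dict maps 1:1 onto run_job's keyword parameters.
    futures = [executor.submit(run_job, **job) for job in jobs]
    results = [f.result() for f in futures]  # order matches submission order
```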
@@ -480,76 +555,69 @@ class ToolInvoker:
         )

         tool_messages = []
-        for message in messages_with_tool_calls:
-            for tool_call in message.tool_calls:
-                tool_name = tool_call.tool_name
-
-                # Check if the tool is available, otherwise return an error message
-                if tool_name not in self._tools_with_names:
-                    error_message = self._handle_error(
-                        ToolNotFoundException(tool_name, list(self._tools_with_names.keys()))
-                    )
-                    tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
-                    continue
-
-                tool_to_invoke = self._tools_with_names[tool_name]
-
-                # 1) Combine user + state inputs
-                llm_args = tool_call.arguments.copy()
-                final_args = self._inject_state_args(tool_to_invoke, llm_args, state)
-
-                # Check whether to inject streaming_callback
-                if (
-                    resolved_enable_streaming_passthrough
-                    and streaming_callback is not None
-                    and "streaming_callback" not in final_args
-                ):
-                    invoke_params = self._get_func_params(tool_to_invoke)
-                    if "streaming_callback" in invoke_params:
-                        final_args["streaming_callback"] = streaming_callback
-
-                # 2) Invoke the tool
-                try:
-                    tool_result = tool_to_invoke.invoke(**final_args)
-                except ToolInvocationError as e:
-                    error_message = self._handle_error(e)
-                    tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
-                    continue
-
-                # 3) Merge outputs into state
-                try:
-                    self._merge_tool_outputs(tool_to_invoke, tool_result, state)
-                except Exception as e:
-                    try:
-                        error_message = self._handle_error(
-                            ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}")
-                        )
-                        tool_messages.append(
-                            ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
-                        )
-                        continue
-                    except ToolOutputMergeError as propagated_e:
-                        # Re-raise with proper error chain
-                        raise propagated_e from e
-
-                # 4) Prepare the tool result ChatMessage message
-                tool_messages.append(
-                    self._prepare_tool_result_message(
-                        result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
-                    )
-                )
-
-                if streaming_callback is not None:
-                    streaming_callback(
-                        StreamingChunk(
-                            content="",
-                            index=len(tool_messages) - 1,
-                            tool_call_result=tool_messages[-1].tool_call_results[0],
-                            start=True,
-                            meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
-                        )
-                    )
+
+        # 1) Collect all tool calls and their parameters for parallel execution
+        tool_call_params, error_messages = self._prepare_tool_call_params(
+            messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough
+        )
+        tool_messages.extend(error_messages)
+
+        # 2) Execute valid tool calls in parallel
+        if tool_call_params:
+            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+                futures = []
+                for params in tool_call_params:
+                    future = executor.submit(self._execute_single_tool_call, **params)  # type: ignore[arg-type]
+                    futures.append(future)
+
+                # 3) Process results in the order they are submitted
+                for future in futures:
+                    result = future.result()
+                    if isinstance(result, ChatMessage):
+                        tool_messages.append(result)
+                    else:
+                        # Handle state merging and prepare tool result message
+                        tool_call, tool_to_invoke, tool_result = result
+
+                        # 4) Merge outputs into state
+                        try:
+                            self._merge_tool_outputs(tool_to_invoke, tool_result, state)
+                        except Exception as e:
+                            try:
+                                error_message = self._handle_error(
+                                    ToolOutputMergeError(
+                                        f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}"
+                                    )
+                                )
+                                tool_messages.append(
+                                    ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
+                                )
+                                continue
+                            except ToolOutputMergeError as propagated_e:
+                                # Re-raise with proper error chain
+                                raise propagated_e from e
+
+                        # 5) Prepare the tool result ChatMessage message
+                        tool_messages.append(
+                            self._prepare_tool_result_message(
+                                result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
+                            )
+                        )
+
+                    # 6) Handle streaming callback
+                    if streaming_callback is not None:
+                        streaming_callback(
+                            StreamingChunk(
+                                content="",
+                                index=len(tool_messages) - 1,
+                                tool_call_result=tool_messages[-1].tool_call_results[0],
+                                start=True,
+                                meta={
+                                    "tool_result": tool_messages[-1].tool_call_results[0].result,
+                                    "tool_call": tool_call,
+                                },
+                            )
+                        )

         # We stream one more chunk that contains a finish_reason if tool_messages were generated
         if len(tool_messages) > 0 and streaming_callback is not None:
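A design note on step 3: iterating `futures` in submission order, rather than using `concurrent.futures.as_completed`, keeps the resulting tool messages aligned with the order of the original tool calls even though execution is concurrent. A minimal demonstration of that property:

```python
import time
from concurrent.futures import ThreadPoolExecutor

def slow_identity(value: int) -> int:
    # The first task sleeps longest, so completion order differs
    # from submission order.
    time.sleep(0.3 - value * 0.1)
    return value

with ThreadPoolExecutor(max_workers=3) as executor:
    futures = [executor.submit(slow_identity, v) for v in (0, 1, 2)]
    ordered = [f.result() for f in futures]

assert ordered == [0, 1, 2]  # submission order preserved
```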
@@ -561,6 +629,31 @@ class ToolInvoker:

         return {"tool_messages": tool_messages, "state": state}

+    def _execute_single_tool_call(self, tool_call: ToolCall, tool_to_invoke: Tool, final_args: Dict[str, Any]):
+        """
+        Execute a single tool call. This method is designed to be run in a thread pool.
+
+        :param tool_call: The ToolCall object containing the tool name and arguments.
+        :param tool_to_invoke: The Tool object that should be invoked.
+        :param final_args: The final arguments to pass to the tool.
+        :returns: Either a ChatMessage with an error or a tuple of (tool_call, tool_to_invoke, tool_result)
+        """
+        try:
+            tool_result = tool_to_invoke.invoke(**final_args)
+            return (tool_call, tool_to_invoke, tool_result)
+        except ToolInvocationError as e:
+            error_message = self._handle_error(e)
+            return ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
+
+    @staticmethod
+    async def invoke_tool_safely(executor: ThreadPoolExecutor, tool_to_invoke: Tool, final_args: Dict[str, Any]) -> Any:
+        """Safely invoke a tool with proper exception handling."""
+        loop = asyncio.get_running_loop()
+        try:
+            return await loop.run_in_executor(executor, partial(tool_to_invoke.invoke, **final_args))
+        except ToolInvocationError as e:
+            return e
+
     @component.output_types(tool_messages=List[ChatMessage], state=State)
     async def run_async(
         self,
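`invoke_tool_safely` uses a standard asyncio idiom: offload a blocking callable with `loop.run_in_executor`, bind keyword arguments via `functools.partial` (since `run_in_executor` forwards only positional arguments), and return the expected exception as a value so a later `asyncio.gather` is not aborted. The same idiom in isolation (with `ValueError` standing in for `ToolInvocationError`):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def blocking_call(x: int, y: int) -> int:
    return x + y

async def call_safely(executor: ThreadPoolExecutor, **kwargs):
    loop = asyncio.get_running_loop()
    try:
        # partial is needed because run_in_executor accepts *args only.
        return await loop.run_in_executor(executor, partial(blocking_call, **kwargs))
    except ValueError as e:  # stand-in for a domain-specific error type
        return e

async def main():
    with ThreadPoolExecutor(max_workers=2) as executor:
        print(await call_safely(executor, x=1, y=2))  # 3

asyncio.run(main())
```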
@@ -571,8 +664,9 @@ class ToolInvoker:
         enable_streaming_callback_passthrough: Optional[bool] = None,
     ) -> Dict[str, Any]:
         """
-        Asynchronously processes ChatMessage objects containing tool calls and invokes the corresponding tools.
+        Asynchronously processes ChatMessage objects containing tool calls.

+        Multiple tool calls are performed concurrently.
         :param messages:
             A list of ChatMessage objects.
         :param state: The runtime state that should be used by the tools.
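A usage sketch for this concurrent async path (not part of the commit; the `upper` tool is hypothetical):

```python
import asyncio
from haystack.components.tools import ToolInvoker
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import Tool

# Hypothetical tool for illustration.
upper_tool = Tool(
    name="upper",
    description="Uppercase the given text",
    parameters={"type": "object", "properties": {"text": {"type": "string"}}},
    function=lambda text: text.upper(),
)

async def main():
    invoker = ToolInvoker(tools=[upper_tool])
    message = ChatMessage.from_assistant(
        tool_calls=[ToolCall(tool_name="upper", arguments={"text": t}) for t in ("a", "b", "c")]
    )
    # The three calls are dispatched concurrently and gathered in order.
    result = await invoker.run_async(messages=[message])
    print([m.tool_call_results[0].result for m in result["tool_messages"]])

asyncio.run(main())
```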
@@ -598,6 +692,7 @@ class ToolInvoker:
         :raises ToolOutputMergeError:
             If merging tool outputs into state fails and `raise_on_failure` is True.
         """
+
         if state is None:
             state = State(schema={})

@@ -614,78 +709,78 @@ class ToolInvoker:
         )

         tool_messages = []
-        for message in messages_with_tool_calls:
-            for tool_call in message.tool_calls:
-                tool_name = tool_call.tool_name
-
-                # Check if the tool is available, otherwise return an error message
-                if tool_name not in self._tools_with_names:
-                    error_message = self._handle_error(
-                        ToolNotFoundException(tool_name, list(self._tools_with_names.keys()))
-                    )
-                    tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
-                    continue
-
-                tool_to_invoke = self._tools_with_names[tool_name]
-
-                # 1) Combine user + state inputs
-                llm_args = tool_call.arguments.copy()
-                final_args = self._inject_state_args(tool_to_invoke, llm_args, state)
-
-                # Check whether to inject streaming_callback
-                if (
-                    resolved_enable_streaming_passthrough
-                    and streaming_callback is not None
-                    and "streaming_callback" not in final_args
-                ):
-                    invoke_params = self._get_func_params(tool_to_invoke)
-                    if "streaming_callback" in invoke_params:
-                        final_args["streaming_callback"] = streaming_callback
-
-                # 2) Invoke the tool asynchronously
-                try:
-                    tool_result = await asyncio.get_running_loop().run_in_executor(
-                        self.executor, partial(tool_to_invoke.invoke, **final_args)
-                    )
-                except ToolInvocationError as e:
-                    error_message = self._handle_error(e)
-                    tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
-                    continue
-
-                # 3) Merge outputs into state
-                try:
-                    self._merge_tool_outputs(tool_to_invoke, tool_result, state)
-                except Exception as e:
-                    try:
-                        error_message = self._handle_error(
-                            ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}")
-                        )
-                        tool_messages.append(
-                            ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
-                        )
-                        continue
-                    except ToolOutputMergeError as propagated_e:
-                        # Re-raise with proper error chain
-                        raise propagated_e from e
-
-                # 4) Prepare the tool result ChatMessage message
-                tool_messages.append(
-                    self._prepare_tool_result_message(
-                        result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
-                    )
-                )
-
-                if streaming_callback is not None:
-                    await streaming_callback(
-                        StreamingChunk(
-                            content="",
-                            index=len(tool_messages) - 1,
-                            tool_call_result=tool_messages[-1].tool_call_results[0],
-                            start=True,
-                            meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
-                        )
-                    )
+
+        # 1) Prepare tool call parameters for execution
+        tool_call_params, error_messages = self._prepare_tool_call_params(
+            messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough
+        )
+        tool_messages.extend(error_messages)
+
+        # 2) Execute valid tool calls in parallel
+        if tool_call_params:
+            with self.executor as executor:
+                tool_call_tasks = []
+                valid_tool_calls = []
+
+                # 3) Create async tasks for valid tool calls
+                for params in tool_call_params:
+                    task = ToolInvoker.invoke_tool_safely(executor, params["tool_to_invoke"], params["final_args"])
+                    tool_call_tasks.append(task)
+                    valid_tool_calls.append((params["tool_call"], params["tool_to_invoke"]))
+
+                if tool_call_tasks:
+                    # 4) Gather results from all tool calls
+                    tool_results = await asyncio.gather(*tool_call_tasks)
+
+                    # 5) Process results
+                    for i, ((tool_call, tool_to_invoke), tool_result) in enumerate(zip(valid_tool_calls, tool_results)):
+                        # Check if the tool_result is a ToolInvocationError (caught by our wrapper)
+                        if isinstance(tool_result, ToolInvocationError):
+                            error_message = self._handle_error(tool_result)
+                            tool_messages.append(
+                                ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
+                            )
+                            continue
+
+                        # 6) Merge outputs into state
+                        try:
+                            self._merge_tool_outputs(tool_to_invoke, tool_result, state)
+                        except Exception as e:
+                            try:
+                                error_message = self._handle_error(
+                                    ToolOutputMergeError(
+                                        f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}"
+                                    )
+                                )
+                                tool_messages.append(
+                                    ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
+                                )
+                                continue
+                            except ToolOutputMergeError as propagated_e:
+                                # Re-raise with proper error chain
+                                raise propagated_e from e
+
+                        # 7) Prepare the tool result ChatMessage message
+                        tool_messages.append(
+                            self._prepare_tool_result_message(
+                                result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
+                            )
+                        )
+
+                        # 8) Handle streaming callback
+                        if streaming_callback is not None:
+                            await streaming_callback(
+                                StreamingChunk(
+                                    content="",
+                                    index=i,
+                                    tool_call_result=tool_messages[-1].tool_call_results[0],
+                                    start=True,
+                                    meta={
+                                        "tool_result": tool_messages[-1].tool_call_results[0].result,
+                                        "tool_call": tool_call,
+                                    },
+                                )
+                            )

         # We stream one more chunk that contains a finish_reason if tool_messages were generated
         if len(tool_messages) > 0 and streaming_callback is not None:
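As in the sync path, ordering is preserved: `asyncio.gather` returns results in the order its awaitables were passed, not in completion order, which is what makes `zip(valid_tool_calls, tool_results)` in step 5 safe. A standalone illustration:

```python
import asyncio

async def delayed(value: int, delay: float) -> int:
    await asyncio.sleep(delay)
    return value

async def main():
    # The first task finishes last, but gather keeps argument order.
    results = await asyncio.gather(delayed(0, 0.3), delayed(1, 0.2), delayed(2, 0.1))
    assert results == [0, 1, 2]

asyncio.run(main())
```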
@@ -0,0 +1,9 @@
+---
+features:
+  - |
+    `ToolInvoker` now executes `tool_calls` in parallel for both sync and async mode.
+
+deprecations:
+  - |
+    The `async_executor` parameter in `ToolInvoker` is deprecated in favor of the `max_workers` parameter and will be removed in Haystack 2.16.0.
+    Use the `max_workers` parameter to control the number of threads used for parallel tool calling.
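To make the feature note concrete, a hedged sketch of the expected payoff (the `slow` tool is hypothetical; exact timings depend on the machine):

```python
import time
from haystack.components.tools import ToolInvoker
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import Tool

def slow(n: int) -> dict:
    time.sleep(0.5)  # simulate a slow external call
    return {"n": n}

slow_tool = Tool(
    name="slow",
    description="A deliberately slow tool",
    parameters={"type": "object", "properties": {"n": {"type": "integer"}}},
    function=slow,
)

invoker = ToolInvoker(tools=[slow_tool], max_workers=4)
message = ChatMessage.from_assistant(
    tool_calls=[ToolCall(tool_name="slow", arguments={"n": i}) for i in range(3)]
)

start = time.perf_counter()
invoker.run(messages=[message])
# With parallel execution the three 0.5 s calls take ~0.5 s, not ~1.5 s.
print(f"took {time.perf_counter() - start:.2f}s")
```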
@@ -6,6 +6,7 @@ from unittest.mock import patch
 import pytest
 import json
 import datetime
+import time

 from haystack import Pipeline
 from haystack.components.builders.prompt_builder import PromptBuilder
@@ -676,6 +677,147 @@ class TestToolInvoker:
         )
         mock_run.assert_called_once_with(messages=[ChatMessage.from_user(text="Hello!")])

+    def test_parallel_tool_calling_with_state_updates(self):
+        """Test that parallel tool execution with state updates works correctly with the state lock."""
+        # Create a shared counter variable to simulate a state value that gets updated
+        execution_log = []
+
+        def function_1():
+            time.sleep(0.1)
+            execution_log.append("tool_1_executed")
+            return {"counter": 1, "tool_name": "tool_1"}
+
+        def function_2():
+            time.sleep(0.1)
+            execution_log.append("tool_2_executed")
+            return {"counter": 2, "tool_name": "tool_2"}
+
+        def function_3():
+            time.sleep(0.1)
+            execution_log.append("tool_3_executed")
+            return {"counter": 3, "tool_name": "tool_3"}
+
+        # Create tools that all update the same state key
+        tool_1 = Tool(
+            name="state_tool_1",
+            description="A tool that updates state counter",
+            parameters={"type": "object", "properties": {}},
+            function=function_1,
+            outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
+        )
+
+        tool_2 = Tool(
+            name="state_tool_2",
+            description="A tool that updates state counter",
+            parameters={"type": "object", "properties": {}},
+            function=function_2,
+            outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
+        )
+
+        tool_3 = Tool(
+            name="state_tool_3",
+            description="A tool that updates state counter",
+            parameters={"type": "object", "properties": {}},
+            function=function_3,
+            outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
+        )
+
+        # Create ToolInvoker with all three tools
+        invoker = ToolInvoker(tools=[tool_1, tool_2, tool_3], raise_on_failure=True)
+
+        state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}})
+        tool_calls = [
+            ToolCall(tool_name="state_tool_1", arguments={}),
+            ToolCall(tool_name="state_tool_2", arguments={}),
+            ToolCall(tool_name="state_tool_3", arguments={}),
+        ]
+        message = ChatMessage.from_assistant(tool_calls=tool_calls)
+        result = invoker.run(messages=[message], state=state)
+
+        # Verify that all three tools were executed
+        assert len(execution_log) == 3
+        assert "tool_1_executed" in execution_log
+        assert "tool_2_executed" in execution_log
+        assert "tool_3_executed" in execution_log
+
+        # Verify that the state was updated correctly
+        # Due to parallel execution, we can't predict which tool will be the last to update
+        assert state.has("counter")
+        assert state.has("last_tool")
+        assert state.get("counter") in [1, 2, 3]  # Should be one of the tool values
+        assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"]  # Should be one of the tool names
+
+    @pytest.mark.asyncio
+    async def test_async_parallel_tool_calling_with_state_updates(self):
+        """Test that parallel tool execution with state updates works correctly with the state lock."""
+        # Create a shared counter variable to simulate a state value that gets updated
+        execution_log = []
+
+        def function_1():
+            time.sleep(0.1)
+            execution_log.append("tool_1_executed")
+            return {"counter": 1, "tool_name": "tool_1"}
+
+        def function_2():
+            time.sleep(0.1)
+            execution_log.append("tool_2_executed")
+            return {"counter": 2, "tool_name": "tool_2"}
+
+        def function_3():
+            time.sleep(0.1)
+            execution_log.append("tool_3_executed")
+            return {"counter": 3, "tool_name": "tool_3"}
+
+        # Create tools that all update the same state key
+        tool_1 = Tool(
+            name="state_tool_1",
+            description="A tool that updates state counter",
+            parameters={"type": "object", "properties": {}},
+            function=function_1,
+            outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
+        )
+
+        tool_2 = Tool(
+            name="state_tool_2",
+            description="A tool that updates state counter",
+            parameters={"type": "object", "properties": {}},
+            function=function_2,
+            outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
+        )
+
+        tool_3 = Tool(
+            name="state_tool_3",
+            description="A tool that updates state counter",
+            parameters={"type": "object", "properties": {}},
+            function=function_3,
+            outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
+        )
+
+        # Create ToolInvoker with all three tools
+        invoker = ToolInvoker(tools=[tool_1, tool_2, tool_3], raise_on_failure=True)
+
+        state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}})
+        tool_calls = [
+            ToolCall(tool_name="state_tool_1", arguments={}),
+            ToolCall(tool_name="state_tool_2", arguments={}),
+            ToolCall(tool_name="state_tool_3", arguments={}),
+        ]
+        message = ChatMessage.from_assistant(tool_calls=tool_calls)
+        result = await invoker.run_async(messages=[message], state=state)
+
+        # Verify that all three tools were executed
+        assert len(execution_log) == 3
+        assert "tool_1_executed" in execution_log
+        assert "tool_2_executed" in execution_log
+        assert "tool_3_executed" in execution_log
+
+        # Verify that the state was updated correctly
+        # Due to parallel execution, we can't predict which tool will be the last to update
+        assert state.has("counter")
+        assert state.has("last_tool")
+        assert state.get("counter") in [1, 2, 3]  # Should be one of the tool values
+        assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"]  # Should be one of the tool names
+

 class TestMergeToolOutputs:
     def test_merge_tool_outputs_result_not_a_dict(self, weather_tool):