feat: Add run_async to Agent (#9239)

* add run_async

* refactor with _check_exit_conditions

* add run_async tests

* reno

* fix linting issues
Julian Risch 2025-04-14 21:01:59 +02:00, committed by GitHub
parent c67d1bf0e9
commit 13780cfcc4
3 changed files with 211 additions and 27 deletions
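
A minimal usage sketch of the API this commit adds. Assumptions: OPENAI_API_KEY is set and `OpenAIChatGenerator` stands in for any ChatGenerator; a generator without `run_async` is transparently run in a thread executor (see the fallback in the diff below).

import asyncio

from haystack.components.agents import Agent
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage

agent = Agent(chat_generator=OpenAIChatGenerator(), tools=[])
agent.warm_up()  # required before run()/run_async()

async def main():
    result = await agent.run_async([ChatMessage.from_user("Hello!")])
    print(result["messages"][-1].text)

asyncio.run(main())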

File: haystack/components/agents/agent.py

@@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
import inspect
from copy import deepcopy
from typing import Any, Dict, List, Optional
@@ -13,7 +14,7 @@ from haystack.core.serialization import component_to_dict
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
from haystack.dataclasses.state_utils import merge_lists
-from haystack.dataclasses.streaming_chunk import SyncStreamingCallbackT
+from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.tools import Tool, deserialize_tools_or_toolset_inplace
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
@@ -63,7 +64,7 @@ class Agent:
state_schema: Optional[Dict[str, Any]] = None,
max_agent_steps: int = 100,
raise_on_tool_invocation_failure: bool = False,
-streaming_callback: Optional[SyncStreamingCallbackT] = None,
+streaming_callback: Optional[StreamingCallbackT] = None,
):
"""
Initialize the agent component.
@@ -189,7 +190,7 @@
def run(
self,
messages: List[ChatMessage],
-streaming_callback: Optional[SyncStreamingCallbackT] = None,
+streaming_callback: Optional[StreamingCallbackT] = None,
**kwargs: Dict[str, Any],
) -> Dict[str, Any]:
"""
@@ -235,28 +236,8 @@
state.set("messages", tool_messages)
# 4. Check if any LLM message's tool call name matches an exit condition
-if self.exit_conditions != ["text"]:
-matched_exit_conditions = set()
-has_errors = False
-for msg in llm_messages:
-if msg.tool_call and msg.tool_call.tool_name in self.exit_conditions:
-matched_exit_conditions.add(msg.tool_call.tool_name)
-# Check if any error is specifically from the tool matching the exit condition
-tool_errors = [
-tool_msg.tool_call_result.error
-for tool_msg in tool_messages
-if tool_msg.tool_call_result.origin.tool_name == msg.tool_call.tool_name
-]
-if any(tool_errors):
-has_errors = True
-# No need to check further if we found an error
-break
-# Only return if at least one exit condition was matched AND none had errors
-if matched_exit_conditions and not has_errors:
-return {**state.data}
+if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
+return {**state.data}
# 5. Fetch the combined messages and send them back to the LLM
messages = state.get("messages")
@@ -266,3 +247,116 @@
"Agent exceeded maximum agent steps of {max_agent_steps}, stopping.", max_agent_steps=self.max_agent_steps
)
return {**state.data}
async def run_async(
self,
messages: List[ChatMessage],
streaming_callback: Optional[StreamingCallbackT] = None,
**kwargs: Dict[str, Any],
) -> Dict[str, Any]:
"""
Asynchronously process messages and execute tools until the exit condition is met.
This is the asynchronous version of the `run` method. It follows the same logic but uses
asynchronous operations where possible, such as calling the `run_async` method of the ChatGenerator
if available.
:param messages: List of chat messages to process
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
:param kwargs: Additional data to pass to the State schema used by the Agent.
The keys must match the schema defined in the Agent's `state_schema`.
:return: Dictionary containing messages and outputs matching the defined output types
"""
if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.")
state = State(schema=self.state_schema, data=kwargs)
if self.system_prompt is not None:
messages = [ChatMessage.from_system(self.system_prompt)] + messages
state.set("messages", messages)
generator_inputs: Dict[str, Any] = {"tools": self.tools}
selected_callback = streaming_callback or self.streaming_callback
if selected_callback is not None:
generator_inputs["streaming_callback"] = selected_callback
# Repeat until the exit condition is met
counter = 0
while counter < self.max_agent_steps:
# 1. Call the ChatGenerator
# Check if the chat generator supports async execution
if getattr(self.chat_generator, "__haystack_supports_async__", False):
result = await self.chat_generator.run_async(messages=messages, **generator_inputs) # type: ignore[attr-defined]
llm_messages = result["replies"]
else:
# Fall back to synchronous run if async is not available
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
None, lambda: self.chat_generator.run(messages=messages, **generator_inputs)
)
llm_messages = result["replies"]
state.set("messages", llm_messages)
# 2. Check if any of the LLM responses contain a tool call
if not any(msg.tool_call for msg in llm_messages):
return {**state.data}
# 3. Call the ToolInvoker
# We only send the messages from the LLM to the tool invoker
# Check if the ToolInvoker supports async execution. Currently, it doesn't.
if getattr(self._tool_invoker, "__haystack_supports_async__", False):
tool_invoker_result = await self._tool_invoker.run_async(messages=llm_messages, state=state) # type: ignore[attr-defined]
else:
loop = asyncio.get_running_loop()
tool_invoker_result = await loop.run_in_executor(
None, lambda: self._tool_invoker.run(messages=llm_messages, state=state)
)
tool_messages = tool_invoker_result["tool_messages"]
state = tool_invoker_result["state"]
state.set("messages", tool_messages)
# 4. Check if any LLM message's tool call name matches an exit condition
if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
return {**state.data}
# 5. Fetch the combined messages and send them back to the LLM
messages = state.get("messages")
counter += 1
logger.warning(
"Agent exceeded maximum agent steps of {max_agent_steps}, stopping.", max_agent_steps=self.max_agent_steps
)
return {**state.data}
def _check_exit_conditions(self, llm_messages: List[ChatMessage], tool_messages: List[ChatMessage]) -> bool:
"""
Check if any of the LLM messages' tool calls match an exit condition and if there are no errors.
:param llm_messages: List of messages from the LLM
:param tool_messages: List of messages from the tool invoker
:return: True if an exit condition is met and there are no errors, False otherwise
"""
matched_exit_conditions = set()
has_errors = False
for msg in llm_messages:
if msg.tool_call and msg.tool_call.tool_name in self.exit_conditions:
matched_exit_conditions.add(msg.tool_call.tool_name)
# Check if any error is specifically from the tool matching the exit condition
tool_errors = [
tool_msg.tool_call_result.error
for tool_msg in tool_messages
if tool_msg.tool_call_result is not None
and tool_msg.tool_call_result.origin.tool_name == msg.tool_call.tool_name
]
if any(tool_errors):
has_errors = True
# No need to check further if we found an error
break
# Only return True if at least one exit condition was matched AND none had errors
return bool(matched_exit_conditions) and not has_errors
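
`_check_exit_conditions` is the block factored out of `run` above, now shared with `run_async`. A hedged sketch of the configuration it acts on; the `exit_conditions` constructor parameter is assumed from its use as `self.exit_conditions` in this diff, and `weather_tool`/"get_weather" are hypothetical:

from haystack.components.agents import Agent
from haystack.components.generators.chat.openai import OpenAIChatGenerator

# With exit_conditions=["get_weather"], the agent returns as soon as the LLM
# emits a tool call named "get_weather" and that call produced no tool error.
# The default ["text"] instead stops once the LLM replies without tool calls.
agent = Agent(
    chat_generator=OpenAIChatGenerator(),
    tools=[weather_tool],  # hypothetical Tool whose name is "get_weather"
    exit_conditions=["get_weather"],
)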

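Both the ChatGenerator and the ToolInvoker calls in `run_async` use the same dispatch pattern. Isolated as a sketch, assuming only that `__haystack_supports_async__` is the marker flag this diff checks:

import asyncio
from typing import Any, Dict

async def call_component(component, **inputs) -> Dict[str, Any]:
    # Await run_async when the component advertises async support; otherwise
    # off-load the blocking run() to the default thread-pool executor so the
    # event loop is not blocked.
    if getattr(component, "__haystack_supports_async__", False):
        return await component.run_async(**inputs)
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: component.run(**inputs))
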
File: new release note (reno YAML)

@@ -0,0 +1,4 @@
---
features:
- |
Add a `run_async` method to the Agent, which calls the `run_async` of the underlying ChatGenerator if available and otherwise falls back to running the synchronous `run` in a thread executor.

File: test/components/agents/test_agent.py

@@ -4,9 +4,9 @@
import os
from datetime import datetime
-from typing import Iterator, Dict, Any, List
+from typing import Iterator, Dict, Any, List, Optional, Union
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, AsyncMock
import pytest
from openai import Stream
@@ -16,10 +16,13 @@ from haystack.components.agents import Agent
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.generators.chat.types import ChatGenerator
from haystack import component
from haystack.core.component.types import OutputSocket
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.dataclasses.chat_message import ChatRole, TextContent
from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.tools import Tool, ComponentTool
from haystack.tools.toolset import Toolset
from haystack.utils import serialize_callable, Secret
from haystack.dataclasses.state_utils import merge_lists
@@ -104,6 +107,22 @@ class MockChatGeneratorWithoutTools(ChatGenerator):
return {"replies": [ChatMessage.from_assistant("Hello")]}
class MockChatGeneratorWithoutRunAsync(ChatGenerator):
"""A mock chat generator that implements ChatGenerator protocol but doesn't have run_async method."""
def to_dict(self) -> Dict[str, Any]:
return {"type": "MockChatGeneratorWithoutRunAsync", "data": {}}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutRunAsync":
return cls()
def run(
self, messages: List[ChatMessage], tools: Optional[Union[List[Tool], Toolset]] = None, **kwargs
) -> Dict[str, Any]:
return {"replies": [ChatMessage.from_assistant("Hello")]}
class TestAgent:
def test_output_types(self, weather_tool, component_tool, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
@@ -500,3 +519,70 @@ class TestAgent:
assert response["messages"][1].tool_calls[0].arguments is not None
assert response["messages"][2].tool_call_results[0].result is not None
assert response["messages"][2].tool_call_results[0].origin is not None
@pytest.mark.asyncio
async def test_run_async_falls_back_to_run_when_chat_generator_has_no_run_async(self, weather_tool):
chat_generator = MockChatGeneratorWithoutRunAsync()
agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
agent.warm_up()
chat_generator.run = MagicMock(return_value={"replies": [ChatMessage.from_assistant("Hello")]})
result = await agent.run_async([ChatMessage.from_user("Hello")])
expected_messages = [
ChatMessage(_role=ChatRole.USER, _content=[TextContent(text="Hello")], _name=None, _meta={})
]
chat_generator.run.assert_called_once_with(messages=expected_messages, tools=[weather_tool])
assert isinstance(result, dict)
assert "messages" in result
assert isinstance(result["messages"], list)
assert len(result["messages"]) == 2
assert all(isinstance(reply, ChatMessage) for reply in result["messages"])
assert "Hello" in result["messages"][1].text
@pytest.mark.asyncio
async def test_run_async_uses_chat_generator_run_async_when_available(self, weather_tool):
# Create a mock chat generator with run_async
# We need to use @component so that __haystack_supports_async__ is set
@component
class MockChatGeneratorWithRunAsync:
def to_dict(self) -> Dict[str, Any]:
return {"type": "MockChatGeneratorWithoutRunAsync", "data": {}}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutRunAsync":
return cls()
def run(
self, messages: List[ChatMessage], tools: Optional[Union[List[Tool], Toolset]] = None, **kwargs
) -> Dict[str, Any]:
return {"replies": [ChatMessage.from_assistant("Hello")]}
async def run_async(
self, messages: List[ChatMessage], tools: Optional[Union[List[Tool], Toolset]] = None, **kwargs
) -> Dict[str, Any]:
return {"replies": [ChatMessage.from_assistant("Hello from run_async")]}
chat_generator = MockChatGeneratorWithRunAsync()
agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
agent.warm_up()
chat_generator.run_async = AsyncMock(
return_value={"replies": [ChatMessage.from_assistant("Hello from run_async")]}
)
result = await agent.run_async([ChatMessage.from_user("Hello")])
expected_messages = [
ChatMessage(_role=ChatRole.USER, _content=[TextContent(text="Hello")], _name=None, _meta={})
]
chat_generator.run_async.assert_called_once_with(messages=expected_messages, tools=[weather_tool])
assert isinstance(result, dict)
assert "messages" in result
assert isinstance(result["messages"], list)
assert len(result["messages"]) == 2
assert all(isinstance(reply, ChatMessage) for reply in result["messages"])
assert "Hello from run_async" in result["messages"][1].text