diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py
index 815b862c8..dce17d6a8 100644
--- a/haystack/components/agents/agent.py
+++ b/haystack/components/agents/agent.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import asyncio
 import inspect
 from copy import deepcopy
 from typing import Any, Dict, List, Optional
@@ -13,7 +14,7 @@ from haystack.core.serialization import component_to_dict
 from haystack.dataclasses import ChatMessage
 from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
 from haystack.dataclasses.state_utils import merge_lists
-from haystack.dataclasses.streaming_chunk import SyncStreamingCallbackT
+from haystack.dataclasses.streaming_chunk import StreamingCallbackT
 from haystack.tools import Tool, deserialize_tools_or_toolset_inplace
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 from haystack.utils.deserialization import deserialize_chatgenerator_inplace
@@ -63,7 +64,7 @@ class Agent:
         state_schema: Optional[Dict[str, Any]] = None,
         max_agent_steps: int = 100,
         raise_on_tool_invocation_failure: bool = False,
-        streaming_callback: Optional[SyncStreamingCallbackT] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
     ):
         """
         Initialize the agent component.
@@ -189,7 +190,7 @@ class Agent:
     def run(
         self,
         messages: List[ChatMessage],
-        streaming_callback: Optional[SyncStreamingCallbackT] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         **kwargs: Dict[str, Any],
     ) -> Dict[str, Any]:
         """
@@ -235,28 +236,8 @@ class Agent:
             state.set("messages", tool_messages)
 
             # 4. Check if any LLM message's tool call name matches an exit condition
-            if self.exit_conditions != ["text"]:
-                matched_exit_conditions = set()
-                has_errors = False
-
-                for msg in llm_messages:
-                    if msg.tool_call and msg.tool_call.tool_name in self.exit_conditions:
-                        matched_exit_conditions.add(msg.tool_call.tool_name)
-
-                        # Check if any error is specifically from the tool matching the exit condition
-                        tool_errors = [
-                            tool_msg.tool_call_result.error
-                            for tool_msg in tool_messages
-                            if tool_msg.tool_call_result.origin.tool_name == msg.tool_call.tool_name
-                        ]
-                        if any(tool_errors):
-                            has_errors = True
-                            # No need to check further if we found an error
-                            break
-
-                # Only return if at least one exit condition was matched AND none had errors
-                if matched_exit_conditions and not has_errors:
-                    return {**state.data}
+            if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
+                return {**state.data}
 
             # 5. Fetch the combined messages and send them back to the LLM
             messages = state.get("messages")
@@ -266,3 +247,116 @@ class Agent:
             "Agent exceeded maximum agent steps of {max_agent_steps}, stopping.", max_agent_steps=self.max_agent_steps
         )
         return {**state.data}
+
+    async def run_async(
+        self,
+        messages: List[ChatMessage],
+        streaming_callback: Optional[StreamingCallbackT] = None,
+        **kwargs: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        """
+        Asynchronously process messages and execute tools until the exit condition is met.
+
+        This is the asynchronous version of the `run` method. It follows the same logic but uses
+        asynchronous operations where possible, such as calling the `run_async` method of the ChatGenerator
+        if available.
+
+        :param messages: List of chat messages to process
+        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
+        :param kwargs: Additional data to pass to the State schema used by the Agent.
+            The keys must match the schema defined in the Agent's `state_schema`.
+        :return: Dictionary containing messages and outputs matching the defined output types
+        """
+        if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
+            raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.")
+
+        state = State(schema=self.state_schema, data=kwargs)
+
+        if self.system_prompt is not None:
+            messages = [ChatMessage.from_system(self.system_prompt)] + messages
+        state.set("messages", messages)
+
+        generator_inputs: Dict[str, Any] = {"tools": self.tools}
+
+        selected_callback = streaming_callback or self.streaming_callback
+        if selected_callback is not None:
+            generator_inputs["streaming_callback"] = selected_callback
+
+        # Repeat until the exit condition is met
+        counter = 0
+        while counter < self.max_agent_steps:
+            # 1. Call the ChatGenerator
+            # Check if the chat generator supports async execution
+            if getattr(self.chat_generator, "__haystack_supports_async__", False):
+                result = await self.chat_generator.run_async(messages=messages, **generator_inputs)  # type: ignore[attr-defined]
+                llm_messages = result["replies"]
+            else:
+                # Fall back to synchronous run if async is not available
+                loop = asyncio.get_running_loop()
+                result = await loop.run_in_executor(
+                    None, lambda: self.chat_generator.run(messages=messages, **generator_inputs)
+                )
+                llm_messages = result["replies"]
+
+            state.set("messages", llm_messages)
+
+            # 2. Check if any of the LLM responses contain a tool call
+            if not any(msg.tool_call for msg in llm_messages):
+                return {**state.data}
+
+            # 3. Call the ToolInvoker
+            # We only send the messages from the LLM to the tool invoker
+            # Check if the ToolInvoker supports async execution. Currently, it doesn't.
+            if getattr(self._tool_invoker, "__haystack_supports_async__", False):
+                tool_invoker_result = await self._tool_invoker.run_async(messages=llm_messages, state=state)  # type: ignore[attr-defined]
+            else:
+                loop = asyncio.get_running_loop()
+                tool_invoker_result = await loop.run_in_executor(
+                    None, lambda: self._tool_invoker.run(messages=llm_messages, state=state)
+                )
+            tool_messages = tool_invoker_result["tool_messages"]
+            state = tool_invoker_result["state"]
+            state.set("messages", tool_messages)
+
+            # 4. Check if any LLM message's tool call name matches an exit condition
+            if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
+                return {**state.data}
+
+            # 5. Fetch the combined messages and send them back to the LLM
+            messages = state.get("messages")
+            counter += 1
+
+        logger.warning(
+            "Agent exceeded maximum agent steps of {max_agent_steps}, stopping.", max_agent_steps=self.max_agent_steps
+        )
+        return {**state.data}
+
+    def _check_exit_conditions(self, llm_messages: List[ChatMessage], tool_messages: List[ChatMessage]) -> bool:
+        """
+        Check if any of the LLM messages' tool calls match an exit condition and if there are no errors.
+
+        :param llm_messages: List of messages from the LLM
+        :param tool_messages: List of messages from the tool invoker
+        :return: True if an exit condition is met and there are no errors, False otherwise
+        """
+        matched_exit_conditions = set()
+        has_errors = False
+
+        for msg in llm_messages:
+            if msg.tool_call and msg.tool_call.tool_name in self.exit_conditions:
+                matched_exit_conditions.add(msg.tool_call.tool_name)
+
+                # Check if any error is specifically from the tool matching the exit condition
+                tool_errors = [
+                    tool_msg.tool_call_result.error
+                    for tool_msg in tool_messages
+                    if tool_msg.tool_call_result is not None
+                    and tool_msg.tool_call_result.origin.tool_name == msg.tool_call.tool_name
+                ]
+                if any(tool_errors):
+                    has_errors = True
+                    # No need to check further if we found an error
+                    break
+
+        # Only return True if at least one exit condition was matched AND none had errors
+        return bool(matched_exit_conditions) and not has_errors
diff --git a/releasenotes/notes/agent-run-async-28d3a6dd7ea4d888.yaml b/releasenotes/notes/agent-run-async-28d3a6dd7ea4d888.yaml
new file mode 100644
index 000000000..44801fd43
--- /dev/null
+++ b/releasenotes/notes/agent-run-async-28d3a6dd7ea4d888.yaml
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Add a `run_async` method to the Agent, which calls the `run_async` method of the underlying ChatGenerator if available.
diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py
index b7e181420..bc1e00a99 100644
--- a/test/components/agents/test_agent.py
+++ b/test/components/agents/test_agent.py
@@ -4,9 +4,9 @@
 
 import os
 from datetime import datetime
-from typing import Iterator, Dict, Any, List
+from typing import Iterator, Dict, Any, List, Optional, Union
 
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, AsyncMock
 
 import pytest
 from openai import Stream
@@ -16,10 +16,13 @@
 from haystack.components.agents import Agent
 from haystack.components.builders.prompt_builder import PromptBuilder
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
 from haystack.components.generators.chat.types import ChatGenerator
+from haystack import component
 from haystack.core.component.types import OutputSocket
 from haystack.dataclasses import ChatMessage, ToolCall
+from haystack.dataclasses.chat_message import ChatRole, TextContent
 from haystack.dataclasses.streaming_chunk import StreamingChunk
 from haystack.tools import Tool, ComponentTool
+from haystack.tools.toolset import Toolset
 from haystack.utils import serialize_callable, Secret
 from haystack.dataclasses.state_utils import merge_lists
@@ -104,6 +107,22 @@ class MockChatGeneratorWithoutTools(ChatGenerator):
         return {"replies": [ChatMessage.from_assistant("Hello")]}
 
 
+class MockChatGeneratorWithoutRunAsync(ChatGenerator):
+    """A mock chat generator that implements the ChatGenerator protocol but doesn't have a run_async method."""
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {"type": "MockChatGeneratorWithoutRunAsync", "data": {}}
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutRunAsync":
+        return cls()
+
+    def run(
+        self, messages: List[ChatMessage], tools: Optional[Union[List[Tool], Toolset]] = None, **kwargs
+    ) -> Dict[str, Any]:
+        return {"replies": [ChatMessage.from_assistant("Hello")]}
+
+
 class TestAgent:
     def test_output_types(self, weather_tool, component_tool, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
@@ -500,3 +519,70 @@ class TestAgent:
response["messages"][1].tool_calls[0].arguments is not None assert response["messages"][2].tool_call_results[0].result is not None assert response["messages"][2].tool_call_results[0].origin is not None + + @pytest.mark.asyncio + async def test_run_async_falls_back_to_run_when_chat_generator_has_no_run_async(self, weather_tool): + chat_generator = MockChatGeneratorWithoutRunAsync() + agent = Agent(chat_generator=chat_generator, tools=[weather_tool]) + agent.warm_up() + + chat_generator.run = MagicMock(return_value={"replies": [ChatMessage.from_assistant("Hello")]}) + + result = await agent.run_async([ChatMessage.from_user("Hello")]) + + expected_messages = [ + ChatMessage(_role=ChatRole.USER, _content=[TextContent(text="Hello")], _name=None, _meta={}) + ] + chat_generator.run.assert_called_once_with(messages=expected_messages, tools=[weather_tool]) + + assert isinstance(result, dict) + assert "messages" in result + assert isinstance(result["messages"], list) + assert len(result["messages"]) == 2 + assert [isinstance(reply, ChatMessage) for reply in result["messages"]] + assert "Hello" in result["messages"][1].text + + @pytest.mark.asyncio + async def test_run_async_uses_chat_generator_run_async_when_available(self, weather_tool): + # Create a mock chat generator with run_async + # We need to use @component so that has_async_run is set + @component + class MockChatGeneratorWithRunAsync: + def to_dict(self) -> Dict[str, Any]: + return {"type": "MockChatGeneratorWithoutRunAsync", "data": {}} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutRunAsync": + return cls() + + def run( + self, messages: List[ChatMessage], tools: Optional[Union[List[Tool], Toolset]] = None, **kwargs + ) -> Dict[str, Any]: + return {"replies": [ChatMessage.from_assistant("Hello")]} + + async def run_async( + self, messages: List[ChatMessage], tools: Optional[Union[List[Tool], Toolset]] = None, **kwargs + ) -> Dict[str, Any]: + return {"replies": [ChatMessage.from_assistant("Hello from run_async")]} + + chat_generator = MockChatGeneratorWithRunAsync() + agent = Agent(chat_generator=chat_generator, tools=[weather_tool]) + agent.warm_up() + + chat_generator.run_async = AsyncMock( + return_value={"replies": [ChatMessage.from_assistant("Hello from run_async")]} + ) + + result = await agent.run_async([ChatMessage.from_user("Hello")]) + + expected_messages = [ + ChatMessage(_role=ChatRole.USER, _content=[TextContent(text="Hello")], _name=None, _meta={}) + ] + chat_generator.run_async.assert_called_once_with(messages=expected_messages, tools=[weather_tool]) + + assert isinstance(result, dict) + assert "messages" in result + assert isinstance(result["messages"], list) + assert len(result["messages"]) == 2 + assert [isinstance(reply, ChatMessage) for reply in result["messages"]] + assert "Hello from run_async" in result["messages"][1].text