Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-24 13:38:53 +00:00)
feat: Add run_async to Agent (#9239)
* add run_async
* refactor with _check_exit_conditions
* add run_async tests
* reno
* fix linting issues
parent c67d1bf0e9
commit 13780cfcc4
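
For orientation, a minimal usage sketch of the new entry point (not part of the commit). The tool definition and model setup are illustrative; the Agent, Tool, and ChatMessage signatures are assumed from the Haystack 2.x API exercised in the diff below.

import asyncio

from haystack.components.agents import Agent
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.tools import Tool


def get_weather(city: str) -> str:
    # Illustrative stand-in for a real tool implementation
    return f"Sunny in {city}"


weather_tool = Tool(
    name="weather",
    description="Get the weather for a city",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)


async def main() -> None:
    # Assumes OPENAI_API_KEY is set in the environment
    agent = Agent(chat_generator=OpenAIChatGenerator(), tools=[weather_tool])
    agent.warm_up()  # required before run()/run_async()
    result = await agent.run_async(messages=[ChatMessage.from_user("What's the weather in Berlin?")])
    print(result["messages"][-1].text)


asyncio.run(main())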
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import asyncio
 import inspect
 from copy import deepcopy
 from typing import Any, Dict, List, Optional
@@ -13,7 +14,7 @@ from haystack.core.serialization import component_to_dict
 from haystack.dataclasses import ChatMessage
 from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
 from haystack.dataclasses.state_utils import merge_lists
-from haystack.dataclasses.streaming_chunk import SyncStreamingCallbackT
+from haystack.dataclasses.streaming_chunk import StreamingCallbackT
 from haystack.tools import Tool, deserialize_tools_or_toolset_inplace
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 from haystack.utils.deserialization import deserialize_chatgenerator_inplace
@@ -63,7 +64,7 @@ class Agent:
         state_schema: Optional[Dict[str, Any]] = None,
         max_agent_steps: int = 100,
         raise_on_tool_invocation_failure: bool = False,
-        streaming_callback: Optional[SyncStreamingCallbackT] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
     ):
         """
         Initialize the agent component.
@@ -189,7 +190,7 @@ class Agent:
     def run(
         self,
         messages: List[ChatMessage],
-        streaming_callback: Optional[SyncStreamingCallbackT] = None,
+        streaming_callback: Optional[StreamingCallbackT] = None,
         **kwargs: Dict[str, Any],
     ) -> Dict[str, Any]:
         """
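
The widened `streaming_callback` type follows from the import swap above: `StreamingCallbackT` is assumed (per the rename) to cover both sync and async callables receiving a `StreamingChunk`, whereas `SyncStreamingCallbackT` covered only the former. A hedged sketch of two callbacks that would fit:

from haystack.dataclasses import StreamingChunk


def on_chunk(chunk: StreamingChunk) -> None:
    # Synchronous callback: print tokens as they arrive
    print(chunk.content, end="", flush=True)


async def on_chunk_async(chunk: StreamingChunk) -> None:
    # Async variant: same effect, natural to pass from run_async
    print(chunk.content, end="", flush=True)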
@@ -235,28 +236,8 @@
             state.set("messages", tool_messages)
 
             # 4. Check if any LLM message's tool call name matches an exit condition
-            if self.exit_conditions != ["text"]:
-                matched_exit_conditions = set()
-                has_errors = False
-
-                for msg in llm_messages:
-                    if msg.tool_call and msg.tool_call.tool_name in self.exit_conditions:
-                        matched_exit_conditions.add(msg.tool_call.tool_name)
-
-                        # Check if any error is specifically from the tool matching the exit condition
-                        tool_errors = [
-                            tool_msg.tool_call_result.error
-                            for tool_msg in tool_messages
-                            if tool_msg.tool_call_result.origin.tool_name == msg.tool_call.tool_name
-                        ]
-                        if any(tool_errors):
-                            has_errors = True
-                            # No need to check further if we found an error
-                            break
-
-                # Only return if at least one exit condition was matched AND none had errors
-                if matched_exit_conditions and not has_errors:
-                    return {**state.data}
+            if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
+                return {**state.data}
 
             # 5. Fetch the combined messages and send them back to the LLM
             messages = state.get("messages")
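
The hunk above replaces the inline exit-condition loop with a call to the new `_check_exit_conditions` helper (defined in the next hunk) without changing behavior. A hedged illustration of the semantics, assuming the `ChatMessage.from_assistant(tool_calls=...)` and `ChatMessage.from_tool(...)` constructors from Haystack's dataclasses:

from haystack.dataclasses import ChatMessage, ToolCall

# One LLM message calling a tool named "search", plus its successful result.
call = ToolCall(tool_name="search", arguments={"query": "haystack"})
llm_messages = [ChatMessage.from_assistant(tool_calls=[call])]
tool_messages = [ChatMessage.from_tool(tool_result="3 hits", origin=call)]

# With exit_conditions=["search"], _check_exit_conditions(llm_messages, tool_messages)
# should return True: the called tool matches an exit condition and its result
# carries no error, so the agent loop returns the accumulated state.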
@@ -266,3 +247,116 @@
             "Agent exceeded maximum agent steps of {max_agent_steps}, stopping.", max_agent_steps=self.max_agent_steps
         )
         return {**state.data}
+
+    async def run_async(
+        self,
+        messages: List[ChatMessage],
+        streaming_callback: Optional[StreamingCallbackT] = None,
+        **kwargs: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        """
+        Asynchronously process messages and execute tools until the exit condition is met.
+
+        This is the asynchronous version of the `run` method. It follows the same logic but uses
+        asynchronous operations where possible, such as calling the `run_async` method of the ChatGenerator
+        if available.
+
+        :param messages: List of chat messages to process
+        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
+        :param kwargs: Additional data to pass to the State schema used by the Agent.
+            The keys must match the schema defined in the Agent's `state_schema`.
+        :return: Dictionary containing messages and outputs matching the defined output types
+        """
+        if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
+            raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.")
+
+        state = State(schema=self.state_schema, data=kwargs)
+
+        if self.system_prompt is not None:
+            messages = [ChatMessage.from_system(self.system_prompt)] + messages
+        state.set("messages", messages)
+
+        generator_inputs: Dict[str, Any] = {"tools": self.tools}
+
+        selected_callback = streaming_callback or self.streaming_callback
+        if selected_callback is not None:
+            generator_inputs["streaming_callback"] = selected_callback
+
+        # Repeat until the exit condition is met
+        counter = 0
+        while counter < self.max_agent_steps:
+            # 1. Call the ChatGenerator
+            # Check if the chat generator supports async execution
+            if getattr(self.chat_generator, "__haystack_supports_async__", False):
+                result = await self.chat_generator.run_async(messages=messages, **generator_inputs)  # type: ignore[attr-defined]
+                llm_messages = result["replies"]
+            else:
+                # Fall back to synchronous run if async is not available
+                loop = asyncio.get_running_loop()
+                result = await loop.run_in_executor(
+                    None, lambda: self.chat_generator.run(messages=messages, **generator_inputs)
+                )
+                llm_messages = result["replies"]
+
+            state.set("messages", llm_messages)
+
+            # 2. Check if any of the LLM responses contain a tool call
+            if not any(msg.tool_call for msg in llm_messages):
+                return {**state.data}
+
+            # 3. Call the ToolInvoker
+            # We only send the messages from the LLM to the tool invoker
+            # Check if the ToolInvoker supports async execution. Currently, it doesn't.
+            if getattr(self._tool_invoker, "__haystack_supports_async__", False):
+                tool_invoker_result = await self._tool_invoker.run_async(messages=llm_messages, state=state)  # type: ignore[attr-defined]
+            else:
+                loop = asyncio.get_running_loop()
+                tool_invoker_result = await loop.run_in_executor(
+                    None, lambda: self._tool_invoker.run(messages=llm_messages, state=state)
+                )
+            tool_messages = tool_invoker_result["tool_messages"]
+            state = tool_invoker_result["state"]
+            state.set("messages", tool_messages)
+
+            # 4. Check if any LLM message's tool call name matches an exit condition
+            if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
+                return {**state.data}
+
+            # 5. Fetch the combined messages and send them back to the LLM
+            messages = state.get("messages")
+            counter += 1
+
+        logger.warning(
+            "Agent exceeded maximum agent steps of {max_agent_steps}, stopping.", max_agent_steps=self.max_agent_steps
+        )
+        return {**state.data}
+
+    def _check_exit_conditions(self, llm_messages: List[ChatMessage], tool_messages: List[ChatMessage]) -> bool:
+        """
+        Check if any of the LLM messages' tool calls match an exit condition and if there are no errors.
+
+        :param llm_messages: List of messages from the LLM
+        :param tool_messages: List of messages from the tool invoker
+        :return: True if an exit condition is met and there are no errors, False otherwise
+        """
+        matched_exit_conditions = set()
+        has_errors = False
+
+        for msg in llm_messages:
+            if msg.tool_call and msg.tool_call.tool_name in self.exit_conditions:
+                matched_exit_conditions.add(msg.tool_call.tool_name)
+
+                # Check if any error is specifically from the tool matching the exit condition
+                tool_errors = [
+                    tool_msg.tool_call_result.error
+                    for tool_msg in tool_messages
+                    if tool_msg.tool_call_result is not None
+                    and tool_msg.tool_call_result.origin.tool_name == msg.tool_call.tool_name
+                ]
+                if any(tool_errors):
+                    has_errors = True
+                    # No need to check further if we found an error
+                    break
+
+        # Only return True if at least one exit condition was matched AND none had errors
+        return bool(matched_exit_conditions) and not has_errors
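
The synchronous fallback in `run_async` is the standard asyncio pattern for wrapping a blocking call: run it on the default thread-pool executor so the event loop stays responsive. A self-contained sketch of the same pattern, with `blocking_run` as an illustrative stand-in for a sync `chat_generator.run`:

import asyncio


def blocking_run(prompt: str) -> str:
    # Stand-in for a synchronous generator call
    return f"echo: {prompt}"


async def main() -> None:
    loop = asyncio.get_running_loop()
    # None selects the default ThreadPoolExecutor
    reply = await loop.run_in_executor(None, lambda: blocking_run("hi"))
    print(reply)


asyncio.run(main())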
releasenotes/notes/agent-run-async-28d3a6dd7ea4d888.yaml (new file, +4)
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Add a `run_async` method to the Agent, which calls the `run_async` of the underlying ChatGenerator if available.
@@ -4,9 +4,9 @@
 
 import os
 from datetime import datetime
-from typing import Iterator, Dict, Any, List
+from typing import Iterator, Dict, Any, List, Optional, Union
 
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, AsyncMock
 import pytest
 
 from openai import Stream
@@ -16,10 +16,13 @@ from haystack.components.agents import Agent
 from haystack.components.builders.prompt_builder import PromptBuilder
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
 from haystack.components.generators.chat.types import ChatGenerator
+from haystack import component
+from haystack.core.component.types import OutputSocket
 from haystack.dataclasses import ChatMessage, ToolCall
+from haystack.dataclasses.chat_message import ChatRole, TextContent
 from haystack.dataclasses.streaming_chunk import StreamingChunk
 from haystack.tools import Tool, ComponentTool
 from haystack.tools.toolset import Toolset
 from haystack.utils import serialize_callable, Secret
 from haystack.dataclasses.state_utils import merge_lists
 
@@ -104,6 +107,22 @@ class MockChatGeneratorWithoutTools(ChatGenerator):
         return {"replies": [ChatMessage.from_assistant("Hello")]}
 
 
+class MockChatGeneratorWithoutRunAsync(ChatGenerator):
+    """A mock chat generator that implements ChatGenerator protocol but doesn't have run_async method."""
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {"type": "MockChatGeneratorWithoutRunAsync", "data": {}}
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutRunAsync":
+        return cls()
+
+    def run(
+        self, messages: List[ChatMessage], tools: Optional[Union[List[Tool], Toolset]] = None, **kwargs
+    ) -> Dict[str, Any]:
+        return {"replies": [ChatMessage.from_assistant("Hello")]}
+
+
 class TestAgent:
     def test_output_types(self, weather_tool, component_tool, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
@@ -500,3 +519,70 @@
         assert response["messages"][1].tool_calls[0].arguments is not None
         assert response["messages"][2].tool_call_results[0].result is not None
         assert response["messages"][2].tool_call_results[0].origin is not None
+
+    @pytest.mark.asyncio
+    async def test_run_async_falls_back_to_run_when_chat_generator_has_no_run_async(self, weather_tool):
+        chat_generator = MockChatGeneratorWithoutRunAsync()
+        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
+        agent.warm_up()
+
+        chat_generator.run = MagicMock(return_value={"replies": [ChatMessage.from_assistant("Hello")]})
+
+        result = await agent.run_async([ChatMessage.from_user("Hello")])
+
+        expected_messages = [
+            ChatMessage(_role=ChatRole.USER, _content=[TextContent(text="Hello")], _name=None, _meta={})
+        ]
+        chat_generator.run.assert_called_once_with(messages=expected_messages, tools=[weather_tool])
+
+        assert isinstance(result, dict)
+        assert "messages" in result
+        assert isinstance(result["messages"], list)
+        assert len(result["messages"]) == 2
+        assert [isinstance(reply, ChatMessage) for reply in result["messages"]]
+        assert "Hello" in result["messages"][1].text
+
+    @pytest.mark.asyncio
+    async def test_run_async_uses_chat_generator_run_async_when_available(self, weather_tool):
+        # Create a mock chat generator with run_async
+        # We need to use @component so that has_async_run is set
+        @component
+        class MockChatGeneratorWithRunAsync:
+            def to_dict(self) -> Dict[str, Any]:
+                return {"type": "MockChatGeneratorWithoutRunAsync", "data": {}}
+
+            @classmethod
+            def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutRunAsync":
+                return cls()
+
+            def run(
+                self, messages: List[ChatMessage], tools: Optional[Union[List[Tool], Toolset]] = None, **kwargs
+            ) -> Dict[str, Any]:
+                return {"replies": [ChatMessage.from_assistant("Hello")]}
+
+            async def run_async(
+                self, messages: List[ChatMessage], tools: Optional[Union[List[Tool], Toolset]] = None, **kwargs
+            ) -> Dict[str, Any]:
+                return {"replies": [ChatMessage.from_assistant("Hello from run_async")]}
+
+        chat_generator = MockChatGeneratorWithRunAsync()
+        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
+        agent.warm_up()
+
+        chat_generator.run_async = AsyncMock(
+            return_value={"replies": [ChatMessage.from_assistant("Hello from run_async")]}
+        )
+
+        result = await agent.run_async([ChatMessage.from_user("Hello")])
+
+        expected_messages = [
+            ChatMessage(_role=ChatRole.USER, _content=[TextContent(text="Hello")], _name=None, _meta={})
+        ]
+        chat_generator.run_async.assert_called_once_with(messages=expected_messages, tools=[weather_tool])
+
+        assert isinstance(result, dict)
+        assert "messages" in result
+        assert isinstance(result["messages"], list)
+        assert len(result["messages"]) == 2
+        assert [isinstance(reply, ChatMessage) for reply in result["messages"]]
+        assert "Hello from run_async" in result["messages"][1].text
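
Both tests hinge on the `__haystack_supports_async__` flag that `run_async` inspects before dispatching. A hedged standalone version of the in-test mock (the flag is assumed to be set by the `@component` decorator when the decorated class defines `run_async`, as the test comment above suggests):

from typing import Any, Dict, List

from haystack import component
from haystack.dataclasses import ChatMessage


@component
class EchoChatGenerator:
    def run(self, messages: List[ChatMessage], **kwargs) -> Dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant("Hello")]}

    async def run_async(self, messages: List[ChatMessage], **kwargs) -> Dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant("Hello from run_async")]}


gen = EchoChatGenerator()
print(getattr(gen, "__haystack_supports_async__", False))  # expected: True -> Agent awaits run_async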