feat: Move storing of messages into State in Agent (#9150)

* Update messages to be stored in State so users can control how they are stored through a handler in the schema

* Fix test

* Add test

* Add reno

* Fix docstring
This commit is contained in:
Sebastian Husch Lee 2025-04-01 11:29:44 +02:00 committed by GitHub
parent f687d49fec
commit bde2d77df0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 74 additions and 23 deletions

View File

@ -58,7 +58,7 @@ class Agent:
system_prompt: Optional[str] = None,
exit_conditions: Optional[List[str]] = None,
state_schema: Optional[Dict[str, Any]] = None,
max_runs_per_component: int = 100,
max_agent_steps: int = 100,
raise_on_tool_invocation_failure: bool = False,
streaming_callback: Optional[SyncStreamingCallbackT] = None,
):
@ -72,8 +72,8 @@ class Agent:
Can include "text" if the agent should return when it generates a message without tool calls,
or tool names that will cause the agent to return once the tool was executed. Defaults to ["text"].
:param state_schema: The schema for the runtime state used by the tools.
:param max_runs_per_component: Maximum number of runs per component. Agent will raise an exception if a
component exceeds the maximum number of runs per component.
:param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
If the agent exceeds this number of steps, it will stop and return the current state.
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
If set to False, the exception will be turned into a chat message and passed to the LLM.
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
@ -91,7 +91,11 @@ class Agent:
if exit_conditions is None:
exit_conditions = ["text"]
if not all(condition in valid_exits for condition in exit_conditions):
raise ValueError(f"Exit conditions must be a subset of {valid_exits}")
raise ValueError(
f"Invalid exit conditions provided: {exit_conditions}. "
f"Valid exit conditions must be a subset of {valid_exits}. "
"Ensure that each exit condition corresponds to either 'text' or a valid tool name."
)
if state_schema is not None:
_validate_schema(state_schema)
@ -101,18 +105,17 @@ class Agent:
self.tools = tools or []
self.system_prompt = system_prompt
self.exit_conditions = exit_conditions
self.max_runs_per_component = max_runs_per_component
self.max_agent_steps = max_agent_steps
self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure
self.streaming_callback = streaming_callback
output_types = {"messages": List[ChatMessage]}
output_types = {}
for param, config in self.state_schema.items():
component.set_input_type(self, name=param, type=config["type"], default=None)
output_types[param] = config["type"]
component.set_output_types(self, **output_types)
self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure)
self._is_warmed_up = False
def warm_up(self) -> None:
@ -142,7 +145,7 @@ class Agent:
system_prompt=self.system_prompt,
exit_conditions=self.exit_conditions,
state_schema=_schema_to_dict(self.state_schema),
max_runs_per_component=self.max_runs_per_component,
max_agent_steps=self.max_agent_steps,
raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
streaming_callback=streaming_callback,
)
@ -170,7 +173,10 @@ class Agent:
return default_from_dict(cls, data)
def run(
self, messages: List[ChatMessage], streaming_callback: Optional[SyncStreamingCallbackT] = None, **kwargs
self,
messages: List[ChatMessage],
streaming_callback: Optional[SyncStreamingCallbackT] = None,
**kwargs: Dict[str, Any],
) -> Dict[str, Any]:
"""
Process messages and execute tools until the exit condition is met.
@ -188,6 +194,7 @@ class Agent:
if self.system_prompt is not None:
messages = [ChatMessage.from_system(self.system_prompt)] + messages
state.set("messages", messages)
generator_inputs: Dict[str, Any] = {"tools": self.tools}
@ -197,21 +204,23 @@ class Agent:
# Repeat until the exit condition is met
counter = 0
while counter < self.max_runs_per_component:
while counter < self.max_agent_steps:
# 1. Call the ChatGenerator
llm_messages = self.chat_generator.run(messages=messages, **generator_inputs)["replies"]
state.set("messages", llm_messages)
# TODO Possible for LLM to return multiple messages (e.g. multiple tool calls)
# Would a better check be to see if any of the messages contain a tool call?
# 2. Check if the LLM response contains a tool call
if llm_messages[0].tool_call is None:
return {"messages": messages + llm_messages, **state.data}
return {**state.data}
# 3. Call the ToolInvoker
# We only send the messages from the LLM to the tool invoker
tool_invoker_result = self._tool_invoker.run(messages=llm_messages, state=state)
tool_messages = tool_invoker_result["messages"]
tool_messages = tool_invoker_result["tool_messages"]
state = tool_invoker_result["state"]
state.set("messages", tool_messages)
# 4. Check the LLM and Tool response for exit conditions, if exit_conditions contains a tool name
# TODO Possible for LLM to return multiple messages (e.g. multiple tool calls)
@ -220,13 +229,14 @@ class Agent:
llm_messages[0].tool_call.tool_name in self.exit_conditions
and not tool_messages[0].tool_call_result.error
):
return {"messages": messages + llm_messages + tool_messages, **state.data}
return {**state.data}
# 5. Combine messages, llm_messages and tool_messages and send to the ChatGenerator
messages = messages + llm_messages + tool_messages
messages = state.get("messages")
counter += 1
logger.warning(
"Agent exceeded maximum runs per component ({max_loops}), stopping.", max_loops=self.max_runs_per_component
"Agent exceeded maximum runs per component ({max_agent_steps}), stopping.",
max_agent_steps=self.max_agent_steps,
)
return {"messages": messages, **state.data}
return {**state.data}

View File

@ -2,8 +2,9 @@
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.type_serialization import deserialize_type, serialize_type
@ -65,6 +66,8 @@ def _validate_schema(schema: Dict[str, Any]) -> None:
raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}")
if definition.get("handler") is not None and not callable(definition["handler"]):
raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None")
if param == "messages" and definition["type"] is not List[ChatMessage]:
raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}")
class State:
@ -91,6 +94,8 @@ class State:
"""
_validate_schema(schema)
self.schema = schema
if self.schema.get("messages") is None:
self.schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists}
self._data = data or {}
# Set default handlers if not provided in schema

View File

@ -48,7 +48,7 @@ def _is_list_type(type_hint: Any) -> bool:
return type_hint is list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) is list)
def merge_lists(current: Union[List[T], T], new: Union[List[T], T]) -> List[T]:
def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]:
"""
Merges two values into a single list.

View File

@ -0,0 +1,6 @@
---
enhancements:
- |
    Now, when using Haystack's Agent, messages are stored and accumulated in State. This means:
    * State is required to have a "messages" type and handler defined in its schema. If not provided, a default type and handler are used.
    * Users can now customize how messages are accumulated by providing a custom handler for "messages" in the State schema.

View File

@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import os
from datetime import datetime
from typing import Iterator, Dict, Any, List
@ -34,15 +35,12 @@ def weather_function(location):
return weather_info.get(location, {"weather": "unknown", "temperature": 0, "unit": "celsius"})
weather_parameters = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
@pytest.fixture
def weather_tool():
return Tool(
name="weather_tool",
description="Provides weather information for a given location.",
parameters=weather_parameters,
parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]},
function=weather_function,
)
@ -163,7 +161,7 @@ class TestAgent:
generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
# Test invalid exit condition
with pytest.raises(ValueError, match="Exit conditions must be a subset of"):
with pytest.raises(ValueError, match="Invalid exit conditions provided:"):
Agent(chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["invalid_tool"])
# Test default exit condition
@ -254,3 +252,32 @@ class TestAgent:
with pytest.raises(TypeError, match="MockChatGeneratorWithoutTools does not accept tools"):
Agent(chat_generator=chat_generator, tools=[weather_tool])
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
@pytest.mark.integration
def test_run(self, weather_tool):
chat_generator = OpenAIChatGenerator(model="gpt-4o-mini")
agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3)
agent.warm_up()
response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")])
assert isinstance(response, dict)
assert "messages" in response
assert isinstance(response["messages"], list)
assert len(response["messages"]) == 4
assert [isinstance(reply, ChatMessage) for reply in response["messages"]]
# Loose check of message texts
assert response["messages"][0].text == "What is the weather in Berlin?"
assert response["messages"][1].text is None
assert response["messages"][2].text is None
assert response["messages"][3].text is not None
# Loose check of message metadata
assert response["messages"][0].meta == {}
assert response["messages"][1].meta.get("model") is not None
assert response["messages"][2].meta == {}
assert response["messages"][3].meta.get("model") is not None
# Loose check of tool calls and results
assert response["messages"][1].tool_calls[0].tool_name == "weather_tool"
assert response["messages"][1].tool_calls[0].arguments is not None
assert response["messages"][2].tool_call_results[0].result is not None
assert response["messages"][2].tool_call_results[0].origin is not None

View File

@ -1,7 +1,9 @@
import pytest
from typing import List, Dict
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.state import State, _validate_schema
from haystack.dataclasses.state_utils import merge_lists
@pytest.fixture
@ -106,6 +108,7 @@ def test_state_has(basic_schema):
def test_state_empty_schema():
state = State({})
assert state.data == {}
assert state.schema == {"messages": {"type": List[ChatMessage], "handler": merge_lists}}
with pytest.raises(ValueError, match="Key 'any_key' not found in schema"):
state.set("any_key", "value")