diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index b5f187a67..3f218eea8 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -58,7 +58,7 @@ class Agent: system_prompt: Optional[str] = None, exit_conditions: Optional[List[str]] = None, state_schema: Optional[Dict[str, Any]] = None, - max_runs_per_component: int = 100, + max_agent_steps: int = 100, raise_on_tool_invocation_failure: bool = False, streaming_callback: Optional[SyncStreamingCallbackT] = None, ): @@ -72,8 +72,8 @@ class Agent: Can include "text" if the agent should return when it generates a message without tool calls, or tool names that will cause the agent to return once the tool was executed. Defaults to ["text"]. :param state_schema: The schema for the runtime state used by the tools. - :param max_runs_per_component: Maximum number of runs per component. Agent will raise an exception if a - component exceeds the maximum number of runs per component. + :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100. + If the agent exceeds this number of steps, it will stop and return the current state. :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails? If set to False, the exception will be turned into a chat message and passed to the LLM. :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. @@ -91,7 +91,11 @@ class Agent: if exit_conditions is None: exit_conditions = ["text"] if not all(condition in valid_exits for condition in exit_conditions): - raise ValueError(f"Exit conditions must be a subset of {valid_exits}") + raise ValueError( + f"Invalid exit conditions provided: {exit_conditions}. " + f"Valid exit conditions must be a subset of {valid_exits}. " + "Ensure that each exit condition corresponds to either 'text' or a valid tool name." 
+ ) if state_schema is not None: _validate_schema(state_schema) @@ -101,18 +105,17 @@ class Agent: self.tools = tools or [] self.system_prompt = system_prompt self.exit_conditions = exit_conditions - self.max_runs_per_component = max_runs_per_component + self.max_agent_steps = max_agent_steps self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure self.streaming_callback = streaming_callback - output_types = {"messages": List[ChatMessage]} + output_types: Dict[str, Any] = {} for param, config in self.state_schema.items(): component.set_input_type(self, name=param, type=config["type"], default=None) output_types[param] = config["type"] component.set_output_types(self, **output_types) self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure) - self._is_warmed_up = False def warm_up(self) -> None: @@ -142,7 +145,7 @@ class Agent: system_prompt=self.system_prompt, exit_conditions=self.exit_conditions, state_schema=_schema_to_dict(self.state_schema), - max_runs_per_component=self.max_runs_per_component, + max_agent_steps=self.max_agent_steps, raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure, streaming_callback=streaming_callback, ) @@ -170,7 +173,10 @@ class Agent: return default_from_dict(cls, data) def run( - self, messages: List[ChatMessage], streaming_callback: Optional[SyncStreamingCallbackT] = None, **kwargs + self, + messages: List[ChatMessage], + streaming_callback: Optional[SyncStreamingCallbackT] = None, + **kwargs: Any, ) -> Dict[str, Any]: """ Process messages and execute tools until the exit condition is met. 
@@ -188,6 +194,7 @@ class Agent: if self.system_prompt is not None: messages = [ChatMessage.from_system(self.system_prompt)] + messages + state.set("messages", messages) generator_inputs: Dict[str, Any] = {"tools": self.tools} @@ -197,21 +204,23 @@ class Agent: # Repeat until the exit condition is met counter = 0 - while counter < self.max_runs_per_component: + while counter < self.max_agent_steps: # 1. Call the ChatGenerator llm_messages = self.chat_generator.run(messages=messages, **generator_inputs)["replies"] + state.set("messages", llm_messages) # TODO Possible for LLM to return multiple messages (e.g. multiple tool calls) # Would a better check be to see if any of the messages contain a tool call? # 2. Check if the LLM response contains a tool call if llm_messages[0].tool_call is None: - return {"messages": messages + llm_messages, **state.data} + return {**state.data} # 3. Call the ToolInvoker # We only send the messages from the LLM to the tool invoker tool_invoker_result = self._tool_invoker.run(messages=llm_messages, state=state) - tool_messages = tool_invoker_result["messages"] + tool_messages = tool_invoker_result["tool_messages"] state = tool_invoker_result["state"] + state.set("messages", tool_messages) # 4. Check the LLM and Tool response for exit conditions, if exit_conditions contains a tool name # TODO Possible for LLM to return multiple messages (e.g. multiple tool calls) @@ -220,13 +229,14 @@ class Agent: llm_messages[0].tool_call.tool_name in self.exit_conditions and not tool_messages[0].tool_call_result.error ): - return {"messages": messages + llm_messages + tool_messages, **state.data} + return {**state.data} # 5. 
Combine messages, llm_messages and tool_messages and send to the ChatGenerator - messages = messages + llm_messages + tool_messages + messages = state.get("messages") counter += 1 logger.warning( - "Agent exceeded maximum runs per component ({max_loops}), stopping.", max_loops=self.max_runs_per_component + "Agent exceeded maximum agent steps ({max_agent_steps}), stopping.", + max_agent_steps=self.max_agent_steps, ) - return {"messages": messages, **state.data} + return {**state.data} diff --git a/haystack/dataclasses/state.py b/haystack/dataclasses/state.py index daf815e0b..02fc3c63e 100644 --- a/haystack/dataclasses/state.py +++ b/haystack/dataclasses/state.py @@ -2,8 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional +from haystack.dataclasses import ChatMessage from haystack.dataclasses.state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values from haystack.utils.callable_serialization import deserialize_callable, serialize_callable from haystack.utils.type_serialization import deserialize_type, serialize_type @@ -65,6 +66,8 @@ def _validate_schema(schema: Dict[str, Any]) -> None: raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}") if definition.get("handler") is not None and not callable(definition["handler"]): raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None") + if param == "messages" and definition["type"] is not List[ChatMessage]: + raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}") class State: @@ -91,6 +94,8 @@ class State: """ _validate_schema(schema) self.schema = schema + if self.schema.get("messages") is None: + self.schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists} self._data = data or {} # Set default handlers if not provided in schema diff --git 
a/haystack/dataclasses/state_utils.py b/haystack/dataclasses/state_utils.py index 19bcf1ded..2b392d812 100644 --- a/haystack/dataclasses/state_utils.py +++ b/haystack/dataclasses/state_utils.py @@ -48,7 +48,7 @@ def _is_list_type(type_hint: Any) -> bool: return type_hint is list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) is list) -def merge_lists(current: Union[List[T], T], new: Union[List[T], T]) -> List[T]: +def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]: """ Merges two values into a single list. diff --git a/releasenotes/notes/add-messages-to-state-1d10f1dc449be694.yaml b/releasenotes/notes/add-messages-to-state-1d10f1dc449be694.yaml new file mode 100644 index 000000000..020ba418d --- /dev/null +++ b/releasenotes/notes/add-messages-to-state-1d10f1dc449be694.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + When using Haystack's Agent, messages are now stored and accumulated in State. This means: + * State is required to have a "messages" type and handler defined in its schema. If not provided, a default type and handler are added automatically. + * Users can now customize how to accumulate messages by providing a custom handler for messages through the State schema. 
diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index 71f76b26b..7865c52f1 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import os from datetime import datetime from typing import Iterator, Dict, Any, List @@ -34,15 +35,12 @@ def weather_function(location): return weather_info.get(location, {"weather": "unknown", "temperature": 0, "unit": "celsius"}) -weather_parameters = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]} - - @pytest.fixture def weather_tool(): return Tool( name="weather_tool", description="Provides weather information for a given location.", - parameters=weather_parameters, + parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, function=weather_function, ) @@ -163,7 +161,7 @@ class TestAgent: generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY")) # Test invalid exit condition - with pytest.raises(ValueError, match="Exit conditions must be a subset of"): + with pytest.raises(ValueError, match="Invalid exit conditions provided:"): Agent(chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["invalid_tool"]) # Test default exit condition @@ -254,3 +252,32 @@ class TestAgent: with pytest.raises(TypeError, match="MockChatGeneratorWithoutTools does not accept tools"): Agent(chat_generator=chat_generator, tools=[weather_tool]) + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_run(self, weather_tool): + chat_generator = OpenAIChatGenerator(model="gpt-4o-mini") + agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3) + agent.warm_up() + response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")]) + + assert isinstance(response, dict) + 
assert "messages" in response + assert isinstance(response["messages"], list) + assert len(response["messages"]) == 4 + assert all(isinstance(reply, ChatMessage) for reply in response["messages"]) + # Loose check of message texts + assert response["messages"][0].text == "What is the weather in Berlin?" + assert response["messages"][1].text is None + assert response["messages"][2].text is None + assert response["messages"][3].text is not None + # Loose check of message metadata + assert response["messages"][0].meta == {} + assert response["messages"][1].meta.get("model") is not None + assert response["messages"][2].meta == {} + assert response["messages"][3].meta.get("model") is not None + # Loose check of tool calls and results + assert response["messages"][1].tool_calls[0].tool_name == "weather_tool" + assert response["messages"][1].tool_calls[0].arguments is not None + assert response["messages"][2].tool_call_results[0].result is not None + assert response["messages"][2].tool_call_results[0].origin is not None diff --git a/test/dataclasses/test_state.py b/test/dataclasses/test_state.py index 748bf2dc6..2a9d23aec 100644 --- a/test/dataclasses/test_state.py +++ b/test/dataclasses/test_state.py @@ -1,7 +1,9 @@ import pytest from typing import List, Dict +from haystack.dataclasses import ChatMessage from haystack.dataclasses.state import State, _validate_schema +from haystack.dataclasses.state_utils import merge_lists @pytest.fixture @@ -106,6 +108,7 @@ def test_state_has(basic_schema): def test_state_empty_schema(): state = State({}) assert state.data == {} + assert state.schema == {"messages": {"type": List[ChatMessage], "handler": merge_lists}} with pytest.raises(ValueError, match="Key 'any_key' not found in schema"): state.set("any_key", "value")