mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 04:56:45 +00:00
feat: Move storing of messages into State in Agent (#9150)
* Update messages to be stored in State so users can control how they are stored through a handler in the schema * Fix test * Add test * Add reno * Fix docstring
This commit is contained in:
parent
f687d49fec
commit
bde2d77df0
@ -58,7 +58,7 @@ class Agent:
|
||||
system_prompt: Optional[str] = None,
|
||||
exit_conditions: Optional[List[str]] = None,
|
||||
state_schema: Optional[Dict[str, Any]] = None,
|
||||
max_runs_per_component: int = 100,
|
||||
max_agent_steps: int = 100,
|
||||
raise_on_tool_invocation_failure: bool = False,
|
||||
streaming_callback: Optional[SyncStreamingCallbackT] = None,
|
||||
):
|
||||
@ -72,8 +72,8 @@ class Agent:
|
||||
Can include "text" if the agent should return when it generates a message without tool calls,
|
||||
or tool names that will cause the agent to return once the tool was executed. Defaults to ["text"].
|
||||
:param state_schema: The schema for the runtime state used by the tools.
|
||||
:param max_runs_per_component: Maximum number of runs per component. Agent will raise an exception if a
|
||||
component exceeds the maximum number of runs per component.
|
||||
:param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
|
||||
If the agent exceeds this number of steps, it will stop and return the current state.
|
||||
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
|
||||
If set to False, the exception will be turned into a chat message and passed to the LLM.
|
||||
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
|
||||
@ -91,7 +91,11 @@ class Agent:
|
||||
if exit_conditions is None:
|
||||
exit_conditions = ["text"]
|
||||
if not all(condition in valid_exits for condition in exit_conditions):
|
||||
raise ValueError(f"Exit conditions must be a subset of {valid_exits}")
|
||||
raise ValueError(
|
||||
f"Invalid exit conditions provided: {exit_conditions}. "
|
||||
f"Valid exit conditions must be a subset of {valid_exits}. "
|
||||
"Ensure that each exit condition corresponds to either 'text' or a valid tool name."
|
||||
)
|
||||
|
||||
if state_schema is not None:
|
||||
_validate_schema(state_schema)
|
||||
@ -101,18 +105,17 @@ class Agent:
|
||||
self.tools = tools or []
|
||||
self.system_prompt = system_prompt
|
||||
self.exit_conditions = exit_conditions
|
||||
self.max_runs_per_component = max_runs_per_component
|
||||
self.max_agent_steps = max_agent_steps
|
||||
self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure
|
||||
self.streaming_callback = streaming_callback
|
||||
|
||||
output_types = {"messages": List[ChatMessage]}
|
||||
output_types = {}
|
||||
for param, config in self.state_schema.items():
|
||||
component.set_input_type(self, name=param, type=config["type"], default=None)
|
||||
output_types[param] = config["type"]
|
||||
component.set_output_types(self, **output_types)
|
||||
|
||||
self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure)
|
||||
|
||||
self._is_warmed_up = False
|
||||
|
||||
def warm_up(self) -> None:
|
||||
@ -142,7 +145,7 @@ class Agent:
|
||||
system_prompt=self.system_prompt,
|
||||
exit_conditions=self.exit_conditions,
|
||||
state_schema=_schema_to_dict(self.state_schema),
|
||||
max_runs_per_component=self.max_runs_per_component,
|
||||
max_agent_steps=self.max_agent_steps,
|
||||
raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
|
||||
streaming_callback=streaming_callback,
|
||||
)
|
||||
@ -170,7 +173,10 @@ class Agent:
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def run(
|
||||
self, messages: List[ChatMessage], streaming_callback: Optional[SyncStreamingCallbackT] = None, **kwargs
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
streaming_callback: Optional[SyncStreamingCallbackT] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process messages and execute tools until the exit condition is met.
|
||||
@ -188,6 +194,7 @@ class Agent:
|
||||
|
||||
if self.system_prompt is not None:
|
||||
messages = [ChatMessage.from_system(self.system_prompt)] + messages
|
||||
state.set("messages", messages)
|
||||
|
||||
generator_inputs: Dict[str, Any] = {"tools": self.tools}
|
||||
|
||||
@ -197,21 +204,23 @@ class Agent:
|
||||
|
||||
# Repeat until the exit condition is met
|
||||
counter = 0
|
||||
while counter < self.max_runs_per_component:
|
||||
while counter < self.max_agent_steps:
|
||||
# 1. Call the ChatGenerator
|
||||
llm_messages = self.chat_generator.run(messages=messages, **generator_inputs)["replies"]
|
||||
state.set("messages", llm_messages)
|
||||
|
||||
# TODO Possible for LLM to return multiple messages (e.g. multiple tool calls)
|
||||
# Would a better check be to see if any of the messages contain a tool call?
|
||||
# 2. Check if the LLM response contains a tool call
|
||||
if llm_messages[0].tool_call is None:
|
||||
return {"messages": messages + llm_messages, **state.data}
|
||||
return {**state.data}
|
||||
|
||||
# 3. Call the ToolInvoker
|
||||
# We only send the messages from the LLM to the tool invoker
|
||||
tool_invoker_result = self._tool_invoker.run(messages=llm_messages, state=state)
|
||||
tool_messages = tool_invoker_result["messages"]
|
||||
tool_messages = tool_invoker_result["tool_messages"]
|
||||
state = tool_invoker_result["state"]
|
||||
state.set("messages", tool_messages)
|
||||
|
||||
# 4. Check the LLM and Tool response for exit conditions, if exit_conditions contains a tool name
|
||||
# TODO Possible for LLM to return multiple messages (e.g. multiple tool calls)
|
||||
@ -220,13 +229,14 @@ class Agent:
|
||||
llm_messages[0].tool_call.tool_name in self.exit_conditions
|
||||
and not tool_messages[0].tool_call_result.error
|
||||
):
|
||||
return {"messages": messages + llm_messages + tool_messages, **state.data}
|
||||
return {**state.data}
|
||||
|
||||
# 5. Combine messages, llm_messages and tool_messages and send to the ChatGenerator
|
||||
messages = messages + llm_messages + tool_messages
|
||||
messages = state.get("messages")
|
||||
counter += 1
|
||||
|
||||
logger.warning(
|
||||
"Agent exceeded maximum runs per component ({max_loops}), stopping.", max_loops=self.max_runs_per_component
|
||||
"Agent exceeded maximum runs per component ({max_agent_steps}), stopping.",
|
||||
max_agent_steps=self.max_agent_steps,
|
||||
)
|
||||
return {"messages": messages, **state.data}
|
||||
return {**state.data}
|
||||
|
||||
@ -2,8 +2,9 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack.dataclasses.state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values
|
||||
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
|
||||
from haystack.utils.type_serialization import deserialize_type, serialize_type
|
||||
@ -65,6 +66,8 @@ def _validate_schema(schema: Dict[str, Any]) -> None:
|
||||
raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}")
|
||||
if definition.get("handler") is not None and not callable(definition["handler"]):
|
||||
raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None")
|
||||
if param == "messages" and definition["type"] is not List[ChatMessage]:
|
||||
raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}")
|
||||
|
||||
|
||||
class State:
|
||||
@ -91,6 +94,8 @@ class State:
|
||||
"""
|
||||
_validate_schema(schema)
|
||||
self.schema = schema
|
||||
if self.schema.get("messages") is None:
|
||||
self.schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists}
|
||||
self._data = data or {}
|
||||
|
||||
# Set default handlers if not provided in schema
|
||||
|
||||
@ -48,7 +48,7 @@ def _is_list_type(type_hint: Any) -> bool:
|
||||
return type_hint is list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) is list)
|
||||
|
||||
|
||||
def merge_lists(current: Union[List[T], T], new: Union[List[T], T]) -> List[T]:
|
||||
def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]:
|
||||
"""
|
||||
Merges two values into a single list.
|
||||
|
||||
|
||||
@ -0,0 +1,6 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Now when using Haystack's Agent the messages are stored and accumulated in State. This means:
|
||||
* State is required to have a "messages" type and handler defined in its schema. If not provided a default type and handler is provided.
|
||||
* Users can now customize how to accumulate messages by providing a custom handler for messages through the State schema.
|
||||
@ -2,6 +2,7 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Iterator, Dict, Any, List
|
||||
|
||||
@ -34,15 +35,12 @@ def weather_function(location):
|
||||
return weather_info.get(location, {"weather": "unknown", "temperature": 0, "unit": "celsius"})
|
||||
|
||||
|
||||
weather_parameters = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def weather_tool():
|
||||
return Tool(
|
||||
name="weather_tool",
|
||||
description="Provides weather information for a given location.",
|
||||
parameters=weather_parameters,
|
||||
parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]},
|
||||
function=weather_function,
|
||||
)
|
||||
|
||||
@ -163,7 +161,7 @@ class TestAgent:
|
||||
generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
|
||||
|
||||
# Test invalid exit condition
|
||||
with pytest.raises(ValueError, match="Exit conditions must be a subset of"):
|
||||
with pytest.raises(ValueError, match="Invalid exit conditions provided:"):
|
||||
Agent(chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["invalid_tool"])
|
||||
|
||||
# Test default exit condition
|
||||
@ -254,3 +252,32 @@ class TestAgent:
|
||||
|
||||
with pytest.raises(TypeError, match="MockChatGeneratorWithoutTools does not accept tools"):
|
||||
Agent(chat_generator=chat_generator, tools=[weather_tool])
|
||||
|
||||
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
|
||||
@pytest.mark.integration
|
||||
def test_run(self, weather_tool):
|
||||
chat_generator = OpenAIChatGenerator(model="gpt-4o-mini")
|
||||
agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3)
|
||||
agent.warm_up()
|
||||
response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")])
|
||||
|
||||
assert isinstance(response, dict)
|
||||
assert "messages" in response
|
||||
assert isinstance(response["messages"], list)
|
||||
assert len(response["messages"]) == 4
|
||||
assert [isinstance(reply, ChatMessage) for reply in response["messages"]]
|
||||
# Loose check of message texts
|
||||
assert response["messages"][0].text == "What is the weather in Berlin?"
|
||||
assert response["messages"][1].text is None
|
||||
assert response["messages"][2].text is None
|
||||
assert response["messages"][3].text is not None
|
||||
# Loose check of message metadata
|
||||
assert response["messages"][0].meta == {}
|
||||
assert response["messages"][1].meta.get("model") is not None
|
||||
assert response["messages"][2].meta == {}
|
||||
assert response["messages"][3].meta.get("model") is not None
|
||||
# Loose check of tool calls and results
|
||||
assert response["messages"][1].tool_calls[0].tool_name == "weather_tool"
|
||||
assert response["messages"][1].tool_calls[0].arguments is not None
|
||||
assert response["messages"][2].tool_call_results[0].result is not None
|
||||
assert response["messages"][2].tool_call_results[0].origin is not None
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import pytest
|
||||
from typing import List, Dict
|
||||
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack.dataclasses.state import State, _validate_schema
|
||||
from haystack.dataclasses.state_utils import merge_lists
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -106,6 +108,7 @@ def test_state_has(basic_schema):
|
||||
def test_state_empty_schema():
|
||||
state = State({})
|
||||
assert state.data == {}
|
||||
assert state.schema == {"messages": {"type": List[ChatMessage], "handler": merge_lists}}
|
||||
with pytest.raises(ValueError, match="Key 'any_key' not found in schema"):
|
||||
state.set("any_key", "value")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user