wip: adding breakpoints to agent

This commit is contained in:
David S. Batista 2025-06-06 21:45:49 +02:00
parent 8e21c501df
commit c798024d5b

View File

@ -3,7 +3,11 @@
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import Any, Dict, List, Optional, Union
import json
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from haystack import component, default_from_dict, default_to_dict, logging, tracing
from haystack.components.generators.chat.types import ChatGenerator
@ -24,6 +28,72 @@ from .state.state_utils import merge_lists
logger = logging.getLogger(__name__)
class AgentBreakpointException(Exception):
"""
Exception raised when an agent breakpoint is triggered.
"""
def __init__(
self,
message: str,
component: Optional[str] = None,
state: Optional[Dict[str, Any]] = None,
results: Optional[Dict[str, Any]] = None,
):
super().__init__(message)
self.component = component
self.state = state
self.results = results
class AgentInvalidResumeStateError(Exception):
"""
Exception raised when an agent is resumed from an invalid state.
"""
def __init__(self, message: str):
super().__init__(message)
def _serialize_agent_data(value: Any) -> Any:
"""
Serializes agent data so it can be saved to a file.
"""
if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")):
serialized_value = value.to_dict()
serialized_value["_type"] = value.__class__.__name__
return serialized_value
elif hasattr(value, "__dict__"):
return {"_type": value.__class__.__name__, "attributes": value.__dict__}
elif isinstance(value, dict):
return {k: _serialize_agent_data(v) for k, v in value.items()}
elif isinstance(value, list):
return [_serialize_agent_data(item) for item in value]
return value
def _deserialize_agent_data(value: Any) -> Any:
"""
Deserializes agent data loaded from a file.
"""
if not value or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, list):
if all(isinstance(i, (str, int, float, bool)) for i in value):
return value
return [_deserialize_agent_data(i) for i in value]
if isinstance(value, dict):
if "_type" in value:
type_name = value.pop("_type")
if type_name == "State":
return State.from_dict(value)
return {k: _deserialize_agent_data(v) for k, v in value.items()}
return value
@component
class Agent:
"""
@ -220,8 +290,234 @@ class Agent:
},
)
def _validate_breakpoints(self, breakpoints: Set[Tuple[str, Optional[int]]]) -> Set[Tuple[str, int]]:
"""
Validates the breakpoints passed to the agent.
Valid breakpoint components are:
- "chat_generator": Breaks before calling the chat generator
- "tool_invoker": Breaks before calling any tool
- "tool_invoker:tool_name": Breaks before calling a specific tool
:param breakpoints: Set of tuples of component names and visit counts at which the agent should stop.
:returns: Set of valid breakpoints.
"""
processed_breakpoints: Set[Tuple[str, int]] = set()
valid_components = {"chat_generator", "tool_invoker"}
# Add tool-specific breakpoints to valid components
for tool in self.tools:
valid_components.add(f"tool_invoker:{tool.name}")
for break_point in breakpoints:
component_name = break_point[0]
if component_name not in valid_components:
raise ValueError(
f"Breakpoint '{component_name}' is not a valid component. "
f"Valid components are: {sorted(valid_components)}"
)
valid_breakpoint: Tuple[str, int] = (component_name, 0 if break_point[1] is None else break_point[1])
processed_breakpoints.add(valid_breakpoint)
return processed_breakpoints
@staticmethod
def save_state(
state: State,
component_visits: Dict[str, int],
generator_inputs: Dict[str, Any],
component_name: str,
original_messages: List[ChatMessage],
original_kwargs: Dict[str, Any],
debug_path: Optional[Union[str, Path]] = None,
callback_fun: Optional[Callable[..., Any]] = None,
) -> Dict[str, Any]:
"""
Save the current agent state to a JSON file or return as a dictionary.
:param state: Current agent State object
:param component_visits: Dictionary tracking component visit counts
:param generator_inputs: Current generator inputs
:param component_name: Name of the component where breakpoint was triggered
:param original_messages: Original input messages
:param original_kwargs: Original input kwargs
:param debug_path: Optional path to save the state file
:param callback_fun: Optional callback function to call with the state
:returns: The saved state dictionary
"""
dt = datetime.now()
print(component_visits)
# For tool-specific breakpoints, extract the base component name for visit count lookup
base_component_name = component_name.split(":")[0] if ":" in component_name else component_name
agent_state = {
"input_data": _serialize_agent_data({"messages": original_messages, "kwargs": original_kwargs}),
"timestamp": dt.isoformat(),
"breakpoint": {"component": component_name, "visits": component_visits[base_component_name]},
"agent_state": {
"state": _serialize_agent_data(state),
"component_visits": component_visits,
"generator_inputs": _serialize_agent_data(generator_inputs),
},
}
if not debug_path:
return agent_state
if isinstance(debug_path, str):
debug_path = Path(debug_path)
if not isinstance(debug_path, Path):
raise ValueError("Debug path must be a string or a Path object.")
debug_path.mkdir(exist_ok=True)
file_name = Path(f"agent_{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json")
try:
with open(debug_path / file_name, "w") as f_out:
json.dump(agent_state, f_out, indent=2)
logger.info(f"Agent state saved at: {file_name}")
if callback_fun is not None:
callback_fun(agent_state)
return agent_state
except Exception as e:
logger.error(f"Failed to save agent state: {str(e)}")
raise
@staticmethod
def load_state(file_path: Union[str, Path]) -> Dict[str, Any]:
"""
Load a saved agent state.
:param file_path: Path to the state file
:returns: Dict containing the loaded state
"""
file_path = Path(file_path)
try:
with open(file_path, "r", encoding="utf-8") as f:
state = json.load(f)
except FileNotFoundError:
raise FileNotFoundError(f"File not found: {file_path}")
except json.JSONDecodeError as e:
raise json.JSONDecodeError(f"Invalid JSON file {file_path}: {str(e)}", e.doc, e.pos)
except IOError as e:
raise IOError(f"Error reading {file_path}: {str(e)}")
try:
Agent._validate_resume_state(state)
except ValueError as e:
raise ValueError(f"Invalid agent state from {file_path}: {str(e)}")
logger.info(f"Successfully loaded agent state from: {file_path}")
return state
@staticmethod
def _validate_resume_state(state: Dict[str, Any]) -> None:
"""
Validates the loaded agent state.
:param state: The state to validate
"""
required_top_keys = {"input_data", "breakpoint", "agent_state"}
missing_top = required_top_keys - state.keys()
if missing_top:
raise ValueError(f"Invalid state file: missing required keys {missing_top}")
agent_state = state["agent_state"]
required_agent_keys = {"state", "component_visits", "generator_inputs"}
missing_agent = required_agent_keys - agent_state.keys()
if missing_agent:
raise ValueError(f"Invalid agent_state: missing required keys {missing_agent}")
logger.info("Agent resume state validated successfully.")
def _check_breakpoints(
self,
breakpoints: Set[Tuple[str, int]],
component_name: str,
component_visits: Dict[str, int],
state: State,
generator_inputs: Dict[str, Any],
original_messages: List[ChatMessage],
original_kwargs: Dict[str, Any],
tool_name: Optional[str] = None,
debug_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Check if a breakpoint should be triggered for the given component.
:param breakpoints: Set of breakpoints to check against
:param component_name: Name of the component to check
:param component_visits: Dictionary tracking component visit counts
:param state: Current agent state
:param generator_inputs: Current generator inputs
:param original_messages: Original input messages
:param original_kwargs: Original input kwargs
:param tool_name: Optional tool name for tool-specific breakpoints
:param debug_path: Optional path to save debug state
"""
# Check general component breakpoints
matching_breakpoints = [bp for bp in breakpoints if bp[0] == component_name]
# Check tool-specific breakpoints if tool_name is provided
if tool_name and component_name == "tool_invoker":
tool_specific_name = f"tool_invoker:{tool_name}"
tool_breakpoints = [bp for bp in breakpoints if bp[0] == tool_specific_name]
matching_breakpoints.extend(tool_breakpoints)
for bp in matching_breakpoints:
visit_count = bp[1]
if visit_count == component_visits[component_name]:
msg = f"Breaking at component {bp[0]} visit count {component_visits[component_name]}"
logger.info(msg)
saved_state = self.save_state(
state=state,
component_visits=component_visits,
generator_inputs=generator_inputs,
component_name=bp[0],
original_messages=original_messages,
original_kwargs=original_kwargs,
debug_path=debug_path,
)
raise AgentBreakpointException(msg, component=bp[0], state=saved_state)
def inject_resume_state(
self, resume_state: Dict[str, Any]
) -> Tuple[State, Dict[str, int], Dict[str, Any], List[ChatMessage]]:
"""
Inject resume state into the agent for continuing execution.
:param resume_state: The saved state to resume from
:returns: Tuple of (state, component_visits, generator_inputs, messages)
"""
if not resume_state:
raise AgentInvalidResumeStateError("Cannot inject resume state: resume_state is None")
self._validate_resume_state(resume_state)
agent_state_data = resume_state["agent_state"]
state = _deserialize_agent_data(agent_state_data["state"])
component_visits = agent_state_data["component_visits"]
generator_inputs = _deserialize_agent_data(agent_state_data["generator_inputs"])
messages = state.get("messages")
logger.info(
f"Resuming agent from {resume_state['breakpoint']['component']} "
f"visit count {resume_state['breakpoint']['visits']}"
)
return state, component_visits, generator_inputs, messages
def run(
self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any
self,
messages: List[ChatMessage],
streaming_callback: Optional[StreamingCallbackT] = None,
breakpoints: Optional[Set[Tuple[str, Optional[int]]]] = None,
resume_state: Optional[Dict[str, Any]] = None,
debug_path: Optional[Union[str, Path]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""
Process messages and execute tools until an exit condition is met.
@ -230,6 +526,11 @@ class Agent:
If a list of dictionaries is provided, each dictionary will be converted to a ChatMessage object.
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
The same callback can be configured to emit tool results when a tool is called.
:param breakpoints: Set of tuples of component names and visit counts at which the agent should break execution.
Valid components are "chat_generator", "tool_invoker", or "tool_invoker:tool_name" for specific tools.
If the visit count is not given, it defaults to 0 (break on first visit).
:param resume_state: A dictionary containing the state of a previously saved agent execution.
:param debug_path: Path to the directory where the agent state should be saved when breakpoints are hit.
:param kwargs: Additional data to pass to the State schema used by the Agent.
The keys must match the schema defined in the Agent's `state_schema`.
:returns:
@ -237,65 +538,127 @@ class Agent:
- "messages": List of all messages exchanged during the agent's run.
- "last_message": The last message exchanged during the agent's run.
- Any additional keys defined in the `state_schema`.
:raises AgentBreakpointException: When a breakpoint is triggered.
"""
if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.")
if self.system_prompt is not None:
messages = [ChatMessage.from_system(self.system_prompt)] + messages
if breakpoints and resume_state:
logger.warning("Breakpoints cannot be provided when resuming an agent. All breakpoints will be ignored.")
state = State(schema=self.state_schema, data=kwargs)
state.set("messages", messages)
component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
# Validate breakpoints if provided
validated_breakpoints = self._validate_breakpoints(breakpoints) if breakpoints else None
# Store original inputs for potential resume
original_messages = deepcopy(messages)
original_kwargs = deepcopy(kwargs)
# Handle resume state or initialize fresh state
if resume_state:
# Restore state from resume data
agent_state_data = resume_state["agent_state"]
state = _deserialize_agent_data(agent_state_data["state"])
component_visits = agent_state_data["component_visits"]
generator_inputs = _deserialize_agent_data(agent_state_data["generator_inputs"])
messages = state.get("messages")
logger.info(
f"Resuming agent from {resume_state['breakpoint']['component']} "
f"visit count {resume_state['breakpoint']['visits']}"
)
else:
if self.system_prompt is not None:
messages = [ChatMessage.from_system(self.system_prompt)] + messages
state = State(schema=self.state_schema, data=kwargs)
state.set("messages", messages)
component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
streaming_callback = select_streaming_callback(
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
)
generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)
if not resume_state:
generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)
with self._create_agent_span() as span:
span.set_content_tag(
"haystack.agent.input",
_deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}),
)
counter = 0
while counter < self.max_agent_steps:
# 1. Call the ChatGenerator
result = Pipeline._run_component(
component_name="chat_generator",
component={"instance": self.chat_generator},
inputs={"messages": messages, **generator_inputs},
component_visits=component_visits,
parent_span=span,
)
llm_messages = result["replies"]
state.set("messages", llm_messages)
try:
while counter < self.max_agent_steps:
# Check breakpoint before calling ChatGenerator
if validated_breakpoints and not resume_state:
self._check_breakpoints(
breakpoints=validated_breakpoints,
component_name="chat_generator",
component_visits=component_visits,
state=state,
generator_inputs=generator_inputs,
original_messages=original_messages,
original_kwargs=original_kwargs,
debug_path=debug_path,
)
# 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools
if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
# 1. Call the ChatGenerator
result = Pipeline._run_component(
component_name="chat_generator",
component={"instance": self.chat_generator},
inputs={"messages": messages, **generator_inputs},
component_visits=component_visits,
parent_span=span,
)
llm_messages = result["replies"]
state.set("messages", llm_messages)
# 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools
if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
counter += 1
break
# Check breakpoint before calling ToolInvoker
if validated_breakpoints and not resume_state:
# Check for tool-specific breakpoints
tool_names = [msg.tool_call.tool_name for msg in llm_messages if msg.tool_call]
for tool_name in tool_names:
self._check_breakpoints(
breakpoints=validated_breakpoints,
component_name="tool_invoker",
component_visits=component_visits,
state=state,
generator_inputs=generator_inputs,
original_messages=original_messages,
original_kwargs=original_kwargs,
tool_name=tool_name,
debug_path=debug_path,
)
# 3. Call the ToolInvoker
# We only send the messages from the LLM to the tool invoker
tool_invoker_result = Pipeline._run_component(
component_name="tool_invoker",
component={"instance": self._tool_invoker},
inputs={"messages": llm_messages, "state": state, "streaming_callback": streaming_callback},
component_visits=component_visits,
parent_span=span,
)
tool_messages = tool_invoker_result["tool_messages"]
state = tool_invoker_result["state"]
state.set("messages", tool_messages)
# 4. Check if any LLM message's tool call name matches an exit condition
if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
counter += 1
break
# 5. Fetch the combined messages and send them back to the LLM
messages = state.get("messages")
counter += 1
break
# 3. Call the ToolInvoker
# We only send the messages from the LLM to the tool invoker
tool_invoker_result = Pipeline._run_component(
component_name="tool_invoker",
component={"instance": self._tool_invoker},
inputs={"messages": llm_messages, "state": state, "streaming_callback": streaming_callback},
component_visits=component_visits,
parent_span=span,
)
tool_messages = tool_invoker_result["tool_messages"]
state = tool_invoker_result["state"]
state.set("messages", tool_messages)
# 4. Check if any LLM message's tool call name matches an exit condition
if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
counter += 1
break
# 5. Fetch the combined messages and send them back to the LLM
messages = state.get("messages")
counter += 1
except AgentBreakpointException as e:
# Add current agent state to the exception and re-raise
e.results = {**state.data}
raise
if counter >= self.max_agent_steps:
logger.warning(
@ -312,7 +675,13 @@ class Agent:
return result
async def run_async(
self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any
self,
messages: List[ChatMessage],
streaming_callback: Optional[StreamingCallbackT] = None,
breakpoints: Optional[Set[Tuple[str, Optional[int]]]] = None,
resume_state: Optional[Dict[str, Any]] = None,
debug_path: Optional[Union[str, Path]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""
Asynchronously process messages and execute tools until the exit condition is met.
@ -325,6 +694,10 @@ class Agent:
:param streaming_callback: An asynchronous callback that will be invoked when a response
is streamed from the LLM. The same callback can be configured to emit tool results
when a tool is called.
:param breakpoints: Set of tuples of component names and visit counts at which the agent should break execution.
Valid components are "chat_generator", "tool_invoker", or "tool_invoker:tool_name" for specific tools.
:param resume_state: A dictionary containing the state of a previously saved agent execution.
:param debug_path: Path to the directory where the agent state should be saved when breakpoints are hit.
:param kwargs: Additional data to pass to the State schema used by the Agent.
The keys must match the schema defined in the Agent's `state_schema`.
:returns:
@ -332,6 +705,7 @@ class Agent:
- "messages": List of all messages exchanged during the agent's run.
- "last_message": The last message exchanged during the agent's run.
- Any additional keys defined in the `state_schema`.
:raises AgentBreakpointException: When a breakpoint is triggered.
"""
if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.")