haystack/test/components/agents/test_agent_breakpoints.py
Sebastian Husch Lee 18b6482e2e
feat: Add ability to pass breakpoint and snapshot to Agent at runtime (#9867)
* Add ability to pass breakpoint and snapshot to agent at runtime

* Update releasenotes/notes/pass-agent-breakpoint-and-snapshot-5ac32800899d0bab.yaml

Co-authored-by: David S. Batista <dsbatista@gmail.com>

---------

Co-authored-by: David S. Batista <dsbatista@gmail.com>
2025-10-13 11:07:12 +00:00

698 lines
31 KiB
Python

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import os
from dataclasses import replace
from pathlib import Path
from typing import Any, Optional, Union
import pytest
from haystack import component
from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.core.errors import BreakpointException
from haystack.core.pipeline.breakpoint import load_pipeline_snapshot
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint
from haystack.tools import Tool, Toolset
def weather_function(location):
    """
    Return canned weather data for a city.

    :param location: City name to look up. Berlin, Paris and Rome are known.
    :returns: A dict with the keys ``weather``, ``temperature`` and ``unit``. Any other location
        yields a placeholder record with ``"unknown"`` weather and a temperature of 0.
    """
    readings = (("Berlin", "mostly sunny", 7), ("Paris", "mostly cloudy", 8), ("Rome", "sunny", 14))
    forecasts = {
        city: {"weather": sky, "temperature": degrees, "unit": "celsius"} for city, sky, degrees in readings
    }
    fallback = {"weather": "unknown", "temperature": 0, "unit": "celsius"}
    return forecasts.get(location, fallback)
@pytest.fixture
def weather_tool():
    """Fixture providing a ``Tool`` named ``weather_tool`` that wraps ``weather_function``."""
    # JSON schema of the tool arguments: a single required string called "location".
    location_schema = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
    tool = Tool(
        name="weather_tool",
        description="Provides weather information for a given location.",
        parameters=location_schema,
        function=weather_function,
    )
    return tool
@component
class MockChatGenerator:
    """
    Chat generator stand-in that replays a fixed script of replies.

    Every call to ``run`` or ``run_async`` hands out the next scripted reply. Once the script is
    exhausted, the last reply is returned on every further call. Both entry points share one
    counter, so sync and async calls advance the same script.
    """

    def __init__(self, responses: Optional[list[ChatMessage]] = None):
        """
        :param responses: Replies to return, in order. When omitted or empty, a two-step script is
            used: an assistant message calling ``weather_tool`` for Berlin, followed by a final
            assistant answer.
        """
        self._counter = 0
        self.responses = responses or [
            ChatMessage.from_assistant(
                "I'll help you check the weather.",
                tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})],
            ),
            ChatMessage.from_assistant("The weather in Berlin is sunny."),
        ]

    def to_dict(self) -> dict[str, Any]:
        """Serialize the component. The scripted responses are not part of the payload."""
        return {"type": "MockChatGenerator", "data": {}}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MockChatGenerator":
        """Rebuild the component with the default script; ``data`` is ignored."""
        return cls()

    def _next_reply(self) -> dict[str, Any]:
        """Return the next scripted reply, repeating the last one once the script is exhausted."""
        if self._counter >= len(self.responses):
            return {"replies": [self.responses[-1]]}
        reply = self.responses[self._counter]
        self._counter += 1
        return {"replies": [reply]}

    @component.output_types(replies=list[ChatMessage])
    def run(
        self, messages: list[ChatMessage], tools: Optional[Union[list[Tool], Toolset]] = None, **kwargs
    ) -> dict[str, Any]:
        """
        Return the next scripted reply.

        :param messages: Accepted for interface compatibility; not inspected.
        :param tools: Accepted for interface compatibility; not inspected.
        :returns: ``{"replies": [<ChatMessage>]}`` containing exactly one message.
        """
        return self._next_reply()

    @component.output_types(replies=list[ChatMessage])
    async def run_async(
        self, messages: list[ChatMessage], tools: Optional[Union[list[Tool], Toolset]] = None, **kwargs
    ) -> dict[str, Any]:
        """
        Async counterpart of ``run``; advances the same script.

        :param messages: Accepted for interface compatibility; not inspected.
        :param tools: Accepted for interface compatibility; not inspected.
        :returns: ``{"replies": [<ChatMessage>]}`` containing exactly one message.
        """
        return self._next_reply()
@pytest.fixture
def agent(weather_tool):
    """Fixture providing an ``Agent`` driven by a default ``MockChatGenerator`` and the weather tool."""
    scripted_generator = MockChatGenerator()
    return Agent(chat_generator=scripted_generator, tools=[weather_tool])
@pytest.fixture
def chat_generator_serialization_schema():
    """Fixture with the serialization schema the tests expect for the chat generator inputs."""
    message_items = {"type": "haystack.dataclasses.chat_message.ChatMessage"}
    tool_items = {"type": "haystack.tools.tool.Tool"}
    properties = {
        "messages": {"type": "array", "items": message_items},
        "tools": {"type": "array", "items": tool_items},
    }
    return {"type": "object", "properties": properties}
class TestAgentBreakpoints:
    """Tests for breakpoints and snapshot-based resumption with the synchronous ``Agent.run``."""

    @staticmethod
    def _user_message() -> dict[str, Any]:
        """Expected serialized form of the user question used by the breakpoint tests."""
        return {"role": "user", "meta": {}, "name": None, "content": [{"text": "What's the weather in Berlin?"}]}

    @staticmethod
    def _assistant_tool_call_message() -> dict[str, Any]:
        """Expected serialized form of the first scripted reply of ``MockChatGenerator``."""
        return {
            "role": "assistant",
            "meta": {},
            "name": None,
            "content": [
                {"text": "I'll help you check the weather."},
                {"tool_call": {"tool_name": "weather_tool", "arguments": {"location": "Berlin"}, "id": None}},
            ],
        }

    @staticmethod
    def _serialized_weather_tool() -> dict[str, Any]:
        """Expected serialized form of the ``weather_tool`` fixture."""
        return {
            "type": "haystack.tools.tool.Tool",
            "data": {
                "name": "weather_tool",
                "description": "Provides weather information for a given location.",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
                "function": "test_agent_breakpoints.weather_function",
                "outputs_to_string": None,
                "inputs_from_state": None,
                "outputs_to_state": None,
            },
        }

    @staticmethod
    def _serialized_state(messages: list[dict[str, Any]]) -> dict[str, Any]:
        """
        Expected serialized agent state holding ``messages``.

        :param messages: Serialized chat messages expected inside the state.
        """
        return {
            "schema": {
                "messages": {
                    "type": "list[haystack.dataclasses.chat_message.ChatMessage]",
                    "handler": "haystack.components.agents.state.state_utils.merge_lists",
                }
            },
            "data": {
                "serialization_schema": {
                    "type": "object",
                    "properties": {
                        "messages": {
                            "type": "array",
                            "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"},
                        }
                    },
                },
                "serialized_data": {"messages": messages},
            },
        }

    def _expected_inputs(
        self,
        chat_generator_schema: dict[str, Any],
        invoker_message_items: dict[str, Any],
        invoker_messages: list[dict[str, Any]],
        state_messages: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """
        Build the ``inputs`` payload expected on a ``BreakpointException``.

        :param chat_generator_schema: Serialization schema expected for the chat generator inputs.
        :param invoker_message_items: Expected ``items`` schema of the tool invoker ``messages`` list.
        :param invoker_messages: Serialized messages expected as tool invoker input.
        :param state_messages: Serialized messages expected inside the tool invoker's state.
        """
        return {
            "chat_generator": {
                "serialization_schema": chat_generator_schema,
                "serialized_data": {
                    "messages": [self._user_message()],
                    "tools": [self._serialized_weather_tool()],
                },
            },
            "tool_invoker": {
                "serialization_schema": {
                    "type": "object",
                    "properties": {
                        "messages": {"type": "array", "items": invoker_message_items},
                        "state": {"type": "haystack.components.agents.state.state.State"},
                        "tools": {"type": "array", "items": {"type": "haystack.tools.tool.Tool"}},
                    },
                },
                "serialized_data": {
                    "messages": invoker_messages,
                    "state": self._serialized_state(state_messages),
                    "tools": [self._serialized_weather_tool()],
                },
            },
        }

    @staticmethod
    def _latest_snapshot(debug_path: str, component_name: str) -> str:
        """
        Return the path of the most recently created snapshot file for ``component_name``.

        Asserts that at least one matching snapshot exists, so a breakpoint that never wrote a
        snapshot fails with a clear assertion.
        """
        snapshot_files = list(Path(debug_path).glob(f"test_agent_{component_name}_*.json"))
        assert len(snapshot_files) > 0
        return str(max(snapshot_files, key=os.path.getctime))

    def test_run_with_chat_generator_breakpoint(self, agent, chat_generator_serialization_schema):
        agent_breakpoint = AgentBreakpoint(
            break_point=Breakpoint(component_name="chat_generator"), agent_name="test_agent"
        )
        with pytest.raises(BreakpointException) as exc_info:
            agent.run(messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoint)

        assert exc_info.value.component == "chat_generator"
        # The break happens before the chat generator runs: the tool invoker has no messages yet
        # and the state only contains the user message.
        assert exc_info.value.inputs == self._expected_inputs(
            chat_generator_schema=chat_generator_serialization_schema,
            invoker_message_items={},
            invoker_messages=[],
            state_messages=[self._user_message()],
        )
        assert exc_info.value.results == self._serialized_state([self._user_message()])

    def test_run_with_tool_invoker_breakpoint(self, agent, chat_generator_serialization_schema):
        agent_breakpoint = AgentBreakpoint(
            break_point=ToolBreakpoint(component_name="tool_invoker", tool_name="weather_tool"), agent_name="test_agent"
        )
        with pytest.raises(BreakpointException) as exc_info:
            agent.run(messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoint)

        assert exc_info.value.component == "tool_invoker"
        # The chat generator already replied, so its tool-call message is the tool invoker input
        # and is part of the state.
        conversation = [self._user_message(), self._assistant_tool_call_message()]
        assert exc_info.value.inputs == self._expected_inputs(
            chat_generator_schema=chat_generator_serialization_schema,
            invoker_message_items={"type": "haystack.dataclasses.chat_message.ChatMessage"},
            invoker_messages=[self._assistant_tool_call_message()],
            state_messages=conversation,
        )
        assert exc_info.value.results == self._serialized_state(conversation)

    def test_resume_from_chat_generator(self, agent, tmp_path):
        debug_path = str(tmp_path / "debug_snapshots")
        agent_breakpoint = AgentBreakpoint(
            break_point=Breakpoint(component_name="chat_generator", snapshot_file_path=debug_path),
            agent_name="test_agent",
        )
        # The first run has to stop at the breakpoint, otherwise there is nothing to resume from.
        with pytest.raises(BreakpointException):
            agent.run(messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoint)

        latest_snapshot_file = self._latest_snapshot(debug_path, "chat_generator")
        result = agent.run(
            messages=[ChatMessage.from_user("This is actually ignored when resuming from snapshot.")],
            snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot,
        )
        assert "messages" in result
        assert "last_message" in result
        # There should be 4 messages: user + assistant + tool call result + final assistant message
        assert len(result["messages"]) == 4

    def test_resume_from_tool_invoker(self, agent, tmp_path):
        messages = [ChatMessage.from_user("What's the weather in Berlin?")]
        debug_path = str(tmp_path / "debug_snapshots")
        tool_bp = ToolBreakpoint(component_name="tool_invoker", snapshot_file_path=debug_path)
        agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test_agent")
        with pytest.raises(BreakpointException):
            agent.run(messages=messages, break_point=agent_breakpoint)

        latest_snapshot_file = self._latest_snapshot(debug_path, "tool_invoker")
        result = agent.run(
            messages=[ChatMessage.from_user("This is actually ignored when resuming from snapshot.")],
            snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot,
        )
        assert "messages" in result
        assert "last_message" in result
        assert len(result["messages"]) > 0

    def test_resume_from_tool_invoker_and_new_breakpoint(self, weather_tool, tmp_path):
        agent = Agent(
            chat_generator=MockChatGenerator(
                [
                    ChatMessage.from_assistant(tool_calls=[ToolCall("weather_tool", {"location": "Berlin"})]),
                    ChatMessage.from_assistant(tool_calls=[ToolCall("weather_tool", {"location": "Paris"})]),
                    ChatMessage.from_assistant(text="The weather in Berlin and Paris is sunny."),
                ]
            ),
            tools=[weather_tool],
        )
        debug_path = str(tmp_path / "debug_snapshots")
        tool_bp = ToolBreakpoint(
            component_name="tool_invoker", tool_name="weather_tool", visit_count=0, snapshot_file_path=debug_path
        )
        agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test_agent")

        # First run to create the snapshot at the tool invoker
        with pytest.raises(BreakpointException):
            agent.run(messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoint)
        first_snapshot_file = self._latest_snapshot(debug_path, "tool_invoker")

        # Now resume from snapshot and trigger new breakpoint at the next visit of the same tool
        new_breakpoint = AgentBreakpoint(break_point=replace(tool_bp, visit_count=1), agent_name="test_agent")
        agent_snapshot = load_pipeline_snapshot(first_snapshot_file).agent_snapshot
        with pytest.raises(BreakpointException):
            # messages not used when resuming from snapshot
            agent.run(messages=[], break_point=new_breakpoint, snapshot=agent_snapshot)
        latest_snapshot_file = self._latest_snapshot(debug_path, "tool_invoker")

        # Resume again, this time the agent should complete
        result = agent.run(
            messages=[],
            # Shouldn't trigger, but we pass here to show that we can pass a breakpoint even if not used
            break_point=AgentBreakpoint(break_point=replace(tool_bp, visit_count=2), agent_name="test_agent"),
            snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot,
        )
        # 1 user + 2 assistant + 2 tool call results + 1 final assistant message
        assert len(result["messages"]) == 6
        assert result["last_message"].text == "The weather in Berlin and Paris is sunny."

    def test_breakpoint_with_invalid_component_name(self):
        invalid_bp = Breakpoint(component_name="invalid_breakpoint")
        with pytest.raises(ValueError):
            AgentBreakpoint(break_point=invalid_bp, agent_name="test_agent")

    def test_breakpoint_with_invalid_tool_name(self, agent):
        # Build the breakpoint outside the ``raises`` block so that only ``agent.run`` may satisfy
        # the expectation.
        agent_breakpoint = AgentBreakpoint(
            break_point=ToolBreakpoint(component_name="tool_invoker", tool_name="invalid_tool"),
            agent_name="test_agent",
        )
        with pytest.raises(ValueError, match="Tool 'invalid_tool' is not available in the agent's tools"):
            agent.run(messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoint)

    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
    @pytest.mark.integration
    def test_live_resume_from_tool_invoker(self, tmp_path, weather_tool):
        agent = Agent(chat_generator=OpenAIChatGenerator(model="gpt-4o"), tools=[weather_tool])
        debug_path = str(tmp_path / "debug_snapshots")
        agent_breakpoint = AgentBreakpoint(
            break_point=ToolBreakpoint(component_name="tool_invoker", snapshot_file_path=debug_path),
            agent_name="test_agent",
        )
        with pytest.raises(BreakpointException):
            agent.run(messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoint)

        latest_snapshot_file = self._latest_snapshot(debug_path, "tool_invoker")
        result = agent.run(
            messages=[ChatMessage.from_user("This is actually ignored when resuming from snapshot.")],
            snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot,
        )
        assert "messages" in result
        assert "last_message" in result
        assert len(result["messages"]) == 4
        assert "berlin" in result["last_message"].text.lower()
class TestAsyncAgentBreakpoints:
    """Tests for breakpoints and snapshot-based resumption with ``Agent.run_async``."""

    @staticmethod
    def _latest_snapshot(debug_path: str, component_name: str) -> str:
        """
        Return the path of the most recently created snapshot file for ``component_name``.

        Asserts that at least one matching snapshot exists, so a breakpoint that never wrote a
        snapshot fails with a clear assertion.
        """
        snapshot_files = list(Path(debug_path).glob(f"test_agent_{component_name}_*.json"))
        assert len(snapshot_files) > 0
        return str(max(snapshot_files, key=os.path.getctime))

    @pytest.mark.asyncio
    async def test_run_async_with_chat_generator_breakpoint(self, agent):
        agent_breakpoint = AgentBreakpoint(
            break_point=Breakpoint(component_name="chat_generator"), agent_name="test_agent"
        )
        with pytest.raises(BreakpointException) as exc_info:
            await agent.run_async(
                messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoint
            )
        assert exc_info.value.component == "chat_generator"
        assert "messages" in exc_info.value.inputs["chat_generator"]["serialized_data"]

    @pytest.mark.asyncio
    async def test_run_async_with_tool_invoker_breakpoint(self, agent):
        agent_breakpoint = AgentBreakpoint(
            break_point=ToolBreakpoint(component_name="tool_invoker", tool_name="weather_tool"), agent_name="test"
        )
        with pytest.raises(BreakpointException) as exc_info:
            await agent.run_async(
                messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoint
            )
        assert exc_info.value.component == "tool_invoker"
        assert "messages" in exc_info.value.inputs["tool_invoker"]["serialized_data"]

    @pytest.mark.asyncio
    async def test_resume_from_chat_generator_async(self, agent, tmp_path):
        debug_path = str(tmp_path / "debug_snapshots")
        chat_generator_bp = Breakpoint(component_name="chat_generator", snapshot_file_path=debug_path)
        agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name="test_agent")
        # The first run has to stop at the breakpoint, otherwise there is nothing to resume from.
        with pytest.raises(BreakpointException):
            await agent.run_async(
                messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoint
            )

        latest_snapshot_file = self._latest_snapshot(debug_path, "chat_generator")
        result = await agent.run_async(
            messages=[ChatMessage.from_user("This is actually ignored when resuming from snapshot.")],
            snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot,
        )
        assert "messages" in result
        assert "last_message" in result
        assert len(result["messages"]) == 4

    @pytest.mark.asyncio
    async def test_resume_from_tool_invoker_async(self, agent, tmp_path):
        debug_path = str(tmp_path / "debug_snapshots")
        messages = [ChatMessage.from_user("What's the weather in Berlin?")]
        tool_bp = ToolBreakpoint(component_name="tool_invoker", tool_name="weather_tool", snapshot_file_path=debug_path)
        agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test_agent")
        with pytest.raises(BreakpointException):
            await agent.run_async(messages=messages, break_point=agent_breakpoint)

        latest_snapshot_file = self._latest_snapshot(debug_path, "tool_invoker")
        result = await agent.run_async(
            messages=[ChatMessage.from_user("This is actually ignored when resuming from snapshot.")],
            snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot,
        )
        assert "messages" in result
        assert "last_message" in result
        assert len(result["messages"]) > 0

    @pytest.mark.asyncio
    async def test_resume_from_tool_invoker_and_new_breakpoint_async(self, weather_tool, tmp_path):
        agent = Agent(
            chat_generator=MockChatGenerator(
                [
                    ChatMessage.from_assistant(tool_calls=[ToolCall("weather_tool", {"location": "Berlin"})]),
                    ChatMessage.from_assistant(tool_calls=[ToolCall("weather_tool", {"location": "Paris"})]),
                    ChatMessage.from_assistant(text="The weather in Berlin and Paris is sunny."),
                ]
            ),
            tools=[weather_tool],
        )
        debug_path = str(tmp_path / "debug_snapshots")
        tool_bp = ToolBreakpoint(
            component_name="tool_invoker", tool_name="weather_tool", visit_count=0, snapshot_file_path=debug_path
        )
        agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test_agent")

        # First run to create the snapshot at the tool invoker
        with pytest.raises(BreakpointException):
            await agent.run_async(
                messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoint
            )
        first_snapshot_file = self._latest_snapshot(debug_path, "tool_invoker")

        # Now resume from snapshot and trigger new breakpoint at the next visit of the same tool
        new_breakpoint = AgentBreakpoint(break_point=replace(tool_bp, visit_count=1), agent_name="test_agent")
        agent_snapshot = load_pipeline_snapshot(first_snapshot_file).agent_snapshot
        with pytest.raises(BreakpointException):
            # messages not used when resuming from snapshot
            await agent.run_async(messages=[], break_point=new_breakpoint, snapshot=agent_snapshot)
        latest_snapshot_file = self._latest_snapshot(debug_path, "tool_invoker")

        # Resume again
        result = await agent.run_async(
            messages=[],
            # Shouldn't trigger, but we pass here to show that we can pass a breakpoint even if not used
            break_point=AgentBreakpoint(break_point=replace(tool_bp, visit_count=2), agent_name="test_agent"),
            snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot,
        )
        # 1 user + 2 assistant + 2 tool call results + 1 final assistant message
        assert len(result["messages"]) == 6
        assert result["last_message"].text == "The weather in Berlin and Paris is sunny."

    @pytest.mark.asyncio
    async def test_breakpoint_with_invalid_tool_name_async(self, agent):
        agent_breakpoint = AgentBreakpoint(
            break_point=ToolBreakpoint(component_name="tool_invoker", tool_name="invalid_tool"), agent_name="test"
        )
        with pytest.raises(ValueError, match="Tool 'invalid_tool' is not available in the agent's tools"):
            await agent.run_async(
                messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoint
            )