haystack/test/components/agents/test_agent_breakpoints_inside_pipeline.py
Abdelrahman Kaseb 5f3c37d287
chore: adopt PEP 585 type hints (#9678)
* chore(lint): enforce and apply PEP 585 type hinting

* Run fmt fixes

* Fix all typing imports using some regex

* Fix all typing written in string in tests

* undo changes in the e2e tests

* make e2e test use list instead of List

* type fixes

* remove type:ignore

* pylint

* Remove typing from Usage example comments

* Remove typing from most of comments

* try to fix e2e tests on comm PRs

* fix

* Add typing.List tests to adjust test compatibility in:
- test/components/agents/test_state_class.py
- test/components/converters/test_output_adapter.py
- test/components/joiners/test_list_joiner.py

* simplify pyproject

* improve relnote

---------

Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
2025-08-07 10:23:14 +02:00

361 lines
16 KiB
Python

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import os
import re
import tempfile
from pathlib import Path
from typing import Optional
import pytest
from haystack import component
from haystack.components.agents import Agent
from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.core.errors import BreakpointException
from haystack.core.pipeline import Pipeline
from haystack.core.pipeline.breakpoint import load_pipeline_snapshot
from haystack.dataclasses import ByteStream, ChatMessage, Document, ToolCall
from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.tools import create_tool_from_function
# Module-level in-memory store shared by every test in this file:
# add_database_tool_function writes into it and the resume tests read it back.
document_store = InMemoryDocumentStore()
@component
class MockLinkContentFetcher:
    """Test double for a link fetcher: hands back one fixed HTML page instead of downloading anything."""

    @component.output_types(streams=list[ByteStream])
    def run(self, urls: list[str]) -> dict[str, list[ByteStream]]:
        """
        Return a single canned HTML page wrapped in a ByteStream.

        :param urls: Requested URLs. Only the first one is used, and only to fill the stream's `url`
            metadata; when the list is empty a default Deepset URL is recorded instead.
        :returns: A dictionary with one key, `streams`, holding a one-element list.
        """
        # Pick the URL recorded in the metadata: first requested URL, else the default.
        if urls:
            page_url = urls[0]
        else:
            page_url = "https://en.wikipedia.org/wiki/Deepset"

        # The page content is a runtime string and is kept exactly as written.
        page_html = """
<!DOCTYPE html>
<html>
<head>
<title>Deepset - About Our Team</title>
</head>
<body>
<h1>About Deepset</h1>
<p>Deepset is a company focused on natural language processing and AI.</p>
<h2>Our Leadership Team</h2>
<div class="team-member">
<h3>Malte Pietsch</h3>
<p>Malte Pietsch is the CEO and co-founder of Deepset. He has extensive experience in machine learning
and natural language processing.</p>
<p>Job Title: Chief Executive Officer</p>
</div>
<div class="team-member">
<h3>Milos Rusic</h3>
<p>Milos Rusic is the CTO and co-founder of Deepset. He specializes in building scalable AI systems
and has worked on various NLP projects.</p>
<p>Job Title: Chief Technology Officer</p>
</div>
<h2>Our Mission</h2>
<p>Deepset aims to make natural language processing accessible to developers and businesses worldwide
through open-source tools and enterprise solutions.</p>
</body>
</html>
"""
        page_stream = ByteStream(data=page_html.encode("utf-8"), mime_type="text/html", meta={"url": page_url})
        return {"streams": [page_stream]}
@component
class MockHTMLToDocument:
    """Test double for an HTML converter: turns HTML ByteStreams into plain-text Documents."""

    @component.output_types(documents=list[Document])
    def run(self, sources: list[ByteStream]) -> dict[str, list[Document]]:
        """
        Convert each HTML ByteStream into a Document containing its visible text.

        This is a deliberately simplified extraction: tags are replaced by spaces and runs of
        whitespace are collapsed, nothing more.

        :param sources: ByteStreams whose `data` is UTF-8 encoded HTML.
        :returns: A dictionary with one key, `documents`, holding one Document per input stream,
            in input order.
        """
        converted = []
        for stream in sources:
            markup = stream.data.decode("utf-8")
            # First blank out every tag, then squeeze the whitespace that is left behind.
            tagless = re.sub(r"<[^>]+>", " ", markup)
            plain_text = re.sub(r"\s+", " ", tagless).strip()
            doc_meta = {
                "url": stream.meta.get("url", "unknown"),
                "mime_type": stream.mime_type,
                "source_type": "html",
            }
            converted.append(Document(content=plain_text, meta=doc_meta))
        return {"documents": converted}
# Tool callback: stores one person as a Document in the module-level `document_store`.
# The content is "<name> <surname> <job_title>" (empty job title when it is None or empty),
# and `other` is kept in the document metadata under the "other" key.
# NOTE(review): documented with comments rather than a docstring on purpose, in case the tool
# factory picks up the function docstring as the tool description - confirm before adding one.
def add_database_tool_function(name: str, surname: str, job_title: Optional[str], other: Optional[str]):
    person_line = " ".join((name, surname, job_title or ""))
    person_document = Document(content=person_line, meta={"other": other})
    document_store.write_documents([person_document])
# Build the tool with create_tool_from_function instead of the @tool decorator,
# since the @tool decorator has issues with deserialization.
add_database_tool = create_tool_from_function(add_database_tool_function, name="add_database_tool")
@pytest.fixture
def pipeline_with_agent(monkeypatch):
    """
    Provide a pipeline made of `fetcher`, `converter`, `builder` and `database_agent` with a scripted LLM.

    The chat generator's `run` is replaced by a local function returning canned replies: the first
    call answers with two `add_database_tool` tool calls, every later call answers with a text summary.
    """
    # The placeholder key is set before the generator is instantiated.
    monkeypatch.setenv("OPENAI_API_KEY", "test_key")
    chat_generator = OpenAIChatGenerator()

    def person_tool_call(name, surname, job_title, other):
        # One tool call asking add_database_tool to store a single person.
        return ToolCall(
            tool_name="add_database_tool",
            arguments={"name": name, "surname": surname, "job_title": job_title, "other": other},
        )

    def tool_calling_reply():
        # Reply used for the first generator call: request both people to be stored.
        return ChatMessage.from_assistant(
            "I'll extract the information about the people mentioned in the context.",
            tool_calls=[
                person_tool_call(
                    "Malte",
                    "Pietsch",
                    "Chief Executive Officer",
                    "CEO and co-founder of Deepset with extensive experience in machine "
                    "learning and natural language processing",
                ),
                person_tool_call(
                    "Milos",
                    "Rusic",
                    "Chief Technology Officer",
                    "CTO and co-founder of Deepset specializing in building scalable "
                    "AI systems and NLP projects",
                ),
            ],
        )

    def summary_reply():
        # Reply used for every call after the first: plain text, no tool calls.
        return ChatMessage.from_assistant(
            "I have successfully extracted and stored information about the following people:\n\n"
            "1. **Malte Pietsch** - Chief Executive Officer\n"
            "   - CEO and co-founder of Deepset\n"
            "   - Extensive experience in machine learning and natural language processing\n\n"
            "2. **Milos Rusic** - Chief Technology Officer\n"
            "   - CTO and co-founder of Deepset\n"
            "   - Specializes in building scalable AI systems and NLP projects\n\n"
            "Both individuals have been added to the knowledge base with their respective information."
        )

    calls_so_far = 0

    def scripted_run(messages, tools=None, **kwargs):
        nonlocal calls_so_far
        calls_so_far += 1
        reply = tool_calling_reply() if calls_so_far == 1 else summary_reply()
        return {"replies": [reply]}

    chat_generator.run = scripted_run

    assistant_instructions = """
You are a database assistant.
Your task is to extract the names of people mentioned in the given context and add them to a knowledge base,
along with additional relevant information about them that can be extracted from the context.
Do not use you own knowledge, stay grounded to the given context.
Do not ask the user for confirmation. Instead, automatically update the knowledge base and return a brief
summary of the people added, including the information stored for each.
"""
    database_assistant = Agent(
        chat_generator=chat_generator,
        tools=[add_database_tool],
        system_prompt=assistant_instructions,
        exit_conditions=["text"],
        max_agent_steps=100,
        raise_on_tool_invocation_failure=False,
    )

    docs_prompt = ChatMessage.from_user("""
{% for doc in docs %}
{{ doc.content|default|truncate(25000) }}
{% endfor %}
""")
    prompt_builder = ChatPromptBuilder(template=[docs_prompt], required_variables=["docs"])

    # Register the components, then wire them: fetcher -> converter -> builder -> database_agent.
    pipeline = Pipeline()
    for component_name, instance in (
        ("fetcher", MockLinkContentFetcher()),
        ("converter", MockHTMLToDocument()),
        ("builder", prompt_builder),
        ("database_agent", database_assistant),
    ):
        pipeline.add_component(component_name, instance)

    pipeline.connect("fetcher.streams", "converter.sources")
    pipeline.connect("converter.documents", "builder.docs")
    pipeline.connect("builder", "database_agent")
    return pipeline
def test_pipeline_without_any_breakpoints(pipeline_with_agent):
    """
    Run the pipeline end to end with no breakpoint and check the agent's final summary.

    The function used to be named `run_pipeline_without_any_breakpoints`; without the `test_`
    prefix pytest's default discovery does not collect it, so the baseline scenario was never run.

    :param pipeline_with_agent: The pipeline fixture defined in this module.
    """
    agent_output = pipeline_with_agent.run(data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}})

    # pipeline completed
    assert "database_agent" in agent_output
    assert "messages" in agent_output["database_agent"]
    assert len(agent_output["database_agent"]["messages"]) > 0

    # final message contains the expected summary
    final_message = agent_output["database_agent"]["messages"][-1].text
    assert "Malte Pietsch" in final_message
    assert "Milos Rusic" in final_message
    assert "Chief Executive Officer" in final_message
    assert "Chief Technology Officer" in final_message


# Backward-compatible alias for the previous name; it does not match the `test_` prefix,
# so the scenario is still collected only once.
run_pipeline_without_any_breakpoints = test_pipeline_without_any_breakpoints
def test_chat_generator_breakpoint_in_pipeline_agent(pipeline_with_agent):
    """
    A breakpoint on the agent's chat generator stops the pipeline and writes a snapshot file.

    :param pipeline_with_agent: The pipeline fixture defined in this module.
    """
    with tempfile.TemporaryDirectory() as debug_path:
        agent_generator_breakpoint = Breakpoint("chat_generator", 0, snapshot_file_path=debug_path)
        agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent")

        # pytest.raises fails the test on its own when no BreakpointException is raised,
        # replacing the manual `assert False` inside a try block.
        with pytest.raises(BreakpointException) as exc_info:
            pipeline_with_agent.run(
                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
            )

        # this is the exception from the Agent
        raised = exc_info.value
        assert raised.component == "chat_generator"
        assert raised.inputs is not None
        assert "messages" in raised.inputs["chat_generator"]["serialized_data"]
        assert raised.results is not None

        # verify that snapshot file was created (checked while the temporary directory still exists)
        chat_generator_snapshot_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json"))
        assert len(chat_generator_snapshot_files) > 0, f"No chat_generator snapshot file found in {debug_path}"
def test_tool_breakpoint_in_pipeline_agent(pipeline_with_agent):
    """
    A tool breakpoint on `add_database_tool` stops the pipeline and writes a snapshot file.

    :param pipeline_with_agent: The pipeline fixture defined in this module.
    """
    with tempfile.TemporaryDirectory() as debug_path:
        agent_tool_breakpoint = ToolBreakpoint(
            "tool_invoker", 0, tool_name="add_database_tool", snapshot_file_path=debug_path
        )
        agent_breakpoint = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent")

        # pytest.raises fails the test on its own when no BreakpointException is raised,
        # replacing the manual `assert False` inside a try block.
        with pytest.raises(BreakpointException) as exc_info:
            pipeline_with_agent.run(
                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
            )

        # this is the exception from the Agent
        raised = exc_info.value
        assert raised.component == "tool_invoker"
        assert raised.inputs is not None
        assert "messages" in raised.inputs["tool_invoker"]["serialized_data"]
        assert raised.results is not None

        # verify that snapshot file was created (checked while the temporary directory still exists)
        tool_invoker_snapshot_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json"))
        assert len(tool_invoker_snapshot_files) > 0, f"No tool_invoker snapshot file found in {debug_path}"
def test_agent_breakpoint_chat_generator_and_resume_pipeline(pipeline_with_agent):
    """
    Break on the agent's chat generator, then resume from the saved snapshot and run to completion.

    :param pipeline_with_agent: The pipeline fixture defined in this module.
    """
    # document_store is module-level and shared by all tests in this file. Empty it first so the
    # document checks at the end can only be satisfied by tool calls made in this test.
    document_store.delete_documents([doc.id for doc in document_store.filter_documents()])

    with tempfile.TemporaryDirectory() as debug_path:
        agent_generator_breakpoint = Breakpoint("chat_generator", 0, snapshot_file_path=debug_path)
        agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent")

        # pytest.raises fails the test on its own when no BreakpointException is raised.
        with pytest.raises(BreakpointException) as exc_info:
            pipeline_with_agent.run(
                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
            )

        raised = exc_info.value
        assert raised.component == "chat_generator"
        assert raised.inputs is not None
        assert "messages" in raised.inputs["chat_generator"]["serialized_data"]
        assert raised.results is not None

        # verify that the snapshot file was created
        chat_generator_snapshot_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json"))
        assert len(chat_generator_snapshot_files) > 0, f"No chat_generator snapshot file found in {debug_path}"

        # resume the pipeline from the most recently created snapshot
        latest_snapshot_file = max(chat_generator_snapshot_files, key=os.path.getctime)
        result = pipeline_with_agent.run(data={}, pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file))

        # pipeline completed successfully after resuming
        assert "database_agent" in result
        assert "messages" in result["database_agent"]
        assert len(result["database_agent"]["messages"]) > 0

        # final message contains the expected summary
        final_message = result["database_agent"]["messages"][-1].text
        assert "Malte Pietsch" in final_message
        assert "Milos Rusic" in final_message
        assert "Chief Executive Officer" in final_message
        assert "Chief Technology Officer" in final_message

        # tool should have been called during the resumed execution
        documents = document_store.filter_documents()
        assert len(documents) >= 2, "Expected at least 2 documents to be added to the database"

        # both people were added
        person_names = [doc.content for doc in documents]
        assert any("Malte Pietsch" in name for name in person_names)
        assert any("Milos Rusic" in name for name in person_names)
def test_agent_breakpoint_tool_and_resume_pipeline(pipeline_with_agent):
    """
    Break on the agent's `add_database_tool` invocation, then resume from the saved snapshot.

    :param pipeline_with_agent: The pipeline fixture defined in this module.
    """
    # document_store is module-level and shared by all tests in this file. Empty it first so the
    # document checks at the end can only be satisfied by tool calls made in this test.
    document_store.delete_documents([doc.id for doc in document_store.filter_documents()])

    with tempfile.TemporaryDirectory() as debug_path:
        agent_tool_breakpoint = ToolBreakpoint(
            "tool_invoker", 0, tool_name="add_database_tool", snapshot_file_path=debug_path
        )
        agent_breakpoint = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent")

        # pytest.raises fails the test on its own when no BreakpointException is raised.
        with pytest.raises(BreakpointException) as exc_info:
            pipeline_with_agent.run(
                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
            )

        raised = exc_info.value
        assert raised.component == "tool_invoker"
        assert raised.inputs is not None
        assert "serialization_schema" in raised.inputs["tool_invoker"]
        assert "serialized_data" in raised.inputs["tool_invoker"]
        assert "messages" in raised.inputs["tool_invoker"]["serialized_data"]
        assert raised.results is not None

        # verify that the snapshot file was created
        tool_invoker_snapshot_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json"))
        assert len(tool_invoker_snapshot_files) > 0, f"No tool_invoker snapshot file found in {debug_path}"

        # resume the pipeline from the most recently created snapshot
        latest_snapshot_file = max(tool_invoker_snapshot_files, key=os.path.getctime)
        pipeline_snapshot = load_pipeline_snapshot(latest_snapshot_file)
        result = pipeline_with_agent.run(data={}, pipeline_snapshot=pipeline_snapshot)

        # pipeline completed successfully after resuming
        assert "database_agent" in result
        assert "messages" in result["database_agent"]
        assert len(result["database_agent"]["messages"]) > 0

        # final message contains the expected summary
        final_message = result["database_agent"]["messages"][-1].text
        assert "Malte Pietsch" in final_message
        assert "Milos Rusic" in final_message
        assert "Chief Executive Officer" in final_message
        assert "Chief Technology Officer" in final_message

        # tool should have been called during the resumed execution
        documents = document_store.filter_documents()
        assert len(documents) >= 2, "Expected at least 2 documents to be added to the database"

        # both people were added
        person_names = [doc.content for doc in documents]
        assert any("Malte Pietsch" in name for name in person_names)
        assert any("Milos Rusic" in name for name in person_names)