# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

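"""Tests for pipeline breakpoints raised from inside an Agent running as a pipeline component.

Each test wires a mocked fetcher/converter/prompt-builder/Agent pipeline, sets a
Breakpoint or ToolBreakpoint on one of the Agent's internals (chat generator or
tool invoker), asserts that ``BreakpointException`` carries the serialized
component inputs, and, where applicable, resumes the run from the snapshot file
written at the break.
"""
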
import os
import re
import tempfile
from pathlib import Path
from typing import Dict, List, Optional

import pytest

from haystack import component
from haystack.components.agents import Agent
from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.core.errors import BreakpointException
from haystack.core.pipeline import Pipeline
from haystack.core.pipeline.breakpoint import load_pipeline_snapshot
from haystack.dataclasses import ByteStream, ChatMessage, Document, ToolCall
from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.tools import create_tool_from_function

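# Shared in-memory store: the agent's tool writes one Document per extracted
# person, and the resume tests inspect it to confirm the tool actually ran.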
document_store = InMemoryDocumentStore()


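# Stand-in for LinkContentFetcher: returns a fixed HTML "About" page as a
# ByteStream instead of fetching anything over the network.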
@component
class MockLinkContentFetcher:
    @component.output_types(streams=List[ByteStream])
    def run(self, urls: List[str]) -> Dict[str, List[ByteStream]]:
        mock_html_content = """
        <!DOCTYPE html>
        <html>
        <head>
            <title>Deepset - About Our Team</title>
        </head>
        <body>
            <h1>About Deepset</h1>
            <p>Deepset is a company focused on natural language processing and AI.</p>
            <h2>Our Leadership Team</h2>
            <div class="team-member">
                <h3>Malte Pietsch</h3>
                <p>Malte Pietsch is the CEO and co-founder of Deepset. He has extensive experience in machine learning
                and natural language processing.</p>
                <p>Job Title: Chief Executive Officer</p>
            </div>
            <div class="team-member">
                <h3>Milos Rusic</h3>
                <p>Milos Rusic is the CTO and co-founder of Deepset. He specializes in building scalable AI systems
                and has worked on various NLP projects.</p>
                <p>Job Title: Chief Technology Officer</p>
            </div>
            <h2>Our Mission</h2>
            <p>Deepset aims to make natural language processing accessible to developers and businesses worldwide
            through open-source tools and enterprise solutions.</p>
        </body>
        </html>
        """

        bytestream = ByteStream(
            data=mock_html_content.encode("utf-8"),
            mime_type="text/html",
            meta={"url": urls[0] if urls else "https://en.wikipedia.org/wiki/Deepset"},
        )

        return {"streams": [bytestream]}


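# Stand-in for HTMLToDocument: strips tags with a regex and wraps the remaining
# text in a Document, avoiding a dependency on a real converter.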
@component
class MockHTMLToDocument:
    @component.output_types(documents=List[Document])
    def run(self, sources: List[ByteStream]) -> Dict[str, List[Document]]:
        """Mock HTML to Document converter that extracts text content from HTML ByteStreams."""
        documents = []
        for source in sources:
            # Extract the HTML content from the ByteStream
            html_content = source.data.decode("utf-8")

            # Simplified text extraction: drop HTML tags, then collapse whitespace
            text_content = re.sub(r"<[^>]+>", " ", html_content)
            text_content = re.sub(r"\s+", " ", text_content).strip()

            # Create a Document with the extracted text
            document = Document(
                content=text_content,
                meta={"url": source.meta.get("url", "unknown"), "mime_type": source.mime_type, "source_type": "html"},
            )
            documents.append(document)

        return {"documents": documents}


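# Tool target: persists one person per call into the shared document store.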
def add_database_tool_function(name: str, surname: str, job_title: Optional[str], other: Optional[str]):
    document_store.write_documents(
        [Document(content=name + " " + surname + " " + (job_title or ""), meta={"other": other})]
    )


# We use this since the @tool decorator has issues with deserialization
add_database_tool = create_tool_from_function(add_database_tool_function, name="add_database_tool")


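# Builds a fetcher -> converter -> prompt builder -> agent pipeline. The
# OpenAIChatGenerator's run() is replaced with a canned two-step script: the
# first call returns two tool calls, the second a plain-text summary that
# satisfies the agent's "text" exit condition.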
@pytest.fixture
def pipeline_with_agent(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "test_key")
    generator = OpenAIChatGenerator()
    call_count = 0

    def mock_run(messages, tools=None, **kwargs):
        nonlocal call_count
        call_count += 1

        if call_count == 1:
            return {
                "replies": [
                    ChatMessage.from_assistant(
                        "I'll extract the information about the people mentioned in the context.",
                        tool_calls=[
                            ToolCall(
                                tool_name="add_database_tool",
                                arguments={
                                    "name": "Malte",
                                    "surname": "Pietsch",
                                    "job_title": "Chief Executive Officer",
                                    "other": "CEO and co-founder of Deepset with extensive experience in machine "
                                    "learning and natural language processing",
                                },
                            ),
                            ToolCall(
                                tool_name="add_database_tool",
                                arguments={
                                    "name": "Milos",
                                    "surname": "Rusic",
                                    "job_title": "Chief Technology Officer",
                                    "other": "CTO and co-founder of Deepset specializing in building scalable "
                                    "AI systems and NLP projects",
                                },
                            ),
                        ],
                    )
                ]
            }
        else:
            return {
                "replies": [
                    ChatMessage.from_assistant(
                        "I have successfully extracted and stored information about the following people:\n\n"
                        "1. **Malte Pietsch** - Chief Executive Officer\n"
                        "   - CEO and co-founder of Deepset\n"
                        "   - Extensive experience in machine learning and natural language processing\n\n"
                        "2. **Milos Rusic** - Chief Technology Officer\n"
                        "   - CTO and co-founder of Deepset\n"
                        "   - Specializes in building scalable AI systems and NLP projects\n\n"
                        "Both individuals have been added to the knowledge base with their respective information."
                    )
                ]
            }

    generator.run = mock_run

    database_assistant = Agent(
        chat_generator=generator,
        tools=[add_database_tool],
        system_prompt="""
        You are a database assistant.
        Your task is to extract the names of people mentioned in the given context and add them to a knowledge base,
        along with additional relevant information about them that can be extracted from the context.
        Do not use your own knowledge; stay grounded in the given context.
        Do not ask the user for confirmation. Instead, automatically update the knowledge base and return a brief
        summary of the people added, including the information stored for each.
        """,
        exit_conditions=["text"],
        max_agent_steps=100,
        raise_on_tool_invocation_failure=False,
    )

    extraction_agent = Pipeline()
    extraction_agent.add_component("fetcher", MockLinkContentFetcher())
    extraction_agent.add_component("converter", MockHTMLToDocument())
    extraction_agent.add_component(
        "builder",
        ChatPromptBuilder(
            template=[
                ChatMessage.from_user("""
                {% for doc in docs %}
                {{ doc.content|default|truncate(25000) }}
                {% endfor %}
                """)
            ],
            required_variables=["docs"],
        ),
    )
    extraction_agent.add_component("database_agent", database_assistant)

    extraction_agent.connect("fetcher.streams", "converter.sources")
    extraction_agent.connect("converter.documents", "builder.docs")
    extraction_agent.connect("builder", "database_agent")

    return extraction_agent


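# Baseline: without a break_point the pipeline runs to completion and the
# agent's final message summarizes both extracted people.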
def test_run_pipeline_without_any_breakpoints(pipeline_with_agent):
    agent_output = pipeline_with_agent.run(data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}})

    # pipeline completed
    assert "database_agent" in agent_output
    assert "messages" in agent_output["database_agent"]
    assert len(agent_output["database_agent"]["messages"]) > 0

    # final message contains the expected summary
    final_message = agent_output["database_agent"]["messages"][-1].text
    assert "Malte Pietsch" in final_message
    assert "Milos Rusic" in final_message
    assert "Chief Executive Officer" in final_message
    assert "Chief Technology Officer" in final_message


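# Breaking on the agent's chat generator: run() must raise BreakpointException
# carrying the serialized chat_generator inputs and write a snapshot file.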
def test_chat_generator_breakpoint_in_pipeline_agent(pipeline_with_agent):
    with tempfile.TemporaryDirectory() as debug_path:
        agent_generator_breakpoint = Breakpoint("chat_generator", 0, snapshot_file_path=debug_path)
        agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent")
        try:
            pipeline_with_agent.run(
                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
            )
            assert False, "Expected BreakpointException was not raised"
        except BreakpointException as e:  # this is the exception from the Agent
            assert e.component == "chat_generator"
            assert e.inputs is not None
            assert "messages" in e.inputs["chat_generator"]["serialized_data"]
            assert e.results is not None

        # verify that the snapshot file was created
        chat_generator_snapshot_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json"))
        assert len(chat_generator_snapshot_files) > 0, f"No chat_generator snapshot file found in {debug_path}"


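# A ToolBreakpoint targets the agent's tool invoker and can be scoped to a
# single tool by name; here it fires before add_database_tool executes.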
def test_tool_breakpoint_in_pipeline_agent(pipeline_with_agent):
    with tempfile.TemporaryDirectory() as debug_path:
        agent_tool_breakpoint = ToolBreakpoint(
            "tool_invoker", 0, tool_name="add_database_tool", snapshot_file_path=debug_path
        )
        agent_breakpoint = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent")
        try:
            pipeline_with_agent.run(
                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
            )
            assert False, "Expected BreakpointException was not raised"
        except BreakpointException as e:  # this is the exception from the Agent
            assert e.component == "tool_invoker"
            assert e.inputs is not None
            assert "messages" in e.inputs["tool_invoker"]["serialized_data"]
            assert e.results is not None

        # verify that the snapshot file was created
        tool_invoker_snapshot_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json"))
        assert len(tool_invoker_snapshot_files) > 0, f"No tool_invoker snapshot file found in {debug_path}"


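# Full break/resume round trip on the chat generator: trigger the breakpoint,
# load the newest snapshot from disk, and resume the pipeline with empty input.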
def test_agent_breakpoint_chat_generator_and_resume_pipeline(pipeline_with_agent):
    with tempfile.TemporaryDirectory() as debug_path:
        agent_generator_breakpoint = Breakpoint("chat_generator", 0, snapshot_file_path=debug_path)
        agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent")
        try:
            pipeline_with_agent.run(
                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
            )
            assert False, "Expected BreakpointException was not raised"
        except BreakpointException as e:
            assert e.component == "chat_generator"
            assert e.inputs is not None
            assert "messages" in e.inputs["chat_generator"]["serialized_data"]
            assert e.results is not None

        # verify that the snapshot file was created
        chat_generator_snapshot_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json"))
        assert len(chat_generator_snapshot_files) > 0, f"No chat_generator snapshot file found in {debug_path}"

        # resume the pipeline from the saved snapshot
        latest_snapshot_file = max(chat_generator_snapshot_files, key=os.path.getctime)
        result = pipeline_with_agent.run(data={}, pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file))

        # pipeline completed successfully after resuming
        assert "database_agent" in result
        assert "messages" in result["database_agent"]
        assert len(result["database_agent"]["messages"]) > 0

        # final message contains the expected summary
        final_message = result["database_agent"]["messages"][-1].text
        assert "Malte Pietsch" in final_message
        assert "Milos Rusic" in final_message
        assert "Chief Executive Officer" in final_message
        assert "Chief Technology Officer" in final_message

        # the tool should have been called during the resumed execution
        documents = document_store.filter_documents()
        assert len(documents) >= 2, "Expected at least 2 documents to be added to the database"

        # both people were added
        person_names = [doc.content for doc in documents]
        assert any("Malte Pietsch" in name for name in person_names)
        assert any("Milos Rusic" in name for name in person_names)


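# Same round trip, breaking on the tool invoker; additionally checks the
# serialization envelope (schema plus data) carried in the exception's inputs.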
def test_agent_breakpoint_tool_and_resume_pipeline(pipeline_with_agent):
    with tempfile.TemporaryDirectory() as debug_path:
        agent_tool_breakpoint = ToolBreakpoint(
            "tool_invoker", 0, tool_name="add_database_tool", snapshot_file_path=debug_path
        )
        agent_breakpoint = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent")
        try:
            pipeline_with_agent.run(
                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
            )
            assert False, "Expected BreakpointException was not raised"
        except BreakpointException as e:
            assert e.component == "tool_invoker"
            assert e.inputs is not None
            assert "serialization_schema" in e.inputs["tool_invoker"]
            assert "serialized_data" in e.inputs["tool_invoker"]
            assert "messages" in e.inputs["tool_invoker"]["serialized_data"]
            assert e.results is not None

        # verify that the snapshot file was created
        tool_invoker_snapshot_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json"))
        assert len(tool_invoker_snapshot_files) > 0, f"No tool_invoker snapshot file found in {debug_path}"

        # resume the pipeline from the saved snapshot
        latest_snapshot_file = max(tool_invoker_snapshot_files, key=os.path.getctime)
        pipeline_snapshot = load_pipeline_snapshot(latest_snapshot_file)
        result = pipeline_with_agent.run(data={}, pipeline_snapshot=pipeline_snapshot)

        # pipeline completed successfully after resuming
        assert "database_agent" in result
        assert "messages" in result["database_agent"]
        assert len(result["database_agent"]["messages"]) > 0

        # final message contains the expected summary
        final_message = result["database_agent"]["messages"][-1].text
        assert "Malte Pietsch" in final_message
        assert "Milos Rusic" in final_message
        assert "Chief Executive Officer" in final_message
        assert "Chief Technology Officer" in final_message

        # the tool should have been called during the resumed execution
        documents = document_store.filter_documents()
        assert len(documents) >= 2, "Expected at least 2 documents to be added to the database"

        # both people were added
        person_names = [doc.content for doc in documents]
        assert any("Malte Pietsch" in name for name in person_names)
        assert any("Milos Rusic" in name for name in person_names)