# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
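
"""
Tests for Agent breakpoints raised while an Agent runs inside a Pipeline, and for
resuming the Pipeline from the snapshot files those breakpoints produce.
"""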

import os
import re
import tempfile
from pathlib import Path
from typing import Optional

import pytest

from haystack import component
from haystack.components.agents import Agent
from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.core.errors import BreakpointException
from haystack.core.pipeline import Pipeline
from haystack.core.pipeline.breakpoint import load_pipeline_snapshot
from haystack.dataclasses import ByteStream, ChatMessage, Document, ToolCall
from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.tools import create_tool_from_function
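
# Module-level store shared by the add_database_tool function and the assertions
# in the resume tests below.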
document_store = InMemoryDocumentStore()


@component
class MockLinkContentFetcher:
    @component.output_types(streams=list[ByteStream])
    def run(self, urls: list[str]) -> dict[str, list[ByteStream]]:
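        """Return a static HTML page about the Deepset team instead of fetching the given URLs."""
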
        mock_html_content = """
        <!DOCTYPE html>
        <html>
            <head>
                <title>Deepset - About Our Team</title>
            </head>
            <body>
                <h1>About Deepset</h1>
                <p>Deepset is a company focused on natural language processing and AI.</p>
                <h2>Our Leadership Team</h2>
                <div class="team-member">
                    <h3>Malte Pietsch</h3>
                    <p>Malte Pietsch is the CEO and co-founder of Deepset. He has extensive experience in machine learning
                    and natural language processing.</p>
                    <p>Job Title: Chief Executive Officer</p>
                </div>
                <div class="team-member">
                    <h3>Milos Rusic</h3>
                    <p>Milos Rusic is the CTO and co-founder of Deepset. He specializes in building scalable AI systems
                    and has worked on various NLP projects.</p>
                    <p>Job Title: Chief Technology Officer</p>
                </div>
                <h2>Our Mission</h2>
                <p>Deepset aims to make natural language processing accessible to developers and businesses worldwide
                through open-source tools and enterprise solutions.</p>
            </body>
        </html>
        """

        bytestream = ByteStream(
            data=mock_html_content.encode("utf-8"),
            mime_type="text/html",
            meta={"url": urls[0] if urls else "https://en.wikipedia.org/wiki/Deepset"},
        )

        return {"streams": [bytestream]}


@component
class MockHTMLToDocument:
    @component.output_types(documents=list[Document])
    def run(self, sources: list[ByteStream]) -> dict[str, list[Document]]:
        """Mock HTML to Document converter that extracts text content from HTML ByteStreams."""

        documents = []
        for source in sources:
            html_content = source.data.decode("utf-8")

            # Simple text extraction - remove HTML tags and extract meaningful content
            # Remove HTML tags
            text_content = re.sub(r"<[^>]+>", " ", html_content)
            # Remove extra whitespace
            text_content = re.sub(r"\s+", " ", text_content).strip()

            document = Document(
                content=text_content,
                meta={"url": source.meta.get("url", "unknown"), "mime_type": source.mime_type, "source_type": "html"},
            )
            documents.append(document)

        return {"documents": documents}


def add_database_tool_function(name: str, surname: str, job_title: Optional[str], other: Optional[str]):
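    """Write a Document describing a person (name, surname, and job title) to the shared document store."""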
    document_store.write_documents(
        [Document(content=name + " " + surname + " " + (job_title or ""), meta={"other": other})]
    )


@pytest.fixture
def pipeline_with_agent(monkeypatch):
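    """
    Build a fetcher -> converter -> prompt builder -> Agent pipeline in which the
    OpenAI chat generator is replaced by a mock, so no real API calls are made.
    """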
    monkeypatch.setenv("OPENAI_API_KEY", "test_key")
    generator = OpenAIChatGenerator()
    call_count = 0

    def mock_run(messages, tools=None, **kwargs):
        nonlocal call_count
        call_count += 1

        if call_count == 1:
            return {
                "replies": [
                    ChatMessage.from_assistant(
                        "I'll extract the information about the people mentioned in the context.",
                        tool_calls=[
                            ToolCall(
                                tool_name="add_database_tool",
                                arguments={
                                    "name": "Malte",
                                    "surname": "Pietsch",
                                    "job_title": "Chief Executive Officer",
                                    "other": "CEO and co-founder of Deepset with extensive experience in machine "
                                    "learning and natural language processing",
                                },
                            ),
                            ToolCall(
                                tool_name="add_database_tool",
                                arguments={
                                    "name": "Milos",
                                    "surname": "Rusic",
                                    "job_title": "Chief Technology Officer",
                                    "other": "CTO and co-founder of Deepset specializing in building scalable "
                                    "AI systems and NLP projects",
                                },
                            ),
                        ],
                    )
                ]
            }
        else:
            return {
                "replies": [
                    ChatMessage.from_assistant(
                        "I have successfully extracted and stored information about the following people:\n\n"
                        "1. **Malte Pietsch** - Chief Executive Officer\n"
                        "   - CEO and co-founder of Deepset\n"
                        "   - Extensive experience in machine learning and natural language processing\n\n"
                        "2. **Milos Rusic** - Chief Technology Officer\n"
                        "   - CTO and co-founder of Deepset\n"
                        "   - Specializes in building scalable AI systems and NLP projects\n\n"
                        "Both individuals have been added to the knowledge base with their respective information."
                    )
                ]
            }

    generator.run = mock_run

    # We use this since the @tool decorator has issues with deserialization
    add_database_tool = create_tool_from_function(add_database_tool_function, name="add_database_tool")

    database_assistant = Agent(
        chat_generator=generator,
        tools=[add_database_tool],
        system_prompt="""
        You are a database assistant.
        Your task is to extract the names of people mentioned in the given context and add them to a knowledge base,
        along with additional relevant information about them that can be extracted from the context.
        Do not use your own knowledge; stay grounded in the given context.
        Do not ask the user for confirmation. Instead, automatically update the knowledge base and return a brief
        summary of the people added, including the information stored for each.
        """,
        exit_conditions=["text"],
        max_agent_steps=100,
        raise_on_tool_invocation_failure=False,
    )

    extraction_agent = Pipeline()
    extraction_agent.add_component("fetcher", MockLinkContentFetcher())
    extraction_agent.add_component("converter", MockHTMLToDocument())
    extraction_agent.add_component(
        "builder",
        ChatPromptBuilder(
            template=[
                ChatMessage.from_user("""
                {% for doc in docs %}
                {{ doc.content|default|truncate(25000) }}
                {% endfor %}
                """)
            ],
            required_variables=["docs"],
        ),
    )
    extraction_agent.add_component("database_agent", database_assistant)

    extraction_agent.connect("fetcher.streams", "converter.sources")
    extraction_agent.connect("converter.documents", "builder.docs")
    extraction_agent.connect("builder", "database_agent")

    return extraction_agent


def test_run_pipeline_without_any_breakpoints(pipeline_with_agent):
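    """The pipeline runs end to end and the agent summarizes both team members when no breakpoint is set."""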
    agent_output = pipeline_with_agent.run(data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}})

    # pipeline completed
    assert "database_agent" in agent_output
    assert "messages" in agent_output["database_agent"]
    assert len(agent_output["database_agent"]["messages"]) > 0

    # final message contains the expected summary
    final_message = agent_output["database_agent"]["messages"][-1].text
    assert "Malte Pietsch" in final_message
    assert "Milos Rusic" in final_message
    assert "Chief Executive Officer" in final_message
    assert "Chief Technology Officer" in final_message


def test_chat_generator_breakpoint_in_pipeline_agent(pipeline_with_agent):
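    """A chat_generator breakpoint inside the agent raises BreakpointException and writes a snapshot file."""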
    with tempfile.TemporaryDirectory() as debug_path:
        agent_generator_breakpoint = Breakpoint("chat_generator", 0, snapshot_file_path=debug_path)
        agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent")
        try:
            pipeline_with_agent.run(
                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
            )
            assert False, "Expected exception was not raised"

        except BreakpointException as e:  # this is the exception from the Agent
            assert e.component == "chat_generator"
            assert e.inputs is not None
            assert "messages" in e.inputs["chat_generator"]["serialized_data"]
            assert e.results is not None

        # verify that snapshot file was created
        chat_generator_snapshot_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json"))
        assert len(chat_generator_snapshot_files) > 0, f"No chat_generator snapshot file found in {debug_path}"


def test_tool_breakpoint_in_pipeline_agent(pipeline_with_agent):
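    """A tool_invoker breakpoint on add_database_tool raises BreakpointException and writes a snapshot file."""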
    with tempfile.TemporaryDirectory() as debug_path:
        agent_tool_breakpoint = ToolBreakpoint(
            "tool_invoker", 0, tool_name="add_database_tool", snapshot_file_path=debug_path
        )
        agent_breakpoint = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent")
        try:
            pipeline_with_agent.run(
                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
            )
            assert False, "Expected exception was not raised"
        except BreakpointException as e:  # this is the exception from the Agent
            assert e.component == "tool_invoker"
            assert e.inputs is not None
            assert "messages" in e.inputs["tool_invoker"]["serialized_data"]
            assert e.results is not None

        # verify that snapshot file was created
        tool_invoker_snapshot_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json"))
        assert len(tool_invoker_snapshot_files) > 0, f"No tool_invoker snapshot file found in {debug_path}"


def test_agent_breakpoint_chat_generator_and_resume_pipeline(pipeline_with_agent):
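    """After breaking at the chat_generator, the pipeline can resume from the saved snapshot and finish."""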
    with tempfile.TemporaryDirectory() as debug_path:
        agent_generator_breakpoint = Breakpoint("chat_generator", 0, snapshot_file_path=debug_path)
        agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent")
        try:
            pipeline_with_agent.run(
                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
            )
            assert False, "Expected BreakpointException was not raised"

        except BreakpointException as e:
            assert e.component == "chat_generator"
            assert e.inputs is not None
            assert "messages" in e.inputs["chat_generator"]["serialized_data"]
            assert e.results is not None

        # verify that the snapshot file was created
        chat_generator_snapshot_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json"))
        assert len(chat_generator_snapshot_files) > 0, f"No chat_generator snapshot file found in {debug_path}"

        # resume the pipeline from the saved snapshot
        latest_snapshot_file = max(chat_generator_snapshot_files, key=os.path.getctime)
        result = pipeline_with_agent.run(data={}, pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file))

        # pipeline completed successfully after resuming
        assert "database_agent" in result
        assert "messages" in result["database_agent"]
        assert len(result["database_agent"]["messages"]) > 0

        # final message contains the expected summary
        final_message = result["database_agent"]["messages"][-1].text
        assert "Malte Pietsch" in final_message
        assert "Milos Rusic" in final_message
        assert "Chief Executive Officer" in final_message
        assert "Chief Technology Officer" in final_message

        # tool should have been called during the resumed execution
        documents = document_store.filter_documents()
        assert len(documents) >= 2, "Expected at least 2 documents to be added to the database"

        # both people were added
        person_names = [doc.content for doc in documents]
        assert any("Malte Pietsch" in name for name in person_names)
        assert any("Milos Rusic" in name for name in person_names)


def test_agent_breakpoint_tool_and_resume_pipeline(pipeline_with_agent):
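    """After breaking at the tool_invoker, the pipeline can resume from the saved snapshot and finish."""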
    with tempfile.TemporaryDirectory() as debug_path:
        agent_tool_breakpoint = ToolBreakpoint(
            "tool_invoker", 0, tool_name="add_database_tool", snapshot_file_path=debug_path
        )
        agent_breakpoint = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent")
        try:
            pipeline_with_agent.run(
                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
            )
            assert False, "Expected BreakpointException was not raised"

        except BreakpointException as e:
            assert e.component == "tool_invoker"
            assert e.inputs is not None
            assert "serialization_schema" in e.inputs["tool_invoker"]
            assert "serialized_data" in e.inputs["tool_invoker"]
            assert "messages" in e.inputs["tool_invoker"]["serialized_data"]
            assert e.results is not None

        # verify that the snapshot file was created
        tool_invoker_snapshot_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json"))
        assert len(tool_invoker_snapshot_files) > 0, f"No tool_invoker snapshot file found in {debug_path}"

        # resume the pipeline from the saved snapshot
        latest_snapshot_file = max(tool_invoker_snapshot_files, key=os.path.getctime)
        pipeline_snapshot = load_pipeline_snapshot(latest_snapshot_file)
        result = pipeline_with_agent.run(data={}, pipeline_snapshot=pipeline_snapshot)

        # pipeline completed successfully after resuming
        assert "database_agent" in result
        assert "messages" in result["database_agent"]
        assert len(result["database_agent"]["messages"]) > 0

        # final message contains the expected summary
        final_message = result["database_agent"]["messages"][-1].text
        assert "Malte Pietsch" in final_message
        assert "Milos Rusic" in final_message
        assert "Chief Executive Officer" in final_message
        assert "Chief Technology Officer" in final_message

        # tool should have been called during the resumed execution
        documents = document_store.filter_documents()
        assert len(documents) >= 2, "Expected at least 2 documents to be added to the database"

        # both people were added
        person_names = [doc.content for doc in documents]
        assert any("Malte Pietsch" in name for name in person_names)
        assert any("Milos Rusic" in name for name in person_names)