haystack/test/components/agents/test_agent.py

503 lines
22 KiB
Python
Raw Normal View History

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import os
from datetime import datetime
from typing import Iterator, Dict, Any, List
from unittest.mock import MagicMock, patch
import pytest
from openai import Stream
from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
from haystack.components.agents import Agent
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.generators.chat.types import ChatGenerator
from haystack.core.component.types import OutputSocket
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.tools import Tool, ComponentTool
from haystack.utils import serialize_callable, Secret
from haystack.dataclasses.state_utils import merge_lists
def streaming_callback_for_serde(chunk: StreamingChunk):
    """No-op streaming callback defined at module level so serde tests can refer to it by name."""
    return None
def weather_function(location):
    """
    Return canned weather data for ``location``.

    Berlin, Paris and Rome have fixed records; every other value gets an
    "unknown" placeholder record with a temperature of 0.
    """
    known_cities = {
        "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
        "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
        "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
    }
    if location in known_cities:
        return known_cities[location]
    # Fallback record for any location that is not in the table above.
    return {"weather": "unknown", "temperature": 0, "unit": "celsius"}
@pytest.fixture
def weather_tool():
    """Fixture: a ``Tool`` named ``weather_tool`` that wraps ``weather_function``."""
    # JSON-schema style parameter spec: one required string argument, ``location``.
    location_schema = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
    tool = Tool(
        name="weather_tool",
        description="Provides weather information for a given location.",
        parameters=location_schema,
        function=weather_function,
    )
    return tool
@pytest.fixture
def component_tool():
    """Fixture: a ``ComponentTool`` named ``parrot`` backed by a ``PromptBuilder`` with template ``{{parrot}}``."""
    parrot_builder = PromptBuilder(template="{{parrot}}")
    return ComponentTool(name="parrot", description="This is a parrot.", component=parrot_builder)
class OpenAIMockStream(Stream[ChatCompletionChunk]):
    """``Stream`` subclass for tests that yields exactly one pre-built ``ChatCompletionChunk``."""

    def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs):
        """
        Store ``mock_chunk`` and initialise the base stream.

        :param mock_chunk: The chunk that ``__stream__`` yields.
        :param client: Client handed to the base class; a falsy value is replaced by a ``MagicMock``.
        """
        if not client:
            client = MagicMock()
        super().__init__(*args, client=client, **kwargs)
        self.mock_chunk = mock_chunk

    def __stream__(self) -> Iterator[ChatCompletionChunk]:
        """Yield the stored chunk once."""
        yield from (self.mock_chunk,)
@pytest.fixture
def openai_mock_chat_completion_chunk():
    """
    Patch the OpenAI chat completions ``create`` call to return a one-chunk mock stream.

    The single chunk is an assistant delta with content "Hello" and finish reason "stop".
    Yields the patched ``create`` mock while the patch is active.
    """
    with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
        hello_delta = chat_completion_chunk.ChoiceDelta(content="Hello", role="assistant")
        only_choice = chat_completion_chunk.Choice(
            finish_reason="stop",
            logprobs=None,
            index=0,
            delta=hello_delta,
        )
        created_at = int(datetime.now().timestamp())
        completion = ChatCompletionChunk(
            id="foo",
            model="gpt-4",
            object="chat.completion.chunk",
            choices=[only_choice],
            created=created_at,
            usage=None,
        )

        mock_stream = OpenAIMockStream(completion, cast_to=None, response=None, client=None)
        mock_chat_completion_create.return_value = mock_stream
        yield mock_chat_completion_create
class MockChatGeneratorWithoutTools(ChatGenerator):
    """Mock ``ChatGenerator`` whose ``run`` method takes no ``tools`` argument."""

    def to_dict(self) -> Dict[str, Any]:
        """Return a fixed serialized form with an empty ``data`` payload."""
        serialized: Dict[str, Any] = {"type": "MockChatGeneratorWithoutTools", "data": {}}
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutTools":
        """Create a fresh instance; the contents of ``data`` are not used."""
        instance = cls()
        return instance

    def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        """Return a single assistant reply "Hello" regardless of ``messages``."""
        hello_reply = ChatMessage.from_assistant("Hello")
        return {"replies": [hello_reply]}
class TestAgent:
    """Tests for the ``Agent`` component: output sockets, serialization, validation, streaming and tool use."""

    def test_output_types(self, weather_tool, component_tool, monkeypatch):
        """The agent exposes a single ``messages`` output socket typed ``List[ChatMessage]``."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        chat_generator = OpenAIChatGenerator()
        agent = Agent(chat_generator=chat_generator, tools=[weather_tool, component_tool])
        assert agent.__haystack_output__._sockets_dict == {
            "messages": OutputSocket(name="messages", type=List[ChatMessage], receivers=[])
        }

    def test_to_dict(self, weather_tool, component_tool, monkeypatch):
        """``to_dict`` produces the expected nested dict for the generator, tools and agent settings."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        generator = OpenAIChatGenerator()
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            exit_conditions=["text", "weather_tool"],
            state_schema={"foo": {"type": str}},
        )
        serialized_agent = agent.to_dict()
        assert serialized_agent == {
            "type": "haystack.components.agents.agent.Agent",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": "gpt-4o-mini",
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                    },
                },
                "tools": [
                    {
                        "type": "haystack.tools.tool.Tool",
                        "data": {
                            "name": "weather_tool",
                            "description": "Provides weather information for a given location.",
                            "parameters": {
                                "type": "object",
                                "properties": {"location": {"type": "string"}},
                                "required": ["location"],
                            },
                            "function": "test_agent.weather_function",
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                    {
                        "type": "haystack.tools.component_tool.ComponentTool",
                        "data": {
                            "component": {
                                "type": "haystack.components.builders.prompt_builder.PromptBuilder",
                                "init_parameters": {
                                    "template": "{{parrot}}",
                                    "variables": None,
                                    "required_variables": None,
                                },
                            },
                            "name": "parrot",
                            "description": "This is a parrot.",
                            "parameters": None,
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                ],
                "system_prompt": None,
                "exit_conditions": ["text", "weather_tool"],
                "state_schema": {"foo": {"type": "str"}},
                "max_agent_steps": 100,
                "raise_on_tool_invocation_failure": False,
                "streaming_callback": None,
            },
        }

    def test_from_dict(self, weather_tool, component_tool, monkeypatch):
        """``from_dict`` rebuilds the generator, tools, exit conditions and state schema from a serialized dict."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        data = {
            "type": "haystack.components.agents.agent.Agent",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": "gpt-4o-mini",
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                    },
                },
                "tools": [
                    {
                        "type": "haystack.tools.tool.Tool",
                        "data": {
                            "name": "weather_tool",
                            "description": "Provides weather information for a given location.",
                            "parameters": {
                                "type": "object",
                                "properties": {"location": {"type": "string"}},
                                "required": ["location"],
                            },
                            "function": "test_agent.weather_function",
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                    {
                        "type": "haystack.tools.component_tool.ComponentTool",
                        "data": {
                            "component": {
                                "type": "haystack.components.builders.prompt_builder.PromptBuilder",
                                "init_parameters": {
                                    "template": "{{parrot}}",
                                    "variables": None,
                                    "required_variables": None,
                                },
                            },
                            "name": "parrot",
                            "description": "This is a parrot.",
                            "parameters": None,
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                ],
                "system_prompt": None,
                "exit_conditions": ["text", "weather_tool"],
                "state_schema": {"foo": {"type": "str"}},
                "max_agent_steps": 100,
                "raise_on_tool_invocation_failure": False,
                "streaming_callback": None,
            },
        }
        agent = Agent.from_dict(data)
        assert isinstance(agent, Agent)
        assert isinstance(agent.chat_generator, OpenAIChatGenerator)
        assert agent.chat_generator.model == "gpt-4o-mini"
        assert agent.chat_generator.api_key == Secret.from_env_var("OPENAI_API_KEY")
        assert agent.tools[0].function is weather_function
        assert isinstance(agent.tools[1]._component, PromptBuilder)
        assert agent.exit_conditions == ["text", "weather_tool"]
        assert agent.state_schema == {
            "foo": {"type": str},
            "messages": {"handler": merge_lists, "type": List[ChatMessage]},
        }

    def test_serde(self, weather_tool, component_tool, monkeypatch):
        """A ``to_dict`` / ``from_dict`` round trip preserves generator type, tools, exit conditions and schema."""
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            exit_conditions=["text", "weather_tool"],
            state_schema={"foo": {"type": str}},
        )

        serialized_agent = agent.to_dict()

        init_parameters = serialized_agent["init_parameters"]

        assert serialized_agent["type"] == "haystack.components.agents.agent.Agent"
        assert (
            init_parameters["chat_generator"]["type"]
            == "haystack.components.generators.chat.openai.OpenAIChatGenerator"
        )
        assert init_parameters["streaming_callback"] is None
        assert init_parameters["tools"][0]["data"]["function"] == serialize_callable(weather_function)
        assert (
            init_parameters["tools"][1]["data"]["component"]["type"]
            == "haystack.components.builders.prompt_builder.PromptBuilder"
        )
        assert init_parameters["exit_conditions"] == ["text", "weather_tool"]

        deserialized_agent = Agent.from_dict(serialized_agent)

        assert isinstance(deserialized_agent, Agent)
        assert isinstance(deserialized_agent.chat_generator, OpenAIChatGenerator)
        assert deserialized_agent.tools[0].function is weather_function
        assert isinstance(deserialized_agent.tools[1]._component, PromptBuilder)
        assert deserialized_agent.exit_conditions == ["text", "weather_tool"]
        assert deserialized_agent.state_schema == {
            "foo": {"type": str},
            "messages": {"handler": merge_lists, "type": List[ChatMessage]},
        }

    def test_serde_with_streaming_callback(self, weather_tool, component_tool, monkeypatch):
        """A module-level streaming callback is serialized by dotted name and restored to the same function."""
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            streaming_callback=streaming_callback_for_serde,
        )

        serialized_agent = agent.to_dict()

        init_parameters = serialized_agent["init_parameters"]
        assert init_parameters["streaming_callback"] == "test_agent.streaming_callback_for_serde"

        deserialized_agent = Agent.from_dict(serialized_agent)
        assert deserialized_agent.streaming_callback is streaming_callback_for_serde

    def test_exit_conditions_validation(self, weather_tool, component_tool, monkeypatch):
        """Unknown exit conditions raise ``ValueError``; the default is ``["text"]``; valid lists are kept as given."""
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))

        # Test invalid exit condition
        with pytest.raises(ValueError, match="Invalid exit conditions provided:"):
            Agent(chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["invalid_tool"])

        # Test default exit condition
        agent = Agent(chat_generator=generator, tools=[weather_tool, component_tool])
        assert agent.exit_conditions == ["text"]

        # Test multiple valid exit conditions
        agent = Agent(
            chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["text", "weather_tool"]
        )
        assert agent.exit_conditions == ["text", "weather_tool"]

    def test_run_with_params_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        """A streaming callback passed to the ``Agent`` constructor is invoked during ``run``."""
        chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))

        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        agent = Agent(chat_generator=chat_generator, streaming_callback=streaming_callback, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")])

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 2
        # all(...) is required here: asserting on a non-empty list is always truthy.
        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
        assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk

    def test_run_with_run_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        """A streaming callback passed directly to ``Agent.run`` is invoked."""
        chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))

        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")], streaming_callback=streaming_callback)

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 2
        # all(...) is required here: asserting on a non-empty list is always truthy.
        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
        assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk

    def test_keep_generator_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        """A streaming callback configured on the chat generator is still invoked when the agent runs."""
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        chat_generator = OpenAIChatGenerator(
            api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
        )

        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")])

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 2
        # all(...) is required here: asserting on a non-empty list is always truthy.
        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
        assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk

    def test_chat_generator_must_support_tools(self, weather_tool):
        """Constructing an agent with a generator whose ``run`` lacks tool support raises ``TypeError``."""
        chat_generator = MockChatGeneratorWithoutTools()

        with pytest.raises(TypeError, match="MockChatGeneratorWithoutTools does not accept tools"):
            Agent(chat_generator=chat_generator, tools=[weather_tool])

    def test_multiple_llm_responses_with_tool_call(self, monkeypatch, weather_tool):
        """When the generator returns a text reply plus a tool call, the tool result is the final message."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        generator = OpenAIChatGenerator()

        mock_messages = [
            ChatMessage.from_assistant("First response"),
            ChatMessage.from_assistant(
                tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
            ),
        ]

        agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=1)
        agent.warm_up()

        # Patch agent.chat_generator.run to return mock_messages
        agent.chat_generator.run = MagicMock(return_value={"replies": mock_messages})

        result = agent.run([ChatMessage.from_user("Hello")])

        assert "messages" in result
        assert len(result["messages"]) == 4
        assert (
            result["messages"][-1].tool_call_result.result
            == "{'weather': 'mostly sunny', 'temperature': 7, 'unit': 'celsius'}"
        )

    def test_exit_conditions_checked_across_all_llm_messages(self, monkeypatch, weather_tool):
        """An exit-condition tool call is honoured even when it is not the first reply from the generator."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        generator = OpenAIChatGenerator()

        # Mock messages where the exit condition appears in the second message
        mock_messages = [
            ChatMessage.from_assistant("First response"),
            ChatMessage.from_assistant(
                tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
            ),
        ]

        agent = Agent(chat_generator=generator, tools=[weather_tool], exit_conditions=["weather_tool"])
        agent.warm_up()

        # Patch agent.chat_generator.run to return mock_messages
        agent.chat_generator.run = MagicMock(return_value={"replies": mock_messages})

        result = agent.run([ChatMessage.from_user("Hello")])

        assert "messages" in result
        assert len(result["messages"]) == 4
        assert result["messages"][-2].tool_call.tool_name == "weather_tool"
        assert (
            result["messages"][-1].tool_call_result.result
            == "{'weather': 'mostly sunny', 'temperature': 7, 'unit': 'celsius'}"
        )

    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
    @pytest.mark.integration
    def test_run(self, weather_tool):
        """Integration-marked end-to-end run; skipped when ``OPENAI_API_KEY`` is not set in the environment."""
        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini")
        agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3)
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")])

        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 4
        # all(...) is required here: asserting on a non-empty list is always truthy.
        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
        # Loose check of message texts
        assert response["messages"][0].text == "What is the weather in Berlin?"
        assert response["messages"][1].text is None
        assert response["messages"][2].text is None
        assert response["messages"][3].text is not None
        # Loose check of message metadata
        assert response["messages"][0].meta == {}
        assert response["messages"][1].meta.get("model") is not None
        assert response["messages"][2].meta == {}
        assert response["messages"][3].meta.get("model") is not None
        # Loose check of tool calls and results
        assert response["messages"][1].tool_calls[0].tool_name == "weather_tool"
        assert response["messages"][1].tool_calls[0].arguments is not None
        assert response["messages"][2].tool_call_results[0].result is not None
        assert response["messages"][2].tool_call_results[0].origin is not None