# haystack/test/components/agents/test_agent.py

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import os
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional, Union
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from openai import Stream
from openai.types.chat import ChatCompletionChunk, chat_completion_chunk

from haystack import component
from haystack.components.agents import Agent
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.generators.chat.types import ChatGenerator
from haystack.core.component.types import OutputSocket
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.dataclasses.chat_message import ChatRole, TextContent
from haystack.dataclasses.state_utils import merge_lists
from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.tools import ComponentTool, Tool
from haystack.tools.toolset import Toolset
from haystack.utils import Secret, serialize_callable


def streaming_callback_for_serde(chunk: StreamingChunk):
    pass


def weather_function(location):
    weather_info = {
        "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
        "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
        "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
    }
    return weather_info.get(location, {"weather": "unknown", "temperature": 0, "unit": "celsius"})
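

# weather_function is the callable wrapped by the weather_tool fixture below;
# the serde tests expect it to serialize to the path "test_agent.weather_function".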


@pytest.fixture
def weather_tool():
    return Tool(
        name="weather_tool",
        description="Provides weather information for a given location.",
        parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]},
        function=weather_function,
    )


@pytest.fixture
def component_tool():
    return ComponentTool(name="parrot", description="This is a parrot.", component=PromptBuilder(template="{{parrot}}"))
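

# The helpers below fake OpenAI's streaming interface so the streaming tests run
# without network access: OpenAIMockStream stands in for the SDK's Stream object,
# and the fixture patches Completions.create to return it.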


class OpenAIMockStream(Stream[ChatCompletionChunk]):
    def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs):
        client = client or MagicMock()
        super().__init__(client=client, *args, **kwargs)
        self.mock_chunk = mock_chunk

    def __stream__(self) -> Iterator[ChatCompletionChunk]:
        yield self.mock_chunk


@pytest.fixture
def openai_mock_chat_completion_chunk():
    """
    Mock the OpenAI API completion chunk response and reuse it for tests.
    """
    with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
        completion = ChatCompletionChunk(
            id="foo",
            model="gpt-4",
            object="chat.completion.chunk",
            choices=[
                chat_completion_chunk.Choice(
                    finish_reason="stop",
                    logprobs=None,
                    index=0,
                    delta=chat_completion_chunk.ChoiceDelta(content="Hello", role="assistant"),
                )
            ],
            created=int(datetime.now().timestamp()),
            usage=None,
        )
        mock_chat_completion_create.return_value = OpenAIMockStream(
            completion, cast_to=None, response=None, client=None
        )
        yield mock_chat_completion_create
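

# Minimal ChatGenerator protocol implementations used to exercise Agent's
# capability checks (tool support, async support) without a real backend.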


class MockChatGeneratorWithoutTools(ChatGenerator):
    """A mock chat generator that implements the ChatGenerator protocol but doesn't support tools."""

    def to_dict(self) -> Dict[str, Any]:
        return {"type": "MockChatGeneratorWithoutTools", "data": {}}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutTools":
        return cls()

    def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant("Hello")]}


class MockChatGeneratorWithoutRunAsync(ChatGenerator):
    """A mock chat generator that implements the ChatGenerator protocol but doesn't have a run_async method."""

    def to_dict(self) -> Dict[str, Any]:
        return {"type": "MockChatGeneratorWithoutRunAsync", "data": {}}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutRunAsync":
        return cls()

    def run(
        self, messages: List[ChatMessage], tools: Optional[Union[List[Tool], Toolset]] = None, **kwargs
    ) -> Dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant("Hello")]}


class TestAgent:
    def test_output_types(self, weather_tool, component_tool, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        chat_generator = OpenAIChatGenerator()
        agent = Agent(chat_generator=chat_generator, tools=[weather_tool, component_tool])
        assert agent.__haystack_output__._sockets_dict == {
            "messages": OutputSocket(name="messages", type=List[ChatMessage], receivers=[])
        }

    def test_to_dict(self, weather_tool, component_tool, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        generator = OpenAIChatGenerator()
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            exit_conditions=["text", "weather_tool"],
            state_schema={"foo": {"type": str}},
        )
        serialized_agent = agent.to_dict()
        assert serialized_agent == {
            "type": "haystack.components.agents.agent.Agent",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": "gpt-4o-mini",
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                    },
                },
                "tools": [
                    {
                        "type": "haystack.tools.tool.Tool",
                        "data": {
                            "name": "weather_tool",
                            "description": "Provides weather information for a given location.",
                            "parameters": {
                                "type": "object",
                                "properties": {"location": {"type": "string"}},
                                "required": ["location"],
                            },
                            "function": "test_agent.weather_function",
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                    {
                        "type": "haystack.tools.component_tool.ComponentTool",
                        "data": {
                            "component": {
                                "type": "haystack.components.builders.prompt_builder.PromptBuilder",
                                "init_parameters": {
                                    "template": "{{parrot}}",
                                    "variables": None,
                                    "required_variables": None,
                                },
                            },
                            "name": "parrot",
                            "description": "This is a parrot.",
                            "parameters": None,
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                ],
                "system_prompt": None,
                "exit_conditions": ["text", "weather_tool"],
                "state_schema": {"foo": {"type": "str"}},
                "max_agent_steps": 100,
                "raise_on_tool_invocation_failure": False,
                "streaming_callback": None,
            },
        }

    def test_from_dict(self, weather_tool, component_tool, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        data = {
            "type": "haystack.components.agents.agent.Agent",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": "gpt-4o-mini",
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                    },
                },
                "tools": [
                    {
                        "type": "haystack.tools.tool.Tool",
                        "data": {
                            "name": "weather_tool",
                            "description": "Provides weather information for a given location.",
                            "parameters": {
                                "type": "object",
                                "properties": {"location": {"type": "string"}},
                                "required": ["location"],
                            },
                            "function": "test_agent.weather_function",
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                    {
                        "type": "haystack.tools.component_tool.ComponentTool",
                        "data": {
                            "component": {
                                "type": "haystack.components.builders.prompt_builder.PromptBuilder",
                                "init_parameters": {
                                    "template": "{{parrot}}",
                                    "variables": None,
                                    "required_variables": None,
                                },
                            },
                            "name": "parrot",
                            "description": "This is a parrot.",
                            "parameters": None,
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                ],
                "system_prompt": None,
                "exit_conditions": ["text", "weather_tool"],
                "state_schema": {"foo": {"type": "str"}},
                "max_agent_steps": 100,
                "raise_on_tool_invocation_failure": False,
                "streaming_callback": None,
            },
        }
        agent = Agent.from_dict(data)
        assert isinstance(agent, Agent)
        assert isinstance(agent.chat_generator, OpenAIChatGenerator)
        assert agent.chat_generator.model == "gpt-4o-mini"
        assert agent.chat_generator.api_key == Secret.from_env_var("OPENAI_API_KEY")
        assert agent.tools[0].function is weather_function
        assert isinstance(agent.tools[1]._component, PromptBuilder)
        assert agent.exit_conditions == ["text", "weather_tool"]
        assert agent.state_schema == {
            "foo": {"type": str},
            "messages": {"handler": merge_lists, "type": List[ChatMessage]},
        }

    def test_serde(self, weather_tool, component_tool, monkeypatch):
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            exit_conditions=["text", "weather_tool"],
            state_schema={"foo": {"type": str}},
        )
        serialized_agent = agent.to_dict()
        init_parameters = serialized_agent["init_parameters"]
        assert serialized_agent["type"] == "haystack.components.agents.agent.Agent"
        assert (
            init_parameters["chat_generator"]["type"]
            == "haystack.components.generators.chat.openai.OpenAIChatGenerator"
        )
        assert init_parameters["streaming_callback"] is None
        assert init_parameters["tools"][0]["data"]["function"] == serialize_callable(weather_function)
        assert (
            init_parameters["tools"][1]["data"]["component"]["type"]
            == "haystack.components.builders.prompt_builder.PromptBuilder"
        )
        assert init_parameters["exit_conditions"] == ["text", "weather_tool"]

        deserialized_agent = Agent.from_dict(serialized_agent)
        assert isinstance(deserialized_agent, Agent)
        assert isinstance(deserialized_agent.chat_generator, OpenAIChatGenerator)
        assert deserialized_agent.tools[0].function is weather_function
        assert isinstance(deserialized_agent.tools[1]._component, PromptBuilder)
        assert deserialized_agent.exit_conditions == ["text", "weather_tool"]
        assert deserialized_agent.state_schema == {
            "foo": {"type": str},
            "messages": {"handler": merge_lists, "type": List[ChatMessage]},
        }

    def test_serde_with_streaming_callback(self, weather_tool, component_tool, monkeypatch):
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            streaming_callback=streaming_callback_for_serde,
        )
        serialized_agent = agent.to_dict()
        init_parameters = serialized_agent["init_parameters"]
        assert init_parameters["streaming_callback"] == "test_agent.streaming_callback_for_serde"

        deserialized_agent = Agent.from_dict(serialized_agent)
        assert deserialized_agent.streaming_callback is streaming_callback_for_serde

    def test_exit_conditions_validation(self, weather_tool, component_tool, monkeypatch):
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))

        # Test invalid exit condition
        with pytest.raises(ValueError, match="Invalid exit conditions provided:"):
            Agent(chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["invalid_tool"])

        # Test default exit condition
        agent = Agent(chat_generator=generator, tools=[weather_tool, component_tool])
        assert agent.exit_conditions == ["text"]

        # Test multiple valid exit conditions
        agent = Agent(
            chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["text", "weather_tool"]
        )
        assert agent.exit_conditions == ["text", "weather_tool"]

    def test_run_with_params_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))

        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        agent = Agent(chat_generator=chat_generator, streaming_callback=streaming_callback, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")])

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 2
        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
        assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk

    def test_run_with_run_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))

        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")], streaming_callback=streaming_callback)

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 2
        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
        assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk

    def test_keep_generator_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        chat_generator = OpenAIChatGenerator(
            api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
        )
        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")])

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 2
        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
        assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk

    def test_chat_generator_must_support_tools(self, weather_tool):
        chat_generator = MockChatGeneratorWithoutTools()
        with pytest.raises(TypeError, match="MockChatGeneratorWithoutTools does not accept tools"):
            Agent(chat_generator=chat_generator, tools=[weather_tool])

    def test_multiple_llm_responses_with_tool_call(self, monkeypatch, weather_tool):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        generator = OpenAIChatGenerator()
        mock_messages = [
            ChatMessage.from_assistant("First response"),
            ChatMessage.from_assistant(
                tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
            ),
        ]
        agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=1)
        agent.warm_up()

        # Patch agent.chat_generator.run to return mock_messages
        agent.chat_generator.run = MagicMock(return_value={"replies": mock_messages})

        result = agent.run([ChatMessage.from_user("Hello")])
        assert "messages" in result
        assert len(result["messages"]) == 4
        assert (
            result["messages"][-1].tool_call_result.result
            == "{'weather': 'mostly sunny', 'temperature': 7, 'unit': 'celsius'}"
        )

    def test_exit_conditions_checked_across_all_llm_messages(self, monkeypatch, weather_tool):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        generator = OpenAIChatGenerator()

        # Mock messages where the exit condition appears in the second message
        mock_messages = [
            ChatMessage.from_assistant("First response"),
            ChatMessage.from_assistant(
                tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
            ),
        ]
        agent = Agent(chat_generator=generator, tools=[weather_tool], exit_conditions=["weather_tool"])
        agent.warm_up()

        # Patch agent.chat_generator.run to return mock_messages
        agent.chat_generator.run = MagicMock(return_value={"replies": mock_messages})

        result = agent.run([ChatMessage.from_user("Hello")])
        assert "messages" in result
        assert len(result["messages"]) == 4
        assert result["messages"][-2].tool_call.tool_name == "weather_tool"
        assert (
            result["messages"][-1].tool_call_result.result
            == "{'weather': 'mostly sunny', 'temperature': 7, 'unit': 'celsius'}"
        )
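
    # The integration test below performs a real OpenAI round trip; it only runs
    # when OPENAI_API_KEY is set and integration tests are selected.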

    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
    @pytest.mark.integration
    def test_run(self, weather_tool):
        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini")
        agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3)
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")])

        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 4
        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])

        # Loose check of message texts
        assert response["messages"][0].text == "What is the weather in Berlin?"
        assert response["messages"][1].text is None
        assert response["messages"][2].text is None
        assert response["messages"][3].text is not None

        # Loose check of message metadata
        assert response["messages"][0].meta == {}
        assert response["messages"][1].meta.get("model") is not None
        assert response["messages"][2].meta == {}
        assert response["messages"][3].meta.get("model") is not None

        # Loose check of tool calls and results
        assert response["messages"][1].tool_calls[0].tool_name == "weather_tool"
        assert response["messages"][1].tool_calls[0].arguments is not None
        assert response["messages"][2].tool_call_results[0].result is not None
        assert response["messages"][2].tool_call_results[0].origin is not None

    @pytest.mark.asyncio
    async def test_run_async_falls_back_to_run_when_chat_generator_has_no_run_async(self, weather_tool):
        chat_generator = MockChatGeneratorWithoutRunAsync()
        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        agent.warm_up()
        chat_generator.run = MagicMock(return_value={"replies": [ChatMessage.from_assistant("Hello")]})

        result = await agent.run_async([ChatMessage.from_user("Hello")])

        expected_messages = [
            ChatMessage(_role=ChatRole.USER, _content=[TextContent(text="Hello")], _name=None, _meta={})
        ]
        chat_generator.run.assert_called_once_with(messages=expected_messages, tools=[weather_tool])
        assert isinstance(result, dict)
        assert "messages" in result
        assert isinstance(result["messages"], list)
        assert len(result["messages"]) == 2
        assert all(isinstance(reply, ChatMessage) for reply in result["messages"])
        assert "Hello" in result["messages"][1].text

    @pytest.mark.asyncio
    async def test_run_async_uses_chat_generator_run_async_when_available(self, weather_tool):
        # Create a mock chat generator with run_async.
        # We need to use @component so that has_async_run is set.
        @component
        class MockChatGeneratorWithRunAsync:
            def to_dict(self) -> Dict[str, Any]:
                return {"type": "MockChatGeneratorWithRunAsync", "data": {}}

            @classmethod
            def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithRunAsync":
                return cls()

            def run(
                self, messages: List[ChatMessage], tools: Optional[Union[List[Tool], Toolset]] = None, **kwargs
            ) -> Dict[str, Any]:
                return {"replies": [ChatMessage.from_assistant("Hello")]}

            async def run_async(
                self, messages: List[ChatMessage], tools: Optional[Union[List[Tool], Toolset]] = None, **kwargs
            ) -> Dict[str, Any]:
                return {"replies": [ChatMessage.from_assistant("Hello from run_async")]}

        chat_generator = MockChatGeneratorWithRunAsync()
        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        agent.warm_up()
        chat_generator.run_async = AsyncMock(
            return_value={"replies": [ChatMessage.from_assistant("Hello from run_async")]}
        )

        result = await agent.run_async([ChatMessage.from_user("Hello")])

        expected_messages = [
            ChatMessage(_role=ChatRole.USER, _content=[TextContent(text="Hello")], _name=None, _meta={})
        ]
        chat_generator.run_async.assert_called_once_with(messages=expected_messages, tools=[weather_tool])
        assert isinstance(result, dict)
        assert "messages" in result
        assert isinstance(result["messages"], list)
        assert len(result["messages"]) == 2
        assert all(isinstance(reply, ChatMessage) for reply in result["messages"])
        assert "Hello from run_async" in result["messages"][1].text