# haystack/test/components/agents/test_agent.py
# (file-viewer header: 257 lines, 10 KiB, Python)

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from typing import Iterator, Dict, Any, List
from unittest.mock import MagicMock, patch
import pytest
from openai import Stream
from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
from haystack.components.agents import Agent
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.generators.chat.types import ChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.tools import Tool, ComponentTool
from haystack.utils import serialize_callable, Secret
def streaming_callback_for_serde(chunk: StreamingChunk):
    """
    No-op streaming callback.

    Defined at module level so the serialization test can pass it to an ``Agent`` and
    compare the restored callback against this exact function object.
    """
    return None
def weather_function(location):
    """
    Return canned weather data for ``location``.

    Known locations are "Berlin", "Paris" and "Rome" (exact, case-sensitive match).
    Any other value yields a placeholder entry with weather "unknown" and temperature 0.

    :param location: Name of the city to look up.
    :returns: A dict with the keys "weather", "temperature" and "unit".
    """
    fallback = {"weather": "unknown", "temperature": 0, "unit": "celsius"}
    known_cities = {
        "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
        "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
        "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
    }
    if location in known_cities:
        return known_cities[location]
    return fallback
# Parameter schema handed to the weather Tool: an object with one required string field, "location".
weather_parameters = {
    "type": "object",
    "properties": {
        "location": {"type": "string"},
    },
    "required": ["location"],
}
@pytest.fixture
def weather_tool():
    """Provide a ``Tool`` named "weather_tool" that wraps ``weather_function`` with ``weather_parameters``."""
    tool_settings = {
        "name": "weather_tool",
        "description": "Provides weather information for a given location.",
        "parameters": weather_parameters,
        "function": weather_function,
    }
    return Tool(**tool_settings)
@pytest.fixture
def component_tool():
    """Provide a ``ComponentTool`` named "parrot" built around a ``PromptBuilder`` with template ``{{parrot}}``."""
    parrot_builder = PromptBuilder(template="{{parrot}}")
    return ComponentTool(
        name="parrot",
        description="This is a parrot.",
        component=parrot_builder,
    )
class OpenAIMockStream(Stream[ChatCompletionChunk]):
    """
    ``Stream`` subclass for tests that yields exactly one pre-built ``ChatCompletionChunk``.

    :param mock_chunk: The chunk that iterating the stream will produce.
    :param client: Client object forwarded to ``Stream``; a ``MagicMock`` is substituted
        when it is ``None`` (or otherwise falsy).
    :param args: Extra positional arguments forwarded to ``Stream.__init__``.
    :param kwargs: Extra keyword arguments forwarded to ``Stream.__init__``.
    """

    def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs):
        client = client or MagicMock()
        # Positional unpacking goes before keyword arguments: the call is the same as
        # `client=client, *args`, but reads in the order Python actually binds them.
        super().__init__(*args, client=client, **kwargs)
        self.mock_chunk = mock_chunk

    def __stream__(self) -> Iterator[ChatCompletionChunk]:
        """Yield the single stored chunk."""
        yield self.mock_chunk
@pytest.fixture
def openai_mock_chat_completion_chunk():
    """
    Patch the OpenAI chat-completions ``create`` call to return a one-chunk mock stream.

    The single chunk is an assistant delta with content "Hello" and finish reason "stop".
    Yields the patched ``create`` mock so tests can inspect it.
    """
    hello_delta = chat_completion_chunk.ChoiceDelta(content="Hello", role="assistant")
    only_choice = chat_completion_chunk.Choice(
        finish_reason="stop",
        logprobs=None,
        index=0,
        delta=hello_delta,
    )
    with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
        created_at = int(datetime.now().timestamp())
        chunk = ChatCompletionChunk(
            id="foo",
            model="gpt-4",
            object="chat.completion.chunk",
            choices=[only_choice],
            created=created_at,
            usage=None,
        )
        mock_stream = OpenAIMockStream(chunk, cast_to=None, response=None, client=None)
        mock_chat_completion_create.return_value = mock_stream
        yield mock_chat_completion_create
class MockChatGeneratorWithoutTools(ChatGenerator):
    """Minimal ``ChatGenerator`` stand-in whose ``run`` accepts only ``messages`` and has no ``tools`` parameter."""

    def to_dict(self) -> Dict[str, Any]:
        """Return a fixed serialized form with an empty ``data`` payload."""
        serialized: Dict[str, Any] = {"type": "MockChatGeneratorWithoutTools", "data": {}}
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutTools":
        """Build a fresh instance; ``data`` is ignored because the mock has no state."""
        instance = cls()
        return instance

    def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        """Return a single canned assistant reply, "Hello", regardless of ``messages``."""
        canned_reply = ChatMessage.from_assistant("Hello")
        return {"replies": [canned_reply]}
class TestAgent:
    """Tests for ``Agent``: serialization round-trips, exit-condition validation and streaming callbacks."""

    @staticmethod
    def _assert_single_hello_reply(response) -> None:
        """
        Check the run result shape shared by the streaming tests.

        The result must be a dict holding a two-item ``messages`` list of ``ChatMessage``
        objects, the second of which contains "Hello".
        """
        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 2
        # all() is required here: asserting a bare list comprehension passes for any
        # non-empty list, even one full of False values.
        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
        assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk

    def test_serde(self, weather_tool, component_tool, monkeypatch):
        """Round-trip an Agent through ``to_dict``/``from_dict`` and check both representations."""
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            exit_conditions=["text", "weather_tool"],
            state_schema={"foo": {"type": str}},
        )

        serialized_agent = agent.to_dict()
        init_parameters = serialized_agent["init_parameters"]

        assert serialized_agent["type"] == "haystack.components.agents.agent.Agent"
        assert (
            init_parameters["chat_generator"]["type"]
            == "haystack.components.generators.chat.openai.OpenAIChatGenerator"
        )
        assert init_parameters["streaming_callback"] is None
        assert init_parameters["tools"][0]["data"]["function"] == serialize_callable(weather_function)
        assert (
            init_parameters["tools"][1]["data"]["component"]["type"]
            == "haystack.components.builders.prompt_builder.PromptBuilder"
        )
        assert init_parameters["exit_conditions"] == ["text", "weather_tool"]

        deserialized_agent = Agent.from_dict(serialized_agent)

        assert isinstance(deserialized_agent, Agent)
        assert isinstance(deserialized_agent.chat_generator, OpenAIChatGenerator)
        assert deserialized_agent.tools[0].function is weather_function
        assert isinstance(deserialized_agent.tools[1]._component, PromptBuilder)
        assert deserialized_agent.exit_conditions == ["text", "weather_tool"]
        assert deserialized_agent.state_schema == {"foo": {"type": str}}

    def test_serde_with_streaming_callback(self, weather_tool, component_tool, monkeypatch):
        """A module-level streaming callback is serialized by dotted name and restored as the same function."""
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            streaming_callback=streaming_callback_for_serde,
        )

        serialized_agent = agent.to_dict()
        init_parameters = serialized_agent["init_parameters"]
        assert init_parameters["streaming_callback"] == "test_agent.streaming_callback_for_serde"

        deserialized_agent = Agent.from_dict(serialized_agent)
        assert deserialized_agent.streaming_callback is streaming_callback_for_serde

    def test_exit_conditions_validation(self, weather_tool, component_tool, monkeypatch):
        """Exit conditions are validated at construction time and default to ``["text"]``."""
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))

        # Test invalid exit condition
        with pytest.raises(ValueError, match="Exit conditions must be a subset of"):
            Agent(chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["invalid_tool"])

        # Test default exit condition
        agent = Agent(chat_generator=generator, tools=[weather_tool, component_tool])
        assert agent.exit_conditions == ["text"]

        # Test multiple valid exit conditions
        agent = Agent(
            chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["text", "weather_tool"]
        )
        assert agent.exit_conditions == ["text", "weather_tool"]

    def test_run_with_params_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        """A streaming callback given to the Agent constructor is invoked during ``run``."""
        chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))

        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        agent = Agent(chat_generator=chat_generator, streaming_callback=streaming_callback, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")])

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        self._assert_single_hello_reply(response)

    def test_run_with_run_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        """A streaming callback passed directly to ``Agent.run`` is invoked."""
        chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))

        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")], streaming_callback=streaming_callback)

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        self._assert_single_hello_reply(response)

    def test_keep_generator_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        """A streaming callback set on the chat generator (not on the Agent) is still invoked."""
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        chat_generator = OpenAIChatGenerator(
            api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
        )

        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")])

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        self._assert_single_hello_reply(response)

    def test_chat_generator_must_support_tools(self, weather_tool):
        """Constructing an Agent with tools and a generator that takes no tools raises ``TypeError``."""
        chat_generator = MockChatGeneratorWithoutTools()

        with pytest.raises(TypeError, match="MockChatGeneratorWithoutTools does not accept tools"):
            Agent(chat_generator=chat_generator, tools=[weather_tool])