# SPDX-FileCopyrightText: 2022-present deepset GmbH
#
# SPDX-License-Identifier: Apache-2.0

import os
from datetime import datetime
from typing import Iterator, Dict, Any, List
from unittest.mock import MagicMock, patch

import pytest
from openai import Stream
from openai.types.chat import ChatCompletionChunk, chat_completion_chunk

from haystack.components.agents import Agent
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.generators.chat.types import ChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.tools import Tool, ComponentTool
from haystack.utils import serialize_callable, Secret


def streaming_callback_for_serde(chunk: StreamingChunk):
    """
    No-op streaming callback defined at module level.

    It lives at module level (not inside a test) so the Agent can serialize it by import path
    and resolve it again on deserialization.
    """
    pass


def weather_function(location):
    """
    Return canned weather data for `location`.

    :param location: City name. Berlin, Paris and Rome have predefined entries.
    :returns: A dict with `weather`, `temperature` and `unit` keys. Unknown locations get
        `{"weather": "unknown", "temperature": 0, "unit": "celsius"}`.
    """
    weather_info = {
        "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
        "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
        "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
    }
    return weather_info.get(location, {"weather": "unknown", "temperature": 0, "unit": "celsius"})


@pytest.fixture
def weather_tool():
    """A function-backed `Tool` named "weather_tool" that wraps `weather_function`."""
    return Tool(
        name="weather_tool",
        description="Provides weather information for a given location.",
        parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]},
        function=weather_function,
    )


@pytest.fixture
def component_tool():
    """A `ComponentTool` named "parrot" wrapping a `PromptBuilder` whose template renders `parrot`."""
    return ComponentTool(name="parrot", description="This is a parrot.", component=PromptBuilder(template="{{parrot}}"))


class OpenAIMockStream(Stream[ChatCompletionChunk]):
    """A fake OpenAI `Stream` that yields exactly one predefined chunk."""

    def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs):
        # Fall back to a MagicMock so the stream can be built without a real OpenAI client.
        client = client or MagicMock()
        # Positional args first, then keywords: unpacking *args after a keyword argument is
        # legal but misleading (flake8-bugbear B026).
        super().__init__(*args, client=client, **kwargs)
        self.mock_chunk = mock_chunk

    def __stream__(self) -> Iterator[ChatCompletionChunk]:
        # Yield only the predefined chunk instead of reading chunks from a response.
        yield self.mock_chunk


@pytest.fixture
def openai_mock_chat_completion_chunk():
    """
    Patch `Completions.create` so every OpenAI chat call returns a one-chunk "Hello" stream.

    The single chunk has role "assistant", content "Hello" and finish_reason "stop".
    Yields the patch mock so tests can inspect how it was called.
    """
    with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
        completion = ChatCompletionChunk(
            id="foo",
            model="gpt-4",
            object="chat.completion.chunk",
            choices=[
                chat_completion_chunk.Choice(
                    finish_reason="stop",
                    logprobs=None,
                    index=0,
                    delta=chat_completion_chunk.ChoiceDelta(content="Hello", role="assistant"),
                )
            ],
            created=int(datetime.now().timestamp()),
            usage=None,
        )
        mock_chat_completion_create.return_value = OpenAIMockStream(
            completion, cast_to=None, response=None, client=None
        )
        yield mock_chat_completion_create


class MockChatGeneratorWithoutTools(ChatGenerator):
    """A mock chat generator that implements ChatGenerator protocol but doesn't support tools."""

    def to_dict(self) -> Dict[str, Any]:
        return {"type": "MockChatGeneratorWithoutTools", "data": {}}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutTools":
        return cls()

    # Deliberately has no `tools` parameter, so the Agent must reject this generator.
    def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant("Hello")]}


def _assert_single_hello_reply(response: Dict[str, Any]) -> None:
    """
    Check an Agent result produced with the `openai_mock_chat_completion_chunk` fixture.

    Expects exactly two ChatMessages: the user input, then the mocked "Hello" reply.
    """
    assert isinstance(response, dict)
    assert "messages" in response
    assert isinstance(response["messages"], list)
    assert len(response["messages"]) == 2
    # `all(...)` is required: asserting the bare list only checks that it is non-empty.
    assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
    assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk


class TestAgent:
    def test_serde(self, weather_tool, component_tool, monkeypatch):
        """An Agent round-trips through to_dict/from_dict with its generator, tools, exits and state schema."""
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            exit_conditions=["text", "weather_tool"],
            state_schema={"foo": {"type": str}},
        )

        serialized_agent = agent.to_dict()

        init_parameters = serialized_agent["init_parameters"]

        assert serialized_agent["type"] == "haystack.components.agents.agent.Agent"
        assert (
            init_parameters["chat_generator"]["type"]
            == "haystack.components.generators.chat.openai.OpenAIChatGenerator"
        )
        assert init_parameters["streaming_callback"] is None
        assert init_parameters["tools"][0]["data"]["function"] == serialize_callable(weather_function)
        assert (
            init_parameters["tools"][1]["data"]["component"]["type"]
            == "haystack.components.builders.prompt_builder.PromptBuilder"
        )
        assert init_parameters["exit_conditions"] == ["text", "weather_tool"]

        deserialized_agent = Agent.from_dict(serialized_agent)

        assert isinstance(deserialized_agent, Agent)
        assert isinstance(deserialized_agent.chat_generator, OpenAIChatGenerator)
        assert deserialized_agent.tools[0].function is weather_function
        assert isinstance(deserialized_agent.tools[1]._component, PromptBuilder)
        assert deserialized_agent.exit_conditions == ["text", "weather_tool"]
        assert deserialized_agent.state_schema == {"foo": {"type": str}}

    def test_serde_with_streaming_callback(self, weather_tool, component_tool, monkeypatch):
        """A module-level streaming callback is serialized by path and resolved back to the same function."""
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            streaming_callback=streaming_callback_for_serde,
        )

        serialized_agent = agent.to_dict()

        init_parameters = serialized_agent["init_parameters"]
        # NOTE(review): this path depends on the module being imported as `test_agent`.
        assert init_parameters["streaming_callback"] == "test_agent.streaming_callback_for_serde"

        deserialized_agent = Agent.from_dict(serialized_agent)
        assert deserialized_agent.streaming_callback is streaming_callback_for_serde

    def test_exit_conditions_validation(self, weather_tool, component_tool, monkeypatch):
        """Unknown exit conditions raise ValueError; the default is ["text"]; valid tool names are accepted."""
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))

        # Test invalid exit condition
        with pytest.raises(ValueError, match="Invalid exit conditions provided:"):
            Agent(chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["invalid_tool"])

        # Test default exit condition
        agent = Agent(chat_generator=generator, tools=[weather_tool, component_tool])
        assert agent.exit_conditions == ["text"]

        # Test multiple valid exit conditions
        agent = Agent(
            chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["text", "weather_tool"]
        )
        assert agent.exit_conditions == ["text", "weather_tool"]

    def test_run_with_params_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        """A streaming callback passed to the Agent constructor is invoked during run()."""
        chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        agent = Agent(chat_generator=chat_generator, streaming_callback=streaming_callback, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")])

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        _assert_single_hello_reply(response)

    def test_run_with_run_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        """A streaming callback passed to run() is invoked even when the Agent was built without one."""
        chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")], streaming_callback=streaming_callback)

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        _assert_single_hello_reply(response)

    def test_keep_generator_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        """A streaming callback configured on the chat generator itself still fires inside the Agent."""
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        chat_generator = OpenAIChatGenerator(
            api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
        )

        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")])

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        _assert_single_hello_reply(response)

    def test_chat_generator_must_support_tools(self, weather_tool):
        """Constructing an Agent with a generator whose run() has no `tools` parameter raises TypeError."""
        chat_generator = MockChatGeneratorWithoutTools()

        with pytest.raises(TypeError, match="MockChatGeneratorWithoutTools does not accept tools"):
            Agent(chat_generator=chat_generator, tools=[weather_tool])

    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
    @pytest.mark.integration
    def test_run(self, weather_tool):
        """Live run: user question -> tool call -> tool result -> final assistant answer."""
        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini")
        agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3)
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")])

        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 4
        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
        # Loose check of message texts
        assert response["messages"][0].text == "What is the weather in Berlin?"
        assert response["messages"][1].text is None
        assert response["messages"][2].text is None
        assert response["messages"][3].text is not None
        # Loose check of message metadata
        assert response["messages"][0].meta == {}
        assert response["messages"][1].meta.get("model") is not None
        assert response["messages"][2].meta == {}
        assert response["messages"][3].meta.get("model") is not None
        # Loose check of tool calls and results
        assert response["messages"][1].tool_calls[0].tool_name == "weather_tool"
        assert response["messages"][1].tool_calls[0].arguments is not None
        assert response["messages"][2].tool_call_results[0].result is not None
        assert response["messages"][2].tool_call_results[0].origin is not None