# SPDX-FileCopyrightText: 2022-present deepset GmbH
#
# SPDX-License-Identifier: Apache-2.0

import os
from datetime import datetime
from typing import Iterator, Dict, Any, List
from unittest.mock import MagicMock, patch

import pytest
from openai import Stream
from openai.types.chat import ChatCompletionChunk, chat_completion_chunk

from haystack.components.agents import Agent
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.generators.chat.types import ChatGenerator
from haystack.core.component.types import OutputSocket
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.tools import Tool, ComponentTool
from haystack.utils import serialize_callable, Secret
from haystack.dataclasses.state_utils import merge_lists


def streaming_callback_for_serde(chunk: StreamingChunk):
    """
    No-op streaming callback defined at module level.

    It must live at module level so it can be serialized by dotted path
    ("test_agent.streaming_callback_for_serde") and resolved again on deserialization.
    """
    pass


def weather_function(location):
    """
    Return canned weather data for `location`.

    Berlin, Paris and Rome have fixed entries; any other location yields an
    "unknown" record with temperature 0.
    """
    weather_info = {
        "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
        "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
        "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
    }
    return weather_info.get(location, {"weather": "unknown", "temperature": 0, "unit": "celsius"})


@pytest.fixture
def weather_tool():
    """A `Tool` wrapping `weather_function`, taking a single required string `location`."""
    return Tool(
        name="weather_tool",
        description="Provides weather information for a given location.",
        parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]},
        function=weather_function,
    )


@pytest.fixture
def component_tool():
    """A `ComponentTool` wrapping a `PromptBuilder` that renders its `parrot` variable."""
    return ComponentTool(name="parrot", description="This is a parrot.", component=PromptBuilder(template="{{parrot}}"))


class OpenAIMockStream(Stream[ChatCompletionChunk]):
    """An OpenAI `Stream` whose iteration yields exactly one pre-built chunk."""

    def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs):
        # Substitute a MagicMock when no client is supplied so the base Stream can be
        # constructed without a real OpenAI client.
        client = client or MagicMock()
        # Positional args go before keyword args (avoids ruff B026); call semantics are unchanged.
        super().__init__(*args, client=client, **kwargs)
        self.mock_chunk = mock_chunk

    def __stream__(self) -> Iterator[ChatCompletionChunk]:
        yield self.mock_chunk


@pytest.fixture
def openai_mock_chat_completion_chunk():
    """
    Mock the OpenAI API completion chunk response and reuse it for tests.

    Patches `Completions.create` so that it returns an `OpenAIMockStream` that emits a
    single assistant chunk with content "Hello" and finish_reason "stop".
    Yields the patched `create` mock.
    """
    with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
        completion = ChatCompletionChunk(
            id="foo",
            model="gpt-4",
            object="chat.completion.chunk",
            choices=[
                chat_completion_chunk.Choice(
                    finish_reason="stop",
                    logprobs=None,
                    index=0,
                    delta=chat_completion_chunk.ChoiceDelta(content="Hello", role="assistant"),
                )
            ],
            created=int(datetime.now().timestamp()),
            usage=None,
        )
        mock_chat_completion_create.return_value = OpenAIMockStream(
            completion, cast_to=None, response=None, client=None
        )
        yield mock_chat_completion_create


class MockChatGeneratorWithoutTools(ChatGenerator):
    """A mock chat generator that implements ChatGenerator protocol but doesn't support tools."""

    def to_dict(self) -> Dict[str, Any]:
        return {"type": "MockChatGeneratorWithoutTools", "data": {}}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutTools":
        return cls()

    def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        # Deliberately has no `tools` parameter; Agent should reject this generator.
        return {"replies": [ChatMessage.from_assistant("Hello")]}


def _assert_hello_reply(response: Dict[str, Any]) -> None:
    """
    Check an `Agent.run()` result produced with the `openai_mock_chat_completion_chunk` fixture.

    Expects the user message plus one assistant reply whose text contains "Hello".
    """
    assert isinstance(response, dict)
    assert "messages" in response
    assert isinstance(response["messages"], list)
    assert len(response["messages"]) == 2
    # `all(...)` is required here: a bare non-empty list comprehension is always truthy.
    assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
    assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk


class TestAgent:
    def test_output_types(self, weather_tool, component_tool, monkeypatch):
        """The Agent exposes a single `messages` output socket of type List[ChatMessage]."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        chat_generator = OpenAIChatGenerator()
        agent = Agent(chat_generator=chat_generator, tools=[weather_tool, component_tool])
        assert agent.__haystack_output__._sockets_dict == {
            "messages": OutputSocket(name="messages", type=List[ChatMessage], receivers=[])
        }

    def test_to_dict(self, weather_tool, component_tool, monkeypatch):
        """`to_dict` produces the full, exact serialized representation of the Agent."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        generator = OpenAIChatGenerator()
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            exit_conditions=["text", "weather_tool"],
            state_schema={"foo": {"type": str}},
        )
        serialized_agent = agent.to_dict()
        assert serialized_agent == {
            "type": "haystack.components.agents.agent.Agent",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": "gpt-4o-mini",
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                    },
                },
                "tools": [
                    {
                        "type": "haystack.tools.tool.Tool",
                        "data": {
                            "name": "weather_tool",
                            "description": "Provides weather information for a given location.",
                            "parameters": {
                                "type": "object",
                                "properties": {"location": {"type": "string"}},
                                "required": ["location"],
                            },
                            "function": "test_agent.weather_function",
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                    {
                        "type": "haystack.tools.component_tool.ComponentTool",
                        "data": {
                            "component": {
                                "type": "haystack.components.builders.prompt_builder.PromptBuilder",
                                "init_parameters": {
                                    "template": "{{parrot}}",
                                    "variables": None,
                                    "required_variables": None,
                                },
                            },
                            "name": "parrot",
                            "description": "This is a parrot.",
                            "parameters": None,
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                ],
                "system_prompt": None,
                "exit_conditions": ["text", "weather_tool"],
                "state_schema": {"foo": {"type": "str"}},
                "max_agent_steps": 100,
                "raise_on_tool_invocation_failure": False,
                "streaming_callback": None,
            },
        }

    def test_from_dict(self, weather_tool, component_tool, monkeypatch):
        """`from_dict` rebuilds generator, tools, exit conditions and the state schema (incl. `messages`)."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        data = {
            "type": "haystack.components.agents.agent.Agent",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": "gpt-4o-mini",
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                    },
                },
                "tools": [
                    {
                        "type": "haystack.tools.tool.Tool",
                        "data": {
                            "name": "weather_tool",
                            "description": "Provides weather information for a given location.",
                            "parameters": {
                                "type": "object",
                                "properties": {"location": {"type": "string"}},
                                "required": ["location"],
                            },
                            "function": "test_agent.weather_function",
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                    {
                        "type": "haystack.tools.component_tool.ComponentTool",
                        "data": {
                            "component": {
                                "type": "haystack.components.builders.prompt_builder.PromptBuilder",
                                "init_parameters": {
                                    "template": "{{parrot}}",
                                    "variables": None,
                                    "required_variables": None,
                                },
                            },
                            "name": "parrot",
                            "description": "This is a parrot.",
                            "parameters": None,
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                ],
                "system_prompt": None,
                "exit_conditions": ["text", "weather_tool"],
                "state_schema": {"foo": {"type": "str"}},
                "max_agent_steps": 100,
                "raise_on_tool_invocation_failure": False,
                "streaming_callback": None,
            },
        }
        agent = Agent.from_dict(data)
        assert isinstance(agent, Agent)
        assert isinstance(agent.chat_generator, OpenAIChatGenerator)
        assert agent.chat_generator.model == "gpt-4o-mini"
        assert agent.chat_generator.api_key == Secret.from_env_var("OPENAI_API_KEY")
        assert agent.tools[0].function is weather_function
        assert isinstance(agent.tools[1]._component, PromptBuilder)
        assert agent.exit_conditions == ["text", "weather_tool"]
        assert agent.state_schema == {
            "foo": {"type": str},
            "messages": {"handler": merge_lists, "type": List[ChatMessage]},
        }

    def test_serde(self, weather_tool, component_tool, monkeypatch):
        """A `to_dict` -> `from_dict` round trip preserves the Agent configuration."""
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            exit_conditions=["text", "weather_tool"],
            state_schema={"foo": {"type": str}},
        )

        serialized_agent = agent.to_dict()

        init_parameters = serialized_agent["init_parameters"]

        assert serialized_agent["type"] == "haystack.components.agents.agent.Agent"
        assert (
            init_parameters["chat_generator"]["type"]
            == "haystack.components.generators.chat.openai.OpenAIChatGenerator"
        )
        assert init_parameters["streaming_callback"] is None
        assert init_parameters["tools"][0]["data"]["function"] == serialize_callable(weather_function)
        assert (
            init_parameters["tools"][1]["data"]["component"]["type"]
            == "haystack.components.builders.prompt_builder.PromptBuilder"
        )
        assert init_parameters["exit_conditions"] == ["text", "weather_tool"]

        deserialized_agent = Agent.from_dict(serialized_agent)

        assert isinstance(deserialized_agent, Agent)
        assert isinstance(deserialized_agent.chat_generator, OpenAIChatGenerator)
        assert deserialized_agent.tools[0].function is weather_function
        assert isinstance(deserialized_agent.tools[1]._component, PromptBuilder)
        assert deserialized_agent.exit_conditions == ["text", "weather_tool"]
        assert deserialized_agent.state_schema == {
            "foo": {"type": str},
            "messages": {"handler": merge_lists, "type": List[ChatMessage]},
        }

    def test_serde_with_streaming_callback(self, weather_tool, component_tool, monkeypatch):
        """A module-level streaming callback serializes by dotted path and deserializes to the same object."""
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            streaming_callback=streaming_callback_for_serde,
        )

        serialized_agent = agent.to_dict()

        init_parameters = serialized_agent["init_parameters"]
        assert init_parameters["streaming_callback"] == "test_agent.streaming_callback_for_serde"

        deserialized_agent = Agent.from_dict(serialized_agent)
        assert deserialized_agent.streaming_callback is streaming_callback_for_serde

    def test_exit_conditions_validation(self, weather_tool, component_tool, monkeypatch):
        """Unknown exit conditions are rejected; the default is ["text"]; valid tool names are accepted."""
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))

        # Test invalid exit condition
        with pytest.raises(ValueError, match="Invalid exit conditions provided:"):
            Agent(chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["invalid_tool"])

        # Test default exit condition
        agent = Agent(chat_generator=generator, tools=[weather_tool, component_tool])
        assert agent.exit_conditions == ["text"]

        # Test multiple valid exit conditions
        agent = Agent(
            chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["text", "weather_tool"]
        )
        assert agent.exit_conditions == ["text", "weather_tool"]

    def test_run_with_params_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        """A streaming callback passed to the Agent constructor is invoked during `run`."""
        chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        agent = Agent(chat_generator=chat_generator, streaming_callback=streaming_callback, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")])

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        _assert_hello_reply(response)

    def test_run_with_run_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        """A streaming callback passed to `run` is invoked."""
        chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")], streaming_callback=streaming_callback)

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        _assert_hello_reply(response)

    def test_keep_generator_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        """A streaming callback configured on the chat generator itself is still invoked via the Agent."""
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        chat_generator = OpenAIChatGenerator(
            api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
        )

        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("Hello")])

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        _assert_hello_reply(response)

    def test_chat_generator_must_support_tools(self, weather_tool):
        """Constructing an Agent with a generator whose `run` lacks `tools` raises TypeError."""
        chat_generator = MockChatGeneratorWithoutTools()

        with pytest.raises(TypeError, match="MockChatGeneratorWithoutTools does not accept tools"):
            Agent(chat_generator=chat_generator, tools=[weather_tool])

    def test_multiple_llm_responses_with_tool_call(self, monkeypatch, weather_tool):
        """When the LLM returns a text reply and a tool call together, the tool call is still executed."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        generator = OpenAIChatGenerator()

        mock_messages = [
            ChatMessage.from_assistant("First response"),
            ChatMessage.from_assistant(
                tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
            ),
        ]

        agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=1)
        agent.warm_up()

        # Patch agent.chat_generator.run to return mock_messages
        agent.chat_generator.run = MagicMock(return_value={"replies": mock_messages})

        result = agent.run([ChatMessage.from_user("Hello")])

        # user message + two LLM replies + one tool result
        assert "messages" in result
        assert len(result["messages"]) == 4
        assert (
            result["messages"][-1].tool_call_result.result
            == "{'weather': 'mostly sunny', 'temperature': 7, 'unit': 'celsius'}"
        )

    def test_exit_conditions_checked_across_all_llm_messages(self, monkeypatch, weather_tool):
        """An exit condition matching a tool call in a later LLM reply (not only the first) is honored."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        generator = OpenAIChatGenerator()

        # Mock messages where the exit condition appears in the second message
        mock_messages = [
            ChatMessage.from_assistant("First response"),
            ChatMessage.from_assistant(
                tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
            ),
        ]

        agent = Agent(chat_generator=generator, tools=[weather_tool], exit_conditions=["weather_tool"])
        agent.warm_up()

        # Patch agent.chat_generator.run to return mock_messages
        agent.chat_generator.run = MagicMock(return_value={"replies": mock_messages})

        result = agent.run([ChatMessage.from_user("Hello")])

        assert "messages" in result
        assert len(result["messages"]) == 4
        assert result["messages"][-2].tool_call.tool_name == "weather_tool"
        assert (
            result["messages"][-1].tool_call_result.result
            == "{'weather': 'mostly sunny', 'temperature': 7, 'unit': 'celsius'}"
        )

    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
    @pytest.mark.integration
    def test_run(self, weather_tool):
        """End-to-end run against the real OpenAI API: one tool call, one tool result, one final answer."""
        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini")
        agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3)
        agent.warm_up()
        response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")])

        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 4
        # `all(...)` is required here: a bare non-empty list comprehension is always truthy.
        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
        # Loose check of message texts
        assert response["messages"][0].text == "What is the weather in Berlin?"
        assert response["messages"][1].text is None
        assert response["messages"][2].text is None
        assert response["messages"][3].text is not None
        # Loose check of message metadata
        assert response["messages"][0].meta == {}
        assert response["messages"][1].meta.get("model") is not None
        assert response["messages"][2].meta == {}
        assert response["messages"][3].meta.get("model") is not None
        # Loose check of tool calls and results
        assert response["messages"][1].tool_calls[0].tool_name == "weather_tool"
        assert response["messages"][1].tool_calls[0].arguments is not None
        assert response["messages"][2].tool_call_results[0].result is not None
        assert response["messages"][2].tool_call_results[0].origin is not None