2025-03-28 14:23:39 +01:00
|
|
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
|
|
#
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
2025-04-01 11:29:44 +02:00
|
|
|
import os
|
2025-03-28 14:23:39 +01:00
|
|
|
from datetime import datetime
|
2025-03-31 14:47:38 +02:00
|
|
|
from typing import Iterator, Dict, Any, List
|
2025-03-28 14:23:39 +01:00
|
|
|
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
from openai import Stream
|
|
|
|
from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
|
|
|
|
|
|
|
|
from haystack.components.agents import Agent
|
|
|
|
from haystack.components.builders.prompt_builder import PromptBuilder
|
|
|
|
from haystack.components.generators.chat.openai import OpenAIChatGenerator
|
2025-03-31 14:47:38 +02:00
|
|
|
from haystack.components.generators.chat.types import ChatGenerator
|
2025-04-10 08:39:17 +02:00
|
|
|
from haystack.core.component.types import OutputSocket
|
2025-04-01 14:51:53 +02:00
|
|
|
from haystack.dataclasses import ChatMessage, ToolCall
|
2025-03-28 14:23:39 +01:00
|
|
|
from haystack.dataclasses.streaming_chunk import StreamingChunk
|
|
|
|
from haystack.tools import Tool, ComponentTool
|
|
|
|
from haystack.utils import serialize_callable, Secret
|
2025-04-10 08:39:17 +02:00
|
|
|
from haystack.dataclasses.state_utils import merge_lists
|
2025-03-28 14:23:39 +01:00
|
|
|
|
|
|
|
|
|
|
|
def streaming_callback_for_serde(chunk: StreamingChunk):
    """
    No-op streaming callback defined at module level.

    It lives at module scope (not inside a test) so it can be referenced by import path when an
    agent configured with it is serialized and deserialized.
    """
    return None
|
|
|
|
|
|
|
|
|
|
|
|
def weather_function(location):
    """
    Return canned weather data for a city.

    :param location: City name; "Berlin", "Paris" and "Rome" have fixed entries.
    :returns: A dict with "weather", "temperature" and "unit" keys. Any other location gets an
        "unknown" placeholder with temperature 0.
    """
    # Built on every call so callers always receive fresh dicts they may mutate freely.
    known_cities = {
        "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
        "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
        "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
    }
    fallback = {"weather": "unknown", "temperature": 0, "unit": "celsius"}
    return known_cities.get(location, fallback)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def weather_tool():
    """Provide a ``Tool`` named "weather_tool" that wraps :func:`weather_function`."""
    # JSON schema: a single required string argument called "location".
    location_schema = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
    return Tool(
        function=weather_function,
        name="weather_tool",
        description="Provides weather information for a given location.",
        parameters=location_schema,
    )
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def component_tool():
    """Provide a ``ComponentTool`` named "parrot" wrapping a PromptBuilder with template ``{{parrot}}``."""
    parrot_builder = PromptBuilder(template="{{parrot}}")
    return ComponentTool(component=parrot_builder, name="parrot", description="This is a parrot.")
|
|
|
|
|
|
|
|
|
|
|
|
class OpenAIMockStream(Stream[ChatCompletionChunk]):
    """An ``openai.Stream`` stand-in whose iteration yields a single pre-built chunk."""

    def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs):
        """
        :param mock_chunk: The chunk to emit when the stream is consumed.
        :param client: Client handed to ``Stream``; any falsy value is replaced by a ``MagicMock``.
        :param args: Forwarded positionally to ``Stream.__init__``.
        :param kwargs: Forwarded as keywords to ``Stream.__init__`` (e.g. ``cast_to``, ``response``).
        """
        if not client:
            client = MagicMock()
        # Positional unpacking placed before the keyword argument; the call is equivalent either way.
        super().__init__(*args, client=client, **kwargs)
        self.mock_chunk = mock_chunk

    def __stream__(self) -> Iterator[ChatCompletionChunk]:
        """Yield the stored chunk exactly once."""
        yield from (self.mock_chunk,)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def openai_mock_chat_completion_chunk():
    """
    Mock the OpenAI API completion chunk response and reuse it for tests.

    ``Completions.create`` is patched to return an :class:`OpenAIMockStream` that yields a single
    assistant chunk with content "Hello" and finish reason "stop". The patch mock is yielded so
    tests can inspect it.
    """
    hello_delta = chat_completion_chunk.ChoiceDelta(content="Hello", role="assistant")
    only_choice = chat_completion_chunk.Choice(finish_reason="stop", logprobs=None, index=0, delta=hello_delta)
    canned_chunk = ChatCompletionChunk(
        id="foo",
        model="gpt-4",
        object="chat.completion.chunk",
        choices=[only_choice],
        created=int(datetime.now().timestamp()),
        usage=None,
    )

    with patch("openai.resources.chat.completions.Completions.create") as mocked_create:
        mocked_create.return_value = OpenAIMockStream(canned_chunk, cast_to=None, response=None, client=None)
        yield mocked_create
|
|
|
|
|
|
|
|
|
2025-03-31 14:47:38 +02:00
|
|
|
class MockChatGeneratorWithoutTools(ChatGenerator):
    """
    A mock chat generator that implements the ChatGenerator protocol but doesn't support tools.

    Its ``run`` method takes only ``messages``.
    """

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a fixed dictionary; the mock has no configuration to store."""
        return {"type": "MockChatGeneratorWithoutTools", "data": {}}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutTools":
        """Ignore ``data`` and build a fresh instance."""
        return cls()

    def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        """Ignore ``messages`` and reply with a single assistant message saying "Hello"."""
        greeting = ChatMessage.from_assistant("Hello")
        return {"replies": [greeting]}
|
|
|
|
|
|
|
|
|
2025-03-28 14:23:39 +01:00
|
|
|
class TestAgent:
|
2025-04-10 08:39:17 +02:00
|
|
|
def test_output_types(self, weather_tool, component_tool, monkeypatch):
    """The agent exposes exactly one output socket, ``messages``, typed ``List[ChatMessage]``."""
    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
    agent = Agent(chat_generator=OpenAIChatGenerator(), tools=[weather_tool, component_tool])

    expected_sockets = {"messages": OutputSocket(name="messages", type=List[ChatMessage], receivers=[])}
    assert agent.__haystack_output__._sockets_dict == expected_sockets
|
|
|
|
|
|
|
|
def test_to_dict(self, weather_tool, component_tool, monkeypatch):
    """
    ``Agent.to_dict`` produces the full nested init-parameter dictionary.

    The weather tool's function path comes from ``serialize_callable(weather_function)``, as in
    ``test_serde``, instead of the literal ``"test_agent.weather_function"``. The expectation then
    stays valid whatever module name pytest assigns to this file (import mode, package layout).
    """
    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
    generator = OpenAIChatGenerator()
    agent = Agent(
        chat_generator=generator,
        tools=[weather_tool, component_tool],
        exit_conditions=["text", "weather_tool"],
        state_schema={"foo": {"type": str}},
    )

    serialized_agent = agent.to_dict()

    assert serialized_agent == {
        "type": "haystack.components.agents.agent.Agent",
        "init_parameters": {
            "chat_generator": {
                "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                "init_parameters": {
                    "model": "gpt-4o-mini",
                    "streaming_callback": None,
                    "api_base_url": None,
                    "organization": None,
                    "generation_kwargs": {},
                    "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                    "timeout": None,
                    "max_retries": None,
                    "tools": None,
                    "tools_strict": False,
                },
            },
            "tools": [
                {
                    "type": "haystack.tools.tool.Tool",
                    "data": {
                        "name": "weather_tool",
                        "description": "Provides weather information for a given location.",
                        "parameters": {
                            "type": "object",
                            "properties": {"location": {"type": "string"}},
                            "required": ["location"],
                        },
                        # Derived, not hard-coded: tracks the module name this file was imported under.
                        "function": serialize_callable(weather_function),
                        "outputs_to_string": None,
                        "inputs_from_state": None,
                        "outputs_to_state": None,
                    },
                },
                {
                    "type": "haystack.tools.component_tool.ComponentTool",
                    "data": {
                        "component": {
                            "type": "haystack.components.builders.prompt_builder.PromptBuilder",
                            "init_parameters": {
                                "template": "{{parrot}}",
                                "variables": None,
                                "required_variables": None,
                            },
                        },
                        "name": "parrot",
                        "description": "This is a parrot.",
                        "parameters": None,
                        "outputs_to_string": None,
                        "inputs_from_state": None,
                        "outputs_to_state": None,
                    },
                },
            ],
            "system_prompt": None,
            "exit_conditions": ["text", "weather_tool"],
            "state_schema": {"foo": {"type": "str"}},
            "max_agent_steps": 100,
            "raise_on_tool_invocation_failure": False,
            "streaming_callback": None,
        },
    }
|
|
|
|
|
|
|
|
def test_from_dict(self, weather_tool, component_tool, monkeypatch):
    """
    ``Agent.from_dict`` rebuilds the generator, tools, exit conditions and state schema.

    The weather tool's function path is produced with ``serialize_callable(weather_function)``
    rather than the literal ``"test_agent.weather_function"``. Deserialization then resolves to
    this very module, so the ``is weather_function`` identity check holds regardless of the
    module name pytest assigned to this file.
    """
    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
    data = {
        "type": "haystack.components.agents.agent.Agent",
        "init_parameters": {
            "chat_generator": {
                "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                "init_parameters": {
                    "model": "gpt-4o-mini",
                    "streaming_callback": None,
                    "api_base_url": None,
                    "organization": None,
                    "generation_kwargs": {},
                    "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                    "timeout": None,
                    "max_retries": None,
                    "tools": None,
                    "tools_strict": False,
                },
            },
            "tools": [
                {
                    "type": "haystack.tools.tool.Tool",
                    "data": {
                        "name": "weather_tool",
                        "description": "Provides weather information for a given location.",
                        "parameters": {
                            "type": "object",
                            "properties": {"location": {"type": "string"}},
                            "required": ["location"],
                        },
                        # Derived, not hard-coded: resolves back to this module's weather_function.
                        "function": serialize_callable(weather_function),
                        "outputs_to_string": None,
                        "inputs_from_state": None,
                        "outputs_to_state": None,
                    },
                },
                {
                    "type": "haystack.tools.component_tool.ComponentTool",
                    "data": {
                        "component": {
                            "type": "haystack.components.builders.prompt_builder.PromptBuilder",
                            "init_parameters": {
                                "template": "{{parrot}}",
                                "variables": None,
                                "required_variables": None,
                            },
                        },
                        "name": "parrot",
                        "description": "This is a parrot.",
                        "parameters": None,
                        "outputs_to_string": None,
                        "inputs_from_state": None,
                        "outputs_to_state": None,
                    },
                },
            ],
            "system_prompt": None,
            "exit_conditions": ["text", "weather_tool"],
            "state_schema": {"foo": {"type": "str"}},
            "max_agent_steps": 100,
            "raise_on_tool_invocation_failure": False,
            "streaming_callback": None,
        },
    }

    agent = Agent.from_dict(data)

    assert isinstance(agent, Agent)
    assert isinstance(agent.chat_generator, OpenAIChatGenerator)
    assert agent.chat_generator.model == "gpt-4o-mini"
    assert agent.chat_generator.api_key == Secret.from_env_var("OPENAI_API_KEY")
    assert agent.tools[0].function is weather_function
    assert isinstance(agent.tools[1]._component, PromptBuilder)
    assert agent.exit_conditions == ["text", "weather_tool"]
    # The agent adds a "messages" entry (merged with merge_lists) to the user-supplied schema.
    assert agent.state_schema == {
        "foo": {"type": str},
        "messages": {"handler": merge_lists, "type": List[ChatMessage]},
    }
|
|
|
|
|
2025-03-28 14:23:39 +01:00
|
|
|
def test_serde(self, weather_tool, component_tool, monkeypatch):
    """A to_dict -> from_dict round trip keeps the generator, tools, exit conditions and state schema."""
    monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
    original = Agent(
        chat_generator=OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY")),
        tools=[weather_tool, component_tool],
        exit_conditions=["text", "weather_tool"],
        state_schema={"foo": {"type": str}},
    )

    # --- serialization ---
    as_dict = original.to_dict()
    params = as_dict["init_parameters"]

    assert as_dict["type"] == "haystack.components.agents.agent.Agent"
    assert params["chat_generator"]["type"] == "haystack.components.generators.chat.openai.OpenAIChatGenerator"
    assert params["streaming_callback"] is None
    serialized_tools = params["tools"]
    assert serialized_tools[0]["data"]["function"] == serialize_callable(weather_function)
    assert (
        serialized_tools[1]["data"]["component"]["type"]
        == "haystack.components.builders.prompt_builder.PromptBuilder"
    )
    assert params["exit_conditions"] == ["text", "weather_tool"]

    # --- deserialization ---
    restored = Agent.from_dict(as_dict)

    assert isinstance(restored, Agent)
    assert isinstance(restored.chat_generator, OpenAIChatGenerator)
    assert restored.tools[0].function is weather_function
    assert isinstance(restored.tools[1]._component, PromptBuilder)
    assert restored.exit_conditions == ["text", "weather_tool"]
    expected_schema = {
        "foo": {"type": str},
        "messages": {"handler": merge_lists, "type": List[ChatMessage]},
    }
    assert restored.state_schema == expected_schema
|
2025-03-28 14:23:39 +01:00
|
|
|
|
|
|
|
def test_serde_with_streaming_callback(self, weather_tool, component_tool, monkeypatch):
    """
    An agent-level streaming callback survives a to_dict -> from_dict round trip.

    The expected path is computed with ``serialize_callable`` instead of the literal
    ``"test_agent.streaming_callback_for_serde"``, matching how ``test_serde`` checks the tool
    function. The assertion then does not depend on the module name pytest gives this file.
    """
    monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
    generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
    agent = Agent(
        chat_generator=generator,
        tools=[weather_tool, component_tool],
        streaming_callback=streaming_callback_for_serde,
    )

    serialized_agent = agent.to_dict()

    init_parameters = serialized_agent["init_parameters"]
    assert init_parameters["streaming_callback"] == serialize_callable(streaming_callback_for_serde)

    deserialized_agent = Agent.from_dict(serialized_agent)
    assert deserialized_agent.streaming_callback is streaming_callback_for_serde
|
|
|
|
|
2025-03-31 11:02:25 +02:00
|
|
|
def test_exit_conditions_validation(self, weather_tool, component_tool, monkeypatch):
    """Exit conditions naming an unknown tool are rejected; the default is ``["text"]``."""
    monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
    generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))

    # An exit condition that matches no registered tool must fail at construction time.
    with pytest.raises(ValueError, match="Invalid exit conditions provided:"):
        Agent(chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["invalid_tool"])

    # Omitting exit_conditions falls back to stopping on a plain-text reply.
    default_agent = Agent(chat_generator=generator, tools=[weather_tool, component_tool])
    assert default_agent.exit_conditions == ["text"]

    # "text" and a real tool name can be combined.
    combined_agent = Agent(
        chat_generator=generator,
        tools=[weather_tool, component_tool],
        exit_conditions=["text", "weather_tool"],
    )
    assert combined_agent.exit_conditions == ["text", "weather_tool"]
|
|
|
|
|
2025-03-28 14:23:39 +01:00
|
|
|
def test_run_with_params_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
    """A streaming callback passed to the ``Agent`` constructor is invoked during ``run``."""
    chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
    streaming_callback_called = False

    def streaming_callback(chunk: StreamingChunk) -> None:
        nonlocal streaming_callback_called
        streaming_callback_called = True

    agent = Agent(chat_generator=chat_generator, streaming_callback=streaming_callback, tools=[weather_tool])
    agent.warm_up()
    response = agent.run([ChatMessage.from_user("Hello")])

    # check we called the streaming callback
    assert streaming_callback_called is True

    # check that the component still returns the correct response
    assert isinstance(response, dict)
    assert "messages" in response
    assert isinstance(response["messages"], list)
    assert len(response["messages"]) == 2
    # all(...) is required: a bare list of booleans is truthy whenever it is non-empty.
    assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
    assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk
|
|
|
|
|
|
|
|
def test_run_with_run_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
    """A streaming callback passed to ``Agent.run`` (not the constructor) is invoked."""
    chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))

    streaming_callback_called = False

    def streaming_callback(chunk: StreamingChunk) -> None:
        nonlocal streaming_callback_called
        streaming_callback_called = True

    agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
    agent.warm_up()
    response = agent.run([ChatMessage.from_user("Hello")], streaming_callback=streaming_callback)

    # check we called the streaming callback
    assert streaming_callback_called is True

    # check that the component still returns the correct response
    assert isinstance(response, dict)
    assert "messages" in response
    assert isinstance(response["messages"], list)
    assert len(response["messages"]) == 2
    # all(...) is required: a bare list of booleans is truthy whenever it is non-empty.
    assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
    assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk
|
|
|
|
|
|
|
|
def test_keep_generator_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
    """A streaming callback configured on the chat generator itself still fires when run via the Agent."""
    streaming_callback_called = False

    def streaming_callback(chunk: StreamingChunk) -> None:
        nonlocal streaming_callback_called
        streaming_callback_called = True

    chat_generator = OpenAIChatGenerator(
        api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
    )

    agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
    agent.warm_up()
    response = agent.run([ChatMessage.from_user("Hello")])

    # check we called the streaming callback
    assert streaming_callback_called is True

    # check that the component still returns the correct response
    assert isinstance(response, dict)
    assert "messages" in response
    assert isinstance(response["messages"], list)
    assert len(response["messages"]) == 2
    # all(...) is required: a bare list of booleans is truthy whenever it is non-empty.
    assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
    assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk
|
2025-03-31 14:47:38 +02:00
|
|
|
|
|
|
|
def test_chat_generator_must_support_tools(self, weather_tool):
    """Building an Agent around a generator that does not support tools raises ``TypeError``."""
    tool_less_generator = MockChatGeneratorWithoutTools()
    expected_message = "MockChatGeneratorWithoutTools does not accept tools"

    with pytest.raises(TypeError, match=expected_message):
        Agent(chat_generator=tool_less_generator, tools=[weather_tool])
|
2025-04-01 11:29:44 +02:00
|
|
|
|
2025-04-01 14:51:53 +02:00
|
|
|
def test_multiple_llm_responses_with_tool_call(self, monkeypatch, weather_tool):
    """If the generator returns a text reply followed by a tool call, the tool is still executed."""
    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
    generator = OpenAIChatGenerator()

    berlin_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
    canned_replies = [
        ChatMessage.from_assistant("First response"),
        ChatMessage.from_assistant(tool_calls=[berlin_call]),
    ]

    agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=1)
    agent.warm_up()

    # Replace the real OpenAI call: every generator invocation returns the canned replies.
    agent.chat_generator.run = MagicMock(return_value={"replies": canned_replies})

    result = agent.run([ChatMessage.from_user("Hello")])

    assert "messages" in result
    assert len(result["messages"]) == 4
    berlin_weather = "{'weather': 'mostly sunny', 'temperature': 7, 'unit': 'celsius'}"
    assert result["messages"][-1].tool_call_result.result == berlin_weather
|
|
|
|
|
|
|
|
def test_exit_conditions_checked_across_all_llm_messages(self, monkeypatch, weather_tool):
    """Checks that a tool-call exit condition in the second reply of a turn is still detected and executed."""
    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
    generator = OpenAIChatGenerator()

    # The exit-triggering tool call is deliberately the second of the two replies.
    canned_replies = [
        ChatMessage.from_assistant("First response"),
        ChatMessage.from_assistant(
            tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
        ),
    ]

    agent = Agent(chat_generator=generator, tools=[weather_tool], exit_conditions=["weather_tool"])
    agent.warm_up()

    # Replace the real OpenAI call with the canned replies.
    agent.chat_generator.run = MagicMock(return_value={"replies": canned_replies})

    result = agent.run([ChatMessage.from_user("Hello")])

    assert "messages" in result
    history = result["messages"]
    assert len(history) == 4
    assert history[-2].tool_call.tool_name == "weather_tool"
    assert history[-1].tool_call_result.result == "{'weather': 'mostly sunny', 'temperature': 7, 'unit': 'celsius'}"
|
|
|
|
|
2025-04-01 11:29:44 +02:00
|
|
|
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
@pytest.mark.integration
def test_run(self, weather_tool):
    """
    Integration test against the real OpenAI API; skipped unless OPENAI_API_KEY is set.

    Expects a four-message history: user question, assistant tool call, tool result, and final
    assistant text.
    """
    chat_generator = OpenAIChatGenerator(model="gpt-4o-mini")
    agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3)
    agent.warm_up()
    response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")])

    assert isinstance(response, dict)
    assert "messages" in response
    assert isinstance(response["messages"], list)
    assert len(response["messages"]) == 4
    # all(...) is required: a bare list of booleans is truthy whenever it is non-empty.
    assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
    # Loose check of message texts
    assert response["messages"][0].text == "What is the weather in Berlin?"
    assert response["messages"][1].text is None
    assert response["messages"][2].text is None
    assert response["messages"][3].text is not None
    # Loose check of message metadata
    assert response["messages"][0].meta == {}
    assert response["messages"][1].meta.get("model") is not None
    assert response["messages"][2].meta == {}
    assert response["messages"][3].meta.get("model") is not None
    # Loose check of tool calls and results
    assert response["messages"][1].tool_calls[0].tool_name == "weather_tool"
    assert response["messages"][1].tool_calls[0].arguments is not None
    assert response["messages"][2].tool_call_results[0].result is not None
    assert response["messages"][2].tool_call_results[0].origin is not None
|