diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index feb96ead1..e5aa4d856 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import inspect from typing import Any, Dict, List, Optional from haystack import component, default_from_dict, default_to_dict, logging @@ -76,7 +77,16 @@ class Agent: :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails? If set to False, the exception will be turned into a chat message and passed to the LLM. :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. + :raises TypeError: If the chat_generator does not support tools parameter in its run method. """ + # Check if chat_generator supports tools parameter + chat_generator_run_method = inspect.signature(chat_generator.run) + if "tools" not in chat_generator_run_method.parameters: + raise TypeError( + f"{type(chat_generator).__name__} does not accept tools parameter in its run method. " + "The Agent component requires a chat generator that supports tools." + ) + valid_exits = ["text"] + [tool.name for tool in tools or []] if exit_conditions is None: exit_conditions = ["text"] diff --git a/releasenotes/notes/check-chatgenerator-tools-support-f684c1ad0d9ad108.yaml b/releasenotes/notes/check-chatgenerator-tools-support-f684c1ad0d9ad108.yaml new file mode 100644 index 000000000..a3b2fb7ef --- /dev/null +++ b/releasenotes/notes/check-chatgenerator-tools-support-f684c1ad0d9ad108.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + The Agent component checks that the ChatGenerator it is initialized with supports tools. If it doesn't, the Agent raises a TypeError. diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index 0452a9181..38dfc2e1c 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from datetime import datetime -from typing import Iterator +from typing import Iterator, Dict, Any, List from unittest.mock import MagicMock, patch import pytest @@ -14,6 +14,7 @@ from openai.types.chat import ChatCompletionChunk, chat_completion_chunk from haystack.components.agents import Agent from haystack.components.builders.prompt_builder import PromptBuilder from haystack.components.generators.chat.openai import OpenAIChatGenerator +from haystack.components.generators.chat.types import ChatGenerator from haystack.dataclasses import ChatMessage from haystack.dataclasses.streaming_chunk import StreamingChunk from haystack.tools import Tool, ComponentTool @@ -91,6 +92,20 @@ def openai_mock_chat_completion_chunk(): yield mock_chat_completion_create +class MockChatGeneratorWithoutTools(ChatGenerator): + """A mock chat generator that implements ChatGenerator protocol but doesn't support tools.""" + + def to_dict(self) -> Dict[str, Any]: + return {"type": "MockChatGeneratorWithoutTools", "data": {}} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutTools": + return cls() + + def run(self, messages: List[ChatMessage]) -> Dict[str, Any]: + return {"replies": [ChatMessage.from_assistant("Hello")]} + + class TestAgent: def test_serde(self, weather_tool, component_tool, monkeypatch): monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key") @@ -231,3 +246,9 @@ class TestAgent: assert len(response["messages"]) == 2 assert [isinstance(reply, ChatMessage) for reply in response["messages"]] assert "Hello" in response["messages"][1].text # see openai_mock_chat_completion_chunk + + def test_chat_generator_must_support_tools(self, weather_tool): + chat_generator = MockChatGeneratorWithoutTools() + + with pytest.raises(TypeError, match="MockChatGeneratorWithoutTools does not accept tools"): + Agent(chat_generator=chat_generator, tools=[weather_tool])