feat: Agent checks that its chat_generator supports tools (#9144)

* check that chat_generator run accepts tools

* reno
This commit is contained in:
Julian Risch 2025-03-31 14:47:38 +02:00 committed by GitHub
parent c8918e43ba
commit fc33382b48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 36 additions and 1 deletions

View File

@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import Any, Dict, List, Optional
from haystack import component, default_from_dict, default_to_dict, logging
@ -76,7 +77,16 @@ class Agent:
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
If set to False, the exception will be turned into a chat message and passed to the LLM.
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
:raises TypeError: If the chat_generator does not support tools parameter in its run method.
"""
# Check if chat_generator supports tools parameter
chat_generator_run_method = inspect.signature(chat_generator.run)
if "tools" not in chat_generator_run_method.parameters:
raise TypeError(
f"{type(chat_generator).__name__} does not accept tools parameter in its run method. "
"The Agent component requires a chat generator that supports tools."
)
valid_exits = ["text"] + [tool.name for tool in tools or []]
if exit_conditions is None:
exit_conditions = ["text"]

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
The Agent component checks that the ChatGenerator it is initialized with supports tools. If it doesn't, the Agent raises a TypeError.

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from typing import Iterator
from typing import Iterator, Dict, Any, List
from unittest.mock import MagicMock, patch
import pytest
@ -14,6 +14,7 @@ from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
from haystack.components.agents import Agent
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.generators.chat.types import ChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.tools import Tool, ComponentTool
@ -91,6 +92,20 @@ def openai_mock_chat_completion_chunk():
yield mock_chat_completion_create
class MockChatGeneratorWithoutTools(ChatGenerator):
"""A mock chat generator that implements ChatGenerator protocol but doesn't support tools."""
def to_dict(self) -> Dict[str, Any]:
return {"type": "MockChatGeneratorWithoutTools", "data": {}}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutTools":
return cls()
def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
return {"replies": [ChatMessage.from_assistant("Hello")]}
class TestAgent:
def test_serde(self, weather_tool, component_tool, monkeypatch):
monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
@ -231,3 +246,9 @@ class TestAgent:
assert len(response["messages"]) == 2
assert [isinstance(reply, ChatMessage) for reply in response["messages"]]
assert "Hello" in response["messages"][1].text # see openai_mock_chat_completion_chunk
def test_chat_generator_must_support_tools(self, weather_tool):
chat_generator = MockChatGeneratorWithoutTools()
with pytest.raises(TypeError, match="MockChatGeneratorWithoutTools does not accept tools"):
Agent(chat_generator=chat_generator, tools=[weather_tool])