mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-28 23:48:53 +00:00
feat: Agent checks that its chat_generator supports tools (#9144)
* check that chat_generator run accepts tools * reno
This commit is contained in:
parent
c8918e43ba
commit
fc33382b48
@ -2,6 +2,7 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
@ -76,7 +77,16 @@ class Agent:
|
||||
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
|
||||
If set to False, the exception will be turned into a chat message and passed to the LLM.
|
||||
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
|
||||
:raises TypeError: If the chat_generator does not support tools parameter in its run method.
|
||||
"""
|
||||
# Check if chat_generator supports tools parameter
|
||||
chat_generator_run_method = inspect.signature(chat_generator.run)
|
||||
if "tools" not in chat_generator_run_method.parameters:
|
||||
raise TypeError(
|
||||
f"{type(chat_generator).__name__} does not accept tools parameter in its run method. "
|
||||
"The Agent component requires a chat generator that supports tools."
|
||||
)
|
||||
|
||||
valid_exits = ["text"] + [tool.name for tool in tools or []]
|
||||
if exit_conditions is None:
|
||||
exit_conditions = ["text"]
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
The Agent component checks that the ChatGenerator it is initialized with supports tools. If it doesn't, the Agent raises a TypeError.
|
||||
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Iterator
|
||||
from typing import Iterator, Dict, Any, List
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
@ -14,6 +14,7 @@ from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
|
||||
from haystack.components.agents import Agent
|
||||
from haystack.components.builders.prompt_builder import PromptBuilder
|
||||
from haystack.components.generators.chat.openai import OpenAIChatGenerator
|
||||
from haystack.components.generators.chat.types import ChatGenerator
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack.dataclasses.streaming_chunk import StreamingChunk
|
||||
from haystack.tools import Tool, ComponentTool
|
||||
@ -91,6 +92,20 @@ def openai_mock_chat_completion_chunk():
|
||||
yield mock_chat_completion_create
|
||||
|
||||
|
||||
class MockChatGeneratorWithoutTools(ChatGenerator):
|
||||
"""A mock chat generator that implements ChatGenerator protocol but doesn't support tools."""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {"type": "MockChatGeneratorWithoutTools", "data": {}}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "MockChatGeneratorWithoutTools":
|
||||
return cls()
|
||||
|
||||
def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
|
||||
return {"replies": [ChatMessage.from_assistant("Hello")]}
|
||||
|
||||
|
||||
class TestAgent:
|
||||
def test_serde(self, weather_tool, component_tool, monkeypatch):
|
||||
monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
|
||||
@ -231,3 +246,9 @@ class TestAgent:
|
||||
assert len(response["messages"]) == 2
|
||||
assert [isinstance(reply, ChatMessage) for reply in response["messages"]]
|
||||
assert "Hello" in response["messages"][1].text # see openai_mock_chat_completion_chunk
|
||||
|
||||
def test_chat_generator_must_support_tools(self, weather_tool):
|
||||
chat_generator = MockChatGeneratorWithoutTools()
|
||||
|
||||
with pytest.raises(TypeError, match="MockChatGeneratorWithoutTools does not accept tools"):
|
||||
Agent(chat_generator=chat_generator, tools=[weather_tool])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user