From b12af1e6a987cde382abd03cfcd6b34951ca722f Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 31 Mar 2025 18:00:59 +0200 Subject: [PATCH] chore: use deserialize_chatgenerator_inplace utility function in Agent (#9149) --- haystack/components/agents/agent.py | 7 ++----- test/components/agents/test_agent.py | 8 +++++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index e5aa4d856..b5f187a67 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -8,12 +8,12 @@ from typing import Any, Dict, List, Optional from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators.chat.types import ChatGenerator from haystack.components.tools import ToolInvoker -from haystack.core.serialization import import_class_by_name from haystack.dataclasses import ChatMessage from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema from haystack.dataclasses.streaming_chunk import SyncStreamingCallbackT from haystack.tools import Tool, deserialize_tools_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable +from haystack.utils.deserialization import deserialize_chatgenerator_inplace logger = logging.getLogger(__name__) @@ -157,10 +157,7 @@ class Agent: """ init_params = data.get("init_parameters", {}) - chat_generator_class = import_class_by_name(init_params["chat_generator"]["type"]) - assert hasattr(chat_generator_class, "from_dict") # we know but mypy doesn't - chat_generator_instance = chat_generator_class.from_dict(init_params["chat_generator"]) - data["init_parameters"]["chat_generator"] = chat_generator_instance + deserialize_chatgenerator_inplace(init_params, key="chat_generator") if "state_schema" in init_params: init_params["state_schema"] = _schema_from_dict(init_params["state_schema"]) diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index 38dfc2e1c..71f76b26b 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -20,8 +20,6 @@ from haystack.dataclasses.streaming_chunk import StreamingChunk from haystack.tools import Tool, ComponentTool from haystack.utils import serialize_callable, Secret -import os - def streaming_callback_for_serde(chunk: StreamingChunk): pass @@ -111,7 +109,10 @@ class TestAgent: monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key") generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY")) agent = Agent( - chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["text", "weather_tool"] + chat_generator=generator, + tools=[weather_tool, component_tool], + exit_conditions=["text", "weather_tool"], + state_schema={"foo": {"type": str}}, ) serialized_agent = agent.to_dict() @@ -138,6 +139,7 @@ class TestAgent: assert deserialized_agent.tools[0].function is weather_function assert isinstance(deserialized_agent.tools[1]._component, PromptBuilder) assert deserialized_agent.exit_conditions == ["text", "weather_tool"] + assert deserialized_agent.state_schema == {"foo": {"type": str}} def test_serde_with_streaming_callback(self, weather_tool, component_tool, monkeypatch): monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")