chore: use deserialize_chatgenerator_inplace utility function in Agent (#9149)

This commit is contained in:
Stefano Fiorucci 2025-03-31 18:00:59 +02:00 committed by GitHub
parent adc3dfc5d2
commit b12af1e6a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 8 deletions

View File

@ -8,12 +8,12 @@ from typing import Any, Dict, List, Optional
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.chat.types import ChatGenerator
from haystack.components.tools import ToolInvoker
from haystack.core.serialization import import_class_by_name
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
from haystack.dataclasses.streaming_chunk import SyncStreamingCallbackT
from haystack.tools import Tool, deserialize_tools_inplace
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
logger = logging.getLogger(__name__)
@ -157,10 +157,7 @@ class Agent:
"""
init_params = data.get("init_parameters", {})
chat_generator_class = import_class_by_name(init_params["chat_generator"]["type"])
assert hasattr(chat_generator_class, "from_dict") # we know but mypy doesn't
chat_generator_instance = chat_generator_class.from_dict(init_params["chat_generator"])
data["init_parameters"]["chat_generator"] = chat_generator_instance
deserialize_chatgenerator_inplace(init_params, key="chat_generator")
if "state_schema" in init_params:
init_params["state_schema"] = _schema_from_dict(init_params["state_schema"])

View File

@ -20,8 +20,6 @@ from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.tools import Tool, ComponentTool
from haystack.utils import serialize_callable, Secret
import os
def streaming_callback_for_serde(chunk: StreamingChunk):
pass
@ -111,7 +109,10 @@ class TestAgent:
monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
agent = Agent(
chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["text", "weather_tool"]
chat_generator=generator,
tools=[weather_tool, component_tool],
exit_conditions=["text", "weather_tool"],
state_schema={"foo": {"type": str}},
)
serialized_agent = agent.to_dict()
@ -138,6 +139,7 @@ class TestAgent:
assert deserialized_agent.tools[0].function is weather_function
assert isinstance(deserialized_agent.tools[1]._component, PromptBuilder)
assert deserialized_agent.exit_conditions == ["text", "weather_tool"]
assert deserialized_agent.state_schema == {"foo": {"type": str}}
def test_serde_with_streaming_callback(self, weather_tool, component_tool, monkeypatch):
monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")