mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-12 23:37:36 +00:00
chore: use deserialize_chatgenerator_inplace utility function in Agent (#9149)
This commit is contained in:
parent
adc3dfc5d2
commit
b12af1e6a9
@ -8,12 +8,12 @@ from typing import Any, Dict, List, Optional
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.components.generators.chat.types import ChatGenerator
|
||||
from haystack.components.tools import ToolInvoker
|
||||
from haystack.core.serialization import import_class_by_name
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
|
||||
from haystack.dataclasses.streaming_chunk import SyncStreamingCallbackT
|
||||
from haystack.tools import Tool, deserialize_tools_inplace
|
||||
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
|
||||
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -157,10 +157,7 @@ class Agent:
|
||||
"""
|
||||
init_params = data.get("init_parameters", {})
|
||||
|
||||
chat_generator_class = import_class_by_name(init_params["chat_generator"]["type"])
|
||||
assert hasattr(chat_generator_class, "from_dict") # we know but mypy doesn't
|
||||
chat_generator_instance = chat_generator_class.from_dict(init_params["chat_generator"])
|
||||
data["init_parameters"]["chat_generator"] = chat_generator_instance
|
||||
deserialize_chatgenerator_inplace(init_params, key="chat_generator")
|
||||
|
||||
if "state_schema" in init_params:
|
||||
init_params["state_schema"] = _schema_from_dict(init_params["state_schema"])
|
||||
|
||||
@ -20,8 +20,6 @@ from haystack.dataclasses.streaming_chunk import StreamingChunk
|
||||
from haystack.tools import Tool, ComponentTool
|
||||
from haystack.utils import serialize_callable, Secret
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def streaming_callback_for_serde(chunk: StreamingChunk):
|
||||
pass
|
||||
@ -111,7 +109,10 @@ class TestAgent:
|
||||
monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
|
||||
generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
|
||||
agent = Agent(
|
||||
chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["text", "weather_tool"]
|
||||
chat_generator=generator,
|
||||
tools=[weather_tool, component_tool],
|
||||
exit_conditions=["text", "weather_tool"],
|
||||
state_schema={"foo": {"type": str}},
|
||||
)
|
||||
|
||||
serialized_agent = agent.to_dict()
|
||||
@ -138,6 +139,7 @@ class TestAgent:
|
||||
assert deserialized_agent.tools[0].function is weather_function
|
||||
assert isinstance(deserialized_agent.tools[1]._component, PromptBuilder)
|
||||
assert deserialized_agent.exit_conditions == ["text", "weather_tool"]
|
||||
assert deserialized_agent.state_schema == {"foo": {"type": str}}
|
||||
|
||||
def test_serde_with_streaming_callback(self, weather_tool, component_tool, monkeypatch):
|
||||
monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user