mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-16 17:48:19 +00:00
chore: use deserialize_chatgenerator_inplace utility function in Agent (#9149)
This commit is contained in:
parent
adc3dfc5d2
commit
b12af1e6a9
@ -8,12 +8,12 @@ from typing import Any, Dict, List, Optional
|
|||||||
from haystack import component, default_from_dict, default_to_dict, logging
|
from haystack import component, default_from_dict, default_to_dict, logging
|
||||||
from haystack.components.generators.chat.types import ChatGenerator
|
from haystack.components.generators.chat.types import ChatGenerator
|
||||||
from haystack.components.tools import ToolInvoker
|
from haystack.components.tools import ToolInvoker
|
||||||
from haystack.core.serialization import import_class_by_name
|
|
||||||
from haystack.dataclasses import ChatMessage
|
from haystack.dataclasses import ChatMessage
|
||||||
from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
|
from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
|
||||||
from haystack.dataclasses.streaming_chunk import SyncStreamingCallbackT
|
from haystack.dataclasses.streaming_chunk import SyncStreamingCallbackT
|
||||||
from haystack.tools import Tool, deserialize_tools_inplace
|
from haystack.tools import Tool, deserialize_tools_inplace
|
||||||
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
|
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
|
||||||
|
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -157,10 +157,7 @@ class Agent:
|
|||||||
"""
|
"""
|
||||||
init_params = data.get("init_parameters", {})
|
init_params = data.get("init_parameters", {})
|
||||||
|
|
||||||
chat_generator_class = import_class_by_name(init_params["chat_generator"]["type"])
|
deserialize_chatgenerator_inplace(init_params, key="chat_generator")
|
||||||
assert hasattr(chat_generator_class, "from_dict") # we know but mypy doesn't
|
|
||||||
chat_generator_instance = chat_generator_class.from_dict(init_params["chat_generator"])
|
|
||||||
data["init_parameters"]["chat_generator"] = chat_generator_instance
|
|
||||||
|
|
||||||
if "state_schema" in init_params:
|
if "state_schema" in init_params:
|
||||||
init_params["state_schema"] = _schema_from_dict(init_params["state_schema"])
|
init_params["state_schema"] = _schema_from_dict(init_params["state_schema"])
|
||||||
|
|||||||
@ -20,8 +20,6 @@ from haystack.dataclasses.streaming_chunk import StreamingChunk
|
|||||||
from haystack.tools import Tool, ComponentTool
|
from haystack.tools import Tool, ComponentTool
|
||||||
from haystack.utils import serialize_callable, Secret
|
from haystack.utils import serialize_callable, Secret
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def streaming_callback_for_serde(chunk: StreamingChunk):
|
def streaming_callback_for_serde(chunk: StreamingChunk):
|
||||||
pass
|
pass
|
||||||
@ -111,7 +109,10 @@ class TestAgent:
|
|||||||
monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
|
monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
|
||||||
generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
|
generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["text", "weather_tool"]
|
chat_generator=generator,
|
||||||
|
tools=[weather_tool, component_tool],
|
||||||
|
exit_conditions=["text", "weather_tool"],
|
||||||
|
state_schema={"foo": {"type": str}},
|
||||||
)
|
)
|
||||||
|
|
||||||
serialized_agent = agent.to_dict()
|
serialized_agent = agent.to_dict()
|
||||||
@ -138,6 +139,7 @@ class TestAgent:
|
|||||||
assert deserialized_agent.tools[0].function is weather_function
|
assert deserialized_agent.tools[0].function is weather_function
|
||||||
assert isinstance(deserialized_agent.tools[1]._component, PromptBuilder)
|
assert isinstance(deserialized_agent.tools[1]._component, PromptBuilder)
|
||||||
assert deserialized_agent.exit_conditions == ["text", "weather_tool"]
|
assert deserialized_agent.exit_conditions == ["text", "weather_tool"]
|
||||||
|
assert deserialized_agent.state_schema == {"foo": {"type": str}}
|
||||||
|
|
||||||
def test_serde_with_streaming_callback(self, weather_tool, component_tool, monkeypatch):
|
def test_serde_with_streaming_callback(self, weather_tool, component_tool, monkeypatch):
|
||||||
monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
|
monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user