From 81fbe546cb5e079f367ce1b11c0bf94ae8fdb70e Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 11 Apr 2025 10:30:48 +0200 Subject: [PATCH] feat: `ChatGenerator` protocol - do not require `to_dict` and `from_dict` methods (#9213) * minimize protocol * progress * rm unneeded test changes * reno * use keywords arguments for clarity --- haystack/components/agents/agent.py | 3 ++- .../evaluators/context_relevance.py | 3 ++- .../components/evaluators/faithfulness.py | 3 ++- .../components/evaluators/llm_evaluator.py | 3 ++- .../extractors/llm_metadata_extractor.py | 3 ++- .../generators/chat/types/protocol.py | 20 ------------------- haystack/utils/deserialization.py | 7 ++----- ...hatgenerator-minimal-2efb76f02ab0f033.yaml | 4 ++++ test/utils/test_deserialization.py | 4 ++-- 9 files changed, 18 insertions(+), 32 deletions(-) create mode 100644 releasenotes/notes/chatgenerator-minimal-2efb76f02ab0f033.yaml diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index b301032c0..815b862c8 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators.chat.types import ChatGenerator from haystack.components.tools import ToolInvoker +from haystack.core.serialization import component_to_dict from haystack.dataclasses import ChatMessage from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema from haystack.dataclasses.state_utils import merge_lists @@ -152,7 +153,7 @@ class Agent: return default_to_dict( self, - chat_generator=self.chat_generator.to_dict(), + chat_generator=component_to_dict(obj=self.chat_generator, name="chat_generator"), tools=[t.to_dict() for t in self.tools], system_prompt=self.system_prompt, exit_conditions=self.exit_conditions, diff --git a/haystack/components/evaluators/context_relevance.py b/haystack/components/evaluators/context_relevance.py index acdf9310a..4dcc3b9cc 100644 --- a/haystack/components/evaluators/context_relevance.py +++ b/haystack/components/evaluators/context_relevance.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional from haystack import component, default_from_dict, default_to_dict from haystack.components.evaluators.llm_evaluator import LLMEvaluator from haystack.components.generators.chat.types import ChatGenerator +from haystack.core.serialization import component_to_dict from haystack.utils import Secret, deserialize_chatgenerator_inplace, deserialize_secrets_inplace # Private global variable for default examples to include in the prompt if the user does not provide any examples @@ -209,7 +210,7 @@ class ContextRelevanceEvaluator(LLMEvaluator): """ return default_to_dict( self, - chat_generator=self._chat_generator.to_dict(), + chat_generator=component_to_dict(obj=self._chat_generator, name="chat_generator"), examples=self.examples, progress_bar=self.progress_bar, raise_on_failure=self.raise_on_failure, diff --git a/haystack/components/evaluators/faithfulness.py b/haystack/components/evaluators/faithfulness.py index 841c64263..35f75f735 100644 --- a/haystack/components/evaluators/faithfulness.py +++ b/haystack/components/evaluators/faithfulness.py @@ -9,6 +9,7 @@ from numpy import mean as np_mean from haystack import component, default_from_dict, default_to_dict from haystack.components.evaluators.llm_evaluator import LLMEvaluator from haystack.components.generators.chat.types import ChatGenerator +from haystack.core.serialization import component_to_dict from haystack.utils import Secret, deserialize_chatgenerator_inplace, deserialize_secrets_inplace # Default examples to include in the prompt if the user does not provide any examples @@ -203,7 +204,7 @@ class FaithfulnessEvaluator(LLMEvaluator): """ return default_to_dict( self, - chat_generator=self._chat_generator.to_dict(), + chat_generator=component_to_dict(obj=self._chat_generator, name="chat_generator"), examples=self.examples, progress_bar=self.progress_bar, raise_on_failure=self.raise_on_failure, diff --git a/haystack/components/evaluators/llm_evaluator.py b/haystack/components/evaluators/llm_evaluator.py index 101b1c8e6..7374d2b4b 100644 --- a/haystack/components/evaluators/llm_evaluator.py +++ b/haystack/components/evaluators/llm_evaluator.py @@ -12,6 +12,7 @@ from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.builders import PromptBuilder from haystack.components.generators.chat.openai import OpenAIChatGenerator from haystack.components.generators.chat.types import ChatGenerator +from haystack.core.serialization import component_to_dict from haystack.dataclasses.chat_message import ChatMessage from haystack.utils import ( Secret, @@ -322,7 +323,7 @@ class LLMEvaluator: inputs=inputs, outputs=self.outputs, examples=self.examples, - chat_generator=self._chat_generator.to_dict(), + chat_generator=component_to_dict(obj=self._chat_generator, name="chat_generator"), progress_bar=self.progress_bar, ) diff --git a/haystack/components/extractors/llm_metadata_extractor.py b/haystack/components/extractors/llm_metadata_extractor.py index f51a44fe5..c6f61127a 100644 --- a/haystack/components/extractors/llm_metadata_extractor.py +++ b/haystack/components/extractors/llm_metadata_extractor.py @@ -17,6 +17,7 @@ from haystack.components.builders import PromptBuilder from haystack.components.generators.chat import AzureOpenAIChatGenerator, OpenAIChatGenerator from haystack.components.generators.chat.types import ChatGenerator from haystack.components.preprocessors import DocumentSplitter +from haystack.core.serialization import component_to_dict from haystack.dataclasses import ChatMessage from haystack.lazy_imports import LazyImport from haystack.utils import ( @@ -281,7 +282,7 @@ class LLMMetadataExtractor: return default_to_dict( self, prompt=self.prompt, - chat_generator=self._chat_generator.to_dict(), + chat_generator=component_to_dict(obj=self._chat_generator, name="chat_generator"), expected_keys=self.expected_keys, page_range=self.expanded_range, raise_on_failure=self.raise_on_failure, diff --git a/haystack/components/generators/chat/types/protocol.py b/haystack/components/generators/chat/types/protocol.py index 79d082fa0..394d72941 100644 --- a/haystack/components/generators/chat/types/protocol.py +++ b/haystack/components/generators/chat/types/protocol.py @@ -21,26 +21,6 @@ class ChatGenerator(Protocol): responses using a Language Model. They return a dictionary. """ - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this ChatGenerator to a dictionary. - - :returns: - The serialized ChatGenerator as a dictionary. - """ - ... - - @classmethod - def from_dict(cls: type[T], data: Dict[str, Any]) -> T: - """ - Deserialize this ChatGenerator from a dictionary. - - :param data: The dictionary representation of this ChatGenerator. - :returns: - An instance of the specific implementing class. - """ - ... - def run(self, messages: List[ChatMessage]) -> Dict[str, Any]: """ Generate messages using the underlying Language Model. diff --git a/haystack/utils/deserialization.py b/haystack/utils/deserialization.py index 7af0c608e..d80c2672f 100644 --- a/haystack/utils/deserialization.py +++ b/haystack/utils/deserialization.py @@ -5,7 +5,7 @@ from typing import Any, Dict from haystack import DeserializationError -from haystack.core.serialization import default_from_dict, import_class_by_name +from haystack.core.serialization import component_from_dict, default_from_dict, import_class_by_name def deserialize_document_store_in_init_params_inplace(data: Dict[str, Any], key: str = "document_store"): @@ -68,7 +68,4 @@ def deserialize_chatgenerator_inplace(data: Dict[str, Any], key: str = "chat_gen except ImportError as e: raise DeserializationError(f"Class '{serialized_chat_generator['type']}' not correctly imported") from e - if not hasattr(chat_generator_class, "from_dict"): - raise DeserializationError(f"Class '{chat_generator_class}' does not have a 'from_dict' method") - - data[key] = chat_generator_class.from_dict(serialized_chat_generator) + data[key] = component_from_dict(cls=chat_generator_class, data=serialized_chat_generator, name="chat_generator") diff --git a/releasenotes/notes/chatgenerator-minimal-2efb76f02ab0f033.yaml b/releasenotes/notes/chatgenerator-minimal-2efb76f02ab0f033.yaml new file mode 100644 index 000000000..a5a4946c8 --- /dev/null +++ b/releasenotes/notes/chatgenerator-minimal-2efb76f02ab0f033.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + The `ChatGenerator` Protocol no longer requires `to_dict` and `from_dict` methods. diff --git a/test/utils/test_deserialization.py b/test/utils/test_deserialization.py index 9f58079be..eb08791ab 100644 --- a/test/utils/test_deserialization.py +++ b/test/utils/test_deserialization.py @@ -130,5 +130,5 @@ class TestDeserializeChatGeneratorInplace: def test_chat_generator_no_from_dict_method(self): chat_generator = ChatGeneratorWithoutFromDict() data = {"chat_generator": chat_generator.to_dict()} - with pytest.raises(DeserializationError): - deserialize_chatgenerator_inplace(data) + deserialize_chatgenerator_inplace(data) + assert isinstance(data["chat_generator"], ChatGeneratorWithoutFromDict)