diff --git a/haystack/utils/deserialization.py b/haystack/utils/deserialization.py index 72f36af17..dddf1da2e 100644 --- a/haystack/utils/deserialization.py +++ b/haystack/utils/deserialization.py @@ -48,6 +48,22 @@ def deserialize_chatgenerator_inplace(data: Dict[str, Any], key: str = "chat_gen :param key: The key in the dictionary where the ChatGenerator is stored. + :raises DeserializationError: + If the key is missing in the serialized data, the value is not a dictionary, + the type key is missing, the class cannot be imported, or the class lacks a 'from_dict' method. + """ + deserialize_component_inplace(data, key=key) + + +def deserialize_component_inplace(data: Dict[str, Any], key: str = "chat_generator") -> None: + """ + Deserialize a Component in a dictionary inplace. + + :param data: + The dictionary with the serialized data. + :param key: + The key in the dictionary where the Component is stored. Default is "chat_generator". + :raises DeserializationError: If the key is missing in the serialized data, the value is not a dictionary, the type key is missing, the class cannot be imported, or the class lacks a 'from_dict' method. @@ -55,17 +71,17 @@ def deserialize_chatgenerator_inplace(data: Dict[str, Any], key: str = "chat_gen if key not in data: raise DeserializationError(f"Missing '{key}' in serialization data") - serialized_chat_generator = data[key] + serialized_component = data[key] - if not isinstance(serialized_chat_generator, dict): + if not isinstance(serialized_component, dict): raise DeserializationError(f"The value of '{key}' is not a dictionary") - if "type" not in serialized_chat_generator: + if "type" not in serialized_component: raise DeserializationError(f"Missing 'type' in {key} serialization data") try: - chat_generator_class = import_class_by_name(serialized_chat_generator["type"]) + component_class = import_class_by_name(serialized_component["type"]) except ImportError as e: - raise DeserializationError(f"Class '{serialized_chat_generator['type']}' not correctly imported") from e + raise DeserializationError(f"Class '{serialized_component['type']}' not correctly imported") from e - data[key] = component_from_dict(cls=chat_generator_class, data=serialized_chat_generator, name="chat_generator") + data[key] = component_from_dict(cls=component_class, data=serialized_component, name=key) diff --git a/releasenotes/notes/deserialisaation-in-place-d52a3bc54b9ea027.yaml b/releasenotes/notes/deserialisaation-in-place-d52a3bc54b9ea027.yaml new file mode 100644 index 000000000..356fb1357 --- /dev/null +++ b/releasenotes/notes/deserialisaation-in-place-d52a3bc54b9ea027.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Added a new `deserialize_component_inplace` function to handle generic component deserialization that works with any component type. diff --git a/test/utils/test_deserialization.py b/test/utils/test_deserialization.py index eb08791ab..44acece4f 100644 --- a/test/utils/test_deserialization.py +++ b/test/utils/test_deserialization.py @@ -8,7 +8,7 @@ import pytest from haystack.document_stores.in_memory.document_store import InMemoryDocumentStore from haystack.utils.deserialization import ( deserialize_document_store_in_init_params_inplace, - deserialize_chatgenerator_inplace, + deserialize_component_inplace, ) from haystack.core.errors import DeserializationError from haystack.components.generators.chat.openai import OpenAIChatGenerator @@ -97,38 +97,37 @@ class TestDeserializeDocumentStoreInInitParamsInplace: deserialize_document_store_in_init_params_inplace(data) -class TestDeserializeChatGeneratorInplace: - def test_deserialize_chatgenerator_inplace(self, monkeypatch): +class TestDeserializeComponentInplace: + def test_deserialize_component_inplace(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") chat_generator = OpenAIChatGenerator() data = {"chat_generator": chat_generator.to_dict()} - - deserialize_chatgenerator_inplace(data) + deserialize_component_inplace(data) assert isinstance(data["chat_generator"], OpenAIChatGenerator) assert data["chat_generator"].to_dict() == chat_generator.to_dict() - def test_missing_chat_generator_key(self): + def test_missing_key(self): data = {"some_key": "some_value"} with pytest.raises(DeserializationError): - deserialize_chatgenerator_inplace(data) + deserialize_component_inplace(data) - def test_chat_generator_is_not_a_dict(self): + def test_component_is_not_a_dict(self): data = {"chat_generator": "not_a_dict"} with pytest.raises(DeserializationError): - deserialize_chatgenerator_inplace(data) + deserialize_component_inplace(data) def test_type_key_missing(self): data = {"chat_generator": {"some_key": "some_value"}} with pytest.raises(DeserializationError): - deserialize_chatgenerator_inplace(data) + deserialize_component_inplace(data) def test_class_not_correctly_imported(self): data = {"chat_generator": {"type": "invalid.module.InvalidClass"}} with pytest.raises(DeserializationError): - deserialize_chatgenerator_inplace(data) + deserialize_component_inplace(data) - def test_chat_generator_no_from_dict_method(self): + def test_component_no_from_dict_method(self): chat_generator = ChatGeneratorWithoutFromDict() data = {"chat_generator": chat_generator.to_dict()} - deserialize_chatgenerator_inplace(data) + deserialize_component_inplace(data) assert isinstance(data["chat_generator"], ChatGeneratorWithoutFromDict)