diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index b1ef5fedf..3552d37bd 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -15,6 +15,7 @@ import networkx # type:ignore from haystack import logging from haystack.core.component import Component, InputSocket, OutputSocket, component from haystack.core.errors import ( + DeserializationError, PipelineConnectError, PipelineDrawingError, PipelineError, @@ -179,7 +180,18 @@ class PipelineBase: # Create a new one component_class = component.registry[component_data["type"]] - instance = component_from_dict(component_class, component_data, name, callbacks) + + try: + instance = component_from_dict(component_class, component_data, name, callbacks) + except Exception as e: + msg = ( + f"Couldn't deserialize component '{name}' of class '{component_class.__name__}' " + f"with the following data: {str(component_data)}. Possible reasons include " + "malformed serialized data, mismatch between the serialized component and the " + "loaded one (due to a breaking change, see " + "https://github.com/deepset-ai/haystack/releases), etc." + ) + raise DeserializationError(msg) from e pipe.add_component(name=name, instance=instance) for connection in data.get("connections", []): @@ -229,10 +241,20 @@ class PipelineBase: The Marshaller used to create the string representation. Defaults to `YamlMarshaller`. :param callbacks: Callbacks to invoke during deserialization. + :raises DeserializationError: + If an error occurs during deserialization. :returns: A `Pipeline` object. """ - return cls.from_dict(marshaller.unmarshal(data), callbacks) + try: + deserialized_data = marshaller.unmarshal(data) + except Exception as e: + raise DeserializationError( + "Error while unmarshalling serialized pipeline data. This is usually " + "caused by malformed or invalid syntax in the serialized representation." + ) from e + + return cls.from_dict(deserialized_data, callbacks) @classmethod def load( @@ -253,10 +275,12 @@ class PipelineBase: The Marshaller used to create the string representation. Defaults to `YamlMarshaller`. :param callbacks: Callbacks to invoke during deserialization. + :raises DeserializationError: + If an error occurs during deserialization. :returns: A `Pipeline` object. """ - return cls.from_dict(marshaller.unmarshal(fp.read()), callbacks) + return cls.loads(fp.read(), marshaller, callbacks) def add_component(self, name: str, instance: Component) -> None: """ diff --git a/releasenotes/notes/better-deserialization-errors-f2b0e226534f4cd2.yaml b/releasenotes/notes/better-deserialization-errors-f2b0e226534f4cd2.yaml new file mode 100644 index 000000000..af2ea1986 --- /dev/null +++ b/releasenotes/notes/better-deserialization-errors-f2b0e226534f4cd2.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Improved error messages for deserialization errors. diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 9c38c6d09..ae8fd34e5 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -13,7 +13,7 @@ from haystack.components.joiners import BranchJoiner from haystack.components.others import Multiplexer from haystack.core.component import component from haystack.core.component.types import InputSocket, OutputSocket, Variadic -from haystack.core.errors import PipelineConnectError, PipelineDrawingError, PipelineError +from haystack.core.errors import DeserializationError, PipelineConnectError, PipelineDrawingError, PipelineError from haystack.core.pipeline import Pipeline, PredefinedPipeline from haystack.core.pipeline.base import ( _add_missing_input_defaults, @@ -73,6 +73,45 @@ class TestPipeline: assert isinstance(pipeline.get_component("Comp1"), FakeComponent) assert isinstance(pipeline.get_component("Comp2"), FakeComponent) + def test_pipeline_loads_invalid_data(self): + invalid_yaml = """components: + Comp1: + init_parameters: + an_init_param: null + type: test.core.pipeline.test_pipeline.FakeComponent + Comp2* + init_parameters: + an_init_param: null + type: test.core.pipeline.test_pipeline.FakeComponent + connections: + * receiver: Comp2.input_ + sender: Comp1.value + max_loops_allowed: 99 + metadata: + """ + + with pytest.raises(DeserializationError, match="unmarshalling serialized"): + pipeline = Pipeline.loads(invalid_yaml) + + invalid_init_parameter_yaml = """components: + Comp1: + init_parameters: + unknown: null + type: test.core.pipeline.test_pipeline.FakeComponent + Comp2: + init_parameters: + an_init_param: null + type: test.core.pipeline.test_pipeline.FakeComponent + connections: + - receiver: Comp2.input_ + sender: Comp1.value + max_loops_allowed: 99 + metadata: {} + """ + + with pytest.raises(DeserializationError, match=".*Comp1.*unknown.*"): + pipeline = Pipeline.loads(invalid_init_parameter_yaml) + def test_pipeline_dump(self, test_files_path, tmp_path): pipeline = Pipeline() pipeline.add_component("Comp1", FakeComponent("Foo"))