refactor: Improve error messages shown during pipeline deserialization (#8016)

* refactor: Improve error messages shown during pipeline deserialization

* Add link to release notes

* Update release notes link
This commit is contained in:
Madeesh Kannan 2024-07-12 16:47:00 +02:00 committed by GitHub
parent 1f05e633a9
commit 94b806815c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 71 additions and 4 deletions

View File

@ -15,6 +15,7 @@ import networkx # type:ignore
from haystack import logging
from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.errors import (
DeserializationError,
PipelineConnectError,
PipelineDrawingError,
PipelineError,
@ -179,7 +180,18 @@ class PipelineBase:
# Create a new one
component_class = component.registry[component_data["type"]]
instance = component_from_dict(component_class, component_data, name, callbacks)
try:
instance = component_from_dict(component_class, component_data, name, callbacks)
except Exception as e:
msg = (
f"Couldn't deserialize component '{name}' of class '{component_class.__name__}' "
f"with the following data: {str(component_data)}. Possible reasons include "
"malformed serialized data, mismatch between the serialized component and the "
"loaded one (due to a breaking change, see "
"https://github.com/deepset-ai/haystack/releases), etc."
)
raise DeserializationError(msg) from e
pipe.add_component(name=name, instance=instance)
for connection in data.get("connections", []):
@ -229,10 +241,20 @@ class PipelineBase:
The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
:param callbacks:
Callbacks to invoke during deserialization.
:raises DeserializationError:
If an error occurs during deserialization.
:returns:
A `Pipeline` object.
"""
return cls.from_dict(marshaller.unmarshal(data), callbacks)
try:
deserialized_data = marshaller.unmarshal(data)
except Exception as e:
raise DeserializationError(
"Error while unmarshalling serialized pipeline data. This is usually "
"caused by malformed or invalid syntax in the serialized representation."
) from e
return cls.from_dict(deserialized_data, callbacks)
@classmethod
def load(
@ -253,10 +275,12 @@ class PipelineBase:
The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
:param callbacks:
Callbacks to invoke during deserialization.
:raises DeserializationError:
If an error occurs during deserialization.
:returns:
A `Pipeline` object.
"""
return cls.from_dict(marshaller.unmarshal(fp.read()), callbacks)
return cls.loads(fp.read(), marshaller, callbacks)
def add_component(self, name: str, instance: Component) -> None:
"""

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Improved error messages for deserialization errors.

View File

@ -13,7 +13,7 @@ from haystack.components.joiners import BranchJoiner
from haystack.components.others import Multiplexer
from haystack.core.component import component
from haystack.core.component.types import InputSocket, OutputSocket, Variadic
from haystack.core.errors import PipelineConnectError, PipelineDrawingError, PipelineError
from haystack.core.errors import DeserializationError, PipelineConnectError, PipelineDrawingError, PipelineError
from haystack.core.pipeline import Pipeline, PredefinedPipeline
from haystack.core.pipeline.base import (
_add_missing_input_defaults,
@ -73,6 +73,45 @@ class TestPipeline:
assert isinstance(pipeline.get_component("Comp1"), FakeComponent)
assert isinstance(pipeline.get_component("Comp2"), FakeComponent)
def test_pipeline_loads_invalid_data(self):
# Verifies that `Pipeline.loads` raises `DeserializationError` for two kinds of
# bad input: text that cannot be unmarshalled at all, and well-formed text that
# describes a component which cannot be rebuilt from its serialized data.
#
# Case 1: syntactically broken input. `Comp2*` is missing its trailing colon and
# the connection entry starts with `*` instead of the `-` used in the valid
# document further down.
invalid_yaml = """components:
Comp1:
init_parameters:
an_init_param: null
type: test.core.pipeline.test_pipeline.FakeComponent
Comp2*
init_parameters:
an_init_param: null
type: test.core.pipeline.test_pipeline.FakeComponent
connections:
* receiver: Comp2.input_
sender: Comp1.value
max_loops_allowed: 99
metadata:
"""
# The expected message fragment is the wording used for failures in the
# unmarshalling step, i.e. before any component is instantiated.
# NOTE(review): the `pipeline` binding is never read; the call alone is what
# must raise.
with pytest.raises(DeserializationError, match="unmarshalling serialized"):
pipeline = Pipeline.loads(invalid_yaml)
# Case 2: the document parses, but `Comp1` is given an init parameter named
# `unknown`, whereas every other entry in this test uses `an_init_param`.
# Presumably `FakeComponent` does not accept `unknown` -- its definition is
# outside this view; confirm there.
invalid_init_parameter_yaml = """components:
Comp1:
init_parameters:
unknown: null
type: test.core.pipeline.test_pipeline.FakeComponent
Comp2:
init_parameters:
an_init_param: null
type: test.core.pipeline.test_pipeline.FakeComponent
connections:
- receiver: Comp2.input_
sender: Comp1.value
max_loops_allowed: 99
metadata: {}
"""
# The error text must name the failing component (`Comp1`) and then include
# the offending parameter name (`unknown`), in that order.
with pytest.raises(DeserializationError, match=".*Comp1.*unknown.*"):
pipeline = Pipeline.loads(invalid_init_parameter_yaml)
def test_pipeline_dump(self, test_files_path, tmp_path):
pipeline = Pipeline()
pipeline.add_component("Comp1", FakeComponent("Foo"))