mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 03:57:19 +00:00
fix: Enforce basic Python types restriction on serialized component data (#8473)
This commit is contained in:
parent
a556e11bf1
commit
906177329b
@ -111,7 +111,7 @@ class PipelineBase:
|
||||
"""
|
||||
components = {}
|
||||
for name, instance in self.graph.nodes(data="instance"): # type:ignore
|
||||
components[name] = component_to_dict(instance)
|
||||
components[name] = component_to_dict(instance, name)
|
||||
|
||||
connections = []
|
||||
for sender, receiver, edge_data in self.graph.edges.data():
|
||||
|
||||
@ -6,7 +6,7 @@ import inspect
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from importlib import import_module
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict, Iterable, Optional, Type
|
||||
|
||||
from haystack.core.component.component import _hook_component_init, logger
|
||||
from haystack.core.errors import DeserializationError, SerializationError
|
||||
@ -30,7 +30,7 @@ class DeserializationCallbacks:
|
||||
component_pre_init: Optional[Callable] = None
|
||||
|
||||
|
||||
def component_to_dict(obj: Any) -> Dict[str, Any]:
|
||||
def component_to_dict(obj: Any, name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Converts a component instance into a dictionary.
|
||||
|
||||
@ -38,37 +38,82 @@ def component_to_dict(obj: Any) -> Dict[str, Any]:
|
||||
|
||||
:param obj:
|
||||
The component to be serialized.
|
||||
:param name:
|
||||
The name of the component.
|
||||
:returns:
|
||||
A dictionary representation of the component.
|
||||
|
||||
:raises SerializationError:
|
||||
If the component doesn't have a `to_dict` method and the values of the init parameters can't be determined.
|
||||
If the component doesn't have a `to_dict` method.
|
||||
If the values of the init parameters can't be determined.
|
||||
If a non-basic Python type is used in the serialized data.
|
||||
"""
|
||||
if hasattr(obj, "to_dict"):
|
||||
return obj.to_dict()
|
||||
data = obj.to_dict()
|
||||
else:
|
||||
init_parameters = {}
|
||||
for param_name, param in inspect.signature(obj.__init__).parameters.items():
|
||||
# Ignore `args` and `kwargs`, used by the default constructor
|
||||
if param_name in ("args", "kwargs"):
|
||||
continue
|
||||
try:
|
||||
# This only works if the Component constructor assigns the init
|
||||
# parameter to an instance variable or property with the same name
|
||||
param_value = getattr(obj, param_name)
|
||||
except AttributeError as e:
|
||||
# If the parameter doesn't have a default value, raise an error
|
||||
if param.default is param.empty:
|
||||
raise SerializationError(
|
||||
f"Cannot determine the value of the init parameter '{param_name}' "
|
||||
f"for the class {obj.__class__.__name__}."
|
||||
f"You can fix this error by assigning 'self.{param_name} = {param_name}' or adding a "
|
||||
f"custom serialization method 'to_dict' to the class."
|
||||
) from e
|
||||
# In case the init parameter was not assigned, we use the default value
|
||||
param_value = param.default
|
||||
init_parameters[param_name] = param_value
|
||||
|
||||
init_parameters = {}
|
||||
for name, param in inspect.signature(obj.__init__).parameters.items():
|
||||
# Ignore `args` and `kwargs`, used by the default constructor
|
||||
if name in ("args", "kwargs"):
|
||||
continue
|
||||
try:
|
||||
# This only works if the Component constructor assigns the init
|
||||
# parameter to an instance variable or property with the same name
|
||||
param_value = getattr(obj, name)
|
||||
except AttributeError as e:
|
||||
# If the parameter doesn't have a default value, raise an error
|
||||
if param.default is param.empty:
|
||||
data = default_to_dict(obj, **init_parameters)
|
||||
|
||||
_validate_component_to_dict_output(obj, name, data)
|
||||
return data
|
||||
|
||||
|
||||
def _validate_component_to_dict_output(component: Any, name: str, data: Dict[str, Any]) -> None:
|
||||
# Ensure that only basic Python types are used in the serde data.
|
||||
def is_allowed_type(obj: Any) -> bool:
|
||||
return isinstance(obj, (str, int, float, bool, list, dict, set, tuple, type(None)))
|
||||
|
||||
def check_iterable(l: Iterable[Any]):
|
||||
for v in l:
|
||||
if not is_allowed_type(v):
|
||||
raise SerializationError(
|
||||
f"Cannot determine the value of the init parameter '{name}' for the class {obj.__class__.__name__}."
|
||||
f"You can fix this error by assigning 'self.{name} = {name}' or adding a "
|
||||
f"custom serialization method 'to_dict' to the class."
|
||||
) from e
|
||||
# In case the init parameter was not assigned, we use the default value
|
||||
param_value = param.default
|
||||
init_parameters[name] = param_value
|
||||
f"Component '{name}' of type '{type(component).__name__}' has an unsupported value "
|
||||
f"of type '{type(v).__name__}' in the serialized data."
|
||||
)
|
||||
if isinstance(v, (list, set, tuple)):
|
||||
check_iterable(v)
|
||||
elif isinstance(v, dict):
|
||||
check_dict(v)
|
||||
|
||||
return default_to_dict(obj, **init_parameters)
|
||||
def check_dict(d: Dict[str, Any]):
|
||||
if any(not isinstance(k, str) for k in data.keys()):
|
||||
raise SerializationError(
|
||||
f"Component '{name}' of type '{type(component).__name__}' has a non-string key in the serialized data."
|
||||
)
|
||||
|
||||
for k, v in d.items():
|
||||
if not is_allowed_type(v):
|
||||
raise SerializationError(
|
||||
f"Component '{name}' of type '{type(component).__name__}' has an unsupported value "
|
||||
f"of type '{type(v).__name__}' in the serialized data under key '{k}'."
|
||||
)
|
||||
if isinstance(v, (list, set, tuple)):
|
||||
check_iterable(v)
|
||||
elif isinstance(v, dict):
|
||||
check_dict(v)
|
||||
|
||||
check_dict(data)
|
||||
|
||||
|
||||
def generate_qualified_class_name(cls: Type[object]) -> str:
|
||||
|
||||
@ -14,14 +14,33 @@ class YamlLoader(yaml.SafeLoader): # pylint: disable=too-many-ancestors
|
||||
return tuple(self.construct_sequence(node))
|
||||
|
||||
|
||||
class YamlDumper(yaml.SafeDumper): # pylint: disable=too-many-ancestors
|
||||
def represent_tuple(self, data: tuple):
|
||||
"""Represent a Python tuple."""
|
||||
return self.represent_sequence("tag:yaml.org,2002:python/tuple", data)
|
||||
|
||||
|
||||
YamlDumper.add_representer(tuple, YamlDumper.represent_tuple)
|
||||
YamlLoader.add_constructor("tag:yaml.org,2002:python/tuple", YamlLoader.construct_python_tuple)
|
||||
|
||||
|
||||
class YamlMarshaller:
|
||||
def marshal(self, dict_: Dict[str, Any]) -> str:
|
||||
"""Return a YAML representation of the given dictionary."""
|
||||
return yaml.dump(dict_)
|
||||
try:
|
||||
return yaml.dump(dict_, Dumper=YamlDumper)
|
||||
except yaml.representer.RepresenterError as e:
|
||||
raise TypeError(
|
||||
"Error dumping pipeline to YAML - Ensure that all pipeline "
|
||||
"components only serialize basic Python types"
|
||||
) from e
|
||||
|
||||
def unmarshal(self, data_: Union[str, bytes, bytearray]) -> Dict[str, Any]:
|
||||
"""Return a dictionary from the given YAML data."""
|
||||
return yaml.load(data_, Loader=YamlLoader)
|
||||
try:
|
||||
return yaml.load(data_, Loader=YamlLoader)
|
||||
except yaml.constructor.ConstructorError as e:
|
||||
raise TypeError(
|
||||
"Error loading pipeline from YAML - Ensure that all pipeline "
|
||||
"components only serialize basic Python types"
|
||||
) from e
|
||||
|
||||
@ -0,0 +1,8 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Improved serialization/deserialization errors to provide extra context about the offending components when possible.
|
||||
|
||||
fixes:
|
||||
- |
|
||||
The serialized data of a component is now explicitly restricted to the following basic Python types: str, int, float, bool, list, dict, set, tuple or None.
|
||||
@ -10,13 +10,14 @@ import pytest
|
||||
|
||||
from haystack.core.pipeline import Pipeline
|
||||
from haystack.core.component import component
|
||||
from haystack.core.errors import DeserializationError
|
||||
from haystack.core.errors import DeserializationError, SerializationError
|
||||
from haystack.testing import factory
|
||||
from haystack.core.serialization import (
|
||||
default_to_dict,
|
||||
default_from_dict,
|
||||
generate_qualified_class_name,
|
||||
import_class_by_name,
|
||||
component_to_dict,
|
||||
)
|
||||
|
||||
|
||||
@ -106,3 +107,25 @@ def test_import_class_by_name_no_valid_class():
|
||||
data = "some.invalid.class"
|
||||
with pytest.raises(ImportError):
|
||||
import_class_by_name(data)
|
||||
|
||||
|
||||
class CustomData:
|
||||
def __init__(self, a: int, b: str) -> None:
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
|
||||
@component()
|
||||
class UnserializableClass:
|
||||
def __init__(self, a: int, b: str, c: CustomData) -> None:
|
||||
self.a = a
|
||||
self.b = b
|
||||
self.c = c
|
||||
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
|
||||
def test_component_to_dict_invalid_type():
|
||||
with pytest.raises(SerializationError, match="unsupported value of type 'CustomData'"):
|
||||
component_to_dict(UnserializableClass(1, "s", CustomData(99, "aa")), "invalid_component")
|
||||
|
||||
@ -6,11 +6,23 @@ import pytest
|
||||
from haystack.marshal.yaml import YamlMarshaller
|
||||
|
||||
|
||||
class InvalidClass:
|
||||
def __init__(self) -> None:
|
||||
self.a = 1
|
||||
self.b = None
|
||||
self.c = "string"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def yaml_data():
|
||||
return {"key": "value", 1: 0.221, "list": [1, 2, 3], "tuple": (1, None, True), "dict": {"set": {False}}}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_yaml_data():
|
||||
return {"key": "value", 1: 0.221, "list": [1, 2, 3], "tuple": (1, InvalidClass(), True), "dict": {"set": {False}}}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def serialized_yaml_str():
|
||||
return """key: value
|
||||
@ -36,6 +48,12 @@ def test_yaml_marshal(yaml_data, serialized_yaml_str):
|
||||
assert marshalled.strip().replace("\n", "") == serialized_yaml_str.strip().replace("\n", "")
|
||||
|
||||
|
||||
def test_yaml_marshal_invalid_type(invalid_yaml_data):
    """Marshalling data that contains a non-basic Python type raises a `TypeError`."""
    # Build the marshaller outside the `raises` block so that only the `marshal`
    # call itself is expected to raise.
    marshaller = YamlMarshaller()
    with pytest.raises(TypeError, match="basic Python types"):
        marshaller.marshal(invalid_yaml_data)
|
||||
|
||||
|
||||
def test_yaml_unmarshal(yaml_data, serialized_yaml_str):
|
||||
marshaller = YamlMarshaller()
|
||||
unmarshalled = marshaller.unmarshal(serialized_yaml_str)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user