haystack/test/utils/test_base_serialization.py

68 lines
2.3 KiB
Python
Raw Normal View History

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import re

import pytest

from haystack.core.errors import DeserializationError, SerializationError
from haystack.utils.base_serialization import serialize_class_instance, deserialize_class_instance
class CustomClass:
    """Test double that implements both ``to_dict`` and ``from_dict``."""

    def to_dict(self):
        """Return a fixed two-key payload."""
        payload = dict(key="value", more=False)
        return payload

    @classmethod
    def from_dict(cls, data):
        """Build a new instance after asserting *data* equals the fixed payload."""
        expected = dict(key="value", more=False)
        assert data == expected
        instance = cls()
        return instance
class CustomClassNoToDict:
    """Test double that offers ``from_dict`` but deliberately has no ``to_dict``."""

    @classmethod
    def from_dict(cls, data):
        """Return a new instance once *data* is confirmed to be the fixed payload."""
        assert data == dict(key="value", more=False)
        return cls()
class CustomClassNoFromDict:
    """Test double that offers ``to_dict`` but deliberately has no ``from_dict``."""

    def to_dict(self):
        """Return a fixed two-key payload."""
        return dict(key="value", more=False)
def test_serialize_class_instance():
    """Serializing an object with ``to_dict`` yields its payload together with its module-qualified type."""
    instance = CustomClass()
    expected = {
        "data": {"key": "value", "more": False},
        "type": "test_base_serialization.CustomClass",
    }
    assert serialize_class_instance(instance) == expected
def test_serialize_class_instance_missing_method():
    """Serializing an object that lacks ``to_dict`` raises ``SerializationError``."""
    instance = CustomClassNoToDict()
    with pytest.raises(SerializationError, match="does not have a 'to_dict' method"):
        serialize_class_instance(instance)
def test_deserialize_class_instance():
    """A well-formed ``CustomClass`` payload deserializes into a ``CustomClass`` instance."""
    restored = deserialize_class_instance(
        {
            "data": {"key": "value", "more": False},
            "type": "test_base_serialization.CustomClass",
        }
    )
    assert isinstance(restored, CustomClass)
def test_deserialize_class_instance_invalid_data():
    """Malformed or unresolvable payloads raise ``DeserializationError`` with a descriptive message.

    ``pytest.raises(match=...)`` applies ``re.search``, so each expected message is passed
    through ``re.escape``; otherwise the ``.`` in the qualified class names would match any
    character and the assertions would be looser than intended.
    """
    # An empty payload is reported as missing its 'type' key.
    with pytest.raises(DeserializationError, match=re.escape("Missing 'type'")):
        deserialize_class_instance({})

    # A type without accompanying data is reported as missing its 'data' key.
    with pytest.raises(DeserializationError, match=re.escape("Missing 'data'")):
        deserialize_class_instance({"type": "test_base_serialization.CustomClass"})

    # A type name the deserializer cannot import.
    with pytest.raises(
        DeserializationError,
        match=re.escape("Class 'test_base_serialization.CustomClass1' not correctly imported"),
    ):
        deserialize_class_instance({"type": "test_base_serialization.CustomClass1", "data": {}})

    # A class that resolves but provides no from_dict classmethod.
    with pytest.raises(
        DeserializationError,
        match=re.escape("Class 'test_base_serialization.CustomClassNoFromDict' does not have a 'from_dict' method"),
    ):
        deserialize_class_instance({"type": "test_base_serialization.CustomClassNoFromDict", "data": {}})