haystack/test/utils/test_base_serialization.py
Sebastian Husch Lee 512dd86d97
feat: Add serialization and deserialization of Enum type when creating a PipelineSnapshot (#9869)
* refactor tests

* Test refactoring and add failing test for enum

* Remove redundant method

* Slight refactoring

* refactoring

* simplification of _deserialize_value_with_schema and _deserialize_value

* Add some more TODOs

* Add support for enum serialization and deserialization

* types

* Add reno

* fix linting

* PR comments

* Add warning message

* dev comment
2025-10-14 10:28:19 +00:00

427 lines
16 KiB
Python

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from enum import Enum
import pytest
from haystack.core.errors import DeserializationError, SerializationError
from haystack.dataclasses import ChatMessage, Document, GeneratedAnswer
from haystack.utils.base_serialization import (
_deserialize_value_with_schema,
_serialize_value_with_schema,
deserialize_class_instance,
serialize_class_instance,
)
class CustomEnum(Enum):
    """Two-member enum used as a fixture by the enum serialization tests in this module."""

    # Member names ("ONE"/"TWO") intentionally differ from their values ("one"/"two").
    ONE = "one"
    TWO = "two"
class CustomClass:
    """Fixture class exposing both ``to_dict`` and ``from_dict``."""

    def to_dict(self):
        """Return the fixed dictionary that represents any instance of this class."""
        payload = {"key": "value", "more": False}
        return payload

    @classmethod
    def from_dict(cls, data):
        """Build a new instance after checking that ``data`` is the fixed payload."""
        expected_payload = {"key": "value", "more": False}
        assert data == expected_payload
        return cls()
class CustomClassNoToDict:
    """Fixture class that defines ``from_dict`` but deliberately no ``to_dict``."""

    @classmethod
    def from_dict(cls, data):
        """Build a new instance after checking that ``data`` is the fixed payload."""
        expected_payload = {"key": "value", "more": False}
        assert data == expected_payload
        return cls()
class CustomClassNoFromDict:
    """Fixture class that defines ``to_dict`` but deliberately no ``from_dict``."""

    def to_dict(self):
        """Return the fixed dictionary that represents any instance of this class."""
        payload = {"key": "value", "more": False}
        return payload
def simple_calc_function(x: int) -> int:
    """Return ``x`` multiplied by two; used as the callable fixture in this module."""
    doubled = x * 2
    return doubled
def test_serialize_class_instance():
    """The serialized form holds the ``to_dict`` payload under "data" and the dotted class path under "type"."""
    expected = {"type": "test_base_serialization.CustomClass", "data": {"key": "value", "more": False}}
    assert serialize_class_instance(CustomClass()) == expected
def test_serialize_class_instance_missing_method():
    """Serializing an object without ``to_dict`` raises ``SerializationError``."""
    instance_without_to_dict = CustomClassNoToDict()
    with pytest.raises(SerializationError, match="does not have a 'to_dict' method"):
        serialize_class_instance(instance_without_to_dict)
def test_deserialize_class_instance():
    """A well-formed "type"/"data" payload is turned back into a ``CustomClass`` instance."""
    payload = {"type": "test_base_serialization.CustomClass", "data": {"key": "value", "more": False}}
    instance = deserialize_class_instance(payload)
    assert isinstance(instance, CustomClass)
def test_deserialize_class_instance_invalid_data():
    """Each malformed payload raises ``DeserializationError`` with the matching message."""
    # (payload, regex expected to match the error message); checked in this order.
    failing_cases = [
        # no "type" key at all
        ({}, "Missing 'type'"),
        # "type" present but "data" missing
        ({"type": "test_base_serialization.CustomClass"}, "Missing 'data'"),
        # "type" names a class that does not exist in this module
        (
            {"type": "test_base_serialization.CustomClass1", "data": {}},
            "Class 'test_base_serialization.CustomClass1' not correctly imported",
        ),
        # class exists but has no from_dict
        (
            {"type": "test_base_serialization.CustomClassNoFromDict", "data": {}},
            "Class 'test_base_serialization.CustomClassNoFromDict' does not have a 'from_dict' method",
        ),
    ]
    for payload, message in failing_cases:
        with pytest.raises(DeserializationError, match=message):
            deserialize_class_instance(payload)
@pytest.mark.parametrize(
    ("value", "result"),
    [
        # int
        (1, {"serialization_schema": {"type": "integer"}, "serialized_data": 1}),
        # float
        (1.5, {"serialization_schema": {"type": "number"}, "serialized_data": 1.5}),
        # str
        ("test", {"serialization_schema": {"type": "string"}, "serialized_data": "test"}),
        # bool (both values)
        (True, {"serialization_schema": {"type": "boolean"}, "serialized_data": True}),
        (False, {"serialization_schema": {"type": "boolean"}, "serialized_data": False}),
        # None
        (None, {"serialization_schema": {"type": "null"}, "serialized_data": None}),
    ],
)
def test_serialize_and_deserialize_primitive_types(value, result):
    """Round-trip a primitive: ``value`` serializes to ``result`` and ``result`` deserializes to ``value``."""
    serialized = _serialize_value_with_schema(value)
    assert serialized == result
    restored = _deserialize_value_with_schema(result)
    assert restored == value
@pytest.mark.parametrize(
    ("value", "result"),
    [
        # empty dict
        ({}, {"serialization_schema": {"type": "object", "properties": {}}, "serialized_data": {}}),
        # empty list
        ([], {"serialization_schema": {"type": "array", "items": {}}, "serialized_data": []}),
        # empty tuple: the schema pins the length with minItems/maxItems
        (
            (),
            {
                "serialization_schema": {"type": "array", "items": {}, "minItems": 0, "maxItems": 0},
                "serialized_data": [],
            },
        ),
        # empty set: the schema carries uniqueItems
        (
            set(),
            {"serialization_schema": {"type": "array", "items": {}, "uniqueItems": True}, "serialized_data": []},
        ),
        # empty containers nested inside a dict
        (
            {"empty_list": [], "empty_dict": {}, "nested_empty": {"empty": []}},
            {
                "serialization_schema": {
                    "type": "object",
                    "properties": {
                        "empty_list": {"type": "array", "items": {}},
                        "empty_dict": {"type": "object", "properties": {}},
                        "nested_empty": {
                            "type": "object",
                            "properties": {"empty": {"type": "array", "items": {}}},
                        },
                    },
                },
                "serialized_data": {"empty_list": [], "empty_dict": {}, "nested_empty": {"empty": []}},
            },
        ),
    ],
)
def test_serializing_and_deserializing_empty_structures(value, result):
    """Round-trip an empty container: serialization yields ``result``, deserialization yields ``value``."""
    serialized = _serialize_value_with_schema(value)
    assert serialized == result
    restored = _deserialize_value_with_schema(result)
    assert restored == value
@pytest.mark.parametrize(
    ("value", "result"),
    [
        # flat list
        (
            [1, 2, 3],
            {
                "serialization_schema": {"type": "array", "items": {"type": "integer"}},
                "serialized_data": [1, 2, 3],
            },
        ),
        # flat set: the schema carries uniqueItems
        (
            {1, 2, 3},
            {
                "serialization_schema": {"type": "array", "items": {"type": "integer"}, "uniqueItems": True},
                "serialized_data": [1, 2, 3],
            },
        ),
        # flat tuple: the schema pins the length with minItems/maxItems
        (
            (1, 2, 3),
            {
                "serialization_schema": {
                    "type": "array",
                    "items": {"type": "integer"},
                    "minItems": 3,
                    "maxItems": 3,
                },
                "serialized_data": [1, 2, 3],
            },
        ),
        # list of lists
        (
            [[1, 2], [3, 4]],
            {
                "serialization_schema": {
                    "type": "array",
                    "items": {"type": "array", "items": {"type": "integer"}},
                },
                "serialized_data": [[1, 2], [3, 4]],
            },
        ),
        # list of sets
        (
            [{1, 2}, {3, 4}],
            {
                "serialization_schema": {
                    "items": {"items": {"type": "integer"}, "type": "array", "uniqueItems": True},
                    "type": "array",
                },
                "serialized_data": [[1, 2], [3, 4]],
            },
        ),
        # tuple of tuples: length constraints appear at both nesting levels
        (
            ((1, 2), (3, 4), (5, 6)),
            {
                "serialization_schema": {
                    "type": "array",
                    "items": {"type": "array", "items": {"type": "integer"}, "minItems": 2, "maxItems": 2},
                    "minItems": 3,
                    "maxItems": 3,
                },
                "serialized_data": [[1, 2], [3, 4], [5, 6]],
            },
        ),
        # list of lists of GeneratedAnswer objects
        (
            [
                [
                    GeneratedAnswer(
                        data="Paris",
                        query="What is the capital of France?",
                        documents=[Document(content="Paris is the capital of France", id="1")],
                        meta={"page": 1},
                    )
                ],
                [
                    GeneratedAnswer(
                        data="Berlin",
                        query="What is the capital of Germany?",
                        documents=[Document(content="Berlin is the capital of Germany", id="2")],
                        meta={"page": 1},
                    )
                ],
            ],
            {
                "serialization_schema": {
                    "type": "array",
                    "items": {
                        "type": "array",
                        "items": {"type": "haystack.dataclasses.answer.GeneratedAnswer"},
                    },
                },
                "serialized_data": [
                    [
                        {
                            "type": "haystack.dataclasses.answer.GeneratedAnswer",
                            "init_parameters": {
                                "data": "Paris",
                                "query": "What is the capital of France?",
                                "documents": [
                                    {
                                        "id": "1",
                                        "content": "Paris is the capital of France",
                                        "blob": None,
                                        "meta": {},
                                        "score": None,
                                        "embedding": None,
                                        "sparse_embedding": None,
                                    }
                                ],
                                "meta": {"page": 1},
                            },
                        }
                    ],
                    [
                        {
                            "type": "haystack.dataclasses.answer.GeneratedAnswer",
                            "init_parameters": {
                                "data": "Berlin",
                                "query": "What is the capital of Germany?",
                                "documents": [
                                    {
                                        "id": "2",
                                        "content": "Berlin is the capital of Germany",
                                        "blob": None,
                                        "meta": {},
                                        "score": None,
                                        "embedding": None,
                                        "sparse_embedding": None,
                                    }
                                ],
                                "meta": {"page": 1},
                            },
                        }
                    ],
                ],
            },
        ),
    ],
)
def test_serialize_and_deserialize_sequence_types(value, result):
    """Round-trip a list, set, tuple or nested sequence: ``value`` <-> ``result``."""
    serialized = _serialize_value_with_schema(value)
    assert serialized == result
    restored = _deserialize_value_with_schema(result)
    assert restored == value
def test_serialize_and_deserialize_nested_dicts():
    """Round-trip a three-level dict of strings through the schema-based (de)serializer."""
    nested = {"key1": {"nested1": "value1", "nested2": {"deep": "value2"}}}

    # Expected schema, assembled from the innermost level outwards.
    nested2_schema = {"type": "object", "properties": {"deep": {"type": "string"}}}
    key1_schema = {"type": "object", "properties": {"nested1": {"type": "string"}, "nested2": nested2_schema}}
    envelope = {
        "serialization_schema": {"type": "object", "properties": {"key1": key1_schema}},
        "serialized_data": {"key1": {"nested1": "value1", "nested2": {"deep": "value2"}}},
    }

    serialized = _serialize_value_with_schema(nested)
    assert serialized == envelope
    restored = _deserialize_value_with_schema(envelope)
    assert restored == nested
def test_serialize_and_deserialize_value_with_schema_with_various_types():
    """Round-trip a dict mixing primitives, containers, ChatMessage, Document and GeneratedAnswer values."""
    answer = GeneratedAnswer(
        data="Paris",
        query="What is the capital of France?",
        documents=[Document(content="Paris is the capital of France", id="2")],
        meta={"page": 1},
    )
    data = {
        "numbers": 1,
        "messages": [ChatMessage.from_user(text="Hello, world!"), ChatMessage.from_assistant(text="Hello, world!")],
        "user_id": "123",
        "dict_of_lists": {"numbers": [1, 2, 3]},
        "documents": [Document(content="Hello, world!", id="1")],
        "list_of_dicts": [{"numbers": [1, 2, 3]}],
        "answers": [answer],
    }

    # Expected schema: one entry per top-level key of `data`.
    schema = {
        "type": "object",
        "properties": {
            "numbers": {"type": "integer"},
            "messages": {"type": "array", "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}},
            "user_id": {"type": "string"},
            "dict_of_lists": {
                "type": "object",
                "properties": {"numbers": {"type": "array", "items": {"type": "integer"}}},
            },
            "documents": {"type": "array", "items": {"type": "haystack.dataclasses.document.Document"}},
            "list_of_dicts": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {"numbers": {"type": "array", "items": {"type": "integer"}}},
                },
            },
            "answers": {"type": "array", "items": {"type": "haystack.dataclasses.answer.GeneratedAnswer"}},
        },
    }

    # Expected serialized payload.
    payload = {
        "numbers": 1,
        "messages": [
            {"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]},
            {"role": "assistant", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]},
        ],
        "user_id": "123",
        "dict_of_lists": {"numbers": [1, 2, 3]},
        # NOTE(review): no "meta" key here, unlike the document nested in the answer below; presumably
        # a top-level Document is serialized with its meta flattened - confirm against Document.to_dict.
        "documents": [
            {
                "id": "1",
                "content": "Hello, world!",
                "blob": None,
                "score": None,
                "embedding": None,
                "sparse_embedding": None,
            }
        ],
        "list_of_dicts": [{"numbers": [1, 2, 3]}],
        "answers": [
            {
                "type": "haystack.dataclasses.answer.GeneratedAnswer",
                "init_parameters": {
                    "data": "Paris",
                    "query": "What is the capital of France?",
                    "documents": [
                        {
                            "id": "2",
                            "content": "Paris is the capital of France",
                            "blob": None,
                            "meta": {},
                            "score": None,
                            "embedding": None,
                            "sparse_embedding": None,
                        }
                    ],
                    "meta": {"page": 1},
                },
            }
        ],
    }
    expected = {"serialization_schema": schema, "serialized_data": payload}

    serialized = _serialize_value_with_schema(data)
    assert serialized == expected
    restored = _deserialize_value_with_schema(expected)
    assert restored == data
def test_serializing_and_deserializing_custom_class_type():
    """A dict holding a ``CustomClass`` instance serializes via ``to_dict`` and deserializes to a new instance."""
    original = {"numbers": 1, "custom_type": CustomClass()}

    expected_schema = {
        "type": "object",
        "properties": {
            "custom_type": {"type": "test_base_serialization.CustomClass"},
            "numbers": {"type": "integer"},
        },
    }
    expected_payload = {"numbers": 1, "custom_type": {"key": "value", "more": False}}

    envelope = _serialize_value_with_schema(original)
    assert envelope == {"serialization_schema": expected_schema, "serialized_data": expected_payload}

    restored = _deserialize_value_with_schema(envelope)
    assert restored["numbers"] == 1
    # CustomClass defines no __eq__, so only the type of the rebuilt object is checked.
    assert isinstance(restored["custom_type"], CustomClass)
def test_serialize_and_deserialize_value_with_callable():
    """A module-level function round-trips as its dotted path under the "typing.Callable" schema type."""
    envelope = {
        "serialization_schema": {"type": "typing.Callable"},
        "serialized_data": "test_base_serialization.simple_calc_function",
    }
    serialized = _serialize_value_with_schema(simple_calc_function)
    assert serialized == envelope
    restored = _deserialize_value_with_schema(envelope)
    assert restored == simple_calc_function
def test_serialize_and_deserialize_value_with_enum():
    """An enum member round-trips as its member name ("ONE") with the enum's dotted path as schema type."""
    member = CustomEnum.ONE
    envelope = {
        "serialization_schema": {"type": "test_base_serialization.CustomEnum"},
        "serialized_data": "ONE",
    }
    serialized = _serialize_value_with_schema(member)
    assert serialized == envelope
    restored = _deserialize_value_with_schema(envelope)
    assert restored == member
def test_deserialize_value_with_wrong_value():
    """Deserializing a name that is not a ``CustomEnum`` member raises ``DeserializationError``."""
    invalid_envelope = {
        "serialization_schema": {"type": "test_base_serialization.CustomEnum"},
        "serialized_data": "NOT_VALID",
    }
    with pytest.raises(DeserializationError, match="Value 'NOT_VALID' is not a valid member of Enum"):
        _deserialize_value_with_schema(invalid_envelope)