haystack/haystack/utils/base_serialization.py
Amna Mubashar 050c987946
chore: remove backward compatibility for State deserialization (#9585)
* remove backward compatability

* Fix linting
2025-07-03 13:20:34 +02:00

331 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict
from haystack.core.errors import DeserializationError, SerializationError
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
def serialize_class_instance(obj: Any) -> Dict[str, Any]:
    """
    Wraps the `to_dict` output of an object together with its qualified class name.

    :param obj:
        The object to be serialized. It must expose a `to_dict` method.
    :returns:
        A dictionary with two keys: `type`, the qualified class name of `obj`, and `data`,
        the value returned by `obj.to_dict()`.
    :raises SerializationError:
        If the object does not have a `to_dict` method.
    """
    if hasattr(obj, "to_dict"):
        payload = obj.to_dict()
        qualified_name = generate_qualified_class_name(type(obj))
        return {"type": qualified_name, "data": payload}

    raise SerializationError(f"Object of class '{type(obj).__name__}' does not have a 'to_dict' method")
def deserialize_class_instance(data: Dict[str, Any]) -> Any:
    """
    Rebuilds an object from the `{"type": ..., "data": ...}` dictionary produced by `serialize_class_instance`.

    :param data:
        The dictionary to deserialize from. Must contain the keys `type` and `data`.
    :returns:
        The object returned by the `from_dict` method of the class named in `data["type"]`.
    :raises DeserializationError:
        If `type` or `data` is missing, the class cannot be imported, or the class does not
        have a `from_dict` method.
    """
    # Checked in this order so that a missing 'type' is reported before a missing 'data'.
    missing_key_messages = {
        "type": "Missing 'type' in serialization data",
        "data": "Missing 'data' in serialization data",
    }
    for required_key, message in missing_key_messages.items():
        if required_key not in data:
            raise DeserializationError(message)

    try:
        target_cls = import_class_by_name(data["type"])
    except ImportError as e:
        raise DeserializationError(f"Class '{data['type']}' not correctly imported") from e

    if hasattr(target_cls, "from_dict"):
        return target_cls.from_dict(data["data"])

    raise DeserializationError(f"Class '{data['type']}' does not have a 'from_dict' method")
def _serialize_value_with_schema(payload: Any) -> Dict[str, Any]:
"""
Serializes a value into a schema-aware format suitable for storage or transmission.
The output format separates the schema information from the actual data, making it easier
to deserialize complex nested structures correctly.
The function handles:
- Objects with to_dict() methods (e.g. dataclasses)
- Objects with __dict__ attributes
- Dictionaries
- Lists, tuples, and sets. Lists with mixed types are not supported.
- Primitive types (str, int, float, bool, None)
:param payload: The value to serialize (can be any type)
:returns: The serialized dict representation of the given value. Contains two keys:
- "serialization_schema": Contains type information for each field.
- "serialized_data": Contains the actual data in a simplified format.
"""
# Handle dictionary case - iterate through fields
if isinstance(payload, dict):
schema: Dict[str, Any] = {}
data: Dict[str, Any] = {}
for field, val in payload.items():
# Recursively serialize each field
serialized_value = _serialize_value_with_schema(val)
schema[field] = serialized_value["serialization_schema"]
data[field] = serialized_value["serialized_data"]
return {"serialization_schema": {"type": "object", "properties": schema}, "serialized_data": data}
# Handle array case - iterate through elements
elif isinstance(payload, (list, tuple, set)):
# Convert to list for consistent handling
pure_list = _convert_to_basic_types(list(payload))
# Determine item type from first element (if any)
if payload:
first = next(iter(payload))
item_schema = _serialize_value_with_schema(first)
base_schema = {"type": "array", "items": item_schema["serialization_schema"]}
else:
base_schema = {"type": "array", "items": {}}
# Add JSON Schema properties to infer sets and tuples
if isinstance(payload, set):
base_schema["uniqueItems"] = True
elif isinstance(payload, tuple):
base_schema["minItems"] = len(payload)
base_schema["maxItems"] = len(payload)
return {"serialization_schema": base_schema, "serialized_data": pure_list}
# Handle Haystack style objects (e.g. dataclasses and Components)
elif hasattr(payload, "to_dict") and callable(payload.to_dict):
type_name = generate_qualified_class_name(type(payload))
pure = _convert_to_basic_types(payload)
schema = {"type": type_name}
return {"serialization_schema": schema, "serialized_data": pure}
# Handle arbitrary objects with __dict__
elif hasattr(payload, "__dict__"):
type_name = generate_qualified_class_name(type(payload))
pure = _convert_to_basic_types(vars(payload))
schema = {"type": type_name}
return {"serialization_schema": schema, "serialized_data": pure}
# Handle primitives
else:
prim_type = _primitive_schema_type(payload)
schema = {"type": prim_type}
return {"serialization_schema": schema, "serialized_data": payload}
def _primitive_schema_type(value: Any) -> str:
"""
Helper function to determine the schema type for primitive values.
"""
if value is None:
return "null"
if isinstance(value, bool):
return "boolean"
if isinstance(value, int):
return "integer"
if isinstance(value, float):
return "number"
if isinstance(value, str):
return "string"
return "string" # fallback
def _convert_to_basic_types(value: Any) -> Any:
"""
Helper function to recursively convert complex Python objects into their basic type equivalents.
This helper function traverses through nested data structures and converts all complex
objects (custom classes, dataclasses, etc.) into basic Python types (dict, list, str,
int, float, bool, None) that can be easily serialized.
The function handles:
- Objects with to_dict() methods: converted using their to_dict implementation
- Objects with __dict__ attribute: converted to plain dictionaries
- Dictionaries: recursively converted values while preserving keys
- Sequences (list, tuple, set): recursively converted while preserving type
- Primitive types: returned as-is
"""
# dataclassstyle objects
if hasattr(value, "to_dict") and callable(value.to_dict):
return _convert_to_basic_types(value.to_dict())
# arbitrary objects with __dict__
if hasattr(value, "__dict__"):
return {k: _convert_to_basic_types(v) for k, v in vars(value).items()}
# dicts
if isinstance(value, dict):
return {k: _convert_to_basic_types(v) for k, v in value.items()}
# sequences
if isinstance(value, (list, tuple, set)):
return [_convert_to_basic_types(v) for v in value]
# primitive
return value
def _deserialize_value_with_schema(serialized: Dict[str, Any]) -> Any: # pylint: disable=too-many-return-statements, # noqa: PLR0911, PLR0912
"""
Deserializes a value with schema information back to its original form.
Takes a dict of the form:
{
"serialization_schema": {"type": "integer"} or {"type": "object", "properties": {...}},
"serialized_data": <the actual data>
}
:param serialized: The serialized dict with schema and data.
:returns: The deserialized value in its original form.
"""
if not serialized or "serialization_schema" not in serialized or "serialized_data" not in serialized:
raise DeserializationError(
f"Invalid format of passed serialized payload. Expected a dictionary with keys "
f"'serialization_schema' and 'serialized_data'. Got: {serialized}"
)
schema = serialized["serialization_schema"]
data = serialized["serialized_data"]
schema_type = schema.get("type")
if not schema_type:
# for backward comaptability till Haystack 2.16 we use legacy implementation
raise DeserializationError(
"Missing 'type' key in 'serialization_schema'. This likely indicates that you're using a serialized "
"State object created with a version of Haystack older than 2.15.0. "
"Support for the old serialization format is removed in Haystack 2.16.0. "
"Please upgrade to the new serialization format to ensure forward compatibility."
)
# Handle object case (dictionary with properties)
if schema_type == "object":
properties = schema.get("properties")
if properties:
result: Dict[str, Any] = {}
if isinstance(data, dict):
for field, raw_value in data.items():
field_schema = properties.get(field)
if field_schema:
# Recursively deserialize each field - avoid creating temporary dict
result[field] = _deserialize_value_with_schema(
{"serialization_schema": field_schema, "serialized_data": raw_value}
)
return result
else:
return _deserialize_value(data)
# Handle array case
elif schema_type == "array":
# Cache frequently accessed schema properties
item_schema = schema.get("items", {})
item_type = item_schema.get("type", "any")
is_set = schema.get("uniqueItems") is True
is_tuple = schema.get("minItems") is not None and schema.get("maxItems") is not None
# Handle nested objects/arrays first (most complex case)
if item_type in ("object", "array"):
return [
_deserialize_value_with_schema({"serialization_schema": item_schema, "serialized_data": item})
for item in data
]
# Helper function to deserialize individual items
def deserialize_item(item):
if item_type == "any":
return _deserialize_value(item)
else:
return _deserialize_value({"type": item_type, "data": item})
# Handle different collection types
if is_set:
return {deserialize_item(item) for item in data}
elif is_tuple:
return tuple(deserialize_item(item) for item in data)
else:
return [deserialize_item(item) for item in data]
# Handle primitive types
elif schema_type in ("null", "boolean", "integer", "number", "string"):
return data
# Handle custom class types
else:
return _deserialize_value({"type": schema_type, "data": data})
def _deserialize_value(value: Any) -> Any: # pylint: disable=too-many-return-statements # noqa: PLR0911
"""
Helper function to deserialize values from their envelope format {"type": T, "data": D}.
Handles four cases:
- Typed envelopes: {"type": T, "data": D} where T determines deserialization method
- Plain dicts: recursively deserialize values
- Collections (list/tuple/set): recursively deserialize elements
- Other values: return as-is
:param value: The value to deserialize
:returns: The deserialized value
"""
# 1) Envelope case
if isinstance(value, dict) and "type" in value and "data" in value:
t = value["type"]
payload = value["data"]
# 1.a) Array
if t == "array":
return [_deserialize_value(child) for child in payload]
# 1.b) Generic object/dict
if t == "object":
return {k: _deserialize_value(v) for k, v in payload.items()}
# 1.c) Primitive
if t in ("null", "boolean", "integer", "number", "string"):
return payload
# 1.d) Custom class
cls = import_class_by_name(t)
# first, recursively deserialize the inner payload
deserialized_payload = {k: _deserialize_value(v) for k, v in payload.items()}
# try from_dict
if hasattr(cls, "from_dict") and callable(cls.from_dict):
return cls.from_dict(deserialized_payload)
# fallback: set attributes on a blank instance
instance = cls.__new__(cls)
for attr_name, attr_value in deserialized_payload.items():
setattr(instance, attr_name, attr_value)
return instance
# 2) Plain dict (no envelope) → recurse
if isinstance(value, dict):
return {k: _deserialize_value(v) for k, v in value.items()}
# 3) Collections → recurse
if isinstance(value, (list, tuple, set)):
return type(value)(_deserialize_value(v) for v in value)
# 4) Fallback (shouldn't usually happen with our schema)
return value