mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-05 08:01:02 +00:00
331 lines
12 KiB
Python
331 lines
12 KiB
Python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||
#
|
||
# SPDX-License-Identifier: Apache-2.0
|
||
|
||
from typing import Any, Dict
|
||
|
||
from haystack.core.errors import DeserializationError, SerializationError
|
||
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
|
||
|
||
|
||
def serialize_class_instance(obj: Any) -> Dict[str, Any]:
    """
    Turns an object exposing a `to_dict` method into a type-tagged dictionary.

    :param obj:
        The object to be serialized.
    :returns:
        A dictionary with two keys: `type`, the qualified class name of `obj` as produced by
        `generate_qualified_class_name`, and `data`, the result of `obj.to_dict()`.
    :raises SerializationError:
        If the object does not have a `to_dict` method.
    """
    if hasattr(obj, "to_dict"):
        # Call `to_dict` before resolving the class name so a failing `to_dict` surfaces first.
        payload = obj.to_dict()
        qualified_name = generate_qualified_class_name(type(obj))
        return {"type": qualified_name, "data": payload}

    raise SerializationError(f"Object of class '{type(obj).__name__}' does not have a 'to_dict' method")
|
||
|
||
|
||
def deserialize_class_instance(data: Dict[str, Any]) -> Any:
    """
    Rebuilds an object from the type-tagged dictionary produced by `serialize_class_instance`.

    The class named under `data["type"]` is imported with `import_class_by_name` and its
    `from_dict` method is called with `data["data"]`.

    :param data:
        The dictionary to deserialize from. Must contain the keys `type` and `data`.
    :returns:
        The deserialized object.
    :raises DeserializationError:
        If the serialization data is malformed, the class type cannot be imported, or the
        class does not have a `from_dict` method.
    """
    # Validate the envelope before touching the import machinery.
    if "type" not in data:
        raise DeserializationError("Missing 'type' in serialization data")
    if "data" not in data:
        raise DeserializationError("Missing 'data' in serialization data")

    try:
        target_class = import_class_by_name(data["type"])
    except ImportError as import_error:
        raise DeserializationError(f"Class '{data['type']}' not correctly imported") from import_error

    if hasattr(target_class, "from_dict"):
        return target_class.from_dict(data["data"])

    raise DeserializationError(f"Class '{data['type']}' does not have a 'from_dict' method")
|
||
|
||
|
||
def _serialize_value_with_schema(payload: Any) -> Dict[str, Any]:
|
||
"""
|
||
Serializes a value into a schema-aware format suitable for storage or transmission.
|
||
|
||
The output format separates the schema information from the actual data, making it easier
|
||
to deserialize complex nested structures correctly.
|
||
|
||
The function handles:
|
||
- Objects with to_dict() methods (e.g. dataclasses)
|
||
- Objects with __dict__ attributes
|
||
- Dictionaries
|
||
- Lists, tuples, and sets. Lists with mixed types are not supported.
|
||
- Primitive types (str, int, float, bool, None)
|
||
|
||
:param payload: The value to serialize (can be any type)
|
||
:returns: The serialized dict representation of the given value. Contains two keys:
|
||
- "serialization_schema": Contains type information for each field.
|
||
- "serialized_data": Contains the actual data in a simplified format.
|
||
|
||
"""
|
||
# Handle dictionary case - iterate through fields
|
||
if isinstance(payload, dict):
|
||
schema: Dict[str, Any] = {}
|
||
data: Dict[str, Any] = {}
|
||
|
||
for field, val in payload.items():
|
||
# Recursively serialize each field
|
||
serialized_value = _serialize_value_with_schema(val)
|
||
schema[field] = serialized_value["serialization_schema"]
|
||
data[field] = serialized_value["serialized_data"]
|
||
|
||
return {"serialization_schema": {"type": "object", "properties": schema}, "serialized_data": data}
|
||
|
||
# Handle array case - iterate through elements
|
||
elif isinstance(payload, (list, tuple, set)):
|
||
# Convert to list for consistent handling
|
||
pure_list = _convert_to_basic_types(list(payload))
|
||
|
||
# Determine item type from first element (if any)
|
||
if payload:
|
||
first = next(iter(payload))
|
||
item_schema = _serialize_value_with_schema(first)
|
||
base_schema = {"type": "array", "items": item_schema["serialization_schema"]}
|
||
else:
|
||
base_schema = {"type": "array", "items": {}}
|
||
|
||
# Add JSON Schema properties to infer sets and tuples
|
||
if isinstance(payload, set):
|
||
base_schema["uniqueItems"] = True
|
||
elif isinstance(payload, tuple):
|
||
base_schema["minItems"] = len(payload)
|
||
base_schema["maxItems"] = len(payload)
|
||
|
||
return {"serialization_schema": base_schema, "serialized_data": pure_list}
|
||
|
||
# Handle Haystack style objects (e.g. dataclasses and Components)
|
||
elif hasattr(payload, "to_dict") and callable(payload.to_dict):
|
||
type_name = generate_qualified_class_name(type(payload))
|
||
pure = _convert_to_basic_types(payload)
|
||
schema = {"type": type_name}
|
||
return {"serialization_schema": schema, "serialized_data": pure}
|
||
|
||
# Handle arbitrary objects with __dict__
|
||
elif hasattr(payload, "__dict__"):
|
||
type_name = generate_qualified_class_name(type(payload))
|
||
pure = _convert_to_basic_types(vars(payload))
|
||
schema = {"type": type_name}
|
||
return {"serialization_schema": schema, "serialized_data": pure}
|
||
|
||
# Handle primitives
|
||
else:
|
||
prim_type = _primitive_schema_type(payload)
|
||
schema = {"type": prim_type}
|
||
return {"serialization_schema": schema, "serialized_data": payload}
|
||
|
||
|
||
def _primitive_schema_type(value: Any) -> str:
|
||
"""
|
||
Helper function to determine the schema type for primitive values.
|
||
"""
|
||
if value is None:
|
||
return "null"
|
||
if isinstance(value, bool):
|
||
return "boolean"
|
||
if isinstance(value, int):
|
||
return "integer"
|
||
if isinstance(value, float):
|
||
return "number"
|
||
if isinstance(value, str):
|
||
return "string"
|
||
return "string" # fallback
|
||
|
||
|
||
def _convert_to_basic_types(value: Any) -> Any:
|
||
"""
|
||
Helper function to recursively convert complex Python objects into their basic type equivalents.
|
||
|
||
This helper function traverses through nested data structures and converts all complex
|
||
objects (custom classes, dataclasses, etc.) into basic Python types (dict, list, str,
|
||
int, float, bool, None) that can be easily serialized.
|
||
|
||
The function handles:
|
||
- Objects with to_dict() methods: converted using their to_dict implementation
|
||
- Objects with __dict__ attribute: converted to plain dictionaries
|
||
- Dictionaries: recursively converted values while preserving keys
|
||
- Sequences (list, tuple, set): recursively converted while preserving type
|
||
- Primitive types: returned as-is
|
||
|
||
"""
|
||
# dataclass‐style objects
|
||
if hasattr(value, "to_dict") and callable(value.to_dict):
|
||
return _convert_to_basic_types(value.to_dict())
|
||
|
||
# arbitrary objects with __dict__
|
||
if hasattr(value, "__dict__"):
|
||
return {k: _convert_to_basic_types(v) for k, v in vars(value).items()}
|
||
|
||
# dicts
|
||
if isinstance(value, dict):
|
||
return {k: _convert_to_basic_types(v) for k, v in value.items()}
|
||
|
||
# sequences
|
||
if isinstance(value, (list, tuple, set)):
|
||
return [_convert_to_basic_types(v) for v in value]
|
||
|
||
# primitive
|
||
return value
|
||
|
||
|
||
def _deserialize_value_with_schema(serialized: Dict[str, Any]) -> Any: # pylint: disable=too-many-return-statements, # noqa: PLR0911, PLR0912
|
||
"""
|
||
Deserializes a value with schema information back to its original form.
|
||
|
||
Takes a dict of the form:
|
||
{
|
||
"serialization_schema": {"type": "integer"} or {"type": "object", "properties": {...}},
|
||
"serialized_data": <the actual data>
|
||
}
|
||
|
||
:param serialized: The serialized dict with schema and data.
|
||
:returns: The deserialized value in its original form.
|
||
"""
|
||
|
||
if not serialized or "serialization_schema" not in serialized or "serialized_data" not in serialized:
|
||
raise DeserializationError(
|
||
f"Invalid format of passed serialized payload. Expected a dictionary with keys "
|
||
f"'serialization_schema' and 'serialized_data'. Got: {serialized}"
|
||
)
|
||
schema = serialized["serialization_schema"]
|
||
data = serialized["serialized_data"]
|
||
|
||
schema_type = schema.get("type")
|
||
|
||
if not schema_type:
|
||
# for backward comaptability till Haystack 2.16 we use legacy implementation
|
||
raise DeserializationError(
|
||
"Missing 'type' key in 'serialization_schema'. This likely indicates that you're using a serialized "
|
||
"State object created with a version of Haystack older than 2.15.0. "
|
||
"Support for the old serialization format is removed in Haystack 2.16.0. "
|
||
"Please upgrade to the new serialization format to ensure forward compatibility."
|
||
)
|
||
|
||
# Handle object case (dictionary with properties)
|
||
if schema_type == "object":
|
||
properties = schema.get("properties")
|
||
if properties:
|
||
result: Dict[str, Any] = {}
|
||
|
||
if isinstance(data, dict):
|
||
for field, raw_value in data.items():
|
||
field_schema = properties.get(field)
|
||
if field_schema:
|
||
# Recursively deserialize each field - avoid creating temporary dict
|
||
result[field] = _deserialize_value_with_schema(
|
||
{"serialization_schema": field_schema, "serialized_data": raw_value}
|
||
)
|
||
|
||
return result
|
||
else:
|
||
return _deserialize_value(data)
|
||
|
||
# Handle array case
|
||
elif schema_type == "array":
|
||
# Cache frequently accessed schema properties
|
||
item_schema = schema.get("items", {})
|
||
item_type = item_schema.get("type", "any")
|
||
is_set = schema.get("uniqueItems") is True
|
||
is_tuple = schema.get("minItems") is not None and schema.get("maxItems") is not None
|
||
|
||
# Handle nested objects/arrays first (most complex case)
|
||
if item_type in ("object", "array"):
|
||
return [
|
||
_deserialize_value_with_schema({"serialization_schema": item_schema, "serialized_data": item})
|
||
for item in data
|
||
]
|
||
|
||
# Helper function to deserialize individual items
|
||
def deserialize_item(item):
|
||
if item_type == "any":
|
||
return _deserialize_value(item)
|
||
else:
|
||
return _deserialize_value({"type": item_type, "data": item})
|
||
|
||
# Handle different collection types
|
||
if is_set:
|
||
return {deserialize_item(item) for item in data}
|
||
elif is_tuple:
|
||
return tuple(deserialize_item(item) for item in data)
|
||
else:
|
||
return [deserialize_item(item) for item in data]
|
||
|
||
# Handle primitive types
|
||
elif schema_type in ("null", "boolean", "integer", "number", "string"):
|
||
return data
|
||
|
||
# Handle custom class types
|
||
else:
|
||
return _deserialize_value({"type": schema_type, "data": data})
|
||
|
||
|
||
def _deserialize_value(value: Any) -> Any: # pylint: disable=too-many-return-statements # noqa: PLR0911
|
||
"""
|
||
Helper function to deserialize values from their envelope format {"type": T, "data": D}.
|
||
|
||
Handles four cases:
|
||
- Typed envelopes: {"type": T, "data": D} where T determines deserialization method
|
||
- Plain dicts: recursively deserialize values
|
||
- Collections (list/tuple/set): recursively deserialize elements
|
||
- Other values: return as-is
|
||
|
||
:param value: The value to deserialize
|
||
:returns: The deserialized value
|
||
|
||
"""
|
||
# 1) Envelope case
|
||
if isinstance(value, dict) and "type" in value and "data" in value:
|
||
t = value["type"]
|
||
payload = value["data"]
|
||
|
||
# 1.a) Array
|
||
if t == "array":
|
||
return [_deserialize_value(child) for child in payload]
|
||
|
||
# 1.b) Generic object/dict
|
||
if t == "object":
|
||
return {k: _deserialize_value(v) for k, v in payload.items()}
|
||
|
||
# 1.c) Primitive
|
||
if t in ("null", "boolean", "integer", "number", "string"):
|
||
return payload
|
||
|
||
# 1.d) Custom class
|
||
cls = import_class_by_name(t)
|
||
# first, recursively deserialize the inner payload
|
||
deserialized_payload = {k: _deserialize_value(v) for k, v in payload.items()}
|
||
# try from_dict
|
||
if hasattr(cls, "from_dict") and callable(cls.from_dict):
|
||
return cls.from_dict(deserialized_payload)
|
||
# fallback: set attributes on a blank instance
|
||
instance = cls.__new__(cls)
|
||
for attr_name, attr_value in deserialized_payload.items():
|
||
setattr(instance, attr_name, attr_value)
|
||
return instance
|
||
|
||
# 2) Plain dict (no envelope) → recurse
|
||
if isinstance(value, dict):
|
||
return {k: _deserialize_value(v) for k, v in value.items()}
|
||
|
||
# 3) Collections → recurse
|
||
if isinstance(value, (list, tuple, set)):
|
||
return type(value)(_deserialize_value(v) for v in value)
|
||
|
||
# 4) Fallback (shouldn't usually happen with our schema)
|
||
return value
|