mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-13 07:47:26 +00:00
chore: extract type serialization (#6586)
* move functions * tests * reno
This commit is contained in:
parent
2dd5a94b04
commit
f877704839
@ -1,13 +1,11 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from typing import List, Dict, Any, Set, get_origin
|
||||
from typing import List, Dict, Any, Set
|
||||
|
||||
from jinja2 import meta, Environment, TemplateSyntaxError
|
||||
from jinja2.nativetypes import NativeEnvironment
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, DeserializationError
|
||||
from haystack import component, default_from_dict, default_to_dict
|
||||
from haystack.utils.type_serialization import serialize_type, deserialize_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -20,117 +18,6 @@ class RouteConditionException(Exception):
|
||||
"""Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter."""
|
||||
|
||||
|
||||
def serialize_type(target: Any) -> str:
|
||||
"""
|
||||
Serializes a type or an instance to its string representation, including the module name.
|
||||
|
||||
This function handles types, instances of types, and special typing objects.
|
||||
It assumes that non-typing objects will have a '__name__' attribute and raises
|
||||
an error if a type cannot be serialized.
|
||||
|
||||
:param target: The object to serialize, can be an instance or a type.
|
||||
:type target: Any
|
||||
:return: The string representation of the type.
|
||||
:raises ValueError: If the type cannot be serialized.
|
||||
"""
|
||||
# If the target is a string and contains a dot, treat it as an already serialized type
|
||||
if isinstance(target, str) and "." in target:
|
||||
return target
|
||||
|
||||
# Determine if the target is a type or an instance of a typing object
|
||||
is_type_or_typing = isinstance(target, type) or bool(get_origin(target))
|
||||
type_obj = target if is_type_or_typing else type(target)
|
||||
module = inspect.getmodule(type_obj)
|
||||
type_obj_repr = repr(type_obj)
|
||||
|
||||
if type_obj_repr.startswith("typing."):
|
||||
# e.g., typing.List[int] -> List[int], we'll add the module below
|
||||
type_name = type_obj_repr.split(".", 1)[1]
|
||||
elif hasattr(type_obj, "__name__"):
|
||||
type_name = type_obj.__name__
|
||||
else:
|
||||
# If type cannot be serialized, raise an error
|
||||
raise ValueError(f"Could not serialize type: {type_obj_repr}")
|
||||
|
||||
# Construct the full path with module name if available
|
||||
if module and hasattr(module, "__name__"):
|
||||
if module.__name__ == "builtins":
|
||||
# omit the module name for builtins, it just clutters the output
|
||||
# e.g. instead of 'builtins.str', we'll just return 'str'
|
||||
full_path = type_name
|
||||
else:
|
||||
full_path = f"{module.__name__}.{type_name}"
|
||||
else:
|
||||
full_path = type_name
|
||||
|
||||
return full_path
|
||||
|
||||
|
||||
def deserialize_type(type_str: str) -> Any:
|
||||
"""
|
||||
Deserializes a type given its full import path as a string, including nested generic types.
|
||||
|
||||
This function will dynamically import the module if it's not already imported
|
||||
and then retrieve the type object from it. It also handles nested generic types like 'typing.List[typing.Dict[int, str]]'.
|
||||
|
||||
:param type_str: The string representation of the type's full import path.
|
||||
:return: The deserialized type object.
|
||||
:raises DeserializationError: If the type cannot be deserialized due to missing module or type.
|
||||
"""
|
||||
|
||||
def parse_generic_args(args_str):
|
||||
args = []
|
||||
bracket_count = 0
|
||||
current_arg = ""
|
||||
|
||||
for char in args_str:
|
||||
if char == "[":
|
||||
bracket_count += 1
|
||||
elif char == "]":
|
||||
bracket_count -= 1
|
||||
|
||||
if char == "," and bracket_count == 0:
|
||||
args.append(current_arg.strip())
|
||||
current_arg = ""
|
||||
else:
|
||||
current_arg += char
|
||||
|
||||
if current_arg:
|
||||
args.append(current_arg.strip())
|
||||
|
||||
return args
|
||||
|
||||
if "[" in type_str and type_str.endswith("]"):
|
||||
# Handle generics
|
||||
main_type_str, generics_str = type_str.split("[", 1)
|
||||
generics_str = generics_str[:-1]
|
||||
|
||||
main_type = deserialize_type(main_type_str)
|
||||
generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str))
|
||||
|
||||
# Reconstruct
|
||||
return main_type[generic_args]
|
||||
|
||||
else:
|
||||
# Handle non-generics
|
||||
parts = type_str.split(".")
|
||||
module_name = ".".join(parts[:-1]) or "builtins"
|
||||
type_name = parts[-1]
|
||||
|
||||
module = sys.modules.get(module_name)
|
||||
if not module:
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError as e:
|
||||
raise DeserializationError(f"Could not import the module: {module_name}") from e
|
||||
|
||||
deserialized_type = getattr(module, type_name, None)
|
||||
if not deserialized_type:
|
||||
raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")
|
||||
|
||||
return deserialized_type
|
||||
|
||||
|
||||
@component
|
||||
class ConditionalRouter:
|
||||
"""
|
||||
|
||||
117
haystack/utils/type_serialization.py
Normal file
117
haystack/utils/type_serialization.py
Normal file
@ -0,0 +1,117 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import sys
|
||||
from typing import Any, get_origin
|
||||
|
||||
from haystack import DeserializationError
|
||||
|
||||
|
||||
def serialize_type(target: Any) -> str:
|
||||
"""
|
||||
Serializes a type or an instance to its string representation, including the module name.
|
||||
|
||||
This function handles types, instances of types, and special typing objects.
|
||||
It assumes that non-typing objects will have a '__name__' attribute and raises
|
||||
an error if a type cannot be serialized.
|
||||
|
||||
:param target: The object to serialize, can be an instance or a type.
|
||||
:type target: Any
|
||||
:return: The string representation of the type.
|
||||
:raises ValueError: If the type cannot be serialized.
|
||||
"""
|
||||
# If the target is a string and contains a dot, treat it as an already serialized type
|
||||
if isinstance(target, str) and "." in target:
|
||||
return target
|
||||
|
||||
# Determine if the target is a type or an instance of a typing object
|
||||
is_type_or_typing = isinstance(target, type) or bool(get_origin(target))
|
||||
type_obj = target if is_type_or_typing else type(target)
|
||||
module = inspect.getmodule(type_obj)
|
||||
type_obj_repr = repr(type_obj)
|
||||
|
||||
if type_obj_repr.startswith("typing."):
|
||||
# e.g., typing.List[int] -> List[int], we'll add the module below
|
||||
type_name = type_obj_repr.split(".", 1)[1]
|
||||
elif hasattr(type_obj, "__name__"):
|
||||
type_name = type_obj.__name__
|
||||
else:
|
||||
# If type cannot be serialized, raise an error
|
||||
raise ValueError(f"Could not serialize type: {type_obj_repr}")
|
||||
|
||||
# Construct the full path with module name if available
|
||||
if module and hasattr(module, "__name__"):
|
||||
if module.__name__ == "builtins":
|
||||
# omit the module name for builtins, it just clutters the output
|
||||
# e.g. instead of 'builtins.str', we'll just return 'str'
|
||||
full_path = type_name
|
||||
else:
|
||||
full_path = f"{module.__name__}.{type_name}"
|
||||
else:
|
||||
full_path = type_name
|
||||
|
||||
return full_path
|
||||
|
||||
|
||||
def deserialize_type(type_str: str) -> Any:
|
||||
"""
|
||||
Deserializes a type given its full import path as a string, including nested generic types.
|
||||
|
||||
This function will dynamically import the module if it's not already imported
|
||||
and then retrieve the type object from it. It also handles nested generic types like 'typing.List[typing.Dict[int, str]]'.
|
||||
|
||||
:param type_str: The string representation of the type's full import path.
|
||||
:return: The deserialized type object.
|
||||
:raises DeserializationError: If the type cannot be deserialized due to missing module or type.
|
||||
"""
|
||||
|
||||
def parse_generic_args(args_str):
|
||||
args = []
|
||||
bracket_count = 0
|
||||
current_arg = ""
|
||||
|
||||
for char in args_str:
|
||||
if char == "[":
|
||||
bracket_count += 1
|
||||
elif char == "]":
|
||||
bracket_count -= 1
|
||||
|
||||
if char == "," and bracket_count == 0:
|
||||
args.append(current_arg.strip())
|
||||
current_arg = ""
|
||||
else:
|
||||
current_arg += char
|
||||
|
||||
if current_arg:
|
||||
args.append(current_arg.strip())
|
||||
|
||||
return args
|
||||
|
||||
if "[" in type_str and type_str.endswith("]"):
|
||||
# Handle generics
|
||||
main_type_str, generics_str = type_str.split("[", 1)
|
||||
generics_str = generics_str[:-1]
|
||||
|
||||
main_type = deserialize_type(main_type_str)
|
||||
generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str))
|
||||
|
||||
# Reconstruct
|
||||
return main_type[generic_args]
|
||||
|
||||
else:
|
||||
# Handle non-generics
|
||||
parts = type_str.split(".")
|
||||
module_name = ".".join(parts[:-1]) or "builtins"
|
||||
type_name = parts[-1]
|
||||
|
||||
module = sys.modules.get(module_name)
|
||||
if not module:
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError as e:
|
||||
raise DeserializationError(f"Could not import the module: {module_name}") from e
|
||||
|
||||
deserialized_type = getattr(module, type_name, None)
|
||||
if not deserialized_type:
|
||||
raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")
|
||||
|
||||
return deserialized_type
|
||||
@ -0,0 +1,3 @@
|
||||
enhancements:
|
||||
- |
|
||||
Move `serialize_type` and `deserialize_type` in the `utils` module.
|
||||
@ -1,12 +1,11 @@
|
||||
import copy
|
||||
import typing
|
||||
from typing import List, Dict
|
||||
from typing import List
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.components.routers import ConditionalRouter
|
||||
from haystack.components.routers.conditional_router import NoRouteSelectedException, serialize_type, deserialize_type
|
||||
from haystack.components.routers.conditional_router import NoRouteSelectedException
|
||||
from haystack.dataclasses import ChatMessage
|
||||
|
||||
|
||||
@ -190,37 +189,6 @@ class TestRouter:
|
||||
with pytest.raises(ValueError):
|
||||
ConditionalRouter(routes)
|
||||
|
||||
def test_output_type_serialization(self):
|
||||
assert serialize_type(str) == "str"
|
||||
assert serialize_type(List[int]) == "typing.List[int]"
|
||||
assert serialize_type(List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]"
|
||||
assert serialize_type(ChatMessage) == "haystack.dataclasses.chat_message.ChatMessage"
|
||||
assert serialize_type(typing.List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]"
|
||||
assert serialize_type(List[ChatMessage]) == "typing.List[haystack.dataclasses.chat_message.ChatMessage]"
|
||||
assert (
|
||||
serialize_type(typing.Dict[int, ChatMessage])
|
||||
== "typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]"
|
||||
)
|
||||
assert serialize_type(int) == "int"
|
||||
assert serialize_type(ChatMessage.from_user("ciao")) == "haystack.dataclasses.chat_message.ChatMessage"
|
||||
|
||||
def test_output_type_deserialization(self):
|
||||
assert deserialize_type("str") == str
|
||||
assert deserialize_type("typing.List[int]") == typing.List[int]
|
||||
assert deserialize_type("typing.List[typing.Dict[str, int]]") == typing.List[Dict[str, int]]
|
||||
assert deserialize_type("typing.Dict[str, int]") == Dict[str, int]
|
||||
assert deserialize_type("typing.Dict[str, typing.List[int]]") == Dict[str, List[int]]
|
||||
assert deserialize_type("typing.List[typing.Dict[str, typing.List[int]]]") == List[Dict[str, List[int]]]
|
||||
assert (
|
||||
deserialize_type("typing.List[haystack.dataclasses.chat_message.ChatMessage]") == typing.List[ChatMessage]
|
||||
)
|
||||
assert (
|
||||
deserialize_type("typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]")
|
||||
== typing.Dict[int, ChatMessage]
|
||||
)
|
||||
assert deserialize_type("haystack.dataclasses.chat_message.ChatMessage") == ChatMessage
|
||||
assert deserialize_type("int") == int
|
||||
|
||||
def test_router_de_serialization(self):
|
||||
routes = [
|
||||
{"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
|
||||
|
||||
37
test/utils/test_type_serialization.py
Normal file
37
test/utils/test_type_serialization.py
Normal file
@ -0,0 +1,37 @@
|
||||
import copy
|
||||
import typing
|
||||
from typing import List, Dict
|
||||
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack.components.routers.conditional_router import serialize_type, deserialize_type
|
||||
|
||||
|
||||
def test_output_type_serialization():
|
||||
assert serialize_type(str) == "str"
|
||||
assert serialize_type(List[int]) == "typing.List[int]"
|
||||
assert serialize_type(List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]"
|
||||
assert serialize_type(ChatMessage) == "haystack.dataclasses.chat_message.ChatMessage"
|
||||
assert serialize_type(typing.List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]"
|
||||
assert serialize_type(List[ChatMessage]) == "typing.List[haystack.dataclasses.chat_message.ChatMessage]"
|
||||
assert (
|
||||
serialize_type(typing.Dict[int, ChatMessage])
|
||||
== "typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]"
|
||||
)
|
||||
assert serialize_type(int) == "int"
|
||||
assert serialize_type(ChatMessage.from_user("ciao")) == "haystack.dataclasses.chat_message.ChatMessage"
|
||||
|
||||
|
||||
def test_output_type_deserialization():
|
||||
assert deserialize_type("str") == str
|
||||
assert deserialize_type("typing.List[int]") == typing.List[int]
|
||||
assert deserialize_type("typing.List[typing.Dict[str, int]]") == typing.List[Dict[str, int]]
|
||||
assert deserialize_type("typing.Dict[str, int]") == Dict[str, int]
|
||||
assert deserialize_type("typing.Dict[str, typing.List[int]]") == Dict[str, List[int]]
|
||||
assert deserialize_type("typing.List[typing.Dict[str, typing.List[int]]]") == List[Dict[str, List[int]]]
|
||||
assert deserialize_type("typing.List[haystack.dataclasses.chat_message.ChatMessage]") == typing.List[ChatMessage]
|
||||
assert (
|
||||
deserialize_type("typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]")
|
||||
== typing.Dict[int, ChatMessage]
|
||||
)
|
||||
assert deserialize_type("haystack.dataclasses.chat_message.ChatMessage") == ChatMessage
|
||||
assert deserialize_type("int") == int
|
||||
Loading…
x
Reference in New Issue
Block a user