mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-18 02:28:36 +00:00
chore: extract type serialization (#6586)
* move functions * tests * reno
This commit is contained in:
parent
2dd5a94b04
commit
f877704839
@ -1,13 +1,11 @@
|
|||||||
import importlib
|
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
from typing import List, Dict, Any, Set
|
||||||
from typing import List, Dict, Any, Set, get_origin
|
|
||||||
|
|
||||||
from jinja2 import meta, Environment, TemplateSyntaxError
|
from jinja2 import meta, Environment, TemplateSyntaxError
|
||||||
from jinja2.nativetypes import NativeEnvironment
|
from jinja2.nativetypes import NativeEnvironment
|
||||||
|
|
||||||
from haystack import component, default_from_dict, default_to_dict, DeserializationError
|
from haystack import component, default_from_dict, default_to_dict
|
||||||
|
from haystack.utils.type_serialization import serialize_type, deserialize_type
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -20,117 +18,6 @@ class RouteConditionException(Exception):
|
|||||||
"""Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter."""
|
"""Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter."""
|
||||||
|
|
||||||
|
|
||||||
def serialize_type(target: Any) -> str:
|
|
||||||
"""
|
|
||||||
Serializes a type or an instance to its string representation, including the module name.
|
|
||||||
|
|
||||||
This function handles types, instances of types, and special typing objects.
|
|
||||||
It assumes that non-typing objects will have a '__name__' attribute and raises
|
|
||||||
an error if a type cannot be serialized.
|
|
||||||
|
|
||||||
:param target: The object to serialize, can be an instance or a type.
|
|
||||||
:type target: Any
|
|
||||||
:return: The string representation of the type.
|
|
||||||
:raises ValueError: If the type cannot be serialized.
|
|
||||||
"""
|
|
||||||
# If the target is a string and contains a dot, treat it as an already serialized type
|
|
||||||
if isinstance(target, str) and "." in target:
|
|
||||||
return target
|
|
||||||
|
|
||||||
# Determine if the target is a type or an instance of a typing object
|
|
||||||
is_type_or_typing = isinstance(target, type) or bool(get_origin(target))
|
|
||||||
type_obj = target if is_type_or_typing else type(target)
|
|
||||||
module = inspect.getmodule(type_obj)
|
|
||||||
type_obj_repr = repr(type_obj)
|
|
||||||
|
|
||||||
if type_obj_repr.startswith("typing."):
|
|
||||||
# e.g., typing.List[int] -> List[int], we'll add the module below
|
|
||||||
type_name = type_obj_repr.split(".", 1)[1]
|
|
||||||
elif hasattr(type_obj, "__name__"):
|
|
||||||
type_name = type_obj.__name__
|
|
||||||
else:
|
|
||||||
# If type cannot be serialized, raise an error
|
|
||||||
raise ValueError(f"Could not serialize type: {type_obj_repr}")
|
|
||||||
|
|
||||||
# Construct the full path with module name if available
|
|
||||||
if module and hasattr(module, "__name__"):
|
|
||||||
if module.__name__ == "builtins":
|
|
||||||
# omit the module name for builtins, it just clutters the output
|
|
||||||
# e.g. instead of 'builtins.str', we'll just return 'str'
|
|
||||||
full_path = type_name
|
|
||||||
else:
|
|
||||||
full_path = f"{module.__name__}.{type_name}"
|
|
||||||
else:
|
|
||||||
full_path = type_name
|
|
||||||
|
|
||||||
return full_path
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_type(type_str: str) -> Any:
|
|
||||||
"""
|
|
||||||
Deserializes a type given its full import path as a string, including nested generic types.
|
|
||||||
|
|
||||||
This function will dynamically import the module if it's not already imported
|
|
||||||
and then retrieve the type object from it. It also handles nested generic types like 'typing.List[typing.Dict[int, str]]'.
|
|
||||||
|
|
||||||
:param type_str: The string representation of the type's full import path.
|
|
||||||
:return: The deserialized type object.
|
|
||||||
:raises DeserializationError: If the type cannot be deserialized due to missing module or type.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def parse_generic_args(args_str):
|
|
||||||
args = []
|
|
||||||
bracket_count = 0
|
|
||||||
current_arg = ""
|
|
||||||
|
|
||||||
for char in args_str:
|
|
||||||
if char == "[":
|
|
||||||
bracket_count += 1
|
|
||||||
elif char == "]":
|
|
||||||
bracket_count -= 1
|
|
||||||
|
|
||||||
if char == "," and bracket_count == 0:
|
|
||||||
args.append(current_arg.strip())
|
|
||||||
current_arg = ""
|
|
||||||
else:
|
|
||||||
current_arg += char
|
|
||||||
|
|
||||||
if current_arg:
|
|
||||||
args.append(current_arg.strip())
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|
||||||
if "[" in type_str and type_str.endswith("]"):
|
|
||||||
# Handle generics
|
|
||||||
main_type_str, generics_str = type_str.split("[", 1)
|
|
||||||
generics_str = generics_str[:-1]
|
|
||||||
|
|
||||||
main_type = deserialize_type(main_type_str)
|
|
||||||
generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str))
|
|
||||||
|
|
||||||
# Reconstruct
|
|
||||||
return main_type[generic_args]
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Handle non-generics
|
|
||||||
parts = type_str.split(".")
|
|
||||||
module_name = ".".join(parts[:-1]) or "builtins"
|
|
||||||
type_name = parts[-1]
|
|
||||||
|
|
||||||
module = sys.modules.get(module_name)
|
|
||||||
if not module:
|
|
||||||
try:
|
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
except ImportError as e:
|
|
||||||
raise DeserializationError(f"Could not import the module: {module_name}") from e
|
|
||||||
|
|
||||||
deserialized_type = getattr(module, type_name, None)
|
|
||||||
if not deserialized_type:
|
|
||||||
raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")
|
|
||||||
|
|
||||||
return deserialized_type
|
|
||||||
|
|
||||||
|
|
||||||
@component
|
@component
|
||||||
class ConditionalRouter:
|
class ConditionalRouter:
|
||||||
"""
|
"""
|
||||||
|
|||||||
117
haystack/utils/type_serialization.py
Normal file
117
haystack/utils/type_serialization.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
from typing import Any, get_origin
|
||||||
|
|
||||||
|
from haystack import DeserializationError
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_type(target: Any) -> str:
|
||||||
|
"""
|
||||||
|
Serializes a type or an instance to its string representation, including the module name.
|
||||||
|
|
||||||
|
This function handles types, instances of types, and special typing objects.
|
||||||
|
It assumes that non-typing objects will have a '__name__' attribute and raises
|
||||||
|
an error if a type cannot be serialized.
|
||||||
|
|
||||||
|
:param target: The object to serialize, can be an instance or a type.
|
||||||
|
:type target: Any
|
||||||
|
:return: The string representation of the type.
|
||||||
|
:raises ValueError: If the type cannot be serialized.
|
||||||
|
"""
|
||||||
|
# If the target is a string and contains a dot, treat it as an already serialized type
|
||||||
|
if isinstance(target, str) and "." in target:
|
||||||
|
return target
|
||||||
|
|
||||||
|
# Determine if the target is a type or an instance of a typing object
|
||||||
|
is_type_or_typing = isinstance(target, type) or bool(get_origin(target))
|
||||||
|
type_obj = target if is_type_or_typing else type(target)
|
||||||
|
module = inspect.getmodule(type_obj)
|
||||||
|
type_obj_repr = repr(type_obj)
|
||||||
|
|
||||||
|
if type_obj_repr.startswith("typing."):
|
||||||
|
# e.g., typing.List[int] -> List[int], we'll add the module below
|
||||||
|
type_name = type_obj_repr.split(".", 1)[1]
|
||||||
|
elif hasattr(type_obj, "__name__"):
|
||||||
|
type_name = type_obj.__name__
|
||||||
|
else:
|
||||||
|
# If type cannot be serialized, raise an error
|
||||||
|
raise ValueError(f"Could not serialize type: {type_obj_repr}")
|
||||||
|
|
||||||
|
# Construct the full path with module name if available
|
||||||
|
if module and hasattr(module, "__name__"):
|
||||||
|
if module.__name__ == "builtins":
|
||||||
|
# omit the module name for builtins, it just clutters the output
|
||||||
|
# e.g. instead of 'builtins.str', we'll just return 'str'
|
||||||
|
full_path = type_name
|
||||||
|
else:
|
||||||
|
full_path = f"{module.__name__}.{type_name}"
|
||||||
|
else:
|
||||||
|
full_path = type_name
|
||||||
|
|
||||||
|
return full_path
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_type(type_str: str) -> Any:
|
||||||
|
"""
|
||||||
|
Deserializes a type given its full import path as a string, including nested generic types.
|
||||||
|
|
||||||
|
This function will dynamically import the module if it's not already imported
|
||||||
|
and then retrieve the type object from it. It also handles nested generic types like 'typing.List[typing.Dict[int, str]]'.
|
||||||
|
|
||||||
|
:param type_str: The string representation of the type's full import path.
|
||||||
|
:return: The deserialized type object.
|
||||||
|
:raises DeserializationError: If the type cannot be deserialized due to missing module or type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def parse_generic_args(args_str):
|
||||||
|
args = []
|
||||||
|
bracket_count = 0
|
||||||
|
current_arg = ""
|
||||||
|
|
||||||
|
for char in args_str:
|
||||||
|
if char == "[":
|
||||||
|
bracket_count += 1
|
||||||
|
elif char == "]":
|
||||||
|
bracket_count -= 1
|
||||||
|
|
||||||
|
if char == "," and bracket_count == 0:
|
||||||
|
args.append(current_arg.strip())
|
||||||
|
current_arg = ""
|
||||||
|
else:
|
||||||
|
current_arg += char
|
||||||
|
|
||||||
|
if current_arg:
|
||||||
|
args.append(current_arg.strip())
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
if "[" in type_str and type_str.endswith("]"):
|
||||||
|
# Handle generics
|
||||||
|
main_type_str, generics_str = type_str.split("[", 1)
|
||||||
|
generics_str = generics_str[:-1]
|
||||||
|
|
||||||
|
main_type = deserialize_type(main_type_str)
|
||||||
|
generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str))
|
||||||
|
|
||||||
|
# Reconstruct
|
||||||
|
return main_type[generic_args]
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Handle non-generics
|
||||||
|
parts = type_str.split(".")
|
||||||
|
module_name = ".".join(parts[:-1]) or "builtins"
|
||||||
|
type_name = parts[-1]
|
||||||
|
|
||||||
|
module = sys.modules.get(module_name)
|
||||||
|
if not module:
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
except ImportError as e:
|
||||||
|
raise DeserializationError(f"Could not import the module: {module_name}") from e
|
||||||
|
|
||||||
|
deserialized_type = getattr(module, type_name, None)
|
||||||
|
if not deserialized_type:
|
||||||
|
raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")
|
||||||
|
|
||||||
|
return deserialized_type
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
enhancements:
|
||||||
|
- |
|
||||||
|
Move `serialize_type` and `deserialize_type` in the `utils` module.
|
||||||
@ -1,12 +1,11 @@
|
|||||||
import copy
|
import copy
|
||||||
import typing
|
from typing import List
|
||||||
from typing import List, Dict
|
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haystack.components.routers import ConditionalRouter
|
from haystack.components.routers import ConditionalRouter
|
||||||
from haystack.components.routers.conditional_router import NoRouteSelectedException, serialize_type, deserialize_type
|
from haystack.components.routers.conditional_router import NoRouteSelectedException
|
||||||
from haystack.dataclasses import ChatMessage
|
from haystack.dataclasses import ChatMessage
|
||||||
|
|
||||||
|
|
||||||
@ -190,37 +189,6 @@ class TestRouter:
|
|||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
ConditionalRouter(routes)
|
ConditionalRouter(routes)
|
||||||
|
|
||||||
def test_output_type_serialization(self):
|
|
||||||
assert serialize_type(str) == "str"
|
|
||||||
assert serialize_type(List[int]) == "typing.List[int]"
|
|
||||||
assert serialize_type(List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]"
|
|
||||||
assert serialize_type(ChatMessage) == "haystack.dataclasses.chat_message.ChatMessage"
|
|
||||||
assert serialize_type(typing.List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]"
|
|
||||||
assert serialize_type(List[ChatMessage]) == "typing.List[haystack.dataclasses.chat_message.ChatMessage]"
|
|
||||||
assert (
|
|
||||||
serialize_type(typing.Dict[int, ChatMessage])
|
|
||||||
== "typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]"
|
|
||||||
)
|
|
||||||
assert serialize_type(int) == "int"
|
|
||||||
assert serialize_type(ChatMessage.from_user("ciao")) == "haystack.dataclasses.chat_message.ChatMessage"
|
|
||||||
|
|
||||||
def test_output_type_deserialization(self):
|
|
||||||
assert deserialize_type("str") == str
|
|
||||||
assert deserialize_type("typing.List[int]") == typing.List[int]
|
|
||||||
assert deserialize_type("typing.List[typing.Dict[str, int]]") == typing.List[Dict[str, int]]
|
|
||||||
assert deserialize_type("typing.Dict[str, int]") == Dict[str, int]
|
|
||||||
assert deserialize_type("typing.Dict[str, typing.List[int]]") == Dict[str, List[int]]
|
|
||||||
assert deserialize_type("typing.List[typing.Dict[str, typing.List[int]]]") == List[Dict[str, List[int]]]
|
|
||||||
assert (
|
|
||||||
deserialize_type("typing.List[haystack.dataclasses.chat_message.ChatMessage]") == typing.List[ChatMessage]
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
deserialize_type("typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]")
|
|
||||||
== typing.Dict[int, ChatMessage]
|
|
||||||
)
|
|
||||||
assert deserialize_type("haystack.dataclasses.chat_message.ChatMessage") == ChatMessage
|
|
||||||
assert deserialize_type("int") == int
|
|
||||||
|
|
||||||
def test_router_de_serialization(self):
|
def test_router_de_serialization(self):
|
||||||
routes = [
|
routes = [
|
||||||
{"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
|
{"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},
|
||||||
|
|||||||
37
test/utils/test_type_serialization.py
Normal file
37
test/utils/test_type_serialization.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import copy
|
||||||
|
import typing
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from haystack.dataclasses import ChatMessage
|
||||||
|
from haystack.components.routers.conditional_router import serialize_type, deserialize_type
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_type_serialization():
|
||||||
|
assert serialize_type(str) == "str"
|
||||||
|
assert serialize_type(List[int]) == "typing.List[int]"
|
||||||
|
assert serialize_type(List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]"
|
||||||
|
assert serialize_type(ChatMessage) == "haystack.dataclasses.chat_message.ChatMessage"
|
||||||
|
assert serialize_type(typing.List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]"
|
||||||
|
assert serialize_type(List[ChatMessage]) == "typing.List[haystack.dataclasses.chat_message.ChatMessage]"
|
||||||
|
assert (
|
||||||
|
serialize_type(typing.Dict[int, ChatMessage])
|
||||||
|
== "typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]"
|
||||||
|
)
|
||||||
|
assert serialize_type(int) == "int"
|
||||||
|
assert serialize_type(ChatMessage.from_user("ciao")) == "haystack.dataclasses.chat_message.ChatMessage"
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_type_deserialization():
|
||||||
|
assert deserialize_type("str") == str
|
||||||
|
assert deserialize_type("typing.List[int]") == typing.List[int]
|
||||||
|
assert deserialize_type("typing.List[typing.Dict[str, int]]") == typing.List[Dict[str, int]]
|
||||||
|
assert deserialize_type("typing.Dict[str, int]") == Dict[str, int]
|
||||||
|
assert deserialize_type("typing.Dict[str, typing.List[int]]") == Dict[str, List[int]]
|
||||||
|
assert deserialize_type("typing.List[typing.Dict[str, typing.List[int]]]") == List[Dict[str, List[int]]]
|
||||||
|
assert deserialize_type("typing.List[haystack.dataclasses.chat_message.ChatMessage]") == typing.List[ChatMessage]
|
||||||
|
assert (
|
||||||
|
deserialize_type("typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]")
|
||||||
|
== typing.Dict[int, ChatMessage]
|
||||||
|
)
|
||||||
|
assert deserialize_type("haystack.dataclasses.chat_message.ChatMessage") == ChatMessage
|
||||||
|
assert deserialize_type("int") == int
|
||||||
Loading…
x
Reference in New Issue
Block a user