chore: extract type serialization (#6586)

* move functions

* tests

* reno
This commit is contained in:
ZanSara 2023-12-19 13:16:20 +00:00 committed by GitHub
parent 2dd5a94b04
commit f877704839
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 162 additions and 150 deletions

View File

@ -1,13 +1,11 @@
import importlib
import inspect
import logging
import sys
from typing import List, Dict, Any, Set, get_origin
from typing import List, Dict, Any, Set
from jinja2 import meta, Environment, TemplateSyntaxError
from jinja2.nativetypes import NativeEnvironment
from haystack import component, default_from_dict, default_to_dict, DeserializationError
from haystack import component, default_from_dict, default_to_dict
from haystack.utils.type_serialization import serialize_type, deserialize_type
logger = logging.getLogger(__name__)
@ -20,117 +18,6 @@ class RouteConditionException(Exception):
"""Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter."""
def serialize_type(target: Any) -> str:
    """
    Build the string form of a type, prefixed with the name of the module that defines it.

    Accepts a type, a typing construct (anything `typing.get_origin` recognises), or an
    arbitrary instance. For an instance, the type of the instance is serialized.
    A string that contains a dot is returned untouched, as it is taken to be the result
    of a previous serialization.

    :param target: The type, typing construct or instance to serialize.
    :type target: Any
    :return: The serialized type, e.g. 'str' or 'typing.List[int]'.
    :raises ValueError: If the object's repr does not start with 'typing.' and it has no '__name__'.
    """
    # Dotted strings are assumed to be serialized types already
    if isinstance(target, str) and "." in target:
        return target

    # Types and typing constructs are used as they are; anything else is replaced by its class
    if isinstance(target, type) or bool(get_origin(target)):
        resolved = target
    else:
        resolved = type(target)

    owner = inspect.getmodule(resolved)
    text = repr(resolved)

    if text.startswith("typing."):
        # Keep everything after the 'typing.' prefix: the module name is prepended again below
        short_name = text.split(".", 1)[1]
    elif hasattr(resolved, "__name__"):
        short_name = resolved.__name__
    else:
        raise ValueError(f"Could not serialize type: {text}")

    # No usable module, or a builtin: the bare name is enough ('str' rather than 'builtins.str')
    if not (owner and hasattr(owner, "__name__")) or owner.__name__ == "builtins":
        return short_name
    return f"{owner.__name__}.{short_name}"
def deserialize_type(type_str: str) -> Any:
    """
    Deserializes a type given its full import path as a string, including nested generic types.

    This function will dynamically import the module if it's not already imported
    and then retrieve the type object from it. It also handles nested generic types like
    'typing.List[typing.Dict[int, str]]' and single-parameter forms like 'typing.Optional[int]'.

    :param type_str: The string representation of the type's full import path.
    :return: The deserialized type object.
    :raises DeserializationError: If the type cannot be deserialized due to missing module or type.
    """

    def parse_generic_args(args_str):
        """Split the text between the outer brackets on top-level commas only."""
        args = []
        bracket_count = 0
        current_arg = ""

        for char in args_str:
            if char == "[":
                bracket_count += 1
            elif char == "]":
                bracket_count -= 1

            # A comma inside nested brackets belongs to an inner generic, not to this level
            if char == "," and bracket_count == 0:
                args.append(current_arg.strip())
                current_arg = ""
            else:
                current_arg += char

        if current_arg:
            args.append(current_arg.strip())

        return args

    if "[" in type_str and type_str.endswith("]"):
        # Handle generics: everything before the first '[' is the generic itself,
        # everything between that bracket and the final ']' are its arguments
        main_type_str, generics_str = type_str.split("[", 1)
        generics_str = generics_str[:-1]

        main_type = deserialize_type(main_type_str)
        generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str))

        # Reconstruct. Single-parameter special forms such as typing.Optional reject a
        # 1-tuple, so a lone argument is passed on its own rather than wrapped.
        if len(generic_args) == 1:
            return main_type[generic_args[0]]
        return main_type[generic_args]

    # Handle non-generics: a name without any dot is looked up in builtins
    parts = type_str.split(".")
    module_name = ".".join(parts[:-1]) or "builtins"
    type_name = parts[-1]

    module = sys.modules.get(module_name)
    if not module:
        try:
            module = importlib.import_module(module_name)
        except ImportError as e:
            raise DeserializationError(f"Could not import the module: {module_name}") from e

    deserialized_type = getattr(module, type_name, None)
    if not deserialized_type:
        raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")

    return deserialized_type
@component
class ConditionalRouter:
"""

View File

@ -0,0 +1,117 @@
import importlib
import inspect
import sys
from typing import Any, get_origin
from haystack import DeserializationError
def serialize_type(target: Any) -> str:
    """
    Turn a type (or the type of an instance) into a string that includes its module name.

    Types and typing constructs (objects for which `typing.get_origin` returns something)
    are serialized directly; for any other object its class is serialized instead.
    Strings containing a dot are treated as already serialized and returned as they are.

    :param target: The object to serialize, can be an instance or a type.
    :type target: Any
    :return: The string representation of the type.
    :raises ValueError: If the type cannot be serialized.
    """
    already_serialized = isinstance(target, str) and "." in target
    if already_serialized:
        return target

    # Decide whether to describe the object itself or the class it is an instance of
    describes_itself = isinstance(target, type) or bool(get_origin(target))
    candidate = target if describes_itself else type(target)

    defining_module = inspect.getmodule(candidate)
    candidate_repr = repr(candidate)

    if candidate_repr.startswith("typing."):
        # e.g. 'typing.List[int]' -> 'List[int]'; the module is added back further down
        _, name = candidate_repr.split(".", 1)
    elif hasattr(candidate, "__name__"):
        name = candidate.__name__
    else:
        # Neither a typing repr nor a named object: nothing sensible to emit
        raise ValueError(f"Could not serialize type: {candidate_repr}")

    has_module_name = defining_module and hasattr(defining_module, "__name__")
    # Builtins are left unqualified: 'str' reads better than 'builtins.str'
    if has_module_name and defining_module.__name__ != "builtins":
        return f"{defining_module.__name__}.{name}"
    return name
def deserialize_type(type_str: str) -> Any:
"""
Deserializes a type given its full import path as a string, including nested generic types.
This function will dynamically import the module if it's not already imported
and then retrieve the type object from it. It also handles nested generic types like 'typing.List[typing.Dict[int, str]]'.
:param type_str: The string representation of the type's full import path.
:return: The deserialized type object.
:raises DeserializationError: If the type cannot be deserialized due to missing module or type.
"""
def parse_generic_args(args_str):
args = []
bracket_count = 0
current_arg = ""
for char in args_str:
if char == "[":
bracket_count += 1
elif char == "]":
bracket_count -= 1
if char == "," and bracket_count == 0:
args.append(current_arg.strip())
current_arg = ""
else:
current_arg += char
if current_arg:
args.append(current_arg.strip())
return args
if "[" in type_str and type_str.endswith("]"):
# Handle generics
main_type_str, generics_str = type_str.split("[", 1)
generics_str = generics_str[:-1]
main_type = deserialize_type(main_type_str)
generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str))
# Reconstruct
return main_type[generic_args]
else:
# Handle non-generics
parts = type_str.split(".")
module_name = ".".join(parts[:-1]) or "builtins"
type_name = parts[-1]
module = sys.modules.get(module_name)
if not module:
try:
module = importlib.import_module(module_name)
except ImportError as e:
raise DeserializationError(f"Could not import the module: {module_name}") from e
deserialized_type = getattr(module, type_name, None)
if not deserialized_type:
raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")
return deserialized_type

View File

@ -0,0 +1,3 @@
enhancements:
- |
Move `serialize_type` and `deserialize_type` to the `utils` module.

View File

@ -1,12 +1,11 @@
import copy
import typing
from typing import List, Dict
from typing import List
from unittest import mock
import pytest
from haystack.components.routers import ConditionalRouter
from haystack.components.routers.conditional_router import NoRouteSelectedException, serialize_type, deserialize_type
from haystack.components.routers.conditional_router import NoRouteSelectedException
from haystack.dataclasses import ChatMessage
@ -190,37 +189,6 @@ class TestRouter:
with pytest.raises(ValueError):
ConditionalRouter(routes)
def test_output_type_serialization(self):
    """Check that types and instances are serialized to their module-qualified string form."""
    expectations = [
        (str, "str"),
        (List[int], "typing.List[int]"),
        (List[Dict[str, int]], "typing.List[typing.Dict[str, int]]"),
        (ChatMessage, "haystack.dataclasses.chat_message.ChatMessage"),
        (typing.List[Dict[str, int]], "typing.List[typing.Dict[str, int]]"),
        (List[ChatMessage], "typing.List[haystack.dataclasses.chat_message.ChatMessage]"),
        (typing.Dict[int, ChatMessage], "typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]"),
        (int, "int"),
        # An instance is serialized as its class
        (ChatMessage.from_user("ciao"), "haystack.dataclasses.chat_message.ChatMessage"),
    ]
    for target, expected in expectations:
        assert serialize_type(target) == expected
def test_output_type_deserialization(self):
    """Check that serialized type strings are resolved back to the matching type objects."""
    expectations = {
        "str": str,
        "typing.List[int]": typing.List[int],
        "typing.List[typing.Dict[str, int]]": typing.List[Dict[str, int]],
        "typing.Dict[str, int]": Dict[str, int],
        "typing.Dict[str, typing.List[int]]": Dict[str, List[int]],
        "typing.List[typing.Dict[str, typing.List[int]]]": List[Dict[str, List[int]]],
        "typing.List[haystack.dataclasses.chat_message.ChatMessage]": typing.List[ChatMessage],
        "typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]": typing.Dict[int, ChatMessage],
        "haystack.dataclasses.chat_message.ChatMessage": ChatMessage,
        "int": int,
    }
    for serialized, expected in expectations.items():
        assert deserialize_type(serialized) == expected
def test_router_de_serialization(self):
routes = [
{"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"},

View File

@ -0,0 +1,37 @@
import copy
import typing
from typing import List, Dict
from haystack.dataclasses import ChatMessage
from haystack.components.routers.conditional_router import serialize_type, deserialize_type
def test_output_type_serialization():
    """Verify the string produced by serialize_type for builtins, typing generics and project classes."""
    chat_message_instance = ChatMessage.from_user("ciao")
    cases = (
        (str, "str"),
        (List[int], "typing.List[int]"),
        (List[Dict[str, int]], "typing.List[typing.Dict[str, int]]"),
        (ChatMessage, "haystack.dataclasses.chat_message.ChatMessage"),
        (typing.List[Dict[str, int]], "typing.List[typing.Dict[str, int]]"),
        (List[ChatMessage], "typing.List[haystack.dataclasses.chat_message.ChatMessage]"),
        (typing.Dict[int, ChatMessage], "typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]"),
        (int, "int"),
        # Passing an instance yields the path of its class
        (chat_message_instance, "haystack.dataclasses.chat_message.ChatMessage"),
    )
    for value, serialized in cases:
        assert serialize_type(value) == serialized
def test_output_type_deserialization():
    """Verify that deserialize_type rebuilds builtins, nested typing generics and project classes."""
    cases = (
        ("str", str),
        ("typing.List[int]", typing.List[int]),
        ("typing.List[typing.Dict[str, int]]", typing.List[Dict[str, int]]),
        ("typing.Dict[str, int]", Dict[str, int]),
        ("typing.Dict[str, typing.List[int]]", Dict[str, List[int]]),
        ("typing.List[typing.Dict[str, typing.List[int]]]", List[Dict[str, List[int]]]),
        ("typing.List[haystack.dataclasses.chat_message.ChatMessage]", typing.List[ChatMessage]),
        ("typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]", typing.Dict[int, ChatMessage]),
        ("haystack.dataclasses.chat_message.ChatMessage", ChatMessage),
        ("int", int),
    )
    for serialized, rebuilt in cases:
        assert deserialize_type(serialized) == rebuilt