From f877704839dc585cf47305d97fd5011be367b907 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Tue, 19 Dec 2023 13:16:20 +0000 Subject: [PATCH] chore: extract type serialization (#6586) * move functions * tests * reno --- .../components/routers/conditional_router.py | 119 +----------------- haystack/utils/type_serialization.py | 117 +++++++++++++++++ ...t_type_serialization-fc3ea6418ba5632d.yaml | 3 + .../routers/test_conditional_router.py | 36 +----- test/utils/test_type_serialization.py | 37 ++++++ 5 files changed, 162 insertions(+), 150 deletions(-) create mode 100644 haystack/utils/type_serialization.py create mode 100644 releasenotes/notes/extract_type_serialization-fc3ea6418ba5632d.yaml create mode 100644 test/utils/test_type_serialization.py diff --git a/haystack/components/routers/conditional_router.py b/haystack/components/routers/conditional_router.py index a080dcb34..ea77a563c 100644 --- a/haystack/components/routers/conditional_router.py +++ b/haystack/components/routers/conditional_router.py @@ -1,13 +1,11 @@ -import importlib -import inspect import logging -import sys -from typing import List, Dict, Any, Set, get_origin +from typing import List, Dict, Any, Set from jinja2 import meta, Environment, TemplateSyntaxError from jinja2.nativetypes import NativeEnvironment -from haystack import component, default_from_dict, default_to_dict, DeserializationError +from haystack import component, default_from_dict, default_to_dict +from haystack.utils.type_serialization import serialize_type, deserialize_type logger = logging.getLogger(__name__) @@ -20,117 +18,6 @@ class RouteConditionException(Exception): """Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter.""" -def serialize_type(target: Any) -> str: - """ - Serializes a type or an instance to its string representation, including the module name. - - This function handles types, instances of types, and special typing objects. - It assumes that non-typing objects will have a '__name__' attribute and raises - an error if a type cannot be serialized. - - :param target: The object to serialize, can be an instance or a type. - :type target: Any - :return: The string representation of the type. - :raises ValueError: If the type cannot be serialized. - """ - # If the target is a string and contains a dot, treat it as an already serialized type - if isinstance(target, str) and "." in target: - return target - - # Determine if the target is a type or an instance of a typing object - is_type_or_typing = isinstance(target, type) or bool(get_origin(target)) - type_obj = target if is_type_or_typing else type(target) - module = inspect.getmodule(type_obj) - type_obj_repr = repr(type_obj) - - if type_obj_repr.startswith("typing."): - # e.g., typing.List[int] -> List[int], we'll add the module below - type_name = type_obj_repr.split(".", 1)[1] - elif hasattr(type_obj, "__name__"): - type_name = type_obj.__name__ - else: - # If type cannot be serialized, raise an error - raise ValueError(f"Could not serialize type: {type_obj_repr}") - - # Construct the full path with module name if available - if module and hasattr(module, "__name__"): - if module.__name__ == "builtins": - # omit the module name for builtins, it just clutters the output - # e.g. instead of 'builtins.str', we'll just return 'str' - full_path = type_name - else: - full_path = f"{module.__name__}.{type_name}" - else: - full_path = type_name - - return full_path - - -def deserialize_type(type_str: str) -> Any: - """ - Deserializes a type given its full import path as a string, including nested generic types. - - This function will dynamically import the module if it's not already imported - and then retrieve the type object from it. It also handles nested generic types like 'typing.List[typing.Dict[int, str]]'. - - :param type_str: The string representation of the type's full import path. - :return: The deserialized type object. - :raises DeserializationError: If the type cannot be deserialized due to missing module or type. - """ - - def parse_generic_args(args_str): - args = [] - bracket_count = 0 - current_arg = "" - - for char in args_str: - if char == "[": - bracket_count += 1 - elif char == "]": - bracket_count -= 1 - - if char == "," and bracket_count == 0: - args.append(current_arg.strip()) - current_arg = "" - else: - current_arg += char - - if current_arg: - args.append(current_arg.strip()) - - return args - - if "[" in type_str and type_str.endswith("]"): - # Handle generics - main_type_str, generics_str = type_str.split("[", 1) - generics_str = generics_str[:-1] - - main_type = deserialize_type(main_type_str) - generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str)) - - # Reconstruct - return main_type[generic_args] - - else: - # Handle non-generics - parts = type_str.split(".") - module_name = ".".join(parts[:-1]) or "builtins" - type_name = parts[-1] - - module = sys.modules.get(module_name) - if not module: - try: - module = importlib.import_module(module_name) - except ImportError as e: - raise DeserializationError(f"Could not import the module: {module_name}") from e - - deserialized_type = getattr(module, type_name, None) - if not deserialized_type: - raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}") - - return deserialized_type - - @component class ConditionalRouter: """ diff --git a/haystack/utils/type_serialization.py b/haystack/utils/type_serialization.py new file mode 100644 index 000000000..5d16554b9 --- /dev/null +++ b/haystack/utils/type_serialization.py @@ -0,0 +1,117 @@ +import importlib +import inspect +import sys +from typing import Any, get_origin + +from haystack import DeserializationError + + +def serialize_type(target: Any) -> str: + """ + Serializes a type or an instance to its string representation, including the module name. + + This function handles types, instances of types, and special typing objects. + It assumes that non-typing objects will have a '__name__' attribute and raises + an error if a type cannot be serialized. + + :param target: The object to serialize, can be an instance or a type. + :type target: Any + :return: The string representation of the type. + :raises ValueError: If the type cannot be serialized. + """ + # If the target is a string and contains a dot, treat it as an already serialized type + if isinstance(target, str) and "." in target: + return target + + # Determine if the target is a type or an instance of a typing object + is_type_or_typing = isinstance(target, type) or bool(get_origin(target)) + type_obj = target if is_type_or_typing else type(target) + module = inspect.getmodule(type_obj) + type_obj_repr = repr(type_obj) + + if type_obj_repr.startswith("typing."): + # e.g., typing.List[int] -> List[int], we'll add the module below + type_name = type_obj_repr.split(".", 1)[1] + elif hasattr(type_obj, "__name__"): + type_name = type_obj.__name__ + else: + # If type cannot be serialized, raise an error + raise ValueError(f"Could not serialize type: {type_obj_repr}") + + # Construct the full path with module name if available + if module and hasattr(module, "__name__"): + if module.__name__ == "builtins": + # omit the module name for builtins, it just clutters the output + # e.g. instead of 'builtins.str', we'll just return 'str' + full_path = type_name + else: + full_path = f"{module.__name__}.{type_name}" + else: + full_path = type_name + + return full_path + + +def deserialize_type(type_str: str) -> Any: + """ + Deserializes a type given its full import path as a string, including nested generic types. + + This function will dynamically import the module if it's not already imported + and then retrieve the type object from it. It also handles nested generic types like 'typing.List[typing.Dict[int, str]]'. + + :param type_str: The string representation of the type's full import path. + :return: The deserialized type object. + :raises DeserializationError: If the type cannot be deserialized due to missing module or type. + """ + + def parse_generic_args(args_str): + args = [] + bracket_count = 0 + current_arg = "" + + for char in args_str: + if char == "[": + bracket_count += 1 + elif char == "]": + bracket_count -= 1 + + if char == "," and bracket_count == 0: + args.append(current_arg.strip()) + current_arg = "" + else: + current_arg += char + + if current_arg: + args.append(current_arg.strip()) + + return args + + if "[" in type_str and type_str.endswith("]"): + # Handle generics + main_type_str, generics_str = type_str.split("[", 1) + generics_str = generics_str[:-1] + + main_type = deserialize_type(main_type_str) + generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str)) + + # Reconstruct + return main_type[generic_args] + + else: + # Handle non-generics + parts = type_str.split(".") + module_name = ".".join(parts[:-1]) or "builtins" + type_name = parts[-1] + + module = sys.modules.get(module_name) + if not module: + try: + module = importlib.import_module(module_name) + except ImportError as e: + raise DeserializationError(f"Could not import the module: {module_name}") from e + + deserialized_type = getattr(module, type_name, None) + if not deserialized_type: + raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}") + + return deserialized_type diff --git a/releasenotes/notes/extract_type_serialization-fc3ea6418ba5632d.yaml b/releasenotes/notes/extract_type_serialization-fc3ea6418ba5632d.yaml new file mode 100644 index 000000000..b380f9ec8 --- /dev/null +++ b/releasenotes/notes/extract_type_serialization-fc3ea6418ba5632d.yaml @@ -0,0 +1,3 @@ +enhancements: + - | + Move `serialize_type` and `deserialize_type` in the `utils` module. diff --git a/test/components/routers/test_conditional_router.py b/test/components/routers/test_conditional_router.py index 177322eea..672231ad1 100644 --- a/test/components/routers/test_conditional_router.py +++ b/test/components/routers/test_conditional_router.py @@ -1,12 +1,11 @@ import copy -import typing -from typing import List, Dict +from typing import List from unittest import mock import pytest from haystack.components.routers import ConditionalRouter -from haystack.components.routers.conditional_router import NoRouteSelectedException, serialize_type, deserialize_type +from haystack.components.routers.conditional_router import NoRouteSelectedException from haystack.dataclasses import ChatMessage @@ -190,37 +189,6 @@ class TestRouter: with pytest.raises(ValueError): ConditionalRouter(routes) - def test_output_type_serialization(self): - assert serialize_type(str) == "str" - assert serialize_type(List[int]) == "typing.List[int]" - assert serialize_type(List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]" - assert serialize_type(ChatMessage) == "haystack.dataclasses.chat_message.ChatMessage" - assert serialize_type(typing.List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]" - assert serialize_type(List[ChatMessage]) == "typing.List[haystack.dataclasses.chat_message.ChatMessage]" - assert ( - serialize_type(typing.Dict[int, ChatMessage]) - == "typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]" - ) - assert serialize_type(int) == "int" - assert serialize_type(ChatMessage.from_user("ciao")) == "haystack.dataclasses.chat_message.ChatMessage" - - def test_output_type_deserialization(self): - assert deserialize_type("str") == str - assert deserialize_type("typing.List[int]") == typing.List[int] - assert deserialize_type("typing.List[typing.Dict[str, int]]") == typing.List[Dict[str, int]] - assert deserialize_type("typing.Dict[str, int]") == Dict[str, int] - assert deserialize_type("typing.Dict[str, typing.List[int]]") == Dict[str, List[int]] - assert deserialize_type("typing.List[typing.Dict[str, typing.List[int]]]") == List[Dict[str, List[int]]] - assert ( - deserialize_type("typing.List[haystack.dataclasses.chat_message.ChatMessage]") == typing.List[ChatMessage] - ) - assert ( - deserialize_type("typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]") - == typing.Dict[int, ChatMessage] - ) - assert deserialize_type("haystack.dataclasses.chat_message.ChatMessage") == ChatMessage - assert deserialize_type("int") == int - def test_router_de_serialization(self): routes = [ {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"}, diff --git a/test/utils/test_type_serialization.py b/test/utils/test_type_serialization.py new file mode 100644 index 000000000..aee596853 --- /dev/null +++ b/test/utils/test_type_serialization.py @@ -0,0 +1,37 @@ +import copy +import typing +from typing import List, Dict + +from haystack.dataclasses import ChatMessage +from haystack.components.routers.conditional_router import serialize_type, deserialize_type + + +def test_output_type_serialization(): + assert serialize_type(str) == "str" + assert serialize_type(List[int]) == "typing.List[int]" + assert serialize_type(List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]" + assert serialize_type(ChatMessage) == "haystack.dataclasses.chat_message.ChatMessage" + assert serialize_type(typing.List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]" + assert serialize_type(List[ChatMessage]) == "typing.List[haystack.dataclasses.chat_message.ChatMessage]" + assert ( + serialize_type(typing.Dict[int, ChatMessage]) + == "typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]" + ) + assert serialize_type(int) == "int" + assert serialize_type(ChatMessage.from_user("ciao")) == "haystack.dataclasses.chat_message.ChatMessage" + + +def test_output_type_deserialization(): + assert deserialize_type("str") == str + assert deserialize_type("typing.List[int]") == typing.List[int] + assert deserialize_type("typing.List[typing.Dict[str, int]]") == typing.List[Dict[str, int]] + assert deserialize_type("typing.Dict[str, int]") == Dict[str, int] + assert deserialize_type("typing.Dict[str, typing.List[int]]") == Dict[str, List[int]] + assert deserialize_type("typing.List[typing.Dict[str, typing.List[int]]]") == List[Dict[str, List[int]]] + assert deserialize_type("typing.List[haystack.dataclasses.chat_message.ChatMessage]") == typing.List[ChatMessage] + assert ( + deserialize_type("typing.Dict[int, haystack.dataclasses.chat_message.ChatMessage]") + == typing.Dict[int, ChatMessage] + ) + assert deserialize_type("haystack.dataclasses.chat_message.ChatMessage") == ChatMessage + assert deserialize_type("int") == int