From 0040016becb9bedf6423fb5b9f48c0cd76b6a438 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Thu, 26 Dec 2024 15:15:54 -0500 Subject: [PATCH] Remove usage of internal pydantic functionality for forward ref eval (#4816) * Remove usage of internal pydantic func * Update python/packages/autogen-core/tests/test_tools.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update python/packages/autogen-core/tests/test_tools.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Remove unused import NoneType from test_tools.py --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../src/autogen_core/_function_utils.py | 28 +++--------- .../src/autogen_core/_pydantic_compat.py | 17 +------ .../packages/autogen-core/tests/test_tools.py | 44 ++++++++++++++++++- 3 files changed, 51 insertions(+), 38 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_function_utils.py b/python/packages/autogen-core/src/autogen_core/_function_utils.py index d541411d4..a6803d085 100644 --- a/python/packages/autogen-core/src/autogen_core/_function_utils.py +++ b/python/packages/autogen-core/src/autogen_core/_function_utils.py @@ -2,13 +2,13 @@ # Credit to original authors import inspect +import typing from logging import getLogger from typing import ( Annotated, Any, Callable, Dict, - ForwardRef, List, Optional, Set, @@ -25,29 +25,13 @@ from pydantic import BaseModel, Field, create_model # type: ignore from pydantic_core import PydanticUndefined from typing_extensions import Literal -from ._pydantic_compat import evaluate_forwardref, model_dump, type2schema +from ._pydantic_compat import model_dump, type2schema logger = getLogger(__name__) T = TypeVar("T") -def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: - """Get the type annotation of a parameter. - - Args: - annotation: The annotation of the parameter - globalns: The global namespace of the function - - Returns: - The type annotation of the parameter - """ - if isinstance(annotation, str): - annotation = ForwardRef(annotation) - annotation = evaluate_forwardref(annotation, globalns, globalns) - return annotation - - def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: """Get the signature of a function with type annotations. @@ -59,16 +43,17 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: """ signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) + type_hints = typing.get_type_hints(call, globalns, include_extras=True) typed_params = [ inspect.Parameter( name=param.name, kind=param.kind, default=param.default, - annotation=get_typed_annotation(param.annotation, globalns), + annotation=type_hints[param.name], ) for param in signature.parameters.values() ] - return_annotation = get_typed_annotation(signature.return_annotation, globalns) + return_annotation = type_hints.get("return", inspect.Signature.empty) typed_signature = inspect.Signature(typed_params, return_annotation=return_annotation) return typed_signature @@ -89,7 +74,8 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any: return None globalns = getattr(call, "__globals__", {}) - return get_typed_annotation(annotation, globalns) + type_hints = typing.get_type_hints(call, globalns, include_extras=True) + return type_hints.get("return", inspect.Signature.empty) def get_param_annotations( diff --git a/python/packages/autogen-core/src/autogen_core/_pydantic_compat.py b/python/packages/autogen-core/src/autogen_core/_pydantic_compat.py index 661350616..c29ccb70c 100644 --- a/python/packages/autogen-core/src/autogen_core/_pydantic_compat.py +++ b/python/packages/autogen-core/src/autogen_core/_pydantic_compat.py @@ -8,26 +8,11 @@ from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION from typing_extensions import get_origin -__all__ = ("model_dump", "type2schema", "evaluate_forwardref") +__all__ = ("model_dump", "type2schema") PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.") -def evaluate_forwardref( - value: Any, - globalns: dict[str, Any] | None = None, - localns: dict[str, Any] | None = None, -) -> Any: - if PYDANTIC_V1: - from pydantic.typing import evaluate_forwardref as evaluate_forwardref_internal - - return evaluate_forwardref_internal(value, globalns, localns) - else: - from pydantic._internal._typing_extra import eval_type_lenient - - return eval_type_lenient(value, globalns, localns) - - def type2schema(t: Type[Any] | None) -> Dict[str, Any]: if PYDANTIC_V1: from pydantic import schema_of # type: ignore diff --git a/python/packages/autogen-core/tests/test_tools.py b/python/packages/autogen-core/tests/test_tools.py index 64d97e115..f6910b507 100644 --- a/python/packages/autogen-core/tests/test_tools.py +++ b/python/packages/autogen-core/tests/test_tools.py @@ -164,6 +164,48 @@ def test_get_typed_signature_string() -> None: assert sig.return_annotation is str +def test_get_typed_signature_params() -> None: + def my_function(arg: str) -> None: + return None + + sig = get_typed_signature(my_function) + assert isinstance(sig, inspect.Signature) + assert sig.return_annotation is type(None) + assert len(sig.parameters) == 1 + assert sig.parameters["arg"].annotation is str + + +def test_get_typed_signature_two_params() -> None: + def my_function(arg: str, arg2: int) -> None: + return None + + sig = get_typed_signature(my_function) + assert isinstance(sig, inspect.Signature) + assert len(sig.parameters) == 2 + assert sig.parameters["arg"].annotation is str + assert sig.parameters["arg2"].annotation is int + + +def test_get_typed_signature_param_str() -> None: + def my_function(arg: "str") -> None: + return None + + sig = get_typed_signature(my_function) + assert isinstance(sig, inspect.Signature) + assert len(sig.parameters) == 1 + assert sig.parameters["arg"].annotation is str + + +def test_get_typed_signature_param_annotated() -> None: + def my_function(arg: Annotated[str, "An arg"]) -> None: + return None + + sig = get_typed_signature(my_function) + assert isinstance(sig, inspect.Signature) + assert len(sig.parameters) == 1 + assert sig.parameters["arg"].annotation == Annotated[str, "An arg"] + + def test_func_tool() -> None: def my_function() -> str: return "result" @@ -227,7 +269,7 @@ def test_func_tool_return_none() -> None: assert tool.name == "my_function" assert tool.description == "Function tool." assert issubclass(tool.args_type(), BaseModel) - assert tool.return_type() is None + assert tool.return_type() is type(None) assert tool.state_type() is None