mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-02 18:59:48 +00:00
Remove usage of internal pydantic functionality for forward ref eval (#4816)
* Remove usage of internal pydantic func * Update python/packages/autogen-core/tests/test_tools.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update python/packages/autogen-core/tests/test_tools.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Remove unused import NoneType from test_tools.py --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
f774eaa105
commit
0040016bec
@ -2,13 +2,13 @@
|
||||
# Credit to original authors
|
||||
|
||||
import inspect
|
||||
import typing
|
||||
from logging import getLogger
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
@ -25,29 +25,13 @@ from pydantic import BaseModel, Field, create_model # type: ignore
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ._pydantic_compat import evaluate_forwardref, model_dump, type2schema
|
||||
from ._pydantic_compat import model_dump, type2schema
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
|
||||
"""Get the type annotation of a parameter.
|
||||
|
||||
Args:
|
||||
annotation: The annotation of the parameter
|
||||
globalns: The global namespace of the function
|
||||
|
||||
Returns:
|
||||
The type annotation of the parameter
|
||||
"""
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||
return annotation
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
"""Get the signature of a function with type annotations.
|
||||
|
||||
@ -59,16 +43,17 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
"""
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
type_hints = typing.get_type_hints(call, globalns, include_extras=True)
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=get_typed_annotation(param.annotation, globalns),
|
||||
annotation=type_hints[param.name],
|
||||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
return_annotation = get_typed_annotation(signature.return_annotation, globalns)
|
||||
return_annotation = type_hints.get("return", inspect.Signature.empty)
|
||||
typed_signature = inspect.Signature(typed_params, return_annotation=return_annotation)
|
||||
return typed_signature
|
||||
|
||||
@ -89,7 +74,8 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
|
||||
return None
|
||||
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
return get_typed_annotation(annotation, globalns)
|
||||
type_hints = typing.get_type_hints(call, globalns, include_extras=True)
|
||||
return type_hints.get("return", inspect.Signature.empty)
|
||||
|
||||
|
||||
def get_param_annotations(
|
||||
|
@ -8,26 +8,11 @@ from pydantic import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from typing_extensions import get_origin
|
||||
|
||||
__all__ = ("model_dump", "type2schema", "evaluate_forwardref")
|
||||
__all__ = ("model_dump", "type2schema")
|
||||
|
||||
PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")
|
||||
|
||||
|
||||
def evaluate_forwardref(
|
||||
value: Any,
|
||||
globalns: dict[str, Any] | None = None,
|
||||
localns: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
if PYDANTIC_V1:
|
||||
from pydantic.typing import evaluate_forwardref as evaluate_forwardref_internal
|
||||
|
||||
return evaluate_forwardref_internal(value, globalns, localns)
|
||||
else:
|
||||
from pydantic._internal._typing_extra import eval_type_lenient
|
||||
|
||||
return eval_type_lenient(value, globalns, localns)
|
||||
|
||||
|
||||
def type2schema(t: Type[Any] | None) -> Dict[str, Any]:
|
||||
if PYDANTIC_V1:
|
||||
from pydantic import schema_of # type: ignore
|
||||
|
@ -164,6 +164,48 @@ def test_get_typed_signature_string() -> None:
|
||||
assert sig.return_annotation is str
|
||||
|
||||
|
||||
def test_get_typed_signature_params() -> None:
|
||||
def my_function(arg: str) -> None:
|
||||
return None
|
||||
|
||||
sig = get_typed_signature(my_function)
|
||||
assert isinstance(sig, inspect.Signature)
|
||||
assert sig.return_annotation is type(None)
|
||||
assert len(sig.parameters) == 1
|
||||
assert sig.parameters["arg"].annotation is str
|
||||
|
||||
|
||||
def test_get_typed_signature_two_params() -> None:
|
||||
def my_function(arg: str, arg2: int) -> None:
|
||||
return None
|
||||
|
||||
sig = get_typed_signature(my_function)
|
||||
assert isinstance(sig, inspect.Signature)
|
||||
assert len(sig.parameters) == 2
|
||||
assert sig.parameters["arg"].annotation is str
|
||||
assert sig.parameters["arg2"].annotation is int
|
||||
|
||||
|
||||
def test_get_typed_signature_param_str() -> None:
|
||||
def my_function(arg: "str") -> None:
|
||||
return None
|
||||
|
||||
sig = get_typed_signature(my_function)
|
||||
assert isinstance(sig, inspect.Signature)
|
||||
assert len(sig.parameters) == 1
|
||||
assert sig.parameters["arg"].annotation is str
|
||||
|
||||
|
||||
def test_get_typed_signature_param_annotated() -> None:
|
||||
def my_function(arg: Annotated[str, "An arg"]) -> None:
|
||||
return None
|
||||
|
||||
sig = get_typed_signature(my_function)
|
||||
assert isinstance(sig, inspect.Signature)
|
||||
assert len(sig.parameters) == 1
|
||||
assert sig.parameters["arg"].annotation == Annotated[str, "An arg"]
|
||||
|
||||
|
||||
def test_func_tool() -> None:
|
||||
def my_function() -> str:
|
||||
return "result"
|
||||
@ -227,7 +269,7 @@ def test_func_tool_return_none() -> None:
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert tool.return_type() is None
|
||||
assert tool.return_type() is type(None)
|
||||
assert tool.state_type() is None
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user