mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-08 22:04:27 +00:00
Remove usage of internal pydantic functionality for forward ref eval (#4816)
* Remove usage of internal pydantic func * Update python/packages/autogen-core/tests/test_tools.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update python/packages/autogen-core/tests/test_tools.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Remove unused import NoneType from test_tools.py --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
f774eaa105
commit
0040016bec
@ -2,13 +2,13 @@
|
|||||||
# Credit to original authors
|
# Credit to original authors
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import typing
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
ForwardRef,
|
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
@ -25,29 +25,13 @@ from pydantic import BaseModel, Field, create_model # type: ignore
|
|||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from ._pydantic_compat import evaluate_forwardref, model_dump, type2schema
|
from ._pydantic_compat import model_dump, type2schema
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
|
|
||||||
"""Get the type annotation of a parameter.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
annotation: The annotation of the parameter
|
|
||||||
globalns: The global namespace of the function
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The type annotation of the parameter
|
|
||||||
"""
|
|
||||||
if isinstance(annotation, str):
|
|
||||||
annotation = ForwardRef(annotation)
|
|
||||||
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
|
||||||
return annotation
|
|
||||||
|
|
||||||
|
|
||||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||||
"""Get the signature of a function with type annotations.
|
"""Get the signature of a function with type annotations.
|
||||||
|
|
||||||
@ -59,16 +43,17 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
|||||||
"""
|
"""
|
||||||
signature = inspect.signature(call)
|
signature = inspect.signature(call)
|
||||||
globalns = getattr(call, "__globals__", {})
|
globalns = getattr(call, "__globals__", {})
|
||||||
|
type_hints = typing.get_type_hints(call, globalns, include_extras=True)
|
||||||
typed_params = [
|
typed_params = [
|
||||||
inspect.Parameter(
|
inspect.Parameter(
|
||||||
name=param.name,
|
name=param.name,
|
||||||
kind=param.kind,
|
kind=param.kind,
|
||||||
default=param.default,
|
default=param.default,
|
||||||
annotation=get_typed_annotation(param.annotation, globalns),
|
annotation=type_hints[param.name],
|
||||||
)
|
)
|
||||||
for param in signature.parameters.values()
|
for param in signature.parameters.values()
|
||||||
]
|
]
|
||||||
return_annotation = get_typed_annotation(signature.return_annotation, globalns)
|
return_annotation = type_hints.get("return", inspect.Signature.empty)
|
||||||
typed_signature = inspect.Signature(typed_params, return_annotation=return_annotation)
|
typed_signature = inspect.Signature(typed_params, return_annotation=return_annotation)
|
||||||
return typed_signature
|
return typed_signature
|
||||||
|
|
||||||
@ -89,7 +74,8 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
globalns = getattr(call, "__globals__", {})
|
globalns = getattr(call, "__globals__", {})
|
||||||
return get_typed_annotation(annotation, globalns)
|
type_hints = typing.get_type_hints(call, globalns, include_extras=True)
|
||||||
|
return type_hints.get("return", inspect.Signature.empty)
|
||||||
|
|
||||||
|
|
||||||
def get_param_annotations(
|
def get_param_annotations(
|
||||||
|
|||||||
@ -8,26 +8,11 @@ from pydantic import BaseModel
|
|||||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||||
from typing_extensions import get_origin
|
from typing_extensions import get_origin
|
||||||
|
|
||||||
__all__ = ("model_dump", "type2schema", "evaluate_forwardref")
|
__all__ = ("model_dump", "type2schema")
|
||||||
|
|
||||||
PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")
|
PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_forwardref(
|
|
||||||
value: Any,
|
|
||||||
globalns: dict[str, Any] | None = None,
|
|
||||||
localns: dict[str, Any] | None = None,
|
|
||||||
) -> Any:
|
|
||||||
if PYDANTIC_V1:
|
|
||||||
from pydantic.typing import evaluate_forwardref as evaluate_forwardref_internal
|
|
||||||
|
|
||||||
return evaluate_forwardref_internal(value, globalns, localns)
|
|
||||||
else:
|
|
||||||
from pydantic._internal._typing_extra import eval_type_lenient
|
|
||||||
|
|
||||||
return eval_type_lenient(value, globalns, localns)
|
|
||||||
|
|
||||||
|
|
||||||
def type2schema(t: Type[Any] | None) -> Dict[str, Any]:
|
def type2schema(t: Type[Any] | None) -> Dict[str, Any]:
|
||||||
if PYDANTIC_V1:
|
if PYDANTIC_V1:
|
||||||
from pydantic import schema_of # type: ignore
|
from pydantic import schema_of # type: ignore
|
||||||
|
|||||||
@ -164,6 +164,48 @@ def test_get_typed_signature_string() -> None:
|
|||||||
assert sig.return_annotation is str
|
assert sig.return_annotation is str
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_typed_signature_params() -> None:
|
||||||
|
def my_function(arg: str) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sig = get_typed_signature(my_function)
|
||||||
|
assert isinstance(sig, inspect.Signature)
|
||||||
|
assert sig.return_annotation is type(None)
|
||||||
|
assert len(sig.parameters) == 1
|
||||||
|
assert sig.parameters["arg"].annotation is str
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_typed_signature_two_params() -> None:
|
||||||
|
def my_function(arg: str, arg2: int) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sig = get_typed_signature(my_function)
|
||||||
|
assert isinstance(sig, inspect.Signature)
|
||||||
|
assert len(sig.parameters) == 2
|
||||||
|
assert sig.parameters["arg"].annotation is str
|
||||||
|
assert sig.parameters["arg2"].annotation is int
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_typed_signature_param_str() -> None:
|
||||||
|
def my_function(arg: "str") -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sig = get_typed_signature(my_function)
|
||||||
|
assert isinstance(sig, inspect.Signature)
|
||||||
|
assert len(sig.parameters) == 1
|
||||||
|
assert sig.parameters["arg"].annotation is str
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_typed_signature_param_annotated() -> None:
|
||||||
|
def my_function(arg: Annotated[str, "An arg"]) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sig = get_typed_signature(my_function)
|
||||||
|
assert isinstance(sig, inspect.Signature)
|
||||||
|
assert len(sig.parameters) == 1
|
||||||
|
assert sig.parameters["arg"].annotation == Annotated[str, "An arg"]
|
||||||
|
|
||||||
|
|
||||||
def test_func_tool() -> None:
|
def test_func_tool() -> None:
|
||||||
def my_function() -> str:
|
def my_function() -> str:
|
||||||
return "result"
|
return "result"
|
||||||
@ -227,7 +269,7 @@ def test_func_tool_return_none() -> None:
|
|||||||
assert tool.name == "my_function"
|
assert tool.name == "my_function"
|
||||||
assert tool.description == "Function tool."
|
assert tool.description == "Function tool."
|
||||||
assert issubclass(tool.args_type(), BaseModel)
|
assert issubclass(tool.args_type(), BaseModel)
|
||||||
assert tool.return_type() is None
|
assert tool.return_type() is type(None)
|
||||||
assert tool.state_type() is None
|
assert tool.state_type() is None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user