mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-03 03:10:04 +00:00
Fix registration of async functions (#1201)
* bug fix for async functions * Update test_conversable_agent.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> * Update test/agentchat/test_conversable_agent.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> * commented out cell in a notebook until issue #1205 is not fixed --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
fba7caee53
commit
2e519b016a
@ -4,7 +4,7 @@ from pydantic import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from typing_extensions import get_origin
|
||||
|
||||
__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema")
|
||||
__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema", "evaluate_forwardref")
|
||||
|
||||
PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")
|
||||
|
||||
|
||||
@ -73,7 +73,7 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
|
||||
return get_typed_annotation(annotation, globalns)
|
||||
|
||||
|
||||
def get_param_annotations(typed_signature: inspect.Signature) -> Dict[int, Union[Annotated[Type, str], Type]]:
|
||||
def get_param_annotations(typed_signature: inspect.Signature) -> Dict[int, Union[Annotated[Type[Any], str], Type[Any]]]:
|
||||
"""Get the type annotations of the parameters of a function
|
||||
|
||||
Args:
|
||||
@ -111,7 +111,7 @@ class ToolFunction(BaseModel):
|
||||
|
||||
|
||||
def get_parameter_json_schema(
|
||||
k: str, v: Union[Annotated[Type, str], Type], default_values: Dict[str, Any]
|
||||
k: str, v: Union[Annotated[Type[Any], str], Type[Any]], default_values: Dict[str, Any]
|
||||
) -> JsonSchemaValue:
|
||||
"""Get a JSON schema for a parameter as defined by the OpenAI API
|
||||
|
||||
@ -124,10 +124,14 @@ def get_parameter_json_schema(
|
||||
A Pydanitc model for the parameter
|
||||
"""
|
||||
|
||||
def type2description(k: str, v: Union[Annotated[Type, str], Type]) -> str:
|
||||
def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
|
||||
# handles Annotated
|
||||
if hasattr(v, "__metadata__"):
|
||||
return v.__metadata__[0]
|
||||
retval = v.__metadata__[0]
|
||||
if isinstance(retval, str):
|
||||
return retval
|
||||
else:
|
||||
raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
|
||||
else:
|
||||
return k
|
||||
|
||||
@ -166,7 +170,9 @@ def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def get_parameters(
|
||||
required: List[str], param_annotations: Dict[str, Union[Annotated[Type, str], Type]], default_values: Dict[str, Any]
|
||||
required: List[str],
|
||||
param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]],
|
||||
default_values: Dict[str, Any],
|
||||
) -> Parameters:
|
||||
"""Get the parameters of a function as defined by the OpenAI API
|
||||
|
||||
@ -278,7 +284,7 @@ def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, de
|
||||
return model_dump(function)
|
||||
|
||||
|
||||
def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type], BaseModel]]:
|
||||
def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type[Any]], BaseModel]]:
|
||||
"""Get a function to load a parameter if it is a Pydantic model
|
||||
|
||||
Args:
|
||||
@ -319,7 +325,7 @@ def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
|
||||
# a function that loads the parameters before calling the original function
|
||||
@functools.wraps(func)
|
||||
def load_parameters_if_needed(*args, **kwargs):
|
||||
def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
|
||||
# load the BaseModels if needed
|
||||
for k, f in kwargs_mapping.items():
|
||||
kwargs[k] = f(kwargs[k], param_annotations[k])
|
||||
@ -327,7 +333,19 @@ def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
# call the original function
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return load_parameters_if_needed
|
||||
@functools.wraps(func)
|
||||
async def _a_load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
|
||||
# load the BaseModels if needed
|
||||
for k, f in kwargs_mapping.items():
|
||||
kwargs[k] = f(kwargs[k], param_annotations[k])
|
||||
|
||||
# call the original function
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return _a_load_parameters_if_needed
|
||||
else:
|
||||
return _load_parameters_if_needed
|
||||
|
||||
|
||||
def serialize_to_str(x: Any) -> str:
|
||||
|
||||
@ -119,7 +119,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 3,
|
||||
"id": "9fb85afb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -134,40 +134,46 @@
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested function Call: timer *****\u001b[0m\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_thUjscBN349eGd6xh3XrVT18): timer *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\"num_seconds\":\"5\"}\n",
|
||||
"\u001b[32m******************************************\u001b[0m\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling function \"timer\" *****\u001b[0m\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"timer\" *****\u001b[0m\n",
|
||||
"Timer is done!\n",
|
||||
"\u001b[32m**************************************************\u001b[0m\n",
|
||||
"\u001b[32m**********************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested function Call: stopwatch *****\u001b[0m\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_ubo7cKE3TKumGHkqGjQtZisy): stopwatch *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\"num_seconds\":\"5\"}\n",
|
||||
"\u001b[32m**********************************************\u001b[0m\n",
|
||||
"\u001b[32m**************************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION stopwatch...\u001b[0m\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling function \"stopwatch\" *****\u001b[0m\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"stopwatch\" *****\u001b[0m\n",
|
||||
"Stopwatch is done!\n",
|
||||
"\u001b[32m******************************************************\u001b[0m\n",
|
||||
"\u001b[32m**************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"Both the timer and the stopwatch for 5 seconds have been completed. \n",
|
||||
"\n",
|
||||
"TERMINATE\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
@ -239,7 +245,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 4,
|
||||
"id": "2472f95c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -274,105 +280,20 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 5,
|
||||
"id": "e2c9267a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"1) Create a timer for 5 seconds.\n",
|
||||
"2) a stopwatch for 5 seconds.\n",
|
||||
"3) Pretty print the result as md.\n",
|
||||
"4) when 1-3 are done, terminate the group chat\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested function Call: timer *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\"num_seconds\":\"5\"}\n",
|
||||
"\u001b[32m******************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling function \"timer\" *****\u001b[0m\n",
|
||||
"Timer is done!\n",
|
||||
"\u001b[32m**************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested function Call: stopwatch *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\"num_seconds\":\"5\"}\n",
|
||||
"\u001b[32m**********************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION stopwatch...\u001b[0m\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling function \"stopwatch\" *****\u001b[0m\n",
|
||||
"Stopwatch is done!\n",
|
||||
"\u001b[32m******************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mMarkdown_agent\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"The results are as follows:\n",
|
||||
"\n",
|
||||
"- Timer: Completed after `5 seconds`.\n",
|
||||
"- Stopwatch: Recorded time of `5 seconds`.\n",
|
||||
"\n",
|
||||
"**Timer and Stopwatch Summary:**\n",
|
||||
"Both the timer and stopwatch were set for `5 seconds` and have now concluded successfully. \n",
|
||||
"\n",
|
||||
"Now, let's proceed to terminate the group chat as requested.\n",
|
||||
"\u001b[32m***** Suggested function Call: terminate_group_chat *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\"message\":\"All tasks have been completed. The group chat will now be terminated. Goodbye!\"}\n",
|
||||
"\u001b[32m*********************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION terminate_group_chat...\u001b[0m\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling function \"terminate_group_chat\" *****\u001b[0m\n",
|
||||
"[GROUPCHAT_TERMINATE] All tasks have been completed. The group chat will now be terminated. Goodbye!\n",
|
||||
"\u001b[32m*****************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await user_proxy.a_initiate_chat( # noqa: F704\n",
|
||||
" manager,\n",
|
||||
" message=\"\"\"\n",
|
||||
"1) Create a timer for 5 seconds.\n",
|
||||
"2) a stopwatch for 5 seconds.\n",
|
||||
"3) Pretty print the result as md.\n",
|
||||
"4) when 1-3 are done, terminate the group chat\"\"\",\n",
|
||||
")"
|
||||
"# todo: remove comment after fixing https://github.com/microsoft/autogen/issues/1205\n",
|
||||
"# await user_proxy.a_initiate_chat( # noqa: F704\n",
|
||||
"# manager,\n",
|
||||
"# message=\"\"\"\n",
|
||||
"# 1) Create a timer for 5 seconds.\n",
|
||||
"# 2) a stopwatch for 5 seconds.\n",
|
||||
"# 3) Pretty print the result as md.\n",
|
||||
"# 4) when 1-3 are done, terminate the group chat\"\"\",\n",
|
||||
"# )"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -1,12 +1,18 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Literal
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
import autogen
|
||||
|
||||
from autogen.agentchat import ConversableAgent, UserProxyAgent
|
||||
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
|
||||
from conftest import skip_openai
|
||||
|
||||
try:
|
||||
@ -445,6 +451,8 @@ def test__wrap_function_sync():
|
||||
== '{"currency":"EUR","amount":100.1}'
|
||||
)
|
||||
|
||||
assert not asyncio.coroutines.iscoroutinefunction(currency_calculator)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test__wrap_function_async():
|
||||
@ -481,6 +489,8 @@ async def test__wrap_function_async():
|
||||
== '{"currency":"EUR","amount":100.1}'
|
||||
)
|
||||
|
||||
assert asyncio.coroutines.iscoroutinefunction(currency_calculator)
|
||||
|
||||
|
||||
def get_origin(d: Dict[str, Callable[..., Any]]) -> Dict[str, Callable[..., Any]]:
|
||||
return {k: v._origin for k, v in d.items()}
|
||||
@ -624,6 +634,161 @@ def test_register_for_execution():
|
||||
assert get_origin(user_proxy_1.function_map) == expected_function_map
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
skip or not sys.version.startswith("3.10"),
|
||||
reason="do not run if openai is not installed or py!=3.10",
|
||||
)
|
||||
def test_function_registration_e2e_sync() -> None:
|
||||
config_list = autogen.config_list_from_json(
|
||||
OAI_CONFIG_LIST,
|
||||
filter_dict={
|
||||
"model": ["gpt-4", "gpt-4-0314", "gpt4", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-v0314"],
|
||||
},
|
||||
file_location=KEY_LOC,
|
||||
)
|
||||
|
||||
llm_config = {
|
||||
"config_list": config_list,
|
||||
}
|
||||
|
||||
coder = autogen.AssistantAgent(
|
||||
name="chatbot",
|
||||
system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
|
||||
llm_config=llm_config,
|
||||
)
|
||||
|
||||
# create a UserProxyAgent instance named "user_proxy"
|
||||
user_proxy = autogen.UserProxyAgent(
|
||||
name="user_proxy",
|
||||
system_message="A proxy for the user for executing code.",
|
||||
is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
|
||||
human_input_mode="NEVER",
|
||||
max_consecutive_auto_reply=10,
|
||||
code_execution_config={"work_dir": "coding"},
|
||||
)
|
||||
|
||||
# define functions according to the function description
|
||||
timer_mock = unittest.mock.MagicMock()
|
||||
stopwatch_mock = unittest.mock.MagicMock()
|
||||
|
||||
# An example async function
|
||||
@user_proxy.register_for_execution()
|
||||
@coder.register_for_llm(description="create a timer for N seconds")
|
||||
def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
|
||||
print("timer is running")
|
||||
for i in range(int(num_seconds)):
|
||||
print(".", end="")
|
||||
time.sleep(0.01)
|
||||
print()
|
||||
|
||||
timer_mock(num_seconds=num_seconds)
|
||||
return "Timer is done!"
|
||||
|
||||
# An example sync function
|
||||
@user_proxy.register_for_execution()
|
||||
@coder.register_for_llm(description="create a stopwatch for N seconds")
|
||||
def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:
|
||||
print("stopwatch is running")
|
||||
# assert False, "stopwatch's alive!"
|
||||
for i in range(int(num_seconds)):
|
||||
print(".", end="")
|
||||
time.sleep(0.01)
|
||||
print()
|
||||
|
||||
stopwatch_mock(num_seconds=num_seconds)
|
||||
return "Stopwatch is done!"
|
||||
|
||||
# start the conversation
|
||||
# 'await' is used to pause and resume code execution for async IO operations.
|
||||
# Without 'await', an async function returns a coroutine object but doesn't execute the function.
|
||||
# With 'await', the async function is executed and the current function is paused until the awaited function returns a result.
|
||||
user_proxy.initiate_chat( # noqa: F704
|
||||
coder,
|
||||
message="Create a timer for 2 seconds and then a stopwatch for 3 seconds.",
|
||||
)
|
||||
|
||||
timer_mock.assert_called_once_with(num_seconds="2")
|
||||
stopwatch_mock.assert_called_once_with(num_seconds="3")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
skip or not sys.version.startswith("3.10"),
|
||||
reason="do not run if openai is not installed or py!=3.10",
|
||||
)
|
||||
@pytest.mark.asyncio()
|
||||
async def test_function_registration_e2e_async() -> None:
|
||||
config_list = autogen.config_list_from_json(
|
||||
OAI_CONFIG_LIST,
|
||||
filter_dict={
|
||||
"model": ["gpt-4", "gpt-4-0314", "gpt4", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-v0314"],
|
||||
},
|
||||
file_location=KEY_LOC,
|
||||
)
|
||||
|
||||
llm_config = {
|
||||
"config_list": config_list,
|
||||
}
|
||||
|
||||
coder = autogen.AssistantAgent(
|
||||
name="chatbot",
|
||||
system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
|
||||
llm_config=llm_config,
|
||||
)
|
||||
|
||||
# create a UserProxyAgent instance named "user_proxy"
|
||||
user_proxy = autogen.UserProxyAgent(
|
||||
name="user_proxy",
|
||||
system_message="A proxy for the user for executing code.",
|
||||
is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
|
||||
human_input_mode="NEVER",
|
||||
max_consecutive_auto_reply=10,
|
||||
code_execution_config={"work_dir": "coding"},
|
||||
)
|
||||
|
||||
# define functions according to the function description
|
||||
timer_mock = unittest.mock.MagicMock()
|
||||
stopwatch_mock = unittest.mock.MagicMock()
|
||||
|
||||
# An example async function
|
||||
@user_proxy.register_for_execution()
|
||||
@coder.register_for_llm(description="create a timer for N seconds")
|
||||
async def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
|
||||
print("timer is running")
|
||||
for i in range(int(num_seconds)):
|
||||
print(".", end="")
|
||||
await asyncio.sleep(0.01)
|
||||
print()
|
||||
|
||||
timer_mock(num_seconds=num_seconds)
|
||||
return "Timer is done!"
|
||||
|
||||
# An example sync function
|
||||
@user_proxy.register_for_execution()
|
||||
@coder.register_for_llm(description="create a stopwatch for N seconds")
|
||||
def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:
|
||||
print("stopwatch is running")
|
||||
# assert False, "stopwatch's alive!"
|
||||
for i in range(int(num_seconds)):
|
||||
print(".", end="")
|
||||
time.sleep(0.01)
|
||||
print()
|
||||
|
||||
stopwatch_mock(num_seconds=num_seconds)
|
||||
return "Stopwatch is done!"
|
||||
|
||||
# start the conversation
|
||||
# 'await' is used to pause and resume code execution for async IO operations.
|
||||
# Without 'await', an async function returns a coroutine object but doesn't execute the function.
|
||||
# With 'await', the async function is executed and the current function is paused until the awaited function returns a result.
|
||||
await user_proxy.a_initiate_chat( # noqa: F704
|
||||
coder,
|
||||
message="Create a timer for 4 seconds and then a stopwatch for 5 seconds.",
|
||||
)
|
||||
|
||||
timer_mock.assert_called_once_with(num_seconds="4")
|
||||
stopwatch_mock.assert_called_once_with(num_seconds="5")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
skip,
|
||||
reason="do not run if skipping openai",
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import unittest.mock
|
||||
from typing import Dict, List, Literal, Optional, Tuple
|
||||
@ -355,7 +356,7 @@ def test_get_load_param_if_needed_function() -> None:
|
||||
assert actual == expected, actual
|
||||
|
||||
|
||||
def test_load_basemodels_if_needed() -> None:
|
||||
def test_load_basemodels_if_needed_sync() -> None:
|
||||
@load_basemodels_if_needed
|
||||
def f(
|
||||
base: Annotated[Currency, "Base currency"],
|
||||
@ -363,6 +364,8 @@ def test_load_basemodels_if_needed() -> None:
|
||||
) -> Tuple[Currency, CurrencySymbol]:
|
||||
return base, quote_currency
|
||||
|
||||
assert not asyncio.coroutines.iscoroutinefunction(f)
|
||||
|
||||
actual = f(base={"currency": "USD", "amount": 123.45}, quote_currency="EUR")
|
||||
assert isinstance(actual[0], Currency)
|
||||
assert actual[0].amount == 123.45
|
||||
@ -370,6 +373,24 @@ def test_load_basemodels_if_needed() -> None:
|
||||
assert actual[1] == "EUR"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_basemodels_if_needed_async() -> None:
|
||||
@load_basemodels_if_needed
|
||||
async def f(
|
||||
base: Annotated[Currency, "Base currency"],
|
||||
quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
|
||||
) -> Tuple[Currency, CurrencySymbol]:
|
||||
return base, quote_currency
|
||||
|
||||
assert asyncio.coroutines.iscoroutinefunction(f)
|
||||
|
||||
actual = await f(base={"currency": "USD", "amount": 123.45}, quote_currency="EUR")
|
||||
assert isinstance(actual[0], Currency)
|
||||
assert actual[0].amount == 123.45
|
||||
assert actual[0].currency == "USD"
|
||||
assert actual[1] == "EUR"
|
||||
|
||||
|
||||
def test_serialize_to_json():
|
||||
assert serialize_to_str("abc") == "abc"
|
||||
assert serialize_to_str(123) == "123"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user