Fix registration of async functions (#1201)

* bug fix for async functions

* Update test_conversable_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Update test/agentchat/test_conversable_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* commented out cell in a notebook until issue #1205 is not fixed

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Davor Runje 2024-01-11 10:01:58 +01:00 committed by GitHub
parent fba7caee53
commit 2e519b016a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 241 additions and 116 deletions

View File

@ -4,7 +4,7 @@ from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION from pydantic.version import VERSION as PYDANTIC_VERSION
from typing_extensions import get_origin from typing_extensions import get_origin
__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema") __all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema", "evaluate_forwardref")
PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.") PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")

View File

@ -73,7 +73,7 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
return get_typed_annotation(annotation, globalns) return get_typed_annotation(annotation, globalns)
def get_param_annotations(typed_signature: inspect.Signature) -> Dict[int, Union[Annotated[Type, str], Type]]: def get_param_annotations(typed_signature: inspect.Signature) -> Dict[int, Union[Annotated[Type[Any], str], Type[Any]]]:
"""Get the type annotations of the parameters of a function """Get the type annotations of the parameters of a function
Args: Args:
@ -111,7 +111,7 @@ class ToolFunction(BaseModel):
def get_parameter_json_schema( def get_parameter_json_schema(
k: str, v: Union[Annotated[Type, str], Type], default_values: Dict[str, Any] k: str, v: Union[Annotated[Type[Any], str], Type[Any]], default_values: Dict[str, Any]
) -> JsonSchemaValue: ) -> JsonSchemaValue:
"""Get a JSON schema for a parameter as defined by the OpenAI API """Get a JSON schema for a parameter as defined by the OpenAI API
@ -124,10 +124,14 @@ def get_parameter_json_schema(
A Pydanitc model for the parameter A Pydanitc model for the parameter
""" """
def type2description(k: str, v: Union[Annotated[Type, str], Type]) -> str: def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
# handles Annotated # handles Annotated
if hasattr(v, "__metadata__"): if hasattr(v, "__metadata__"):
return v.__metadata__[0] retval = v.__metadata__[0]
if isinstance(retval, str):
return retval
else:
raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
else: else:
return k return k
@ -166,7 +170,9 @@ def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]:
def get_parameters( def get_parameters(
required: List[str], param_annotations: Dict[str, Union[Annotated[Type, str], Type]], default_values: Dict[str, Any] required: List[str],
param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]],
default_values: Dict[str, Any],
) -> Parameters: ) -> Parameters:
"""Get the parameters of a function as defined by the OpenAI API """Get the parameters of a function as defined by the OpenAI API
@ -278,7 +284,7 @@ def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, de
return model_dump(function) return model_dump(function)
def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type], BaseModel]]: def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type[Any]], BaseModel]]:
"""Get a function to load a parameter if it is a Pydantic model """Get a function to load a parameter if it is a Pydantic model
Args: Args:
@ -319,7 +325,7 @@ def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:
# a function that loads the parameters before calling the original function # a function that loads the parameters before calling the original function
@functools.wraps(func) @functools.wraps(func)
def load_parameters_if_needed(*args, **kwargs): def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
# load the BaseModels if needed # load the BaseModels if needed
for k, f in kwargs_mapping.items(): for k, f in kwargs_mapping.items():
kwargs[k] = f(kwargs[k], param_annotations[k]) kwargs[k] = f(kwargs[k], param_annotations[k])
@ -327,7 +333,19 @@ def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:
# call the original function # call the original function
return func(*args, **kwargs) return func(*args, **kwargs)
return load_parameters_if_needed @functools.wraps(func)
async def _a_load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
# load the BaseModels if needed
for k, f in kwargs_mapping.items():
kwargs[k] = f(kwargs[k], param_annotations[k])
# call the original function
return await func(*args, **kwargs)
if inspect.iscoroutinefunction(func):
return _a_load_parameters_if_needed
else:
return _load_parameters_if_needed
def serialize_to_str(x: Any) -> str: def serialize_to_str(x: Any) -> str:

View File

@ -119,7 +119,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 3,
"id": "9fb85afb", "id": "9fb85afb",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -134,40 +134,46 @@
"--------------------------------------------------------------------------------\n", "--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n", "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n", "\n",
"\u001b[32m***** Suggested function Call: timer *****\u001b[0m\n", "\u001b[32m***** Suggested tool Call (call_thUjscBN349eGd6xh3XrVT18): timer *****\u001b[0m\n",
"Arguments: \n", "Arguments: \n",
"{\"num_seconds\":\"5\"}\n", "{\"num_seconds\":\"5\"}\n",
"\u001b[32m******************************************\u001b[0m\n", "\u001b[32m**********************************************************************\u001b[0m\n",
"\n", "\n",
"--------------------------------------------------------------------------------\n", "--------------------------------------------------------------------------------\n",
"\u001b[35m\n", "\u001b[35m\n",
">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n", ">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n", "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n", "\n",
"\u001b[32m***** Response from calling function \"timer\" *****\u001b[0m\n", "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"timer\" *****\u001b[0m\n",
"Timer is done!\n", "Timer is done!\n",
"\u001b[32m**************************************************\u001b[0m\n", "\u001b[32m**********************************************\u001b[0m\n",
"\n", "\n",
"--------------------------------------------------------------------------------\n", "--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n", "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n", "\n",
"\u001b[32m***** Suggested function Call: stopwatch *****\u001b[0m\n", "\u001b[32m***** Suggested tool Call (call_ubo7cKE3TKumGHkqGjQtZisy): stopwatch *****\u001b[0m\n",
"Arguments: \n", "Arguments: \n",
"{\"num_seconds\":\"5\"}\n", "{\"num_seconds\":\"5\"}\n",
"\u001b[32m**********************************************\u001b[0m\n", "\u001b[32m**************************************************************************\u001b[0m\n",
"\n", "\n",
"--------------------------------------------------------------------------------\n", "--------------------------------------------------------------------------------\n",
"\u001b[35m\n", "\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION stopwatch...\u001b[0m\n", ">>>>>>>> EXECUTING FUNCTION stopwatch...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n", "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n", "\n",
"\u001b[32m***** Response from calling function \"stopwatch\" *****\u001b[0m\n", "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"stopwatch\" *****\u001b[0m\n",
"Stopwatch is done!\n", "Stopwatch is done!\n",
"\u001b[32m******************************************************\u001b[0m\n", "\u001b[32m**************************************************\u001b[0m\n",
"\n", "\n",
"--------------------------------------------------------------------------------\n", "--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n", "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n", "\n",
"Both the timer and the stopwatch for 5 seconds have been completed. \n",
"\n",
"TERMINATE\n", "TERMINATE\n",
"\n", "\n",
"--------------------------------------------------------------------------------\n" "--------------------------------------------------------------------------------\n"
@ -239,7 +245,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 4,
"id": "2472f95c", "id": "2472f95c",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -274,105 +280,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 5,
"id": "e2c9267a", "id": "e2c9267a",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\n",
"1) Create a timer for 5 seconds.\n",
"2) a stopwatch for 5 seconds.\n",
"3) Pretty print the result as md.\n",
"4) when 1-3 are done, terminate the group chat\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Suggested function Call: timer *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m******************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Response from calling function \"timer\" *****\u001b[0m\n",
"Timer is done!\n",
"\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Suggested function Call: stopwatch *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m**********************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION stopwatch...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Response from calling function \"stopwatch\" *****\u001b[0m\n",
"Stopwatch is done!\n",
"\u001b[32m******************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mMarkdown_agent\u001b[0m (to chat_manager):\n",
"\n",
"The results are as follows:\n",
"\n",
"- Timer: Completed after `5 seconds`.\n",
"- Stopwatch: Recorded time of `5 seconds`.\n",
"\n",
"**Timer and Stopwatch Summary:**\n",
"Both the timer and stopwatch were set for `5 seconds` and have now concluded successfully. \n",
"\n",
"Now, let's proceed to terminate the group chat as requested.\n",
"\u001b[32m***** Suggested function Call: terminate_group_chat *****\u001b[0m\n",
"Arguments: \n",
"{\"message\":\"All tasks have been completed. The group chat will now be terminated. Goodbye!\"}\n",
"\u001b[32m*********************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION terminate_group_chat...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Response from calling function \"terminate_group_chat\" *****\u001b[0m\n",
"[GROUPCHAT_TERMINATE] All tasks have been completed. The group chat will now be terminated. Goodbye!\n",
"\u001b[32m*****************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [ "source": [
"await user_proxy.a_initiate_chat( # noqa: F704\n", "# todo: remove comment after fixing https://github.com/microsoft/autogen/issues/1205\n",
" manager,\n", "# await user_proxy.a_initiate_chat( # noqa: F704\n",
" message=\"\"\"\n", "# manager,\n",
"1) Create a timer for 5 seconds.\n", "# message=\"\"\"\n",
"2) a stopwatch for 5 seconds.\n", "# 1) Create a timer for 5 seconds.\n",
"3) Pretty print the result as md.\n", "# 2) a stopwatch for 5 seconds.\n",
"4) when 1-3 are done, terminate the group chat\"\"\",\n", "# 3) Pretty print the result as md.\n",
")" "# 4) when 1-3 are done, terminate the group chat\"\"\",\n",
"# )"
] ]
}, },
{ {

View File

@ -1,12 +1,18 @@
import asyncio
import copy import copy
import sys
import time
from typing import Any, Callable, Dict, Literal from typing import Any, Callable, Dict, Literal
import unittest
import pytest import pytest
from unittest.mock import patch from unittest.mock import patch
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
import autogen
from autogen.agentchat import ConversableAgent, UserProxyAgent from autogen.agentchat import ConversableAgent, UserProxyAgent
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
from conftest import skip_openai from conftest import skip_openai
try: try:
@ -445,6 +451,8 @@ def test__wrap_function_sync():
== '{"currency":"EUR","amount":100.1}' == '{"currency":"EUR","amount":100.1}'
) )
assert not asyncio.coroutines.iscoroutinefunction(currency_calculator)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test__wrap_function_async(): async def test__wrap_function_async():
@ -481,6 +489,8 @@ async def test__wrap_function_async():
== '{"currency":"EUR","amount":100.1}' == '{"currency":"EUR","amount":100.1}'
) )
assert asyncio.coroutines.iscoroutinefunction(currency_calculator)
def get_origin(d: Dict[str, Callable[..., Any]]) -> Dict[str, Callable[..., Any]]: def get_origin(d: Dict[str, Callable[..., Any]]) -> Dict[str, Callable[..., Any]]:
return {k: v._origin for k, v in d.items()} return {k: v._origin for k, v in d.items()}
@ -624,6 +634,161 @@ def test_register_for_execution():
assert get_origin(user_proxy_1.function_map) == expected_function_map assert get_origin(user_proxy_1.function_map) == expected_function_map
@pytest.mark.skipif(
skip or not sys.version.startswith("3.10"),
reason="do not run if openai is not installed or py!=3.10",
)
def test_function_registration_e2e_sync() -> None:
config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST,
filter_dict={
"model": ["gpt-4", "gpt-4-0314", "gpt4", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-v0314"],
},
file_location=KEY_LOC,
)
llm_config = {
"config_list": config_list,
}
coder = autogen.AssistantAgent(
name="chatbot",
system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
llm_config=llm_config,
)
# create a UserProxyAgent instance named "user_proxy"
user_proxy = autogen.UserProxyAgent(
name="user_proxy",
system_message="A proxy for the user for executing code.",
is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
human_input_mode="NEVER",
max_consecutive_auto_reply=10,
code_execution_config={"work_dir": "coding"},
)
# define functions according to the function description
timer_mock = unittest.mock.MagicMock()
stopwatch_mock = unittest.mock.MagicMock()
# An example async function
@user_proxy.register_for_execution()
@coder.register_for_llm(description="create a timer for N seconds")
def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
print("timer is running")
for i in range(int(num_seconds)):
print(".", end="")
time.sleep(0.01)
print()
timer_mock(num_seconds=num_seconds)
return "Timer is done!"
# An example sync function
@user_proxy.register_for_execution()
@coder.register_for_llm(description="create a stopwatch for N seconds")
def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:
print("stopwatch is running")
# assert False, "stopwatch's alive!"
for i in range(int(num_seconds)):
print(".", end="")
time.sleep(0.01)
print()
stopwatch_mock(num_seconds=num_seconds)
return "Stopwatch is done!"
# start the conversation
# 'await' is used to pause and resume code execution for async IO operations.
# Without 'await', an async function returns a coroutine object but doesn't execute the function.
# With 'await', the async function is executed and the current function is paused until the awaited function returns a result.
user_proxy.initiate_chat( # noqa: F704
coder,
message="Create a timer for 2 seconds and then a stopwatch for 3 seconds.",
)
timer_mock.assert_called_once_with(num_seconds="2")
stopwatch_mock.assert_called_once_with(num_seconds="3")
@pytest.mark.skipif(
skip or not sys.version.startswith("3.10"),
reason="do not run if openai is not installed or py!=3.10",
)
@pytest.mark.asyncio()
async def test_function_registration_e2e_async() -> None:
config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST,
filter_dict={
"model": ["gpt-4", "gpt-4-0314", "gpt4", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-v0314"],
},
file_location=KEY_LOC,
)
llm_config = {
"config_list": config_list,
}
coder = autogen.AssistantAgent(
name="chatbot",
system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
llm_config=llm_config,
)
# create a UserProxyAgent instance named "user_proxy"
user_proxy = autogen.UserProxyAgent(
name="user_proxy",
system_message="A proxy for the user for executing code.",
is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
human_input_mode="NEVER",
max_consecutive_auto_reply=10,
code_execution_config={"work_dir": "coding"},
)
# define functions according to the function description
timer_mock = unittest.mock.MagicMock()
stopwatch_mock = unittest.mock.MagicMock()
# An example async function
@user_proxy.register_for_execution()
@coder.register_for_llm(description="create a timer for N seconds")
async def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
print("timer is running")
for i in range(int(num_seconds)):
print(".", end="")
await asyncio.sleep(0.01)
print()
timer_mock(num_seconds=num_seconds)
return "Timer is done!"
# An example sync function
@user_proxy.register_for_execution()
@coder.register_for_llm(description="create a stopwatch for N seconds")
def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:
print("stopwatch is running")
# assert False, "stopwatch's alive!"
for i in range(int(num_seconds)):
print(".", end="")
time.sleep(0.01)
print()
stopwatch_mock(num_seconds=num_seconds)
return "Stopwatch is done!"
# start the conversation
# 'await' is used to pause and resume code execution for async IO operations.
# Without 'await', an async function returns a coroutine object but doesn't execute the function.
# With 'await', the async function is executed and the current function is paused until the awaited function returns a result.
await user_proxy.a_initiate_chat( # noqa: F704
coder,
message="Create a timer for 4 seconds and then a stopwatch for 5 seconds.",
)
timer_mock.assert_called_once_with(num_seconds="4")
stopwatch_mock.assert_called_once_with(num_seconds="5")
@pytest.mark.skipif( @pytest.mark.skipif(
skip, skip,
reason="do not run if skipping openai", reason="do not run if skipping openai",

View File

@ -1,3 +1,4 @@
import asyncio
import inspect import inspect
import unittest.mock import unittest.mock
from typing import Dict, List, Literal, Optional, Tuple from typing import Dict, List, Literal, Optional, Tuple
@ -355,7 +356,7 @@ def test_get_load_param_if_needed_function() -> None:
assert actual == expected, actual assert actual == expected, actual
def test_load_basemodels_if_needed() -> None: def test_load_basemodels_if_needed_sync() -> None:
@load_basemodels_if_needed @load_basemodels_if_needed
def f( def f(
base: Annotated[Currency, "Base currency"], base: Annotated[Currency, "Base currency"],
@ -363,6 +364,8 @@ def test_load_basemodels_if_needed() -> None:
) -> Tuple[Currency, CurrencySymbol]: ) -> Tuple[Currency, CurrencySymbol]:
return base, quote_currency return base, quote_currency
assert not asyncio.coroutines.iscoroutinefunction(f)
actual = f(base={"currency": "USD", "amount": 123.45}, quote_currency="EUR") actual = f(base={"currency": "USD", "amount": 123.45}, quote_currency="EUR")
assert isinstance(actual[0], Currency) assert isinstance(actual[0], Currency)
assert actual[0].amount == 123.45 assert actual[0].amount == 123.45
@ -370,6 +373,24 @@ def test_load_basemodels_if_needed() -> None:
assert actual[1] == "EUR" assert actual[1] == "EUR"
@pytest.mark.asyncio
async def test_load_basemodels_if_needed_async() -> None:
@load_basemodels_if_needed
async def f(
base: Annotated[Currency, "Base currency"],
quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
) -> Tuple[Currency, CurrencySymbol]:
return base, quote_currency
assert asyncio.coroutines.iscoroutinefunction(f)
actual = await f(base={"currency": "USD", "amount": 123.45}, quote_currency="EUR")
assert isinstance(actual[0], Currency)
assert actual[0].amount == 123.45
assert actual[0].currency == "USD"
assert actual[1] == "EUR"
def test_serialize_to_json(): def test_serialize_to_json():
assert serialize_to_str("abc") == "abc" assert serialize_to_str("abc") == "abc"
assert serialize_to_str(123) == "123" assert serialize_to_str(123) == "123"