Fix registration of async functions (#1201)

* bug fix for async functions

* Update test_conversable_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Update test/agentchat/test_conversable_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* commented out cell in a notebook until issue #1205 is not fixed

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Davor Runje 2024-01-11 10:01:58 +01:00 committed by GitHub
parent fba7caee53
commit 2e519b016a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 241 additions and 116 deletions

View File

@ -4,7 +4,7 @@ from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from typing_extensions import get_origin
__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema")
__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema", "evaluate_forwardref")
PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")

View File

@ -73,7 +73,7 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
return get_typed_annotation(annotation, globalns)
def get_param_annotations(typed_signature: inspect.Signature) -> Dict[int, Union[Annotated[Type, str], Type]]:
def get_param_annotations(typed_signature: inspect.Signature) -> Dict[int, Union[Annotated[Type[Any], str], Type[Any]]]:
"""Get the type annotations of the parameters of a function
Args:
@ -111,7 +111,7 @@ class ToolFunction(BaseModel):
def get_parameter_json_schema(
k: str, v: Union[Annotated[Type, str], Type], default_values: Dict[str, Any]
k: str, v: Union[Annotated[Type[Any], str], Type[Any]], default_values: Dict[str, Any]
) -> JsonSchemaValue:
"""Get a JSON schema for a parameter as defined by the OpenAI API
@ -124,10 +124,14 @@ def get_parameter_json_schema(
A Pydanitc model for the parameter
"""
def type2description(k: str, v: Union[Annotated[Type, str], Type]) -> str:
def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
# handles Annotated
if hasattr(v, "__metadata__"):
return v.__metadata__[0]
retval = v.__metadata__[0]
if isinstance(retval, str):
return retval
else:
raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
else:
return k
@ -166,7 +170,9 @@ def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]:
def get_parameters(
required: List[str], param_annotations: Dict[str, Union[Annotated[Type, str], Type]], default_values: Dict[str, Any]
required: List[str],
param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]],
default_values: Dict[str, Any],
) -> Parameters:
"""Get the parameters of a function as defined by the OpenAI API
@ -278,7 +284,7 @@ def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, de
return model_dump(function)
def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type], BaseModel]]:
def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type[Any]], BaseModel]]:
"""Get a function to load a parameter if it is a Pydantic model
Args:
@ -319,7 +325,7 @@ def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:
# a function that loads the parameters before calling the original function
@functools.wraps(func)
def load_parameters_if_needed(*args, **kwargs):
def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
# load the BaseModels if needed
for k, f in kwargs_mapping.items():
kwargs[k] = f(kwargs[k], param_annotations[k])
@ -327,7 +333,19 @@ def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:
# call the original function
return func(*args, **kwargs)
return load_parameters_if_needed
@functools.wraps(func)
async def _a_load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
# load the BaseModels if needed
for k, f in kwargs_mapping.items():
kwargs[k] = f(kwargs[k], param_annotations[k])
# call the original function
return await func(*args, **kwargs)
if inspect.iscoroutinefunction(func):
return _a_load_parameters_if_needed
else:
return _load_parameters_if_needed
def serialize_to_str(x: Any) -> str:

View File

@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "9fb85afb",
"metadata": {},
"outputs": [
@ -134,40 +134,46 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested function Call: timer *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_thUjscBN349eGd6xh3XrVT18): timer *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m******************************************\u001b[0m\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling function \"timer\" *****\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"timer\" *****\u001b[0m\n",
"Timer is done!\n",
"\u001b[32m**************************************************\u001b[0m\n",
"\u001b[32m**********************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested function Call: stopwatch *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_ubo7cKE3TKumGHkqGjQtZisy): stopwatch *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m**********************************************\u001b[0m\n",
"\u001b[32m**************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION stopwatch...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling function \"stopwatch\" *****\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"stopwatch\" *****\u001b[0m\n",
"Stopwatch is done!\n",
"\u001b[32m******************************************************\u001b[0m\n",
"\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"Both the timer and the stopwatch for 5 seconds have been completed. \n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n"
@ -239,7 +245,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"id": "2472f95c",
"metadata": {},
"outputs": [],
@ -274,105 +280,20 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"id": "e2c9267a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\n",
"1) Create a timer for 5 seconds.\n",
"2) a stopwatch for 5 seconds.\n",
"3) Pretty print the result as md.\n",
"4) when 1-3 are done, terminate the group chat\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Suggested function Call: timer *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m******************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Response from calling function \"timer\" *****\u001b[0m\n",
"Timer is done!\n",
"\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Suggested function Call: stopwatch *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m**********************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION stopwatch...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Response from calling function \"stopwatch\" *****\u001b[0m\n",
"Stopwatch is done!\n",
"\u001b[32m******************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mMarkdown_agent\u001b[0m (to chat_manager):\n",
"\n",
"The results are as follows:\n",
"\n",
"- Timer: Completed after `5 seconds`.\n",
"- Stopwatch: Recorded time of `5 seconds`.\n",
"\n",
"**Timer and Stopwatch Summary:**\n",
"Both the timer and stopwatch were set for `5 seconds` and have now concluded successfully. \n",
"\n",
"Now, let's proceed to terminate the group chat as requested.\n",
"\u001b[32m***** Suggested function Call: terminate_group_chat *****\u001b[0m\n",
"Arguments: \n",
"{\"message\":\"All tasks have been completed. The group chat will now be terminated. Goodbye!\"}\n",
"\u001b[32m*********************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION terminate_group_chat...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Response from calling function \"terminate_group_chat\" *****\u001b[0m\n",
"[GROUPCHAT_TERMINATE] All tasks have been completed. The group chat will now be terminated. Goodbye!\n",
"\u001b[32m*****************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"outputs": [],
"source": [
"await user_proxy.a_initiate_chat( # noqa: F704\n",
" manager,\n",
" message=\"\"\"\n",
"1) Create a timer for 5 seconds.\n",
"2) a stopwatch for 5 seconds.\n",
"3) Pretty print the result as md.\n",
"4) when 1-3 are done, terminate the group chat\"\"\",\n",
")"
"# todo: remove comment after fixing https://github.com/microsoft/autogen/issues/1205\n",
"# await user_proxy.a_initiate_chat( # noqa: F704\n",
"# manager,\n",
"# message=\"\"\"\n",
"# 1) Create a timer for 5 seconds.\n",
"# 2) a stopwatch for 5 seconds.\n",
"# 3) Pretty print the result as md.\n",
"# 4) when 1-3 are done, terminate the group chat\"\"\",\n",
"# )"
]
},
{

View File

@ -1,12 +1,18 @@
import asyncio
import copy
import sys
import time
from typing import Any, Callable, Dict, Literal
import unittest
import pytest
from unittest.mock import patch
from pydantic import BaseModel, Field
from typing_extensions import Annotated
import autogen
from autogen.agentchat import ConversableAgent, UserProxyAgent
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
from conftest import skip_openai
try:
@ -445,6 +451,8 @@ def test__wrap_function_sync():
== '{"currency":"EUR","amount":100.1}'
)
assert not asyncio.coroutines.iscoroutinefunction(currency_calculator)
@pytest.mark.asyncio
async def test__wrap_function_async():
@ -481,6 +489,8 @@ async def test__wrap_function_async():
== '{"currency":"EUR","amount":100.1}'
)
assert asyncio.coroutines.iscoroutinefunction(currency_calculator)
def get_origin(d: Dict[str, Callable[..., Any]]) -> Dict[str, Callable[..., Any]]:
return {k: v._origin for k, v in d.items()}
@ -624,6 +634,161 @@ def test_register_for_execution():
assert get_origin(user_proxy_1.function_map) == expected_function_map
@pytest.mark.skipif(
skip or not sys.version.startswith("3.10"),
reason="do not run if openai is not installed or py!=3.10",
)
def test_function_registration_e2e_sync() -> None:
config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST,
filter_dict={
"model": ["gpt-4", "gpt-4-0314", "gpt4", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-v0314"],
},
file_location=KEY_LOC,
)
llm_config = {
"config_list": config_list,
}
coder = autogen.AssistantAgent(
name="chatbot",
system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
llm_config=llm_config,
)
# create a UserProxyAgent instance named "user_proxy"
user_proxy = autogen.UserProxyAgent(
name="user_proxy",
system_message="A proxy for the user for executing code.",
is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
human_input_mode="NEVER",
max_consecutive_auto_reply=10,
code_execution_config={"work_dir": "coding"},
)
# define functions according to the function description
timer_mock = unittest.mock.MagicMock()
stopwatch_mock = unittest.mock.MagicMock()
# An example async function
@user_proxy.register_for_execution()
@coder.register_for_llm(description="create a timer for N seconds")
def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
print("timer is running")
for i in range(int(num_seconds)):
print(".", end="")
time.sleep(0.01)
print()
timer_mock(num_seconds=num_seconds)
return "Timer is done!"
# An example sync function
@user_proxy.register_for_execution()
@coder.register_for_llm(description="create a stopwatch for N seconds")
def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:
print("stopwatch is running")
# assert False, "stopwatch's alive!"
for i in range(int(num_seconds)):
print(".", end="")
time.sleep(0.01)
print()
stopwatch_mock(num_seconds=num_seconds)
return "Stopwatch is done!"
# start the conversation
# 'await' is used to pause and resume code execution for async IO operations.
# Without 'await', an async function returns a coroutine object but doesn't execute the function.
# With 'await', the async function is executed and the current function is paused until the awaited function returns a result.
user_proxy.initiate_chat( # noqa: F704
coder,
message="Create a timer for 2 seconds and then a stopwatch for 3 seconds.",
)
timer_mock.assert_called_once_with(num_seconds="2")
stopwatch_mock.assert_called_once_with(num_seconds="3")
@pytest.mark.skipif(
skip or not sys.version.startswith("3.10"),
reason="do not run if openai is not installed or py!=3.10",
)
@pytest.mark.asyncio()
async def test_function_registration_e2e_async() -> None:
config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST,
filter_dict={
"model": ["gpt-4", "gpt-4-0314", "gpt4", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-v0314"],
},
file_location=KEY_LOC,
)
llm_config = {
"config_list": config_list,
}
coder = autogen.AssistantAgent(
name="chatbot",
system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
llm_config=llm_config,
)
# create a UserProxyAgent instance named "user_proxy"
user_proxy = autogen.UserProxyAgent(
name="user_proxy",
system_message="A proxy for the user for executing code.",
is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
human_input_mode="NEVER",
max_consecutive_auto_reply=10,
code_execution_config={"work_dir": "coding"},
)
# define functions according to the function description
timer_mock = unittest.mock.MagicMock()
stopwatch_mock = unittest.mock.MagicMock()
# An example async function
@user_proxy.register_for_execution()
@coder.register_for_llm(description="create a timer for N seconds")
async def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
print("timer is running")
for i in range(int(num_seconds)):
print(".", end="")
await asyncio.sleep(0.01)
print()
timer_mock(num_seconds=num_seconds)
return "Timer is done!"
# An example sync function
@user_proxy.register_for_execution()
@coder.register_for_llm(description="create a stopwatch for N seconds")
def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:
print("stopwatch is running")
# assert False, "stopwatch's alive!"
for i in range(int(num_seconds)):
print(".", end="")
time.sleep(0.01)
print()
stopwatch_mock(num_seconds=num_seconds)
return "Stopwatch is done!"
# start the conversation
# 'await' is used to pause and resume code execution for async IO operations.
# Without 'await', an async function returns a coroutine object but doesn't execute the function.
# With 'await', the async function is executed and the current function is paused until the awaited function returns a result.
await user_proxy.a_initiate_chat( # noqa: F704
coder,
message="Create a timer for 4 seconds and then a stopwatch for 5 seconds.",
)
timer_mock.assert_called_once_with(num_seconds="4")
stopwatch_mock.assert_called_once_with(num_seconds="5")
@pytest.mark.skipif(
skip,
reason="do not run if skipping openai",

View File

@ -1,3 +1,4 @@
import asyncio
import inspect
import unittest.mock
from typing import Dict, List, Literal, Optional, Tuple
@ -355,7 +356,7 @@ def test_get_load_param_if_needed_function() -> None:
assert actual == expected, actual
def test_load_basemodels_if_needed() -> None:
def test_load_basemodels_if_needed_sync() -> None:
@load_basemodels_if_needed
def f(
base: Annotated[Currency, "Base currency"],
@ -363,6 +364,8 @@ def test_load_basemodels_if_needed() -> None:
) -> Tuple[Currency, CurrencySymbol]:
return base, quote_currency
assert not asyncio.coroutines.iscoroutinefunction(f)
actual = f(base={"currency": "USD", "amount": 123.45}, quote_currency="EUR")
assert isinstance(actual[0], Currency)
assert actual[0].amount == 123.45
@ -370,6 +373,24 @@ def test_load_basemodels_if_needed() -> None:
assert actual[1] == "EUR"
@pytest.mark.asyncio
async def test_load_basemodels_if_needed_async() -> None:
@load_basemodels_if_needed
async def f(
base: Annotated[Currency, "Base currency"],
quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
) -> Tuple[Currency, CurrencySymbol]:
return base, quote_currency
assert asyncio.coroutines.iscoroutinefunction(f)
actual = await f(base={"currency": "USD", "amount": 123.45}, quote_currency="EUR")
assert isinstance(actual[0], Currency)
assert actual[0].amount == 123.45
assert actual[0].currency == "USD"
assert actual[1] == "EUR"
def test_serialize_to_json():
assert serialize_to_str("abc") == "abc"
assert serialize_to_str(123) == "123"