Function calling upgrade (#1443)

* function calling upgraded: async/sync mixing works now for all combinations and register_function added to simplify registration of functions without decorators

* polishing

* fixing tests

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Davor Runje 2024-01-31 16:30:55 +01:00 committed by GitHub
parent 0107b52d5a
commit a2d4b47503
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 827 additions and 221 deletions

3
.gitignore vendored
View File

@ -175,3 +175,6 @@ test/test_files/agenteval-in-out/out/
# Files created by tests
*tmp_code_*
test/agentchat/test_agent_scripts/*
# test cache
.cache_test

View File

@ -1,14 +1,15 @@
from .agent import Agent
from .assistant_agent import AssistantAgent
from .conversable_agent import ConversableAgent
from .conversable_agent import ConversableAgent, register_function
from .groupchat import GroupChat, GroupChatManager
from .user_proxy_agent import UserProxyAgent
__all__ = [
__all__ = (
"Agent",
"ConversableAgent",
"AssistantAgent",
"UserProxyAgent",
"GroupChat",
"GroupChatManager",
]
"register_function",
)

View File

@ -914,9 +914,20 @@ class ConversableAgent(Agent):
func_call = message["function_call"]
func = self._function_map.get(func_call.get("name", None), None)
if inspect.iscoroutinefunction(func):
return False, None
try:
# get the running loop if it was already created
loop = asyncio.get_running_loop()
close_loop = False
except RuntimeError:
# create a loop if there is no running loop
loop = asyncio.new_event_loop()
close_loop = True
_, func_return = self.execute_function(message["function_call"])
_, func_return = loop.run_until_complete(self.a_execute_function(func_call))
if close_loop:
loop.close()
else:
_, func_return = self.execute_function(message["function_call"])
return True, func_return
return False, None
@ -943,7 +954,9 @@ class ConversableAgent(Agent):
func = self._function_map.get(func_name, None)
if func and inspect.iscoroutinefunction(func):
_, func_return = await self.a_execute_function(func_call)
return True, func_return
else:
_, func_return = self.execute_function(func_call)
return True, func_return
return False, None
@ -968,8 +981,20 @@ class ConversableAgent(Agent):
function_call = tool_call.get("function", {})
func = self._function_map.get(function_call.get("name", None), None)
if inspect.iscoroutinefunction(func):
continue
_, func_return = self.execute_function(function_call)
try:
# get the running loop if it was already created
loop = asyncio.get_running_loop()
close_loop = False
except RuntimeError:
# create a loop if there is no running loop
loop = asyncio.new_event_loop()
close_loop = True
_, func_return = loop.run_until_complete(self.a_execute_function(function_call))
if close_loop:
loop.close()
else:
_, func_return = self.execute_function(function_call)
tool_returns.append(
{
"tool_call_id": id,
@ -1986,3 +2011,30 @@ class ConversableAgent(Agent):
return None
else:
return self.client.total_usage_summary
def register_function(
f: Callable[..., Any],
*,
caller: ConversableAgent,
executor: ConversableAgent,
name: Optional[str] = None,
description: str,
) -> None:
"""Register a function to be proposed by an agent and executed for an executor.
This function can be used instead of function decorators `@ConversationAgent.register_for_llm` and
`@ConversationAgent.register_for_execution`.
Args:
f: the function to be registered.
caller: the agent calling the function, typically an instance of ConversableAgent.
executor: the agent executing the function, typically an instance of UserProxy.
name: name of the function. If None, the function name will be used (default: None).
description: description of the function. The description is used by LLM to decode whether the function
is called. Make sure the description is properly describing what the function does or it might not be
called by LLM when needed.
"""
f = caller.register_for_llm(name=name, description=description)(f)
executor.register_for_execution(name=name)(f)

File diff suppressed because one or more lines are too long

View File

@ -61,7 +61,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "dca301a4",
"metadata": {},
"outputs": [],
@ -71,6 +71,7 @@
"from typing_extensions import Annotated\n",
"\n",
"import autogen\n",
"from autogen.cache import Cache\n",
"\n",
"config_list = autogen.config_list_from_json(\n",
" \"OAI_CONFIG_LIST\",\n",
@ -119,65 +120,10 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "9fb85afb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"Create a timer for 5 seconds and then a stopwatch for 5 seconds.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_fGgH8U261nOnx3JGNJWslhh6): timer *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"timer\" *****\u001b[0m\n",
"Timer is done!\n",
"\u001b[32m**********************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_BZs6ynF8gtcZKhONiIRZkECB): stopwatch *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m**************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING ASYNC FUNCTION stopwatch...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"stopwatch\" *****\u001b[0m\n",
"Stopwatch is done!\n",
"\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"outputs": [],
"source": [
"llm_config = {\n",
" \"config_list\": config_list,\n",
@ -201,7 +147,7 @@
"\n",
"# define functions according to the function description\n",
"\n",
"# An example async function\n",
"# An example async function registered using register_for_llm and register_for_execution decorators\n",
"\n",
"\n",
"@user_proxy.register_for_execution()\n",
@ -213,25 +159,104 @@
" return \"Timer is done!\"\n",
"\n",
"\n",
"# An example sync function\n",
"@user_proxy.register_for_execution()\n",
"@coder.register_for_llm(description=\"create a stopwatch for N seconds\")\n",
"# An example sync function registered using register_function\n",
"def stopwatch(num_seconds: Annotated[str, \"Number of seconds in the stopwatch.\"]) -> str:\n",
" for i in range(int(num_seconds)):\n",
" time.sleep(1)\n",
" return \"Stopwatch is done!\"\n",
"\n",
"\n",
"# start the conversation\n",
"# 'await' is used to pause and resume code execution for async IO operations.\n",
"# Without 'await', an async function returns a coroutine object but doesn't execute the function.\n",
"# With 'await', the async function is executed and the current function is paused until the awaited function returns a result.\n",
"await user_proxy.a_initiate_chat( # noqa: F704\n",
" coder,\n",
" message=\"Create a timer for 5 seconds and then a stopwatch for 5 seconds.\",\n",
"autogen.agentchat.register_function(\n",
" stopwatch,\n",
" caller=coder,\n",
" executor=user_proxy,\n",
" description=\"create a stopwatch for N seconds\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "159cd7b6",
"metadata": {},
"source": [
"Start the conversation. `await` is used to pause and resume code execution for async IO operations. Without `await`, an async function returns a coroutine object but doesn't execute the function. With `await`, the async function is executed and the current function is paused until the awaited function returns a result."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "37514ea3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"Create a timer for 5 seconds and then a stopwatch for 5 seconds.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_h6324df0CdGPDNjPO8GrnAQJ): timer *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"call_h6324df0CdGPDNjPO8GrnAQJ\" *****\u001b[0m\n",
"Timer is done!\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_7SzbQxI8Nsl6dPQtScoSGPAu): stopwatch *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
"\u001b[32m**************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING ASYNC FUNCTION stopwatch...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"call_7SzbQxI8Nsl6dPQtScoSGPAu\" *****\u001b[0m\n",
"Stopwatch is done!\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"with Cache.disk():\n",
" await user_proxy.a_initiate_chat( # noqa: F704\n",
" coder,\n",
" message=\"Create a timer for 5 seconds and then a stopwatch for 5 seconds.\",\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "950f3de7",
@ -243,7 +268,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"id": "2472f95c",
"metadata": {},
"outputs": [],
@ -276,9 +301,17 @@
")"
]
},
{
"cell_type": "markdown",
"id": "612bdd22",
"metadata": {},
"source": [
"Finally, we initialize the chat that would use the functions defined above:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "e2c9267a",
"metadata": {},
"outputs": [
@ -293,14 +326,21 @@
"2) Pretty print the result as md.\n",
"3) when 1 and 2 are done, terminate the group chat\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_zlHKR9LBzCqs1iLId5kvNvJ5): timer *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_qlS3QkcY1NkfgpKtCoR6oGo7): timer *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\": \"5\"}\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_rH1dgbS9itiJO1Gwnxxhcm35): stopwatch *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_TEHlvMgCp0S3RzBbVsVPXWeL): stopwatch *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\": \"5\"}\n",
"\u001b[32m**************************************************************************\u001b[0m\n",
@ -314,29 +354,23 @@
"\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Response from calling tool \"timer\" *****\u001b[0m\n",
"\u001b[32m***** Response from calling tool \"call_qlS3QkcY1NkfgpKtCoR6oGo7\" *****\u001b[0m\n",
"Timer is done!\n",
"\u001b[32m**********************************************\u001b[0m\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Response from calling tool \"stopwatch\" *****\u001b[0m\n",
"\u001b[32m***** Response from calling tool \"call_TEHlvMgCp0S3RzBbVsVPXWeL\" *****\u001b[0m\n",
"Stopwatch is done!\n",
"\u001b[32m**************************************************\u001b[0m\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mMarkdown_agent\u001b[0m (to chat_manager):\n",
"\n",
"The results of the timer and stopwatch are as follows:\n",
"\n",
"- Timer: Timer is done!\n",
"- Stopwatch: Stopwatch is done!\n",
"\n",
"Now, I will proceed to terminate the group chat.\n",
"\u001b[32m***** Suggested tool Call (call_3Js7oU80vPatnA8IiaKXB5Xu): terminate_group_chat *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_JuQwvj4FigfvGyBeTMglY2ee): terminate_group_chat *****\u001b[0m\n",
"Arguments: \n",
"{\"message\":\"The session has concluded, and the group chat will now be terminated.\"}\n",
"{\"message\":\"Both timer and stopwatch have completed their countdowns. The group chat is now being terminated.\"}\n",
"\u001b[32m*************************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
@ -346,23 +380,26 @@
"\n",
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\u001b[32m***** Response from calling tool \"terminate_group_chat\" *****\u001b[0m\n",
"[GROUPCHAT_TERMINATE] The session has concluded, and the group chat will now be terminated.\n",
"\u001b[32m*************************************************************\u001b[0m\n",
"\u001b[32m***** Response from calling tool \"call_JuQwvj4FigfvGyBeTMglY2ee\" *****\u001b[0m\n",
"[GROUPCHAT_TERMINATE] Both timer and stopwatch have completed their countdowns. The group chat is now being terminated.\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"# todo: remove comment after fixing https://github.com/microsoft/autogen/issues/1205\n",
"await user_proxy.a_initiate_chat( # noqa: F704\n",
" manager,\n",
" message=\"\"\"\n",
"message = \"\"\"\n",
"1) Create a timer and a stopwatch for 5 seconds each in parallel.\n",
"2) Pretty print the result as md.\n",
"3) when 1 and 2 are done, terminate the group chat\"\"\",\n",
")"
"3) when 1 and 2 are done, terminate the group chat\n",
"\"\"\"\n",
"\n",
"with Cache.disk():\n",
" await user_proxy.a_initiate_chat( # noqa: F704\n",
" manager,\n",
" message=message,\n",
" )"
]
},
{
@ -390,7 +427,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.10.13"
}
},
"nbformat": 4,

View File

@ -63,6 +63,8 @@
"from typing_extensions import Annotated\n",
"\n",
"import autogen\n",
"from autogen.cache import Cache\n",
"\n",
"\n",
"config_list = autogen.config_list_from_json(\n",
" \"OAI_CONFIG_LIST\",\n",
@ -274,9 +276,9 @@
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"currency_calculator\" *****\u001b[0m\n",
"\u001b[32m***** Response from calling tool \"call_ubo7cKE3TKumGHkqGjQtZisy\" *****\u001b[0m\n",
"112.22727272727272 EUR\n",
"\u001b[32m************************************************************\u001b[0m\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
@ -298,11 +300,12 @@
}
],
"source": [
"# start the conversation\n",
"user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"How much is 123.45 USD in EUR?\",\n",
")"
"with Cache.disk():\n",
" # start the conversation\n",
" user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"How much is 123.45 USD in EUR?\",\n",
" )"
]
},
{
@ -353,14 +356,21 @@
" amount: Annotated[float, Field(0, description=\"Amount of currency\", ge=0)]\n",
"\n",
"\n",
"@user_proxy.register_for_execution()\n",
"@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n",
"# another way to register a function is to use register_function instead of register_for_execution and register_for_llm decorators\n",
"def currency_calculator(\n",
" base: Annotated[Currency, \"Base currency: amount and currency symbol\"],\n",
" quote_currency: Annotated[CurrencySymbol, \"Quote currency symbol\"] = \"USD\",\n",
") -> Currency:\n",
" quote_amount = exchange_rate(base.currency, quote_currency) * base.amount\n",
" return Currency(amount=quote_amount, currency=quote_currency)"
" return Currency(amount=quote_amount, currency=quote_currency)\n",
"\n",
"\n",
"autogen.agentchat.register_function(\n",
" currency_calculator,\n",
" caller=chatbot,\n",
" executor=user_proxy,\n",
" description=\"Currency exchange calculator.\",\n",
")"
]
},
{
@ -434,14 +444,14 @@
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"currency_calculator\" *****\u001b[0m\n",
"\u001b[32m***** Response from calling tool \"call_0VuU2rATuOgYrGmcBnXzPXlh\" *****\u001b[0m\n",
"{\"currency\":\"USD\",\"amount\":123.45300000000002}\n",
"\u001b[32m************************************************************\u001b[0m\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"112.23 Euros is equivalent to approximately 123.45 US Dollars.\n",
"112.23 Euros is approximately 123.45 US Dollars.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
@ -458,11 +468,12 @@
}
],
"source": [
"# start the conversation\n",
"user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"How much is 112.23 Euros in US Dollars?\",\n",
")"
"with Cache.disk():\n",
" # start the conversation\n",
" user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"How much is 112.23 Euros in US Dollars?\",\n",
" )"
]
},
{
@ -494,9 +505,9 @@
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"currency_calculator\" *****\u001b[0m\n",
"\u001b[32m***** Response from calling tool \"call_A6lqMu7s5SyDvftTSeQTtPcj\" *****\u001b[0m\n",
"{\"currency\":\"EUR\",\"amount\":112.22727272727272}\n",
"\u001b[32m************************************************************\u001b[0m\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
@ -518,11 +529,12 @@
}
],
"source": [
"# start the conversation\n",
"user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"How much is 123.45 US Dollars in Euros?\",\n",
")"
"with Cache.disk():\n",
" # start the conversation\n",
" user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"How much is 123.45 US Dollars in Euros?\",\n",
" )"
]
},
{

View File

@ -13,6 +13,7 @@ from typing_extensions import Annotated
import autogen
from autogen.agentchat import ConversableAgent, UserProxyAgent
from autogen.agentchat.conversable_agent import register_function
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
from conftest import skip_openai
@ -823,6 +824,47 @@ def test_register_for_execution():
assert get_origin(user_proxy_1.function_map) == expected_function_map
def test_register_functions():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
agent = ConversableAgent(name="agent", llm_config={"config_list": []})
user_proxy = UserProxyAgent(name="user_proxy")
def exec_python(cell: Annotated[str, "Valid Python cell to execute."]) -> str:
pass
register_function(
exec_python,
caller=agent,
executor=user_proxy,
description="run cell in ipython and return the execution result.",
)
expected_function_map = {"exec_python": exec_python}
assert get_origin(user_proxy.function_map) == expected_function_map
expected = [
{
"type": "function",
"function": {
"description": "run cell in ipython and return the execution result.",
"name": "exec_python",
"parameters": {
"type": "object",
"properties": {
"cell": {
"type": "string",
"description": "Valid Python cell to execute.",
}
},
"required": ["cell"],
},
},
}
]
assert agent.llm_config["tools"] == expected
@pytest.mark.skipif(
skip or not sys.version.startswith("3.10"),
reason="do not run if openai is not installed or py!=3.10",
@ -860,7 +902,7 @@ def test_function_registration_e2e_sync() -> None:
timer_mock = unittest.mock.MagicMock()
stopwatch_mock = unittest.mock.MagicMock()
# An example async function
# An example async function registered using decorators
@user_proxy.register_for_execution()
@coder.register_for_llm(description="create a timer for N seconds")
def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
@ -873,9 +915,7 @@ def test_function_registration_e2e_sync() -> None:
timer_mock(num_seconds=num_seconds)
return "Timer is done!"
# An example sync function
@user_proxy.register_for_execution()
@coder.register_for_llm(description="create a stopwatch for N seconds")
# An example sync function registered using register_function
def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:
print("stopwatch is running")
# assert False, "stopwatch's alive!"
@ -887,6 +927,8 @@ def test_function_registration_e2e_sync() -> None:
stopwatch_mock(num_seconds=num_seconds)
return "Stopwatch is done!"
register_function(stopwatch, caller=coder, executor=user_proxy, description="create a stopwatch for N seconds")
# start the conversation
# 'await' is used to pause and resume code execution for async IO operations.
# Without 'await', an async function returns a coroutine object but doesn't execute the function.
@ -938,9 +980,7 @@ async def test_function_registration_e2e_async() -> None:
timer_mock = unittest.mock.MagicMock()
stopwatch_mock = unittest.mock.MagicMock()
# An example async function
@user_proxy.register_for_execution()
@coder.register_for_llm(description="create a timer for N seconds")
# An example async function registered using register_function
async def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
print("timer is running")
for i in range(int(num_seconds)):
@ -951,7 +991,9 @@ async def test_function_registration_e2e_async() -> None:
timer_mock(num_seconds=num_seconds)
return "Timer is done!"
# An example sync function
register_function(timer, caller=coder, executor=user_proxy, description="create a timer for N seconds")
# An example sync function registered using decorators
@user_proxy.register_for_execution()
@coder.register_for_llm(description="create a stopwatch for N seconds")
def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:

View File

@ -0,0 +1,378 @@
import json
from typing import Any, Callable, Dict, List
import pytest
from autogen.agentchat.conversable_agent import ConversableAgent
def _tool_func_1(arg1: str, arg2: str) -> str:
return f"_tool_func_1: {arg1} {arg2}"
def _tool_func_2(arg1: str, arg2: str) -> str:
return f"_tool_func_2: {arg1} {arg2}"
def _tool_func_error(arg1: str, arg2: str) -> str:
raise RuntimeError("Error in tool function")
async def _a_tool_func_1(arg1: str, arg2: str) -> str:
return f"_tool_func_1: {arg1} {arg2}"
async def _a_tool_func_2(arg1: str, arg2: str) -> str:
return f"_tool_func_2: {arg1} {arg2}"
async def _a_tool_func_error(arg1: str, arg2: str) -> str:
raise RuntimeError("Error in tool function")
_tool_use_message_1 = {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "1",
"type": "function",
"function": {
"name": "_tool_func_1",
"arguments": json.dumps({"arg1": "value1", "arg2": "value2"}),
},
},
{
"id": "2",
"type": "function",
"function": {
"name": "_tool_func_2",
"arguments": json.dumps({"arg1": "value3", "arg2": "value4"}),
},
},
],
}
_tool_use_message_1_bad_json = {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "1",
"type": "function",
"function": {
"name": "_tool_func_1",
# add extra comma to make json invalid
"arguments": json.dumps({"arg1": "value3", "arg2": "value4"})[:-1] + ",}",
},
},
{
"id": "2",
"type": "function",
"function": {
"name": "_tool_func_2",
"arguments": json.dumps({"arg1": "value3", "arg2": "value4"}),
},
},
],
}
_tool_use_message_1_expected_reply = {
"role": "tool",
"tool_responses": [
{"tool_call_id": "1", "role": "tool", "content": "_tool_func_1: value1 value2"},
{"tool_call_id": "2", "role": "tool", "content": "_tool_func_2: value3 value4"},
],
# "content": "Tool Call Id: 1\n_tool_func_1: value1 value2\n\nTool Call Id: 2\n_tool_func_2: value3 value4",
"content": "_tool_func_1: value1 value2\n\n_tool_func_2: value3 value4",
}
_tool_use_message_1_bad_json_expected_reply = {
"role": "tool",
"tool_responses": [
{
"tool_call_id": "1",
"role": "tool",
"content": "Error: Expecting property name enclosed in double quotes: line 1 column 37 (char 36)\n You argument should follow json format.",
},
{"tool_call_id": "2", "role": "tool", "content": "_tool_func_2: value3 value4"},
],
"content": "Error: Expecting property name enclosed in double quotes: line 1 column 37 (char 36)\n You argument should follow json format.\n\n_tool_func_2: value3 value4",
}
_tool_use_message_1_error_expected_reply = {
"role": "tool",
"tool_responses": [
{"tool_call_id": "1", "role": "tool", "content": "_tool_func_1: value1 value2"},
{
"tool_call_id": "2",
"role": "tool",
"content": "Error: Error in tool function",
},
],
"content": "_tool_func_1: value1 value2\n\nError: Error in tool function",
}
_tool_use_message_1_not_found_expected_reply = {
"role": "tool",
"tool_responses": [
{"tool_call_id": "1", "role": "tool", "content": "_tool_func_1: value1 value2"},
{
"tool_call_id": "2",
"role": "tool",
"content": "Error: Function _tool_func_2 not found.",
},
],
"content": "_tool_func_1: value1 value2\n\nError: Function _tool_func_2 not found.",
}
_function_use_message_1 = {
"role": "assistant",
"content": None,
"function_call": {
"name": "_tool_func_1",
"arguments": json.dumps({"arg1": "value1", "arg2": "value2"}),
},
}
_function_use_message_1_bad_json = {
"role": "assistant",
"content": None,
"function_call": {
"name": "_tool_func_1",
"arguments": json.dumps({"arg1": "value1", "arg2": "value2"})[:-1] + ",}",
},
}
_function_use_message_1_expected_reply = {
"name": "_tool_func_1",
"role": "function",
"content": "_tool_func_1: value1 value2",
}
_function_use_message_1_bad_json_expected_reply = {
"name": "_tool_func_1",
"role": "function",
"content": "Error: Expecting property name enclosed in double quotes: line 1 column 37 (char 36)\n You argument should follow json format.",
}
_function_use_message_1_error_expected_reply = {
"name": "_tool_func_1",
"role": "function",
"content": "Error: Error in tool function",
}
_function_use_message_1_not_found_expected_reply = {
"name": "_tool_func_1",
"role": "function",
"content": "Error: Function _tool_func_1 not found.",
}
_text_message = {"content": "Hi!", "role": "user"}
def _get_function_map(is_function_async: bool, drop_tool_2: bool = False) -> Dict[str, Callable[..., Any]]:
if is_function_async:
return (
{
"_tool_func_1": _a_tool_func_1,
"_tool_func_2": _a_tool_func_2,
}
if not drop_tool_2
else {
"_tool_func_1": _a_tool_func_1,
}
)
else:
return (
{
"_tool_func_1": _tool_func_1,
"_tool_func_2": _tool_func_2,
}
if not drop_tool_2
else {
"_tool_func_1": _tool_func_1,
}
)
def _get_error_function_map(
is_function_async: bool, error_on_tool_func_2: bool = True
) -> Dict[str, Callable[..., Any]]:
if is_function_async:
return {
"_tool_func_1": _a_tool_func_1 if error_on_tool_func_2 else _a_tool_func_error,
"_tool_func_2": _a_tool_func_error if error_on_tool_func_2 else _a_tool_func_2,
}
else:
return {
"_tool_func_1": _tool_func_1 if error_on_tool_func_2 else _tool_func_error,
"_tool_func_2": _tool_func_error if error_on_tool_func_2 else _tool_func_2,
}
@pytest.mark.parametrize("is_function_async", [True, False])
def test_generate_function_call_reply_on_function_call_message(is_function_async: bool) -> None:
agent = ConversableAgent(name="agent", llm_config=False)
# empty function_map
agent._function_map = {}
messages = [_function_use_message_1]
finished, retval = agent.generate_function_call_reply(messages)
assert (finished, retval) == (True, _function_use_message_1_not_found_expected_reply)
# function map set
agent._function_map = _get_function_map(is_function_async)
# correct function call, multiple times to make sure cleanups are done properly
for _ in range(3):
messages = [_function_use_message_1]
finished, retval = agent.generate_function_call_reply(messages)
assert (finished, retval) == (True, _function_use_message_1_expected_reply)
# bad JSON
messages = [_function_use_message_1_bad_json]
finished, retval = agent.generate_function_call_reply(messages)
assert (finished, retval) == (True, _function_use_message_1_bad_json_expected_reply)
# tool call
messages = [_tool_use_message_1]
finished, retval = agent.generate_function_call_reply(messages)
assert (finished, retval) == (False, None)
# text message
messages: List[Dict[str, str]] = [_text_message]
finished, retval = agent.generate_function_call_reply(messages)
assert (finished, retval) == (False, None)
# error in function (raises Exception)
agent._function_map = _get_error_function_map(is_function_async, error_on_tool_func_2=False)
messages = [_function_use_message_1]
finished, retval = agent.generate_function_call_reply(messages)
assert (finished, retval) == (True, _function_use_message_1_error_expected_reply)
@pytest.mark.asyncio()
@pytest.mark.parametrize("is_function_async", [True, False])
async def test_a_generate_function_call_reply_on_function_call_message(is_function_async: bool) -> None:
agent = ConversableAgent(name="agent", llm_config=False)
# empty function_map
agent._function_map = {}
messages = [_function_use_message_1]
finished, retval = await agent.a_generate_function_call_reply(messages)
assert (finished, retval) == (True, _function_use_message_1_not_found_expected_reply)
# function map set
agent._function_map = _get_function_map(is_function_async)
# correct function call, multiple times to make sure cleanups are done properly
for _ in range(3):
messages = [_function_use_message_1]
finished, retval = await agent.a_generate_function_call_reply(messages)
assert (finished, retval) == (True, _function_use_message_1_expected_reply)
# bad JSON
messages = [_function_use_message_1_bad_json]
finished, retval = await agent.a_generate_function_call_reply(messages)
assert (finished, retval) == (True, _function_use_message_1_bad_json_expected_reply)
# tool call
messages = [_tool_use_message_1]
finished, retval = await agent.a_generate_function_call_reply(messages)
assert (finished, retval) == (False, None)
# text message
messages: List[Dict[str, str]] = [_text_message]
finished, retval = await agent.a_generate_function_call_reply(messages)
assert (finished, retval) == (False, None)
# error in function (raises Exception)
agent._function_map = _get_error_function_map(is_function_async, error_on_tool_func_2=False)
messages = [_function_use_message_1]
finished, retval = await agent.a_generate_function_call_reply(messages)
assert (finished, retval) == (True, _function_use_message_1_error_expected_reply)
@pytest.mark.parametrize("is_function_async", [True, False])
def test_generate_tool_calls_reply_on_function_call_message(is_function_async: bool) -> None:
agent = ConversableAgent(name="agent", llm_config=False)
# empty function_map
agent._function_map = _get_function_map(is_function_async, drop_tool_2=True)
messages = [_tool_use_message_1]
finished, retval = agent.generate_tool_calls_reply(messages)
assert (finished, retval) == (True, _tool_use_message_1_not_found_expected_reply)
# function map set
agent._function_map = _get_function_map(is_function_async)
# correct function call, multiple times to make sure cleanups are done properly
for _ in range(3):
messages = [_tool_use_message_1]
finished, retval = agent.generate_tool_calls_reply(messages)
assert (finished, retval) == (True, _tool_use_message_1_expected_reply)
# bad JSON
messages = [_tool_use_message_1_bad_json]
finished, retval = agent.generate_tool_calls_reply(messages)
assert (finished, retval) == (True, _tool_use_message_1_bad_json_expected_reply)
# function call
messages = [_function_use_message_1]
finished, retval = agent.generate_tool_calls_reply(messages)
assert (finished, retval) == (False, None)
# text message
messages: List[Dict[str, str]] = [_text_message]
finished, retval = agent.generate_tool_calls_reply(messages)
assert (finished, retval) == (False, None)
# error in function (raises Exception)
agent._function_map = _get_error_function_map(is_function_async)
messages = [_tool_use_message_1]
finished, retval = agent.generate_tool_calls_reply(messages)
assert (finished, retval) == (True, _tool_use_message_1_error_expected_reply)
@pytest.mark.asyncio()
@pytest.mark.parametrize("is_function_async", [True, False])
async def test_a_generate_tool_calls_reply_on_function_call_message(is_function_async: bool) -> None:
agent = ConversableAgent(name="agent", llm_config=False)
# empty function_map
agent._function_map = _get_function_map(is_function_async, drop_tool_2=True)
messages = [_tool_use_message_1]
finished, retval = await agent.a_generate_tool_calls_reply(messages)
assert (finished, retval) == (True, _tool_use_message_1_not_found_expected_reply)
# function map set
agent._function_map = _get_function_map(is_function_async)
# correct function call, multiple times to make sure cleanups are done properly
for _ in range(3):
messages = [_tool_use_message_1]
finished, retval = await agent.a_generate_tool_calls_reply(messages)
assert (finished, retval) == (True, _tool_use_message_1_expected_reply)
# bad JSON
messages = [_tool_use_message_1_bad_json]
finished, retval = await agent.a_generate_tool_calls_reply(messages)
assert (finished, retval) == (True, _tool_use_message_1_bad_json_expected_reply)
# function call
messages = [_function_use_message_1]
finished, retval = await agent.a_generate_tool_calls_reply(messages)
assert (finished, retval) == (False, None)
# text message
messages: List[Dict[str, str]] = [_text_message]
finished, retval = await agent.a_generate_tool_calls_reply(messages)
assert (finished, retval) == (False, None)
# error in function (raises Exception)
agent._function_map = _get_error_function_map(is_function_async)
messages = [_tool_use_message_1]
finished, retval = await agent.a_generate_tool_calls_reply(messages)
assert (finished, retval) == (True, _tool_use_message_1_error_expected_reply)

View File

@ -102,7 +102,6 @@ user_proxy = autogen.UserProxyAgent(
``` python
CurrencySymbol = Literal["USD", "EUR"]
def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:
if base_currency == quote_currency:
return 1.0
@ -156,12 +155,30 @@ you can call the decorators as functions:
```python
# Register the function with the chatbot's llm_config.
chatbot.register_for_llm(description="Currency exchange calculator.")(currency_calculator)
currency_calculator = chatbot.register_for_llm(description="Currency exchange calculator.")(currency_calculator)
# Register the function with the user_proxy's function_map.
user_proxy.register_for_execution()(currency_calculator)
```
Alternatevely, you can also use `autogen.agentchat.register_function()` instead as follows:
```python
def currency_calculator(
base_amount: Annotated[float, "Amount of currency in base_currency"],
base_currency: Annotated[CurrencySymbol, "Base currency"] = "USD",
quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
) -> str:
quote_amount = exchange_rate(base_currency, quote_currency) * base_amount
return f"{quote_amount} {quote_currency}"
autogen.agentchat.register_function(
currency_calculator,
agent=chatbot,
executor=user_proxy,
description="Currency exchange calculator.",
)
```
4. Agents can now use the function as follows:
```python
user_proxy.initiate_chat(
@ -216,14 +233,19 @@ class Currency(BaseModel):
# parameter of type float, must be greater or equal to 0 with default value 0
amount: Annotated[float, Field(0, description="Amount of currency", ge=0)]
@user_proxy.register_for_execution()
@chatbot.register_for_llm(description="Currency exchange calculator.")
def currency_calculator(
base: Annotated[Currency, "Base currency: amount and currency symbol"],
quote_currency: Annotated[CurrencySymbol, "Quote currency symbol"] = "USD",
) -> Currency:
quote_amount = exchange_rate(base.currency, quote_currency) * base.amount
return Currency(amount=quote_amount, currency=quote_currency)
autogen.agentchat.register_function(
currency_calculator,
agent=chatbot,
executor=user_proxy,
description="Currency exchange calculator.",
)
```
The generated JSON schema has additional properties such as minimum value encoded: