Mirror of https://github.com/microsoft/autogen.git, synced 2025-12-25 22:18:53 +00:00
Function calling upgrade (#1443)
* function calling upgraded: async/sync mixing now works for all combinations, and register_function was added to simplify registration of functions without decorators
* polishing
* fixing tests

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent 0107b52d5a
commit a2d4b47503

.gitignore (vendored): 3 lines changed
@@ -175,3 +175,6 @@ test/test_files/agenteval-in-out/out/
 # Files created by tests
 *tmp_code_*
 test/agentchat/test_agent_scripts/*
+
+# test cache
+.cache_test

autogen/agentchat/__init__.py
@@ -1,14 +1,15 @@
 from .agent import Agent
 from .assistant_agent import AssistantAgent
-from .conversable_agent import ConversableAgent
+from .conversable_agent import ConversableAgent, register_function
 from .groupchat import GroupChat, GroupChatManager
 from .user_proxy_agent import UserProxyAgent
 
-__all__ = [
+__all__ = (
     "Agent",
     "ConversableAgent",
     "AssistantAgent",
     "UserProxyAgent",
     "GroupChat",
     "GroupChatManager",
-]
+    "register_function",
+)
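With the re-export above, `register_function` can be imported either from the package root or from the module that defines it. A quick sketch (assumes an installed autogen checkout that includes this commit):

```python
# the package-level name is a re-export of the callable defined in conversable_agent.py
from autogen.agentchat import register_function
from autogen.agentchat.conversable_agent import register_function as register_function_direct

assert register_function is register_function_direct
```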
autogen/agentchat/conversable_agent.py
@@ -914,9 +914,20 @@ class ConversableAgent(Agent):
             func_call = message["function_call"]
             func = self._function_map.get(func_call.get("name", None), None)
             if inspect.iscoroutinefunction(func):
-                return False, None
+                try:
+                    # get the running loop if it was already created
+                    loop = asyncio.get_running_loop()
+                    close_loop = False
+                except RuntimeError:
+                    # create a loop if there is no running loop
+                    loop = asyncio.new_event_loop()
+                    close_loop = True
 
-            _, func_return = self.execute_function(message["function_call"])
+                _, func_return = loop.run_until_complete(self.a_execute_function(func_call))
+                if close_loop:
+                    loop.close()
+            else:
+                _, func_return = self.execute_function(message["function_call"])
             return True, func_return
         return False, None
 
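The pattern added above (reuse the running loop if one exists, otherwise create and later close a temporary one) is a general way to drive a coroutine from synchronous code. A standalone sketch of the idea with illustrative names (`a_compute`, `compute_sync`), not autogen's own implementation; it only calls `run_until_complete` when no loop is running, since asyncio does not allow re-entering an already running loop:

```python
import asyncio


async def a_compute(x: int) -> int:
    # stand-in for an async registered function such as a_execute_function
    await asyncio.sleep(0)
    return x * 2


def compute_sync(x: int) -> int:
    """Run the coroutine when no event loop is running yet."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # no running loop: create a temporary one, run the coroutine, close it
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(a_compute(x))
        finally:
            loop.close()
    # inside an already running loop, run_until_complete would raise;
    # async callers should simply `await a_compute(x)` instead
    raise RuntimeError("call `await a_compute(x)` from async code")


print(compute_sync(21))  # prints 42 when called from plain synchronous code
```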
@@ -943,7 +954,9 @@ class ConversableAgent(Agent):
             func = self._function_map.get(func_name, None)
             if func and inspect.iscoroutinefunction(func):
                 _, func_return = await self.a_execute_function(func_call)
-                return True, func_return
+            else:
+                _, func_return = self.execute_function(func_call)
+            return True, func_return
 
         return False, None
 
@@ -968,8 +981,20 @@ class ConversableAgent(Agent):
             function_call = tool_call.get("function", {})
             func = self._function_map.get(function_call.get("name", None), None)
             if inspect.iscoroutinefunction(func):
-                continue
-            _, func_return = self.execute_function(function_call)
+                try:
+                    # get the running loop if it was already created
+                    loop = asyncio.get_running_loop()
+                    close_loop = False
+                except RuntimeError:
+                    # create a loop if there is no running loop
+                    loop = asyncio.new_event_loop()
+                    close_loop = True
+
+                _, func_return = loop.run_until_complete(self.a_execute_function(function_call))
+                if close_loop:
+                    loop.close()
+            else:
+                _, func_return = self.execute_function(function_call)
             tool_returns.append(
                 {
                     "tool_call_id": id,
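For reference, the message shapes consumed and produced here match the fixtures in the new test/agentchat/test_function_and_tool_calling.py added later in this commit. A condensed round trip for a single tool call, with the ID and argument values taken from those fixtures:

```python
import json

# assistant message proposing a tool call (input to generate_tool_calls_reply)
tool_use_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "1",
            "type": "function",
            "function": {
                "name": "_tool_func_1",
                "arguments": json.dumps({"arg1": "value1", "arg2": "value2"}),
            },
        }
    ],
}

# reply produced after the registered (sync or async) function has been executed
expected_reply = {
    "role": "tool",
    "tool_responses": [
        {"tool_call_id": "1", "role": "tool", "content": "_tool_func_1: value1 value2"},
    ],
    "content": "_tool_func_1: value1 value2",
}
```

An unknown function name or invalid JSON arguments produce an "Error: ..." content instead, as covered by the same fixtures.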
@@ -1986,3 +2011,30 @@ class ConversableAgent(Agent):
             return None
         else:
             return self.client.total_usage_summary
+
+
+def register_function(
+    f: Callable[..., Any],
+    *,
+    caller: ConversableAgent,
+    executor: ConversableAgent,
+    name: Optional[str] = None,
+    description: str,
+) -> None:
+    """Register a function to be proposed by an agent and executed by an executor.
+
+    This function can be used instead of the function decorators `@ConversableAgent.register_for_llm` and
+    `@ConversableAgent.register_for_execution`.
+
+    Args:
+        f: the function to be registered.
+        caller: the agent calling the function, typically an instance of ConversableAgent.
+        executor: the agent executing the function, typically an instance of UserProxyAgent.
+        name: name of the function. If None, the function name will be used (default: None).
+        description: description of the function. The description is used by the LLM to decide whether the function
+            is called. Make sure the description properly describes what the function does, or it might not be
+            called by the LLM when needed.
+
+    """
+    f = caller.register_for_llm(name=name, description=description)(f)
+    executor.register_for_execution(name=name)(f)
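A minimal usage sketch of the new `register_function` helper, modeled on the `test_register_functions` test added later in this commit; the agent names, the empty `config_list`, and the mocked API key follow that test and are only illustrative:

```python
import os

from typing_extensions import Annotated

from autogen.agentchat import ConversableAgent, UserProxyAgent, register_function

os.environ.setdefault("OPENAI_API_KEY", "mock")  # the test uses a mocked key

agent = ConversableAgent(name="agent", llm_config={"config_list": []})
user_proxy = UserProxyAgent(name="user_proxy")


def exec_python(cell: Annotated[str, "Valid Python cell to execute."]) -> str:
    return "stub result"


# one call replaces the @register_for_llm / @register_for_execution decorator pair
register_function(
    exec_python,
    caller=agent,         # proposes the call via its llm_config "tools"
    executor=user_proxy,  # executes it via its function_map
    description="run cell in ipython and return the execution result.",
)

print("exec_python" in user_proxy.function_map)          # True
print(agent.llm_config["tools"][0]["function"]["name"])  # "exec_python"
```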
File diff suppressed because one or more lines are too long
@@ -61,7 +61,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "id": "dca301a4",
    "metadata": {},
    "outputs": [],
@@ -71,6 +71,7 @@
    "from typing_extensions import Annotated\n",
    "\n",
    "import autogen\n",
+   "from autogen.cache import Cache\n",
    "\n",
    "config_list = autogen.config_list_from_json(\n",
    "    \"OAI_CONFIG_LIST\",\n",
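The new `Cache` import is used later in the notebook to wrap chat runs in a disk-backed cache so repeated executions can reuse prior LLM responses. A minimal sketch of that pattern; the `initiate_chat` call is shown commented out because the chat agents are defined elsewhere in the notebook:

```python
from autogen.cache import Cache

# wrap a chat run in an on-disk cache, as the updated notebooks do
with Cache.disk():
    # user_proxy.initiate_chat(chatbot, message="How much is 123.45 USD in EUR?")
    pass
```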
@ -119,65 +120,10 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"id": "9fb85afb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"Create a timer for 5 seconds and then a stopwatch for 5 seconds.\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_fGgH8U261nOnx3JGNJWslhh6): timer *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\"num_seconds\":\"5\"}\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"timer\" *****\u001b[0m\n",
|
||||
"Timer is done!\n",
|
||||
"\u001b[32m**********************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_BZs6ynF8gtcZKhONiIRZkECB): stopwatch *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\"num_seconds\":\"5\"}\n",
|
||||
"\u001b[32m**************************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING ASYNC FUNCTION stopwatch...\u001b[0m\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"stopwatch\" *****\u001b[0m\n",
|
||||
"Stopwatch is done!\n",
|
||||
"\u001b[32m**************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"TERMINATE\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm_config = {\n",
|
||||
" \"config_list\": config_list,\n",
|
||||
@@ -201,7 +147,7 @@
    "\n",
    "# define functions according to the function description\n",
    "\n",
-   "# An example async function\n",
+   "# An example async function registered using register_for_llm and register_for_execution decorators\n",
    "\n",
    "\n",
    "@user_proxy.register_for_execution()\n",
@@ -213,25 +159,104 @@
    "    return \"Timer is done!\"\n",
    "\n",
    "\n",
-   "# An example sync function\n",
-   "@user_proxy.register_for_execution()\n",
-   "@coder.register_for_llm(description=\"create a stopwatch for N seconds\")\n",
+   "# An example sync function registered using register_function\n",
    "def stopwatch(num_seconds: Annotated[str, \"Number of seconds in the stopwatch.\"]) -> str:\n",
    "    for i in range(int(num_seconds)):\n",
    "        time.sleep(1)\n",
    "    return \"Stopwatch is done!\"\n",
    "\n",
    "\n",
-   "# start the conversation\n",
-   "# 'await' is used to pause and resume code execution for async IO operations.\n",
-   "# Without 'await', an async function returns a coroutine object but doesn't execute the function.\n",
-   "# With 'await', the async function is executed and the current function is paused until the awaited function returns a result.\n",
-   "await user_proxy.a_initiate_chat( # noqa: F704\n",
-   "    coder,\n",
-   "    message=\"Create a timer for 5 seconds and then a stopwatch for 5 seconds.\",\n",
+   "autogen.agentchat.register_function(\n",
+   "    stopwatch,\n",
+   "    caller=coder,\n",
+   "    executor=user_proxy,\n",
+   "    description=\"create a stopwatch for N seconds\",\n",
    ")"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "id": "159cd7b6",
+  "metadata": {},
+  "source": [
+   "Start the conversation. `await` is used to pause and resume code execution for async IO operations. Without `await`, an async function returns a coroutine object but doesn't execute the function. With `await`, the async function is executed and the current function is paused until the awaited function returns a result."
+  ]
+ },
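The markdown cell added above explains why `await` is needed; a tiny standalone illustration of the same point, with a placeholder `timer` coroutine and `asyncio.run` standing in for the notebook's top-level `await`:

```python
import asyncio


async def timer(num_seconds: str) -> str:
    # placeholder for the notebook's async tool
    await asyncio.sleep(int(num_seconds))
    return "Timer is done!"


coro = timer("1")
print(type(coro).__name__)  # "coroutine": calling the async def did not run its body

# awaiting it (here via asyncio.run, since this is a plain script, not a notebook)
print(asyncio.run(coro))    # "Timer is done!"
```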
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "37514ea3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"Create a timer for 5 seconds and then a stopwatch for 5 seconds.\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_h6324df0CdGPDNjPO8GrnAQJ): timer *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\"num_seconds\":\"5\"}\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_h6324df0CdGPDNjPO8GrnAQJ\" *****\u001b[0m\n",
|
||||
"Timer is done!\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_7SzbQxI8Nsl6dPQtScoSGPAu): stopwatch *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\"num_seconds\":\"5\"}\n",
|
||||
"\u001b[32m**************************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING ASYNC FUNCTION stopwatch...\u001b[0m\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_7SzbQxI8Nsl6dPQtScoSGPAu\" *****\u001b[0m\n",
|
||||
"Stopwatch is done!\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"TERMINATE\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"with Cache.disk():\n",
|
||||
" await user_proxy.a_initiate_chat( # noqa: F704\n",
|
||||
" coder,\n",
|
||||
" message=\"Create a timer for 5 seconds and then a stopwatch for 5 seconds.\",\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "950f3de7",
|
||||
@ -243,7 +268,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 5,
|
||||
"id": "2472f95c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -276,9 +301,17 @@
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "612bdd22",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Finally, we initialize the chat that would use the functions defined above:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 6,
|
||||
"id": "e2c9267a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -293,14 +326,21 @@
|
||||
"2) Pretty print the result as md.\n",
|
||||
"3) when 1 and 2 are done, terminate the group chat\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_zlHKR9LBzCqs1iLId5kvNvJ5): timer *****\u001b[0m\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_qlS3QkcY1NkfgpKtCoR6oGo7): timer *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\"num_seconds\": \"5\"}\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_rH1dgbS9itiJO1Gwnxxhcm35): stopwatch *****\u001b[0m\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_TEHlvMgCp0S3RzBbVsVPXWeL): stopwatch *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\"num_seconds\": \"5\"}\n",
|
||||
"\u001b[32m**************************************************************************\u001b[0m\n",
|
||||
@ -314,29 +354,23 @@
|
||||
"\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"timer\" *****\u001b[0m\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_qlS3QkcY1NkfgpKtCoR6oGo7\" *****\u001b[0m\n",
|
||||
"Timer is done!\n",
|
||||
"\u001b[32m**********************************************\u001b[0m\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"stopwatch\" *****\u001b[0m\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_TEHlvMgCp0S3RzBbVsVPXWeL\" *****\u001b[0m\n",
|
||||
"Stopwatch is done!\n",
|
||||
"\u001b[32m**************************************************\u001b[0m\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mMarkdown_agent\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"The results of the timer and stopwatch are as follows:\n",
|
||||
"\n",
|
||||
"- Timer: Timer is done!\n",
|
||||
"- Stopwatch: Stopwatch is done!\n",
|
||||
"\n",
|
||||
"Now, I will proceed to terminate the group chat.\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_3Js7oU80vPatnA8IiaKXB5Xu): terminate_group_chat *****\u001b[0m\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_JuQwvj4FigfvGyBeTMglY2ee): terminate_group_chat *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\"message\":\"The session has concluded, and the group chat will now be terminated.\"}\n",
|
||||
"{\"message\":\"Both timer and stopwatch have completed their countdowns. The group chat is now being terminated.\"}\n",
|
||||
"\u001b[32m*************************************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
@ -346,23 +380,26 @@
|
||||
"\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"terminate_group_chat\" *****\u001b[0m\n",
|
||||
"[GROUPCHAT_TERMINATE] The session has concluded, and the group chat will now be terminated.\n",
|
||||
"\u001b[32m*************************************************************\u001b[0m\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_JuQwvj4FigfvGyBeTMglY2ee\" *****\u001b[0m\n",
|
||||
"[GROUPCHAT_TERMINATE] Both timer and stopwatch have completed their countdowns. The group chat is now being terminated.\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# todo: remove comment after fixing https://github.com/microsoft/autogen/issues/1205\n",
|
||||
"await user_proxy.a_initiate_chat( # noqa: F704\n",
|
||||
" manager,\n",
|
||||
" message=\"\"\"\n",
|
||||
"message = \"\"\"\n",
|
||||
"1) Create a timer and a stopwatch for 5 seconds each in parallel.\n",
|
||||
"2) Pretty print the result as md.\n",
|
||||
"3) when 1 and 2 are done, terminate the group chat\"\"\",\n",
|
||||
")"
|
||||
"3) when 1 and 2 are done, terminate the group chat\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"with Cache.disk():\n",
|
||||
" await user_proxy.a_initiate_chat( # noqa: F704\n",
|
||||
" manager,\n",
|
||||
" message=message,\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -390,7 +427,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
"version": "3.10.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -63,6 +63,8 @@
    "from typing_extensions import Annotated\n",
    "\n",
    "import autogen\n",
+   "from autogen.cache import Cache\n",
+   "\n",
    "\n",
    "config_list = autogen.config_list_from_json(\n",
    "    \"OAI_CONFIG_LIST\",\n",
@ -274,9 +276,9 @@
|
||||
"\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"currency_calculator\" *****\u001b[0m\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_ubo7cKE3TKumGHkqGjQtZisy\" *****\u001b[0m\n",
|
||||
"112.22727272727272 EUR\n",
|
||||
"\u001b[32m************************************************************\u001b[0m\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
@ -298,11 +300,12 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# start the conversation\n",
|
||||
"user_proxy.initiate_chat(\n",
|
||||
" chatbot,\n",
|
||||
" message=\"How much is 123.45 USD in EUR?\",\n",
|
||||
")"
|
||||
"with Cache.disk():\n",
|
||||
" # start the conversation\n",
|
||||
" user_proxy.initiate_chat(\n",
|
||||
" chatbot,\n",
|
||||
" message=\"How much is 123.45 USD in EUR?\",\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -353,14 +356,21 @@
    "    amount: Annotated[float, Field(0, description=\"Amount of currency\", ge=0)]\n",
    "\n",
    "\n",
-   "@user_proxy.register_for_execution()\n",
-   "@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n",
+   "# another way to register a function is to use register_function instead of register_for_execution and register_for_llm decorators\n",
    "def currency_calculator(\n",
    "    base: Annotated[Currency, \"Base currency: amount and currency symbol\"],\n",
    "    quote_currency: Annotated[CurrencySymbol, \"Quote currency symbol\"] = \"USD\",\n",
    ") -> Currency:\n",
    "    quote_amount = exchange_rate(base.currency, quote_currency) * base.amount\n",
-   "    return Currency(amount=quote_amount, currency=quote_currency)"
+   "    return Currency(amount=quote_amount, currency=quote_currency)\n",
+   "\n",
+   "\n",
+   "autogen.agentchat.register_function(\n",
+   "    currency_calculator,\n",
+   "    caller=chatbot,\n",
+   "    executor=user_proxy,\n",
+   "    description=\"Currency exchange calculator.\",\n",
+   ")"
   ]
  },
  {
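The `currency_calculator` registered above takes a pydantic `Currency` model, and constraints declared with `Field` (such as `ge=0`) are carried into the generated JSON schema, as the docs change at the end of this diff notes. A small sketch of inspecting such a schema, assuming pydantic v2 and a simplified field style (not the notebook's exact `Annotated` definition):

```python
from typing import Literal

from pydantic import BaseModel, Field

CurrencySymbol = Literal["USD", "EUR"]


class Currency(BaseModel):
    currency: CurrencySymbol = Field(description="Currency symbol")
    amount: float = Field(0, description="Amount of currency", ge=0)


# pydantic v2 API; the emitted schema includes "minimum": 0 for the amount field
print(Currency.model_json_schema()["properties"]["amount"])
```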
@ -434,14 +444,14 @@
|
||||
"\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"currency_calculator\" *****\u001b[0m\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_0VuU2rATuOgYrGmcBnXzPXlh\" *****\u001b[0m\n",
|
||||
"{\"currency\":\"USD\",\"amount\":123.45300000000002}\n",
|
||||
"\u001b[32m************************************************************\u001b[0m\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"112.23 Euros is equivalent to approximately 123.45 US Dollars.\n",
|
||||
"112.23 Euros is approximately 123.45 US Dollars.\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
@ -458,11 +468,12 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# start the conversation\n",
|
||||
"user_proxy.initiate_chat(\n",
|
||||
" chatbot,\n",
|
||||
" message=\"How much is 112.23 Euros in US Dollars?\",\n",
|
||||
")"
|
||||
"with Cache.disk():\n",
|
||||
" # start the conversation\n",
|
||||
" user_proxy.initiate_chat(\n",
|
||||
" chatbot,\n",
|
||||
" message=\"How much is 112.23 Euros in US Dollars?\",\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -494,9 +505,9 @@
|
||||
"\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"currency_calculator\" *****\u001b[0m\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_A6lqMu7s5SyDvftTSeQTtPcj\" *****\u001b[0m\n",
|
||||
"{\"currency\":\"EUR\",\"amount\":112.22727272727272}\n",
|
||||
"\u001b[32m************************************************************\u001b[0m\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
@ -518,11 +529,12 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# start the conversation\n",
|
||||
"user_proxy.initiate_chat(\n",
|
||||
" chatbot,\n",
|
||||
" message=\"How much is 123.45 US Dollars in Euros?\",\n",
|
||||
")"
|
||||
"with Cache.disk():\n",
|
||||
" # start the conversation\n",
|
||||
" user_proxy.initiate_chat(\n",
|
||||
" chatbot,\n",
|
||||
" message=\"How much is 123.45 US Dollars in Euros?\",\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
test/agentchat/test_conversable_agent.py
@@ -13,6 +13,7 @@ from typing_extensions import Annotated
 import autogen
 
 from autogen.agentchat import ConversableAgent, UserProxyAgent
+from autogen.agentchat.conversable_agent import register_function
 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
 from conftest import skip_openai
 
@ -823,6 +824,47 @@ def test_register_for_execution():
|
||||
assert get_origin(user_proxy_1.function_map) == expected_function_map
|
||||
|
||||
|
||||
def test_register_functions():
|
||||
with pytest.MonkeyPatch.context() as mp:
|
||||
mp.setenv("OPENAI_API_KEY", "mock")
|
||||
agent = ConversableAgent(name="agent", llm_config={"config_list": []})
|
||||
user_proxy = UserProxyAgent(name="user_proxy")
|
||||
|
||||
def exec_python(cell: Annotated[str, "Valid Python cell to execute."]) -> str:
|
||||
pass
|
||||
|
||||
register_function(
|
||||
exec_python,
|
||||
caller=agent,
|
||||
executor=user_proxy,
|
||||
description="run cell in ipython and return the execution result.",
|
||||
)
|
||||
|
||||
expected_function_map = {"exec_python": exec_python}
|
||||
assert get_origin(user_proxy.function_map) == expected_function_map
|
||||
|
||||
expected = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"description": "run cell in ipython and return the execution result.",
|
||||
"name": "exec_python",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cell": {
|
||||
"type": "string",
|
||||
"description": "Valid Python cell to execute.",
|
||||
}
|
||||
},
|
||||
"required": ["cell"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
assert agent.llm_config["tools"] == expected
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
skip or not sys.version.startswith("3.10"),
|
||||
reason="do not run if openai is not installed or py!=3.10",
|
||||
@ -860,7 +902,7 @@ def test_function_registration_e2e_sync() -> None:
|
||||
timer_mock = unittest.mock.MagicMock()
|
||||
stopwatch_mock = unittest.mock.MagicMock()
|
||||
|
||||
# An example async function
|
||||
# An example async function registered using decorators
|
||||
@user_proxy.register_for_execution()
|
||||
@coder.register_for_llm(description="create a timer for N seconds")
|
||||
def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
|
||||
@ -873,9 +915,7 @@ def test_function_registration_e2e_sync() -> None:
|
||||
timer_mock(num_seconds=num_seconds)
|
||||
return "Timer is done!"
|
||||
|
||||
# An example sync function
|
||||
@user_proxy.register_for_execution()
|
||||
@coder.register_for_llm(description="create a stopwatch for N seconds")
|
||||
# An example sync function registered using register_function
|
||||
def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:
|
||||
print("stopwatch is running")
|
||||
# assert False, "stopwatch's alive!"
|
||||
@ -887,6 +927,8 @@ def test_function_registration_e2e_sync() -> None:
|
||||
stopwatch_mock(num_seconds=num_seconds)
|
||||
return "Stopwatch is done!"
|
||||
|
||||
register_function(stopwatch, caller=coder, executor=user_proxy, description="create a stopwatch for N seconds")
|
||||
|
||||
# start the conversation
|
||||
# 'await' is used to pause and resume code execution for async IO operations.
|
||||
# Without 'await', an async function returns a coroutine object but doesn't execute the function.
|
||||
@ -938,9 +980,7 @@ async def test_function_registration_e2e_async() -> None:
|
||||
timer_mock = unittest.mock.MagicMock()
|
||||
stopwatch_mock = unittest.mock.MagicMock()
|
||||
|
||||
# An example async function
|
||||
@user_proxy.register_for_execution()
|
||||
@coder.register_for_llm(description="create a timer for N seconds")
|
||||
# An example async function registered using register_function
|
||||
async def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
|
||||
print("timer is running")
|
||||
for i in range(int(num_seconds)):
|
||||
@ -951,7 +991,9 @@ async def test_function_registration_e2e_async() -> None:
|
||||
timer_mock(num_seconds=num_seconds)
|
||||
return "Timer is done!"
|
||||
|
||||
# An example sync function
|
||||
register_function(timer, caller=coder, executor=user_proxy, description="create a timer for N seconds")
|
||||
|
||||
# An example sync function registered using decorators
|
||||
@user_proxy.register_for_execution()
|
||||
@coder.register_for_llm(description="create a stopwatch for N seconds")
|
||||
def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:
|
||||
|
||||
test/agentchat/test_function_and_tool_calling.py (new file, 378 lines)
@@ -0,0 +1,378 @@
|
||||
import json
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from autogen.agentchat.conversable_agent import ConversableAgent
|
||||
|
||||
|
||||
def _tool_func_1(arg1: str, arg2: str) -> str:
|
||||
return f"_tool_func_1: {arg1} {arg2}"
|
||||
|
||||
|
||||
def _tool_func_2(arg1: str, arg2: str) -> str:
|
||||
return f"_tool_func_2: {arg1} {arg2}"
|
||||
|
||||
|
||||
def _tool_func_error(arg1: str, arg2: str) -> str:
|
||||
raise RuntimeError("Error in tool function")
|
||||
|
||||
|
||||
async def _a_tool_func_1(arg1: str, arg2: str) -> str:
|
||||
return f"_tool_func_1: {arg1} {arg2}"
|
||||
|
||||
|
||||
async def _a_tool_func_2(arg1: str, arg2: str) -> str:
|
||||
return f"_tool_func_2: {arg1} {arg2}"
|
||||
|
||||
|
||||
async def _a_tool_func_error(arg1: str, arg2: str) -> str:
|
||||
raise RuntimeError("Error in tool function")
|
||||
|
||||
|
||||
_tool_use_message_1 = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "_tool_func_1",
|
||||
"arguments": json.dumps({"arg1": "value1", "arg2": "value2"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "_tool_func_2",
|
||||
"arguments": json.dumps({"arg1": "value3", "arg2": "value4"}),
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
_tool_use_message_1_bad_json = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "_tool_func_1",
|
||||
# add extra comma to make json invalid
|
||||
"arguments": json.dumps({"arg1": "value3", "arg2": "value4"})[:-1] + ",}",
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "_tool_func_2",
|
||||
"arguments": json.dumps({"arg1": "value3", "arg2": "value4"}),
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
_tool_use_message_1_expected_reply = {
|
||||
"role": "tool",
|
||||
"tool_responses": [
|
||||
{"tool_call_id": "1", "role": "tool", "content": "_tool_func_1: value1 value2"},
|
||||
{"tool_call_id": "2", "role": "tool", "content": "_tool_func_2: value3 value4"},
|
||||
],
|
||||
# "content": "Tool Call Id: 1\n_tool_func_1: value1 value2\n\nTool Call Id: 2\n_tool_func_2: value3 value4",
|
||||
"content": "_tool_func_1: value1 value2\n\n_tool_func_2: value3 value4",
|
||||
}
|
||||
|
||||
|
||||
_tool_use_message_1_bad_json_expected_reply = {
|
||||
"role": "tool",
|
||||
"tool_responses": [
|
||||
{
|
||||
"tool_call_id": "1",
|
||||
"role": "tool",
|
||||
"content": "Error: Expecting property name enclosed in double quotes: line 1 column 37 (char 36)\n You argument should follow json format.",
|
||||
},
|
||||
{"tool_call_id": "2", "role": "tool", "content": "_tool_func_2: value3 value4"},
|
||||
],
|
||||
"content": "Error: Expecting property name enclosed in double quotes: line 1 column 37 (char 36)\n You argument should follow json format.\n\n_tool_func_2: value3 value4",
|
||||
}
|
||||
|
||||
_tool_use_message_1_error_expected_reply = {
|
||||
"role": "tool",
|
||||
"tool_responses": [
|
||||
{"tool_call_id": "1", "role": "tool", "content": "_tool_func_1: value1 value2"},
|
||||
{
|
||||
"tool_call_id": "2",
|
||||
"role": "tool",
|
||||
"content": "Error: Error in tool function",
|
||||
},
|
||||
],
|
||||
"content": "_tool_func_1: value1 value2\n\nError: Error in tool function",
|
||||
}
|
||||
|
||||
_tool_use_message_1_not_found_expected_reply = {
|
||||
"role": "tool",
|
||||
"tool_responses": [
|
||||
{"tool_call_id": "1", "role": "tool", "content": "_tool_func_1: value1 value2"},
|
||||
{
|
||||
"tool_call_id": "2",
|
||||
"role": "tool",
|
||||
"content": "Error: Function _tool_func_2 not found.",
|
||||
},
|
||||
],
|
||||
"content": "_tool_func_1: value1 value2\n\nError: Function _tool_func_2 not found.",
|
||||
}
|
||||
|
||||
_function_use_message_1 = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"function_call": {
|
||||
"name": "_tool_func_1",
|
||||
"arguments": json.dumps({"arg1": "value1", "arg2": "value2"}),
|
||||
},
|
||||
}
|
||||
|
||||
_function_use_message_1_bad_json = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"function_call": {
|
||||
"name": "_tool_func_1",
|
||||
"arguments": json.dumps({"arg1": "value1", "arg2": "value2"})[:-1] + ",}",
|
||||
},
|
||||
}
|
||||
|
||||
_function_use_message_1_expected_reply = {
|
||||
"name": "_tool_func_1",
|
||||
"role": "function",
|
||||
"content": "_tool_func_1: value1 value2",
|
||||
}
|
||||
|
||||
_function_use_message_1_bad_json_expected_reply = {
|
||||
"name": "_tool_func_1",
|
||||
"role": "function",
|
||||
"content": "Error: Expecting property name enclosed in double quotes: line 1 column 37 (char 36)\n You argument should follow json format.",
|
||||
}
|
||||
|
||||
_function_use_message_1_error_expected_reply = {
|
||||
"name": "_tool_func_1",
|
||||
"role": "function",
|
||||
"content": "Error: Error in tool function",
|
||||
}
|
||||
|
||||
_function_use_message_1_not_found_expected_reply = {
|
||||
"name": "_tool_func_1",
|
||||
"role": "function",
|
||||
"content": "Error: Function _tool_func_1 not found.",
|
||||
}
|
||||
|
||||
_text_message = {"content": "Hi!", "role": "user"}
|
||||
|
||||
|
||||
def _get_function_map(is_function_async: bool, drop_tool_2: bool = False) -> Dict[str, Callable[..., Any]]:
|
||||
if is_function_async:
|
||||
return (
|
||||
{
|
||||
"_tool_func_1": _a_tool_func_1,
|
||||
"_tool_func_2": _a_tool_func_2,
|
||||
}
|
||||
if not drop_tool_2
|
||||
else {
|
||||
"_tool_func_1": _a_tool_func_1,
|
||||
}
|
||||
)
|
||||
else:
|
||||
return (
|
||||
{
|
||||
"_tool_func_1": _tool_func_1,
|
||||
"_tool_func_2": _tool_func_2,
|
||||
}
|
||||
if not drop_tool_2
|
||||
else {
|
||||
"_tool_func_1": _tool_func_1,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_error_function_map(
|
||||
is_function_async: bool, error_on_tool_func_2: bool = True
|
||||
) -> Dict[str, Callable[..., Any]]:
|
||||
if is_function_async:
|
||||
return {
|
||||
"_tool_func_1": _a_tool_func_1 if error_on_tool_func_2 else _a_tool_func_error,
|
||||
"_tool_func_2": _a_tool_func_error if error_on_tool_func_2 else _a_tool_func_2,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"_tool_func_1": _tool_func_1 if error_on_tool_func_2 else _tool_func_error,
|
||||
"_tool_func_2": _tool_func_error if error_on_tool_func_2 else _tool_func_2,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_function_async", [True, False])
|
||||
def test_generate_function_call_reply_on_function_call_message(is_function_async: bool) -> None:
|
||||
agent = ConversableAgent(name="agent", llm_config=False)
|
||||
|
||||
# empty function_map
|
||||
agent._function_map = {}
|
||||
messages = [_function_use_message_1]
|
||||
finished, retval = agent.generate_function_call_reply(messages)
|
||||
assert (finished, retval) == (True, _function_use_message_1_not_found_expected_reply)
|
||||
|
||||
# function map set
|
||||
agent._function_map = _get_function_map(is_function_async)
|
||||
|
||||
# correct function call, multiple times to make sure cleanups are done properly
|
||||
for _ in range(3):
|
||||
messages = [_function_use_message_1]
|
||||
finished, retval = agent.generate_function_call_reply(messages)
|
||||
assert (finished, retval) == (True, _function_use_message_1_expected_reply)
|
||||
|
||||
# bad JSON
|
||||
messages = [_function_use_message_1_bad_json]
|
||||
finished, retval = agent.generate_function_call_reply(messages)
|
||||
assert (finished, retval) == (True, _function_use_message_1_bad_json_expected_reply)
|
||||
|
||||
# tool call
|
||||
messages = [_tool_use_message_1]
|
||||
finished, retval = agent.generate_function_call_reply(messages)
|
||||
assert (finished, retval) == (False, None)
|
||||
|
||||
# text message
|
||||
messages: List[Dict[str, str]] = [_text_message]
|
||||
finished, retval = agent.generate_function_call_reply(messages)
|
||||
assert (finished, retval) == (False, None)
|
||||
|
||||
# error in function (raises Exception)
|
||||
agent._function_map = _get_error_function_map(is_function_async, error_on_tool_func_2=False)
|
||||
messages = [_function_use_message_1]
|
||||
finished, retval = agent.generate_function_call_reply(messages)
|
||||
assert (finished, retval) == (True, _function_use_message_1_error_expected_reply)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.parametrize("is_function_async", [True, False])
|
||||
async def test_a_generate_function_call_reply_on_function_call_message(is_function_async: bool) -> None:
|
||||
agent = ConversableAgent(name="agent", llm_config=False)
|
||||
|
||||
# empty function_map
|
||||
agent._function_map = {}
|
||||
messages = [_function_use_message_1]
|
||||
finished, retval = await agent.a_generate_function_call_reply(messages)
|
||||
assert (finished, retval) == (True, _function_use_message_1_not_found_expected_reply)
|
||||
|
||||
# function map set
|
||||
agent._function_map = _get_function_map(is_function_async)
|
||||
|
||||
# correct function call, multiple times to make sure cleanups are done properly
|
||||
for _ in range(3):
|
||||
messages = [_function_use_message_1]
|
||||
finished, retval = await agent.a_generate_function_call_reply(messages)
|
||||
assert (finished, retval) == (True, _function_use_message_1_expected_reply)
|
||||
|
||||
# bad JSON
|
||||
messages = [_function_use_message_1_bad_json]
|
||||
finished, retval = await agent.a_generate_function_call_reply(messages)
|
||||
assert (finished, retval) == (True, _function_use_message_1_bad_json_expected_reply)
|
||||
|
||||
# tool call
|
||||
messages = [_tool_use_message_1]
|
||||
finished, retval = await agent.a_generate_function_call_reply(messages)
|
||||
assert (finished, retval) == (False, None)
|
||||
|
||||
# text message
|
||||
messages: List[Dict[str, str]] = [_text_message]
|
||||
finished, retval = await agent.a_generate_function_call_reply(messages)
|
||||
assert (finished, retval) == (False, None)
|
||||
|
||||
# error in function (raises Exception)
|
||||
agent._function_map = _get_error_function_map(is_function_async, error_on_tool_func_2=False)
|
||||
messages = [_function_use_message_1]
|
||||
finished, retval = await agent.a_generate_function_call_reply(messages)
|
||||
assert (finished, retval) == (True, _function_use_message_1_error_expected_reply)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_function_async", [True, False])
|
||||
def test_generate_tool_calls_reply_on_function_call_message(is_function_async: bool) -> None:
|
||||
agent = ConversableAgent(name="agent", llm_config=False)
|
||||
|
||||
# empty function_map
|
||||
agent._function_map = _get_function_map(is_function_async, drop_tool_2=True)
|
||||
messages = [_tool_use_message_1]
|
||||
finished, retval = agent.generate_tool_calls_reply(messages)
|
||||
assert (finished, retval) == (True, _tool_use_message_1_not_found_expected_reply)
|
||||
|
||||
# function map set
|
||||
agent._function_map = _get_function_map(is_function_async)
|
||||
|
||||
# correct function call, multiple times to make sure cleanups are done properly
|
||||
for _ in range(3):
|
||||
messages = [_tool_use_message_1]
|
||||
finished, retval = agent.generate_tool_calls_reply(messages)
|
||||
assert (finished, retval) == (True, _tool_use_message_1_expected_reply)
|
||||
|
||||
# bad JSON
|
||||
messages = [_tool_use_message_1_bad_json]
|
||||
finished, retval = agent.generate_tool_calls_reply(messages)
|
||||
assert (finished, retval) == (True, _tool_use_message_1_bad_json_expected_reply)
|
||||
|
||||
# function call
|
||||
messages = [_function_use_message_1]
|
||||
finished, retval = agent.generate_tool_calls_reply(messages)
|
||||
assert (finished, retval) == (False, None)
|
||||
|
||||
# text message
|
||||
messages: List[Dict[str, str]] = [_text_message]
|
||||
finished, retval = agent.generate_tool_calls_reply(messages)
|
||||
assert (finished, retval) == (False, None)
|
||||
|
||||
# error in function (raises Exception)
|
||||
agent._function_map = _get_error_function_map(is_function_async)
|
||||
messages = [_tool_use_message_1]
|
||||
finished, retval = agent.generate_tool_calls_reply(messages)
|
||||
assert (finished, retval) == (True, _tool_use_message_1_error_expected_reply)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.parametrize("is_function_async", [True, False])
|
||||
async def test_a_generate_tool_calls_reply_on_function_call_message(is_function_async: bool) -> None:
|
||||
agent = ConversableAgent(name="agent", llm_config=False)
|
||||
|
||||
# empty function_map
|
||||
agent._function_map = _get_function_map(is_function_async, drop_tool_2=True)
|
||||
messages = [_tool_use_message_1]
|
||||
finished, retval = await agent.a_generate_tool_calls_reply(messages)
|
||||
assert (finished, retval) == (True, _tool_use_message_1_not_found_expected_reply)
|
||||
|
||||
# function map set
|
||||
agent._function_map = _get_function_map(is_function_async)
|
||||
|
||||
# correct function call, multiple times to make sure cleanups are done properly
|
||||
for _ in range(3):
|
||||
messages = [_tool_use_message_1]
|
||||
finished, retval = await agent.a_generate_tool_calls_reply(messages)
|
||||
assert (finished, retval) == (True, _tool_use_message_1_expected_reply)
|
||||
|
||||
# bad JSON
|
||||
messages = [_tool_use_message_1_bad_json]
|
||||
finished, retval = await agent.a_generate_tool_calls_reply(messages)
|
||||
assert (finished, retval) == (True, _tool_use_message_1_bad_json_expected_reply)
|
||||
|
||||
# function call
|
||||
messages = [_function_use_message_1]
|
||||
finished, retval = await agent.a_generate_tool_calls_reply(messages)
|
||||
assert (finished, retval) == (False, None)
|
||||
|
||||
# text message
|
||||
messages: List[Dict[str, str]] = [_text_message]
|
||||
finished, retval = await agent.a_generate_tool_calls_reply(messages)
|
||||
assert (finished, retval) == (False, None)
|
||||
|
||||
# error in function (raises Exception)
|
||||
agent._function_map = _get_error_function_map(is_function_async)
|
||||
messages = [_tool_use_message_1]
|
||||
finished, retval = await agent.a_generate_tool_calls_reply(messages)
|
||||
assert (finished, retval) == (True, _tool_use_message_1_error_expected_reply)
|
||||
@@ -102,7 +102,6 @@ user_proxy = autogen.UserProxyAgent(
 ``` python
 CurrencySymbol = Literal["USD", "EUR"]
 
-
 def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:
     if base_currency == quote_currency:
         return 1.0
@@ -156,12 +155,30 @@ you can call the decorators as functions:
 
 ```python
 # Register the function with the chatbot's llm_config.
-chatbot.register_for_llm(description="Currency exchange calculator.")(currency_calculator)
+currency_calculator = chatbot.register_for_llm(description="Currency exchange calculator.")(currency_calculator)
 
 # Register the function with the user_proxy's function_map.
 user_proxy.register_for_execution()(currency_calculator)
 ```
 
+Alternatively, you can use `autogen.agentchat.register_function()` instead, as follows:
+```python
+def currency_calculator(
+    base_amount: Annotated[float, "Amount of currency in base_currency"],
+    base_currency: Annotated[CurrencySymbol, "Base currency"] = "USD",
+    quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
+) -> str:
+    quote_amount = exchange_rate(base_currency, quote_currency) * base_amount
+    return f"{quote_amount} {quote_currency}"
+
+autogen.agentchat.register_function(
+    currency_calculator,
+    caller=chatbot,
+    executor=user_proxy,
+    description="Currency exchange calculator.",
+)
+```
+
 4. Agents can now use the function as follows:
 ```python
 user_proxy.initiate_chat(
@@ -216,14 +233,19 @@ class Currency(BaseModel):
     # parameter of type float, must be greater or equal to 0 with default value 0
     amount: Annotated[float, Field(0, description="Amount of currency", ge=0)]
 
-@user_proxy.register_for_execution()
-@chatbot.register_for_llm(description="Currency exchange calculator.")
 def currency_calculator(
     base: Annotated[Currency, "Base currency: amount and currency symbol"],
     quote_currency: Annotated[CurrencySymbol, "Quote currency symbol"] = "USD",
 ) -> Currency:
     quote_amount = exchange_rate(base.currency, quote_currency) * base.amount
     return Currency(amount=quote_amount, currency=quote_currency)
+
+autogen.agentchat.register_function(
+    currency_calculator,
+    caller=chatbot,
+    executor=user_proxy,
+    description="Currency exchange calculator.",
+)
 ```
 
 The generated JSON schema has additional properties such as minimum value encoded: