mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-27 06:59:03 +00:00
[Core] [Tool Call] adjust conversable agent to support tool_calls (#974)
* adjust conversable and compressible agents to support tool_calls
* split out tools into their own reply def
* copilot typo
* address review comments
* revert compressible_agent and token_count_utils calls
* cleanup terminate check and remove unnecessary code
* doc search and update
* return function/tool calls as interrupted when user provides a reply to a tool call request
* fix tool name reference
* fix formatting
* fix initiate receiving a dict
* missed changed roled
* ignore incoming role, more similiar to existing code
* consistency
* redundant to_dict
* fix todo comment
* uneeded change
* handle dict reply in groupchat
* Fix generate_tool_call_calls_reply_comment
* change method annotation for register_for_llm from functions to tools
* typo autogen/agentchat/conversable_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* add deprecation comments for function_call
* tweak doc strings
* switch to ToolFunction type
* update the return to
* fix generate_init_message return type
* Revert "fix generate_init_message return type"

This reverts commit 645ba8b76afa06f160223ecdac6f3dc1822fd249.

* undo force init to dict
* fix notebooks and groupchat tool handling
* fix type
* use get for key error
* fix teachable to pull content from dict
* change single message tool response
* cleanup unnessary changes
* little better tool response concatenation
* update tools tests
* add skip openai check to tools tests
* fix nits
* move func name normalization to oai_reply and assert configured names
* fix whitespace
* remove extra normalize
* tool name is now normalized in the generate_reply function, so will not be incorrect when sent to receive
* validate function names in init and expand comments for validation methods
* fix dict comprehension
* Dummy llm config for unit tests
* handle tool_calls set to None
* fix tool name reference
* method operates on responses not calls

---------

Co-authored-by: Yiran Wu <32823396+kevin666aa@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent e673500129
commit 40dbf31a92
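Background for the diff below: with openai-python v1.1.0 the API moved from a single `function_call` per assistant message to a list of `tool_calls`, each carrying an id that the matching "tool" response must echo back. A minimal sketch of the two message shapes this commit teaches ConversableAgent to produce and consume (ids and names are illustrative, not taken from the diff):

# Assistant turn requesting a tool invocation:
assistant_msg = {
    "role": "assistant",
    "content": None,  # content may be None when tool_calls is present
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "echo", "arguments": '{"str": "hi"}'},
        }
    ],
}
# The executing agent answers with one "tool" message per call id:
tool_msg = {"role": "tool", "tool_call_id": "call_abc123", "content": "hi"}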
@@ -4,6 +4,7 @@ import functools
 import inspect
 import json
 import logging
+import re
 from collections import defaultdict
 from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union

@@ -80,7 +81,7 @@ class ConversableAgent(Agent):
                 the number of auto reply reaches the max_consecutive_auto_reply.
             (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
                 when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
-            function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions.
+            function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions, also used for tool calls.
             code_execution_config (dict or False): config for the code execution.
                 To disable code execution, set to False. Otherwise, set to a dictionary with the following keys:
                 - work_dir (Optional, str): The working directory for the code execution.

@@ -133,13 +134,19 @@ class ConversableAgent(Agent):
         )
         self._consecutive_auto_reply_counter = defaultdict(int)
         self._max_consecutive_auto_reply_dict = defaultdict(self.max_consecutive_auto_reply)
-        self._function_map = {} if function_map is None else function_map
+        self._function_map = (
+            {}
+            if function_map is None
+            else {name: callable for name, callable in function_map.items() if self._assert_valid_name(name)}
+        )
         self._default_auto_reply = default_auto_reply
         self._reply_func_list = []
         self.reply_at_receive = defaultdict(bool)
         self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
         self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply)
         self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
+        self.register_reply([Agent, None], ConversableAgent.generate_tool_calls_reply)
+        self.register_reply([Agent, None], ConversableAgent.a_generate_tool_calls_reply)
         self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
         self.register_reply([Agent, None], ConversableAgent.a_generate_function_call_reply)
         self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)

@@ -280,13 +287,35 @@ class ConversableAgent(Agent):
         else:
             return dict(message)

+    @staticmethod
+    def _normalize_name(name):
+        """
+        LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_".
+
+        Prefer _assert_valid_name for validating user configuration or input
+        """
+        return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
+
+    @staticmethod
+    def _assert_valid_name(name):
+        """
+        Ensure that configured names are valid, raises ValueError if not.
+
+        For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.
+        """
+        if not re.match(r"^[a-zA-Z0-9_-]+$", name):
+            raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
+        if len(name) > 64:
+            raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
+        return name
+
     def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: Agent) -> bool:
         """Append a message to the ChatCompletion conversation.

         If the message received is a string, it will be put in the "content" field of the new dictionary.
-        If the message received is a dictionary but does not have any of the two fields "content" or "function_call",
+        If the message received is a dictionary but does not have any of the three fields "content", "function_call", or "tool_calls",
         this message is not a valid ChatCompletion message.
-        If only "function_call" is provided, "content" will be set to None if not provided, and the role of the message will be forced "assistant".
+        If only "function_call" or "tool_calls" is provided, "content" will be set to None if not provided, and the role of the message will be forced "assistant".

         Args:
             message (dict or str): message to be appended to the ChatCompletion conversation.
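A quick illustration of the two name helpers added above (a hypothetical REPL session; the behavior follows directly from the regexes in the diff):

from autogen import ConversableAgent

# _normalize_name munges whatever name the LLM invented into an API-safe one:
ConversableAgent._normalize_name("multi_tool_call.echo")  # -> "multi_tool_call_echo"
ConversableAgent._normalize_name("x" * 100)               # truncated to 64 chars

# _assert_valid_name refuses instead of munging, for user-supplied config:
ConversableAgent._assert_valid_name("currency-calculator_2")  # returns the name
ConversableAgent._assert_valid_name("bad.name")              # raises ValueError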
@@ -298,17 +327,24 @@ class ConversableAgent(Agent):
         """
         message = self._message_to_dict(message)
         # create oai message to be appended to the oai conversation that can be passed to oai directly.
-        oai_message = {k: message[k] for k in ("content", "function_call", "name", "context") if k in message}
+        oai_message = {
+            k: message[k]
+            for k in ("content", "function_call", "tool_calls", "tool_responses", "tool_call_id", "name", "context")
+            if k in message and message[k] is not None
+        }
         if "content" not in oai_message:
-            if "function_call" in oai_message:
+            if "function_call" in oai_message or "tool_calls" in oai_message:
                 oai_message["content"] = None  # if only function_call is provided, content will be set to None.
             else:
                 return False

-        oai_message["role"] = "function" if message.get("role") == "function" else role
-        if "function_call" in oai_message:
+        if message.get("role") in ["function", "tool"]:
+            oai_message["role"] = message.get("role")
+        else:
+            oai_message["role"] = role
+
+        if oai_message.get("function_call", False) or oai_message.get("tool_calls", False):
             oai_message["role"] = "assistant"  # only messages with role 'assistant' can have a function call.
-            oai_message["function_call"] = dict(oai_message["function_call"])
         self._oai_messages[conversation_id].append(oai_message)
         return True

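Concretely, a received dict like the following now survives validation: "tool_calls" counts as payload even without content, None-valued keys are dropped, and the stored copy gets role "assistant" (the id is illustrative):

message = {
    "content": None,
    "function_call": None,  # dropped: None values are filtered out
    "tool_calls": [
        {"id": "call_1", "type": "function",
         "function": {"name": "echo", "arguments": "{}"}}
    ],
}
# agent._append_oai_message(message, "user", sender) stores
# {"content": None, "tool_calls": [...], "role": "assistant"} and returns True.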
@@ -415,8 +451,14 @@ class ConversableAgent(Agent):
         print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
         message = self._message_to_dict(message)

-        if message.get("role") == "function":
-            func_print = f"***** Response from calling function \"{message['name']}\" *****"
+        if message.get("tool_responses"):  # Handle tool multi-call responses
+            for tool_response in message["tool_responses"]:
+                self._print_received_message(tool_response, sender)
+            if message.get("role") == "tool":
+                return  # If role is tool, then content is just a concatenation of all tool_responses
+
+        if message.get("role") in ["function", "tool"]:
+            func_print = f"***** Response from calling {message['role']} \"{message['name']}\" *****"
             print(colored(func_print, "green"), flush=True)
             print(message["content"], flush=True)
             print(colored("*" * len(func_print), "green"), flush=True)

@@ -430,7 +472,7 @@ class ConversableAgent(Agent):
                 self.llm_config and self.llm_config.get("allow_format_str_template", False),
             )
             print(content_str(content), flush=True)
-        if "function_call" in message:
+        if "function_call" in message and message["function_call"]:
             function_call = dict(message["function_call"])
             func_print = (
                 f"***** Suggested function Call: {function_call.get('name', '(No function name found)')} *****"

@@ -443,10 +485,23 @@ class ConversableAgent(Agent):
                 sep="",
             )
             print(colored("*" * len(func_print), "green"), flush=True)
+        if "tool_calls" in message and message["tool_calls"]:
+            for tool_call in message["tool_calls"]:
+                id = tool_call.get("id", "(No id found)")
+                function_call = dict(tool_call.get("function", {}))
+                func_print = f"***** Suggested tool Call ({id}): {function_call.get('name', '(No function name found)')} *****"
+                print(colored(func_print, "green"), flush=True)
+                print(
+                    "Arguments: \n",
+                    function_call.get("arguments", "(No arguments found)"),
+                    flush=True,
+                    sep="",
+                )
+                print(colored("*" * len(func_print), "green"), flush=True)

         print("\n", "-" * 80, flush=True, sep="")

     def _process_received_message(self, message: Union[Dict, str], sender: Agent, silent: bool):
         message = self._message_to_dict(message)
         # When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
         valid = self._append_oai_message(message, "user", sender)
         if not valid:

@@ -471,11 +526,12 @@ class ConversableAgent(Agent):
         Args:
             message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (either content or function_call need to be provided).
                 1. "content": content of the message, can be None.
-                2. "function_call": a dictionary containing the function name and arguments.
-                3. "role": role of the message, can be "assistant", "user", "function".
+                2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls")
+                3. "tool_calls": a list of dictionaries containing the function name and arguments.
+                4. "role": role of the message, can be "assistant", "user", "function", "tool".
                     This field is only needed to distinguish between "function" or "assistant"/"user".
-                4. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
-                5. "context" (dict): the context of the message, which will be passed to
+                5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
+                6. "context" (dict): the context of the message, which will be passed to
                     [OpenAIWrapper.create](../oai/client#create).
             sender: sender of an Agent instance.
             request_reply (bool or None): whether a reply is requested from the sender.

@@ -507,11 +563,12 @@ class ConversableAgent(Agent):
         Args:
             message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (either content or function_call need to be provided).
                 1. "content": content of the message, can be None.
-                2. "function_call": a dictionary containing the function name and arguments.
-                3. "role": role of the message, can be "assistant", "user", "function".
+                2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls")
+                3. "tool_calls": a list of dictionaries containing the function name and arguments.
+                4. "role": role of the message, can be "assistant", "user", "function".
                     This field is only needed to distinguish between "function" or "assistant"/"user".
-                4. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
-                5. "context" (dict): the context of the message, which will be passed to
+                5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
+                6. "context" (dict): the context of the message, which will be passed to
                     [OpenAIWrapper.create](../oai/client#create).
             sender: sender of an Agent instance.
             request_reply (bool or None): whether a reply is requested from the sender.

@@ -631,15 +688,35 @@ class ConversableAgent(Agent):
         if messages is None:
             messages = self._oai_messages[sender]

+        # unroll tool_responses
+        all_messages = []
+        for message in messages:
+            tool_responses = message.get("tool_responses", [])
+            if tool_responses:
+                all_messages += tool_responses
+                # tool role on the parent message means the content is just concatentation of all of the tool_responses
+                if message.get("role") != "tool":
+                    all_messages.append({key: message[key] for key in message if key != "tool_responses"})
+            else:
+                all_messages.append(message)
+
         # TODO: #1143 handle token limit exceeded error
         response = client.create(
-            context=messages[-1].pop("context", None), messages=self._oai_system_message + messages
+            context=messages[-1].pop("context", None), messages=self._oai_system_message + all_messages
         )

+        # TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
         extracted_response = client.extract_text_or_completion_object(response)[0]

+        # ensure function and tool calls will be accepted when sent back to the LLM
+        if not isinstance(extracted_response, str):
+            extracted_response = model_dump(extracted_response)
+        if isinstance(extracted_response, dict):
+            if extracted_response.get("function_call"):
+                extracted_response["function_call"]["name"] = self._normalize_name(
+                    extracted_response["function_call"]["name"]
+                )
+            for tool_call in extracted_response.get("tool_calls") or []:
+                tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"])
         return True, extracted_response

     async def a_generate_oai_reply(
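The unrolling loop above exists because generate_tool_calls_reply (added further down) stores a multi-call result as one parent message carrying a "tool_responses" list, while the OpenAI endpoint expects one "tool" message per call id. Roughly, as a sketch with illustrative ids:

stored = {
    "role": "tool",
    "content": "Tool call: echo\nId: call_1\nhi\n\nTool call: echo\nId: call_2\nbye",
    "tool_responses": [
        {"tool_call_id": "call_1", "role": "tool", "content": "hi"},
        {"tool_call_id": "call_2", "role": "tool", "content": "bye"},
    ],
}
# Unrolling emits the two inner "tool" messages and, because the parent's
# role is "tool", drops the parent (its content merely concatenates them).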
@@ -708,13 +785,23 @@ class ConversableAgent(Agent):
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
     ) -> Tuple[bool, Union[Dict, None]]:
-        """Generate a reply using function call."""
+        """
+        Generate a reply using function call.
+
+        "function_call" replaced by "tool_calls" as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0)
+        See https://platform.openai.com/docs/api-reference/chat/create#chat-create-functions
+        """
         if config is None:
             config = self
         if messages is None:
             messages = self._oai_messages[sender]
         message = messages[-1]
-        if "function_call" in message:
+        if "function_call" in message and message["function_call"]:
+            func_call = message["function_call"]
+            func = self._function_map.get(func_call.get("name", None), None)
+            if asyncio.coroutines.iscoroutinefunction(func):
+                return False, None
+
             _, func_return = self.execute_function(message["function_call"])
             return True, func_return
         return False, None

@@ -725,7 +812,12 @@ class ConversableAgent(Agent):
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
     ) -> Tuple[bool, Union[Dict, None]]:
-        """Generate a reply using async function call."""
+        """
+        Generate a reply using async function call.
+
+        "function_call" replaced by "tool_calls" as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0)
+        See https://platform.openai.com/docs/api-reference/chat/create#chat-create-functions
+        """
         if config is None:
             config = self
         if messages is None:

@@ -741,6 +833,89 @@ class ConversableAgent(Agent):

         return False, None

+    def _str_for_tool_response(self, tool_response):
+        func_name = tool_response.get("name", "")
+        func_id = tool_response.get("tool_call_id", "")
+        response = tool_response.get("content", "")
+        return f"Tool call: {func_name}\nId: {func_id}\n{response}"
+
+    def generate_tool_calls_reply(
+        self,
+        messages: Optional[List[Dict]] = None,
+        sender: Optional[Agent] = None,
+        config: Optional[Any] = None,
+    ) -> Tuple[bool, Union[Dict, None]]:
+        """Generate a reply using tool call."""
+        if config is None:
+            config = self
+        if messages is None:
+            messages = self._oai_messages[sender]
+        message = messages[-1]
+        if "tool_calls" in message and message["tool_calls"]:
+            tool_calls = message["tool_calls"]
+            tool_returns = []
+            for tool_call in tool_calls:
+                id = tool_call["id"]
+                function_call = tool_call.get("function", {})
+                func = self._function_map.get(function_call.get("name", None), None)
+                if asyncio.coroutines.iscoroutinefunction(func):
+                    continue
+                _, func_return = self.execute_function(function_call)
+                tool_returns.append(
+                    {
+                        "tool_call_id": id,
+                        "role": "tool",
+                        "name": func_return.get("name", ""),
+                        "content": func_return.get("content", ""),
+                    }
+                )
+            return True, {
+                "role": "tool",
+                "tool_responses": tool_returns,
+                "content": "\n\n".join([self._str_for_tool_response(tool_return) for tool_return in tool_returns]),
+            }
+        return False, None
+
+    async def _a_execute_tool_call(self, tool_call):
+        id = tool_call["id"]
+        function_call = tool_call.get("function", {})
+        _, func_return = await self.a_execute_function(function_call)
+        return {
+            "tool_call_id": id,
+            "role": "tool",
+            "name": func_return.get("name", ""),
+            "content": func_return.get("content", ""),
+        }
+
+    async def a_generate_tool_calls_reply(
+        self,
+        messages: Optional[List[Dict]] = None,
+        sender: Optional[Agent] = None,
+        config: Optional[Any] = None,
+    ) -> Tuple[bool, Union[Dict, None]]:
+        """Generate a reply using async function call."""
+        if config is None:
+            config = self
+        if messages is None:
+            messages = self._oai_messages[sender]
+        message = messages[-1]
+        async_tool_calls = []
+        for tool_call in message.get("tool_calls", []):
+            func = self._function_map.get(tool_call.get("function", {}).get("name", None), None)
+            if func and asyncio.coroutines.iscoroutinefunction(func):
+                async_tool_calls.append(self._a_execute_tool_call(tool_call))
+        if len(async_tool_calls) > 0:
+            tool_returns = await asyncio.gather(*async_tool_calls)
+            return True, {
+                "role": "tool",
+                "tool_responses": tool_returns,
+                "content": "\n\n".join(
+                    [self._str_for_tool_response(tool_return["content"]) for tool_return in tool_returns]
+                ),
+            }
+
+        return False, None
+
     def check_termination_and_human_reply(
         self,
         messages: Optional[List[Dict]] = None,
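So for a last message shaped like the assistant sketch shown earlier, generate_tool_calls_reply returns (True, reply) where reply bundles every synchronously executed call (ids illustrative):

reply = {
    "role": "tool",
    "tool_responses": [
        {"tool_call_id": "call_abc123", "role": "tool",
         "name": "echo", "content": "hi"},
    ],
    "content": "Tool call: echo\nId: call_abc123\nhi",
}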
@@ -821,7 +996,28 @@ class ConversableAgent(Agent):
             if reply or self._max_consecutive_auto_reply_dict[sender] == 0:
                 # reset the consecutive_auto_reply_counter
                 self._consecutive_auto_reply_counter[sender] = 0
-                return True, reply
+                # User provided a custom response, return function and tool failures indicating user interruption
+                tool_returns = []
+                if message.get("function_call", False):
+                    tool_returns.append(
+                        {
+                            "role": "function",
+                            "name": message["function_call"].get("name", ""),
+                            "content": "USER INTERRUPTED",
+                        }
+                    )
+
+                if message.get("tool_calls", False):
+                    tool_returns.extend(
+                        [
+                            {"role": "tool", "tool_call_id": tool_call.get("id", ""), "content": "USER INTERRUPTED"}
+                            for tool_call in message["tool_calls"]
+                        ]
+                    )
+
+                response = {"role": "user", "content": reply, "tool_responses": tool_returns}
+
+                return True, response

         # increment the consecutive_auto_reply_counter
         self._consecutive_auto_reply_counter[sender] += 1

@@ -906,9 +1102,29 @@ class ConversableAgent(Agent):

         # send the human reply
         if reply or self._max_consecutive_auto_reply_dict[sender] == 0:
+            # User provided a custom response, return function and tool results indicating user interruption
             # reset the consecutive_auto_reply_counter
             self._consecutive_auto_reply_counter[sender] = 0
-            return True, reply
+            tool_returns = []
+            if message.get("function_call", False):
+                tool_returns.append(
+                    {
+                        "role": "function",
+                        "name": message["function_call"].get("name", ""),
+                        "content": "USER INTERRUPTED",
+                    }
+                )
+
+            if message.get("tool_calls", False):
+                tool_returns.extend(
+                    [
+                        {"role": "tool", "tool_call_id": tool_call.get("id", ""), "content": "USER INTERRUPTED"}
+                        for tool_call in message["tool_calls"]
+                    ]
+                )
+
+            response = {"role": "user", "content": reply, "tool_responses": tool_returns}
+            return True, response

         # increment the consecutive_auto_reply_counter
         self._consecutive_auto_reply_counter[sender] += 1
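In both the sync and async paths the effect is the same: when a human answers a pending call with free text, every outstanding call id still receives a terminal "USER INTERRUPTED" response, so the API never sees a dangling tool_call (sketch, id illustrative):

response = {
    "role": "user",
    "content": "never mind, summarize what you have so far",
    "tool_responses": [
        {"role": "tool", "tool_call_id": "call_abc123",
         "content": "USER INTERRUPTED"},
    ],
}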
@@ -930,9 +1146,10 @@ class ConversableAgent(Agent):
         Use registered auto reply functions to generate replies.
         By default, the following functions are checked in order:
         1. check_termination_and_human_reply
-        2. generate_function_call_reply
-        3. generate_code_execution_reply
-        4. generate_oai_reply
+        2. generate_function_call_reply (deprecated in favor of tool_calls)
+        3. generate_tool_calls_reply
+        4. generate_code_execution_reply
+        5. generate_oai_reply
         Every function returns a tuple (final, reply).
         When a function returns final=False, the next function will be checked.
         So by default, termination and human reply will be checked first.

@@ -982,8 +1199,9 @@ class ConversableAgent(Agent):
         By default, the following functions are checked in order:
         1. check_termination_and_human_reply
         2. generate_function_call_reply
-        3. generate_code_execution_reply
-        4. generate_oai_reply
+        3. generate_tool_calls_reply
+        4. generate_code_execution_reply
+        5. generate_oai_reply
         Every function returns a tuple (final, reply).
         When a function returns final=False, the next function will be checked.
         So by default, termination and human reply will be checked first.

@@ -1173,15 +1391,18 @@ class ConversableAgent(Agent):
     def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict[str, str]]:
         """Execute a function call and return the result.

-        Override this function to modify the way to execute a function call.
+        Override this function to modify the way to execute function and tool calls.

         Args:
-            func_call: a dictionary extracted from openai message at key "function_call" with keys "name" and "arguments".
+            func_call: a dictionary extracted from openai message at "function_call" or "tool_calls" with keys "name" and "arguments".

         Returns:
             A tuple of (is_exec_success, result_dict).
             is_exec_success (boolean): whether the execution is successful.
             result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function".
+
+        "function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0)
+        See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call
         """
         func_name = func_call.get("name", "")
         func = self._function_map.get(func_name, None)

@@ -1225,15 +1446,18 @@ class ConversableAgent(Agent):
     async def a_execute_function(self, func_call):
         """Execute an async function call and return the result.

-        Override this function to modify the way async functions are executed.
+        Override this function to modify the way async functions and tools are executed.

         Args:
-            func_call: a dictionary extracted from openai message at key "function_call" with keys "name" and "arguments".
+            func_call: a dictionary extracted from openai message at key "function_call" or "tool_calls" with keys "name" and "arguments".

         Returns:
             A tuple of (is_exec_success, result_dict).
             is_exec_success (boolean): whether the execution is successful.
             result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function".
+
+        "function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0)
+        See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call
         """
         func_name = func_call.get("name", "")
         func = self._function_map.get(func_name, None)

@@ -1289,6 +1513,8 @@ class ConversableAgent(Agent):
         Args:
             function_map: a dictionary mapping function names to functions.
         """
+        for name in function_map.keys():
+            self._assert_valid_name(name)
         self._function_map.update(function_map)

     def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None):

@@ -1297,6 +1523,9 @@ class ConversableAgent(Agent):
         Args:
             func_sig (str or dict): description/name of the function to update/remove to the model. See: https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions
             is_remove: whether removing the function from llm_config with name 'func_sig'
+
+        Deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0)
+        See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call
         """

         if not isinstance(self.llm_config, dict):

@@ -1314,6 +1543,7 @@ class ConversableAgent(Agent):
                     func for func in self.llm_config["functions"] if func["name"] != func_sig
                 ]
         else:
+            self._assert_valid_name(func_sig["name"])
             if "functions" in self.llm_config.keys():
                 self.llm_config["functions"] = [
                     func for func in self.llm_config["functions"] if func.get("name") != func_sig["name"]

@@ -1326,9 +1556,48 @@ class ConversableAgent(Agent):

         self.client = OpenAIWrapper(**self.llm_config)

-    def can_execute_function(self, name: str) -> bool:
+    def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None):
+        """update a tool_signature in the LLM configuration for tool_call.
+
+        Args:
+            tool_sig (str or dict): description/name of the tool to update/remove to the model. See: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools
+            is_remove: whether removing the tool from llm_config with name 'tool_sig'
+        """
+
+        if not self.llm_config:
+            error_msg = "To update a tool signature, agent must have an llm_config"
+            logger.error(error_msg)
+            raise AssertionError(error_msg)
+
+        if is_remove:
+            if "tools" not in self.llm_config.keys():
+                error_msg = "The agent config doesn't have tool {name}.".format(name=tool_sig)
+                logger.error(error_msg)
+                raise AssertionError(error_msg)
+            else:
+                self.llm_config["tools"] = [
+                    tool for tool in self.llm_config["tools"] if tool["function"]["name"] != tool_sig
+                ]
+        else:
+            self._assert_valid_name(tool_sig["function"]["name"])
+            if "tools" in self.llm_config.keys():
+                self.llm_config["tools"] = [
+                    tool
+                    for tool in self.llm_config["tools"]
+                    if tool.get("function", {}).get("name") != tool_sig["function"]["name"]
+                ] + [tool_sig]
+            else:
+                self.llm_config["tools"] = [tool_sig]
+
+        if len(self.llm_config["tools"]) == 0:
+            del self.llm_config["tools"]
+
+        self.client = OpenAIWrapper(**self.llm_config)
+
+    def can_execute_function(self, name: Union[List[str], str]) -> bool:
         """Whether the agent can execute the function."""
-        return name in self._function_map
+        names = name if isinstance(name, list) else [name]
+        return all([n in self._function_map for n in names])

     @property
     def function_map(self) -> Dict[str, Callable]:
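Usage mirrors update_function_signature, except the argument is the tool-wrapped schema, and removal goes by bare name (a sketch; `agent` stands for any ConversableAgent with an llm_config):

agent.update_tool_signature(
    {
        "type": "function",
        "function": {
            "name": "echo",
            "description": "Echo the input string.",
            "parameters": {
                "type": "object",
                "properties": {"str": {"type": "string"}},
                "required": ["str"],
            },
        },
    },
    is_remove=False,
)
agent.update_tool_signature("echo", is_remove=True)  # remove by name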
@@ -1433,7 +1702,7 @@ class ConversableAgent(Agent):
         if self.llm_config is None:
             raise RuntimeError("LLM config must be setup before registering a function for LLM.")

-        self.update_function_signature(f, is_remove=False)
+        self.update_tool_signature(f, is_remove=False)

         return func

@@ -155,11 +155,21 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
                 "Or, use direct communication instead."
             )

-        if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
+        if (
+            self.func_call_filter
+            and self.messages
+            and ("function_call" in self.messages[-1] or "tool_calls" in self.messages[-1])
+        ):
+            funcs = []
+            if "function_call" in self.messages[-1]:
+                funcs += [self.messages[-1]["function_call"]["name"]]
+            if "tool_calls" in self.messages[-1]:
+                funcs += [
+                    tool["function"]["name"] for tool in self.messages[-1]["tool_calls"] if tool["type"] == "function"
+                ]
+
             # find agents with the right function_map which contains the function name
-            agents = [
-                agent for agent in self.agents if agent.can_execute_function(self.messages[-1]["function_call"]["name"])
-            ]
+            agents = [agent for agent in self.agents if agent.can_execute_function(funcs)]
             if len(agents) == 1:
                 # only one agent can execute the function
                 return agents[0], agents

@@ -170,7 +180,7 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
             return agents[0], agents
         elif not agents:
             raise ValueError(
-                f"No agent can execute the function {self.messages[-1]['function_call']['name']}. "
+                f"No agent can execute the function {', '.join(funcs)}. "
                 "Please check the function_map of the agents."
             )
         # remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False

@@ -193,7 +203,14 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
             return selected_agent
         # auto speaker selection
         selector.update_system_message(self.select_speaker_msg(agents))
-        context = self.messages + [{"role": "system", "content": self.select_speaker_prompt(agents)}]
+
+        # If last message is a tool call or function call, blank the call so the api doesn't throw
+        messages = self.messages.copy()
+        if messages[-1].get("function_call", False):
+            messages[-1] = dict(messages[-1], function_call=None)
+        if messages[-1].get("tool_calls", False):
+            messages[-1] = dict(messages[-1], tool_calls=None)
+        context = messages + [{"role": "system", "content": self.select_speaker_prompt(agents)}]
         final, name = selector.generate_oai_reply(context)

         if not final:
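The dict(...) rebuild above nulls the call field on a copy instead of mutating the stored history, e.g.:

last = {"role": "assistant", "content": None, "tool_calls": [{"id": "call_1"}]}
blanked = dict(last, tool_calls=None)
# blanked == {"role": "assistant", "content": None, "tool_calls": None}
# while `last` inside self.messages is left untouched.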
@@ -275,6 +292,8 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
             Dict: a counter for mentioned agents.
         """
+        # Cast message content to str
+        if isinstance(message_content, dict):
+            message_content = message_content["content"]
         message_content = content_str(message_content)

         mentions = dict()

@@ -103,6 +103,13 @@ class Function(BaseModel):
     parameters: Annotated[Parameters, Field(description="Parameters of the function")]


+class ToolFunction(BaseModel):
+    """A function under tool as defined by the OpenAI API."""
+
+    type: Literal["function"] = "function"
+    function: Annotated[Function, Field(description="Function under tool")]
+
+
 def get_parameter_json_schema(
     k: str, v: Union[Annotated[Type, str], Type], default_values: Dict[str, Any]
 ) -> JsonSchemaValue:
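With ToolFunction in place, get_function_schema returns the nested tool layout instead of a bare function object, which is what llm_config["tools"] and update_tool_signature expect. A sketch (the echo function here is hypothetical):

from typing import Annotated
from autogen.function_utils import get_function_schema

def echo(s: Annotated[str, "String to echo"]) -> str:
    return s

get_function_schema(echo, description="Echo the input.", name="echo")
# {'type': 'function',
#  'function': {'description': 'Echo the input.', 'name': 'echo',
#               'parameters': {...}}}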
@@ -260,10 +267,12 @@ def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, de

     parameters = get_parameters(required, param_annotations, default_values=default_values)

-    function = Function(
-        description=description,
-        name=fname,
-        parameters=parameters,
+    function = ToolFunction(
+        function=Function(
+            description=description,
+            name=fname,
+            parameters=parameters,
+        )
     )

     return model_dump(function)

@@ -287,9 +287,9 @@ class OpenAIWrapper:

     def _completions_create(self, client, params):
         completions = client.chat.completions if "messages" in params else client.completions
-        # If streaming is enabled, has messages, and does not have functions, then
+        # If streaming is enabled, has messages, and does not have functions or tools, then
         # iterate over the chunks of the response
-        if params.get("stream", False) and "messages" in params and "functions" not in params:
+        if params.get("stream", False) and "messages" in params and "functions" not in params and "tools" not in params:
             response_contents = [""] * params.get("n", 1)
             finish_reasons = [""] * params.get("n", 1)
             completion_tokens = 0

@@ -352,8 +352,8 @@ class OpenAIWrapper:

             response.choices.append(choice)
         else:
-            # If streaming is not enabled or using functions, send a regular chat completion request
-            # Functions are not supported, so ensure streaming is disabled
+            # If streaming is not enabled, using functions, or tools, send a regular chat completion request
+            # Functions and Tools are not supported, so ensure streaming is disabled
             params = params.copy()
             params["stream"] = False
             response = completions.create(**params)

@@ -185,20 +185,21 @@
    {
     "data": {
      "text/plain": [
-      "[{'description': 'Currency exchange calculator.',\n",
-      " 'name': 'currency_calculator',\n",
-      " 'parameters': {'type': 'object',\n",
-      " 'properties': {'base_amount': {'type': 'number',\n",
-      " 'description': 'Amount of currency in base_currency'},\n",
-      " 'base_currency': {'enum': ['USD', 'EUR'],\n",
-      " 'type': 'string',\n",
-      " 'default': 'USD',\n",
-      " 'description': 'Base currency'},\n",
-      " 'quote_currency': {'enum': ['USD', 'EUR'],\n",
-      " 'type': 'string',\n",
-      " 'default': 'EUR',\n",
-      " 'description': 'Quote currency'}},\n",
-      " 'required': ['base_amount']}}]"
+      "[{'type': 'function',\n",
+      " 'function': {'description': 'Currency exchange calculator.',\n",
+      " 'name': 'currency_calculator',\n",
+      " 'parameters': {'type': 'object',\n",
+      " 'properties': {'base_amount': {'type': 'number',\n",
+      " 'description': 'Amount of currency in base_currency'},\n",
+      " 'base_currency': {'enum': ['USD', 'EUR'],\n",
+      " 'type': 'string',\n",
+      " 'default': 'USD',\n",
+      " 'description': 'Base currency'},\n",
+      " 'quote_currency': {'enum': ['USD', 'EUR'],\n",
+      " 'type': 'string',\n",
+      " 'default': 'EUR',\n",
+      " 'description': 'Quote currency'}},\n",
+      " 'required': ['base_amount']}}}]"
     ]
    },
    "execution_count": 4,

@@ -207,7 +208,7 @@
    }
   ],
   "source": [
-   "chatbot.llm_config[\"functions\"]"
+   "chatbot.llm_config[\"tools\"]"
   ]
  },
 {

@@ -259,10 +260,14 @@
    "--------------------------------------------------------------------------------\n",
    "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
    "\n",
-   "\u001b[32m***** Suggested function Call: currency_calculator *****\u001b[0m\n",
+   "\u001b[32m***** Suggested tool Call (call_2mZCDF9fe8WJh6SveIwdGGEy): currency_calculator *****\u001b[0m\n",
    "Arguments: \n",
-   "{\"base_amount\":123.45,\"base_currency\":\"USD\",\"quote_currency\":\"EUR\"}\n",
-   "\u001b[32m********************************************************\u001b[0m\n",
+   "{\n",
+   " \"base_amount\": 123.45,\n",
+   " \"base_currency\": \"USD\",\n",
+   " \"quote_currency\": \"EUR\"\n",
+   "}\n",
+   "\u001b[32m************************************************************************************\u001b[0m\n",
    "\n",
    "--------------------------------------------------------------------------------\n",
    "\u001b[35m\n",

@@ -276,7 +281,7 @@
    "--------------------------------------------------------------------------------\n",
    "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
    "\n",
-   "123.45 USD is equivalent to approximately 112.23 EUR.\n",
+   "123.45 USD is approximately 112.23 EUR.\n",
    "\n",
    "--------------------------------------------------------------------------------\n",
    "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",

@@ -370,27 +375,28 @@
    {
     "data": {
      "text/plain": [
-      "[{'description': 'Currency exchange calculator.',\n",
-      " 'name': 'currency_calculator',\n",
-      " 'parameters': {'type': 'object',\n",
-      " 'properties': {'base': {'properties': {'currency': {'description': 'Currency symbol',\n",
-      " 'enum': ['USD', 'EUR'],\n",
-      " 'title': 'Currency',\n",
-      " 'type': 'string'},\n",
-      " 'amount': {'default': 0,\n",
-      " 'description': 'Amount of currency',\n",
-      " 'minimum': 0.0,\n",
-      " 'title': 'Amount',\n",
-      " 'type': 'number'}},\n",
-      " 'required': ['currency'],\n",
-      " 'title': 'Currency',\n",
-      " 'type': 'object',\n",
-      " 'description': 'Base currency: amount and currency symbol'},\n",
-      " 'quote_currency': {'enum': ['USD', 'EUR'],\n",
-      " 'type': 'string',\n",
-      " 'default': 'USD',\n",
-      " 'description': 'Quote currency symbol'}},\n",
-      " 'required': ['base']}}]"
+      "[{'type': 'function',\n",
+      " 'function': {'description': 'Currency exchange calculator.',\n",
+      " 'name': 'currency_calculator',\n",
+      " 'parameters': {'type': 'object',\n",
+      " 'properties': {'base': {'properties': {'currency': {'description': 'Currency symbol',\n",
+      " 'enum': ['USD', 'EUR'],\n",
+      " 'title': 'Currency',\n",
+      " 'type': 'string'},\n",
+      " 'amount': {'default': 0,\n",
+      " 'description': 'Amount of currency',\n",
+      " 'minimum': 0.0,\n",
+      " 'title': 'Amount',\n",
+      " 'type': 'number'}},\n",
+      " 'required': ['currency'],\n",
+      " 'title': 'Currency',\n",
+      " 'type': 'object',\n",
+      " 'description': 'Base currency: amount and currency symbol'},\n",
+      " 'quote_currency': {'enum': ['USD', 'EUR'],\n",
+      " 'type': 'string',\n",
+      " 'default': 'USD',\n",
+      " 'description': 'Quote currency symbol'}},\n",
+      " 'required': ['base']}}}]"
     ]
    },
    "execution_count": 8,

@@ -399,7 +405,7 @@
    }
   ],
   "source": [
-   "chatbot.llm_config[\"functions\"]"
+   "chatbot.llm_config[\"tools\"]"
   ]
  },
 {

@@ -419,10 +425,16 @@
    "--------------------------------------------------------------------------------\n",
    "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
    "\n",
-   "\u001b[32m***** Suggested function Call: currency_calculator *****\u001b[0m\n",
+   "\u001b[32m***** Suggested tool Call (call_MLtsPcVJXhdpvDPNNxfTB3OB): currency_calculator *****\u001b[0m\n",
    "Arguments: \n",
-   "{\"base\":{\"currency\":\"EUR\",\"amount\":112.23},\"quote_currency\":\"USD\"}\n",
-   "\u001b[32m********************************************************\u001b[0m\n",
+   "{\n",
+   " \"base\": {\n",
+   " \"currency\": \"EUR\",\n",
+   " \"amount\": 112.23\n",
+   " },\n",
+   " \"quote_currency\": \"USD\"\n",
+   "}\n",
+   "\u001b[32m************************************************************************************\u001b[0m\n",
    "\n",
    "--------------------------------------------------------------------------------\n",
    "\u001b[35m\n",

@@ -436,7 +448,7 @@
    "--------------------------------------------------------------------------------\n",
    "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
    "\n",
-   "112.23 Euros is equivalent to approximately 123.45 US Dollars.\n",
+   "112.23 Euros is approximately 123.45 US Dollars.\n",
    "\n",
    "--------------------------------------------------------------------------------\n",
    "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",

@@ -477,10 +489,16 @@
    "--------------------------------------------------------------------------------\n",
    "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
    "\n",
-   "\u001b[32m***** Suggested function Call: currency_calculator *****\u001b[0m\n",
+   "\u001b[32m***** Suggested tool Call (call_WrBjnoLeXilBPuj9nTJLM5wh): currency_calculator *****\u001b[0m\n",
    "Arguments: \n",
-   "{\"base\":{\"currency\":\"USD\",\"amount\":123.45},\"quote_currency\":\"EUR\"}\n",
-   "\u001b[32m********************************************************\u001b[0m\n",
+   "{\n",
+   " \"base\": {\n",
+   " \"currency\": \"USD\",\n",
+   " \"amount\": 123.45\n",
+   " },\n",
+   " \"quote_currency\": \"EUR\"\n",
+   "}\n",
+   "\u001b[32m************************************************************************************\u001b[0m\n",
    "\n",
    "--------------------------------------------------------------------------------\n",
    "\u001b[35m\n",

@@ -543,7 +561,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.10.13"
+  "version": "3.11.6"
  }
 },
 "nbformat": 4,

@@ -489,9 +489,9 @@ def get_origin(d: Dict[str, Callable[..., Any]]) -> Dict[str, Callable[..., Any]
 def test_register_for_llm():
     with pytest.MonkeyPatch.context() as mp:
         mp.setenv("OPENAI_API_KEY", "mock")
-        agent3 = ConversableAgent(name="agent3", llm_config={})
-        agent2 = ConversableAgent(name="agent2", llm_config={})
-        agent1 = ConversableAgent(name="agent1", llm_config={})
+        agent3 = ConversableAgent(name="agent3", llm_config={"config_list": []})
+        agent2 = ConversableAgent(name="agent2", llm_config={"config_list": []})
+        agent1 = ConversableAgent(name="agent1", llm_config={"config_list": []})

         @agent3.register_for_llm()
         @agent2.register_for_llm(name="python")

@@ -501,27 +501,30 @@ def test_register_for_llm():

         expected1 = [
             {
-                "description": "run cell in ipython and return the execution result.",
-                "name": "exec_python",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "cell": {
-                            "type": "string",
-                            "description": "Valid Python cell to execute.",
-                        }
+                "type": "function",
+                "function": {
+                    "description": "run cell in ipython and return the execution result.",
+                    "name": "exec_python",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "cell": {
+                                "type": "string",
+                                "description": "Valid Python cell to execute.",
+                            }
+                        },
+                        "required": ["cell"],
                     },
-                    "required": ["cell"],
                 },
             }
         ]
         expected2 = copy.deepcopy(expected1)
-        expected2[0]["name"] = "python"
+        expected2[0]["function"]["name"] = "python"
         expected3 = expected2

-        assert agent1.llm_config["functions"] == expected1
-        assert agent2.llm_config["functions"] == expected2
-        assert agent3.llm_config["functions"] == expected3
+        assert agent1.llm_config["tools"] == expected1
+        assert agent2.llm_config["tools"] == expected2
+        assert agent3.llm_config["tools"] == expected3

         @agent3.register_for_llm()
         @agent2.register_for_llm()

@@ -531,26 +534,29 @@ def test_register_for_llm():

         expected1 = expected1 + [
             {
-                "name": "sh",
-                "description": "run a shell script and return the execution result.",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "script": {
-                            "type": "string",
-                            "description": "Valid shell script to execute.",
-                        }
+                "type": "function",
+                "function": {
+                    "name": "sh",
+                    "description": "run a shell script and return the execution result.",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "script": {
+                                "type": "string",
+                                "description": "Valid shell script to execute.",
+                            }
+                        },
+                        "required": ["script"],
                     },
-                    "required": ["script"],
                 },
             }
         ]
         expected2 = expected2 + [expected1[1]]
         expected3 = expected3 + [expected1[1]]

-        assert agent1.llm_config["functions"] == expected1
-        assert agent2.llm_config["functions"] == expected2
-        assert agent3.llm_config["functions"] == expected3
+        assert agent1.llm_config["tools"] == expected1
+        assert agent2.llm_config["tools"] == expected2
+        assert agent3.llm_config["tools"] == expected3


 def test_register_for_llm_without_description():

@@ -586,7 +592,7 @@ def test_register_for_llm_without_LLM():
 def test_register_for_execution():
     with pytest.MonkeyPatch.context() as mp:
         mp.setenv("OPENAI_API_KEY", "mock")
-        agent = ConversableAgent(name="agent", llm_config={})
+        agent = ConversableAgent(name="agent", llm_config={"config_list": []})
         user_proxy_1 = UserProxyAgent(name="user_proxy_1")
         user_proxy_2 = UserProxyAgent(name="user_proxy_2")

test/agentchat/test_tool_calls.py (new file, 235 lines)
@@ -0,0 +1,235 @@
+try:
+    from openai import OpenAI
+except ImportError:
+    OpenAI = None
+import inspect
+import pytest
+import json
+import autogen
+from conftest import skip_openai
+from autogen.math_utils import eval_math_responses
+from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
+import sys
+from autogen.oai.client import TOOL_ENABLED
+
+try:
+    from openai import OpenAI
+except ImportError:
+    skip = True
+else:
+    skip = False or skip_openai
+
+
+@pytest.mark.skipif(skip_openai or not TOOL_ENABLED, reason="openai>=1.1.0 not installed or requested to skip")
+def test_eval_math_responses():
+    config_list = autogen.config_list_from_models(
+        KEY_LOC, exclude="aoai", model_list=["gpt-4-0613", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k"]
+    )
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "eval_math_responses",
+                "description": "Select a response for a math problem using voting, and check if the response is correct if the solution is provided",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "responses": {
+                            "type": "array",
+                            "items": {"type": "string"},
+                            "description": "The responses in a list",
+                        },
+                        "solution": {
+                            "type": "string",
+                            "description": "The canonical solution",
+                        },
+                    },
+                    "required": ["responses"],
+                },
+            },
+        },
+    ]
+    client = autogen.OpenAIWrapper(config_list=config_list)
+    response = client.create(
+        messages=[
+            {
+                "role": "user",
+                "content": 'evaluate the math responses ["1", "5/2", "5/2"] against the true answer \\frac{5}{2}',
+            },
+        ],
+        tools=tools,
+    )
+    print(response)
+    responses = client.extract_text_or_completion_object(response)
+    print(responses[0])
+    tool_calls = responses[0].tool_calls
+    function_call = tool_calls[0].function
+    name, arguments = function_call.name, json.loads(function_call.arguments)
+    assert name == "eval_math_responses"
+    print(arguments["responses"])
+    # if isinstance(arguments["responses"], str):
+    #     arguments["responses"] = json.loads(arguments["responses"])
+    arguments["responses"] = [f"\\boxed{{{x}}}" for x in arguments["responses"]]
+    print(arguments["responses"])
+    arguments["solution"] = f"\\boxed{{{arguments['solution']}}}"
+    print(eval_math_responses(**arguments))
+
+
+@pytest.mark.skipif(
+    skip_openai or not TOOL_ENABLED or not sys.version.startswith("3.10"),
+    reason="do not run if openai is <1.1.0 or py!=3.10 or requested to skip",
+)
+def test_update_tool():
+    config_list_gpt4 = autogen.config_list_from_json(
+        OAI_CONFIG_LIST,
+        filter_dict={
+            "model": ["gpt-4", "gpt-4-0314", "gpt4", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-v0314"],
+        },
+        file_location=KEY_LOC,
+    )
+    llm_config = {
+        "config_list": config_list_gpt4,
+        "seed": 42,
+        "tools": [],
+    }
+
+    user_proxy = autogen.UserProxyAgent(
+        name="user_proxy",
+        human_input_mode="NEVER",
+        is_termination_msg=lambda x: True if "TERMINATE" in x.get("content") else False,
+    )
+    assistant = autogen.AssistantAgent(name="test", llm_config=llm_config)
+
+    # Define a new function *after* the assistant has been created
+    assistant.update_tool_signature(
+        {
+            "type": "function",
+            "function": {
+                "name": "greet_user",
+                "description": "Greets the user.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {},
+                    "required": [],
+                },
+            },
+        },
+        is_remove=False,
+    )
+    user_proxy.initiate_chat(
+        assistant,
+        message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
+    )
+    messages1 = assistant.chat_messages[user_proxy][-1]["content"]
+    print(messages1)
+
+    assistant.update_tool_signature("greet_user", is_remove=True)
+    user_proxy.initiate_chat(
+        assistant,
+        message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
+    )
+    messages2 = assistant.chat_messages[user_proxy][-1]["content"]
+    print(messages2)
+    # The model should know about the function in the context of the conversation
+    assert "greet_user" in messages1
+    assert "greet_user" not in messages2
+
+
+@pytest.mark.skipif(not TOOL_ENABLED, reason="openai>=1.1.0 not installed")
+def test_multi_tool_call():
+    class FakeAgent(autogen.Agent):
+        def __init__(self, name):
+            super().__init__(name)
+            self.received = []
+
+        def receive(
+            self,
+            message,
+            sender,
+            request_reply=None,
+            silent=False,
+        ):
+            message = message if isinstance(message, list) else [message]
+            self.received.extend(message)
+
+    user_proxy = autogen.UserProxyAgent(
+        name="user_proxy",
+        human_input_mode="NEVER",
+        is_termination_msg=lambda x: True if "TERMINATE" in x.get("content") else False,
+    )
+    user_proxy.register_function({"echo": lambda str: str})
+
+    fake_agent = FakeAgent("fake_agent")
+
+    user_proxy.receive(
+        message={
+            "content": "test multi tool call",
+            "tool_calls": [
+                {
+                    "id": "tool_1",
+                    "type": "function",
+                    "function": {"name": "echo", "arguments": json.JSONEncoder().encode({"str": "hello world"})},
+                },
+                {
+                    "id": "tool_2",
+                    "type": "function",
+                    "function": {
+                        "name": "echo",
+                        "arguments": json.JSONEncoder().encode({"str": "goodbye and thanks for all the fish"}),
+                    },
+                },
+                {
+                    "id": "tool_3",
+                    "type": "function",
+                    "function": {
+                        "name": "multi_tool_call_echo",  # normalized "multi_tool_call.echo"
+                        "arguments": json.JSONEncoder().encode({"str": "goodbye and thanks for all the fish"}),
+                    },
+                },
+            ],
+        },
+        sender=fake_agent,
+        request_reply=True,
+    )
+
+    assert fake_agent.received == [
+        {
+            "role": "tool",
+            "tool_responses": [
+                {"tool_call_id": "tool_1", "role": "tool", "name": "echo", "content": "hello world"},
+                {
+                    "tool_call_id": "tool_2",
+                    "role": "tool",
+                    "name": "echo",
+                    "content": "goodbye and thanks for all the fish",
+                },
+                {
+                    "tool_call_id": "tool_3",
+                    "role": "tool",
+                    "name": "multi_tool_call_echo",
+                    "content": "Error: Function multi_tool_call_echo not found.",
+                },
+            ],
+            "content": inspect.cleandoc(
+                """
+                Tool call: echo
+                Id: tool_1
+                hello world
+
+                Tool call: echo
+                Id: tool_2
+                goodbye and thanks for all the fish
+
+                Tool call: multi_tool_call_echo
+                Id: tool_3
+                Error: Function multi_tool_call_echo not found.
+                """
+            ),
+        }
+    ]
+
+
+if __name__ == "__main__":
+    test_update_tool()
+    test_eval_math_responses()
+    test_multi_tool_call()
@@ -210,54 +210,60 @@ def test_get_function_schema_missing() -> None:


def test_get_function_schema() -> None:
    expected_v2 = {
-        "description": "function g",
-        "name": "fancy name for g",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "a": {"type": "string", "description": "Parameter a"},
-                "b": {"type": "integer", "description": "b", "default": 2},
-                "c": {"type": "number", "description": "Parameter c", "default": 0.1},
-                "d": {
-                    "additionalProperties": {
-                        "maxItems": 2,
-                        "minItems": 2,
-                        "prefixItems": [
-                            {"anyOf": [{"type": "integer"}, {"type": "null"}]},
-                            {"items": {"type": "number"}, "type": "array"},
-                        ],
-                        "type": "array",
-                    },
-                    "type": "object",
-                    "description": "d",
-                },
-            },
-            "required": ["a", "d"],
-        },
+        "type": "function",
+        "function": {
+            "description": "function g",
+            "name": "fancy name for g",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "a": {"type": "string", "description": "Parameter a"},
+                    "b": {"type": "integer", "description": "b", "default": 2},
+                    "c": {"type": "number", "description": "Parameter c", "default": 0.1},
+                    "d": {
+                        "additionalProperties": {
+                            "maxItems": 2,
+                            "minItems": 2,
+                            "prefixItems": [
+                                {"anyOf": [{"type": "integer"}, {"type": "null"}]},
+                                {"items": {"type": "number"}, "type": "array"},
+                            ],
+                            "type": "array",
+                        },
+                        "type": "object",
+                        "description": "d",
+                    },
+                },
+                "required": ["a", "d"],
+            },
+        },
    }

    # the difference is that the v1 version does not handle Union types (Optional is Union[T, None])
    expected_v1 = {
-        "description": "function g",
-        "name": "fancy name for g",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "a": {"type": "string", "description": "Parameter a"},
-                "b": {"type": "integer", "description": "b", "default": 2},
-                "c": {"type": "number", "description": "Parameter c", "default": 0.1},
-                "d": {
-                    "type": "object",
-                    "additionalProperties": {
-                        "type": "array",
-                        "minItems": 2,
-                        "maxItems": 2,
-                        "items": [{"type": "integer"}, {"type": "array", "items": {"type": "number"}}],
-                    },
-                    "description": "d",
-                },
-            },
-            "required": ["a", "d"],
-        },
+        "type": "function",
+        "function": {
+            "description": "function g",
+            "name": "fancy name for g",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "a": {"type": "string", "description": "Parameter a"},
+                    "b": {"type": "integer", "description": "b", "default": 2},
+                    "c": {"type": "number", "description": "Parameter c", "default": 0.1},
+                    "d": {
+                        "type": "object",
+                        "additionalProperties": {
+                            "type": "array",
+                            "minItems": 2,
+                            "maxItems": 2,
+                            "items": [{"type": "integer"}, {"type": "array", "items": {"type": "number"}}],
+                        },
+                        "description": "d",
+                    },
+                },
+                "required": ["a", "d"],
+            },
+        },
    }

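Both rewritten fixtures make the same structural point: the legacy flat `functions` schema is now nested under a `{"type": "function", "function": ...}` envelope, matching the OpenAI tools API. A one-line sketch of the wrapping (hypothetical helper, not code from this commit):

```python
def to_tool_format(function_schema: dict) -> dict:
    # The old "functions"-style schema becomes the "function" field of a tool entry.
    return {"type": "function", "function": function_schema}


legacy = {"description": "function g", "name": "fancy name for g", "parameters": {"type": "object"}}
assert to_tool_format(legacy)["function"]["name"] == "fancy name for g"
```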
@@ -291,39 +297,42 @@ def test_get_function_schema_pydantic() -> None:
        pass

    expected = {
-        "description": "Currency exchange calculator.",
-        "name": "currency_calculator",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "base": {
-                    "properties": {
-                        "currency": {
-                            "description": "Currency code",
-                            "enum": ["USD", "EUR"],
-                            "title": "Currency",
-                            "type": "string",
-                        },
-                        "amount": {
-                            "default": 100.0,
-                            "description": "Amount of money in the currency",
-                            "title": "Amount",
-                            "type": "number",
-                        },
-                    },
-                    "required": ["currency"],
-                    "title": "Currency",
-                    "type": "object",
-                    "description": "Base currency: amount and currency symbol",
-                },
-                "quote_currency": {
-                    "enum": ["USD", "EUR"],
-                    "type": "string",
-                    "default": "EUR",
-                    "description": "Quote currency symbol (default: 'EUR')",
-                },
-            },
-            "required": ["base"],
-        },
+        "type": "function",
+        "function": {
+            "description": "Currency exchange calculator.",
+            "name": "currency_calculator",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "base": {
+                        "properties": {
+                            "currency": {
+                                "description": "Currency code",
+                                "enum": ["USD", "EUR"],
+                                "title": "Currency",
+                                "type": "string",
+                            },
+                            "amount": {
+                                "default": 100.0,
+                                "description": "Amount of money in the currency",
+                                "title": "Amount",
+                                "type": "number",
+                            },
+                        },
+                        "required": ["currency"],
+                        "title": "Currency",
+                        "type": "object",
+                        "description": "Base currency: amount and currency symbol",
+                    },
+                    "quote_currency": {
+                        "enum": ["USD", "EUR"],
+                        "type": "string",
+                        "default": "EUR",
+                        "description": "Quote currency symbol (default: 'EUR')",
+                    },
+                },
+                "required": ["base"],
+            },
+        },
    }

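For orientation, a Pydantic model along the following lines reproduces the nested `base` schema asserted above. This is a reconstruction from the expected dict (field names, descriptions, enum values, and the 100.0 default), not necessarily the verbatim test fixture:

```python
from typing import Annotated, Literal

from pydantic import BaseModel, Field

CurrencySymbol = Literal["USD", "EUR"]  # becomes the "enum": ["USD", "EUR"]


class Currency(BaseModel):
    currency: Annotated[CurrencySymbol, Field(description="Currency code")]
    amount: Annotated[float, Field(description="Amount of money in the currency")] = 100.0


def currency_calculator(
    base: Annotated[Currency, "Base currency: amount and currency symbol"],
    quote_currency: Annotated[CurrencySymbol, "Quote currency symbol (default: 'EUR')"] = "EUR",
) -> str:
    pass
```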
@@ -23,7 +23,7 @@ We have designed a generic [`ConversableAgent`](../reference/agentchat/conversab

- The [`AssistantAgent`](../reference/agentchat/assistant_agent.md#assistantagent-objects) is designed to act as an AI assistant, using LLMs by default but not requiring human input or code execution. It could write Python code (in a Python coding block) for a user to execute when a message (typically a description of a task that needs to be solved) is received. Under the hood, the Python code is written by LLM (e.g., GPT-4). It can also receive the execution results and suggest corrections or bug fixes. Its behavior can be altered by passing a new system message. The LLM [inference](#enhanced-inference) configuration can be configured via [`llm_config`].

-- The [`UserProxyAgent`](../reference/agentchat/user_proxy_agent.md#userproxyagent-objects) is conceptually a proxy agent for humans, soliciting human input as the agent's reply at each interaction turn by default and also having the capability to execute code and call functions. The [`UserProxyAgent`](../reference/agentchat/user_proxy_agent.md#userproxyagent-objects) triggers code execution automatically when it detects an executable code block in the received message and no human user input is provided. Code execution can be disabled by setting the `code_execution_config` parameter to False. LLM-based response is disabled by default. It can be enabled by setting `llm_config` to a dict corresponding to the [inference](/docs/Use-Cases/enhanced_inference) configuration. When `llm_config` is set as a dictionary, [`UserProxyAgent`](../reference/agentchat/user_proxy_agent.md#userproxyagent-objects) can generate replies using an LLM when code execution is not performed.
+- The [`UserProxyAgent`](../reference/agentchat/user_proxy_agent.md#userproxyagent-objects) is conceptually a proxy agent for humans, soliciting human input as the agent's reply at each interaction turn by default and also having the capability to execute code and call functions or tools. The [`UserProxyAgent`](../reference/agentchat/user_proxy_agent.md#userproxyagent-objects) triggers code execution automatically when it detects an executable code block in the received message and no human user input is provided. Code execution can be disabled by setting the `code_execution_config` parameter to False. LLM-based response is disabled by default. It can be enabled by setting `llm_config` to a dict corresponding to the [inference](/docs/Use-Cases/enhanced_inference) configuration. When `llm_config` is set as a dictionary, [`UserProxyAgent`](../reference/agentchat/user_proxy_agent.md#userproxyagent-objects) can generate replies using an LLM when code execution is not performed.

The auto-reply capability of [`ConversableAgent`](../reference/agentchat/conversable_agent.md#conversableagent-objects) allows for more autonomous multi-agent communication while retaining the possibility of human intervention.
One can also easily extend it by registering reply functions with the [`register_reply()`](../reference/agentchat/conversable_agent.md#register_reply) method.
@@ -39,16 +39,16 @@ assistant = AssistantAgent(name="assistant")
# create a UserProxyAgent instance named "user_proxy"
user_proxy = UserProxyAgent(name="user_proxy")
```

-#### Function calling
+#### Tool calling

-Function calling enables agents to interact with external tools and APIs more efficiently.
+Tool calling enables agents to interact with external tools and APIs more efficiently.
This feature allows the AI model to intelligently choose to output a JSON object containing
-arguments to call specific functions based on the user's input. A function to be called is
+arguments to call specific tools based on the user's input. A tool to be called is
specified with a JSON schema describing its parameters and their types. Writing such JSON schema
is complex and error-prone and that is why AutoGen framework provides two high level function decorators for automatically generating such schema using type hints on standard Python datatypes
or Pydantic models:

-1. [`ConversableAgent.register_for_llm`](../reference/agentchat/conversable_agent#register_for_llm) is used to register the function in the `llm_config` of a ConversableAgent. The ConversableAgent agent can propose execution of a registrated function, but the actual execution will be performed by a UserProxy agent.
+1. [`ConversableAgent.register_for_llm`](../reference/agentchat/conversable_agent#register_for_llm) is used to register the function as a Tool in the `llm_config` of a ConversableAgent. The ConversableAgent agent can propose execution of a registered Tool, but the actual execution will be performed by a UserProxy agent.

2. [`ConversableAgent.register_for_execution`](../reference/agentchat/conversable_agent#register_for_execution) is used to register the function in the `function_map` of a UserProxy agent.
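Taken together, the two registration steps are typically applied as stacked decorators on a single function. A minimal sketch, assuming an `AssistantAgent` named `chatbot` (configured with an `llm_config`) and the `user_proxy` from the snippet above; the toy exchange rate is illustrative only:

```python
from typing import Annotated


@user_proxy.register_for_execution()
@chatbot.register_for_llm(description="Currency exchange calculator.")
def currency_calculator(
    base_amount: Annotated[float, "Amount of currency in base_currency"],
    base_currency: Annotated[str, "Base currency"] = "USD",
    quote_currency: Annotated[str, "Quote currency"] = "EUR",
) -> str:
    # Toy fixed rate so the sketch is self-contained.
    rate = 1.1 if (base_currency, quote_currency) == ("USD", "EUR") else 1.0
    return f"{rate * base_amount} {quote_currency}"
```

The LLM-facing registration adds the generated schema to `chatbot`'s tool list, while the execution registration adds the callable to `user_proxy`'s `function_map`.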
@@ -81,9 +81,10 @@ def currency_calculator(

Notice the use of [Annotated](https://docs.python.org/3/library/typing.html?highlight=annotated#typing.Annotated) to specify the type and the description of each parameter. The return value of the function must be either string or serializable to string using the [`json.dumps()`](https://docs.python.org/3/library/json.html#json.dumps) or [`Pydantic` model dump to JSON](https://docs.pydantic.dev/latest/concepts/serialization/#modelmodel_dump_json) (both version 1.x and 2.x are supported).

-You can check the JSON schema generated by the decorator `chatbot.llm_config["functions"]`:
+You can check the JSON schema generated by the decorator in `chatbot.llm_config["tools"]`:
```python
-[{'description': 'Currency exchange calculator.',
+[{'type': 'function', 'function':
+ {'description': 'Currency exchange calculator.',
  'name': 'currency_calculator',
  'parameters': {'type': 'object',
   'properties': {'base_amount': {'type': 'number',
@@ -96,7 +97,7 @@ You can check the JSON schema generated by the decorator `chatbot.llm_config["fu
    'type': 'string',
    'default': 'EUR',
    'description': 'Quote currency'}},
-   'required': ['base_amount']}}]
+   'required': ['base_amount']}}}]
```
Agents can now use the function as follows:
```
@@ -107,7 +108,7 @@ How much is 123.45 USD in EUR?

--------------------------------------------------------------------------------
chatbot (to user_proxy):

-***** Suggested function Call: currency_calculator *****
+***** Suggested tool Call: currency_calculator *****
Arguments:
{"base_amount":123.45,"base_currency":"USD","quote_currency":"EUR"}
********************************************************
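A transcript like the one above is produced by kicking off the conversation in the usual way; a sketch, again assuming the `chatbot`/`user_proxy` pair with `currency_calculator` registered:

```python
# The user proxy executes the suggested tool call and returns the tool response.
user_proxy.initiate_chat(
    chatbot,
    message="How much is 123.45 USD in EUR?",
)
```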
@@ -163,7 +164,8 @@ def currency_calculator(

The generated JSON schema has additional properties such as minimum value encoded:
```python
-[{'description': 'Currency exchange calculator.',
+[{'type': 'function', 'function':
+ {'description': 'Currency exchange calculator.',
  'name': 'currency_calculator',
  'parameters': {'type': 'object',
   'properties': {'base': {'properties': {'currency': {'description': 'Currency symbol',
@@ -183,7 +185,7 @@ The generated JSON schema has additional properties such as minimum value encode
    'type': 'string',
    'default': 'USD',
    'description': 'Quote currency symbol'}},
-   'required': ['base']}}]
+   'required': ['base']}}}]
```

For more in-depth examples, please check the following:
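Such a minimum constraint comes from a constrained Pydantic field; an illustrative sketch (not the verbatim docs model):

```python
from typing import Annotated

from pydantic import BaseModel, Field


class Currency(BaseModel):
    # ge=0 is emitted into the generated JSON schema as "minimum": 0.
    amount: Annotated[float, Field(description="Amount of money in the currency", ge=0)] = 0
```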