Support function_call in autogen/agent (#1091)

* update funccall

* code format

* update to comments

* update notebook

* remove test for py3.7

* allow funccall to class functions

* add test and clean up notebook

* revise notebook and test

* update

* update mathagent

* Update flaml/autogen/agent/agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Update flaml/autogen/agent/user_proxy_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* revise to comments

* revise function call design, notebook and test. add doc

* code format

* ad message_to_dict function

* update mathproxyagent

* revise docstr

* update

* Update flaml/autogen/agent/math_user_proxy_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Update flaml/autogen/agent/math_user_proxy_agent.py

Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>

* Update flaml/autogen/agent/user_proxy_agent.py

Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>

* simply funccall in userproxyagent, rewind auto-gen.md, revise to comments

* code format

* update

* remove notebook for another pr

* revise oai_conversation part  in agent, revise function exec in user_proxy_agent

* update test_funccall

* update

* update

* fix pydantic version

* Update test/autogen/test_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* fix bug

* fix bug

* update

* update is_termination_msg to accept dict

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>
Co-authored-by: Li Jiang <bnujli@gmail.com>
This commit is contained in:
Yiran Wu 2023-07-06 06:08:44 +08:00 committed by GitHub
parent dd9202bb01
commit ca10b286cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 306 additions and 39 deletions

View File

@ -1,5 +1,6 @@
from .agent import Agent
from .assistant_agent import AssistantAgent
from .user_proxy_agent import UserProxyAgent
from .math_user_proxy_agent import MathUserProxyAgent
__all__ = ["Agent", "AssistantAgent", "UserProxyAgent"]
__all__ = ["Agent", "AssistantAgent", "UserProxyAgent", "MathUserProxyAgent"]

View File

@ -1,4 +1,5 @@
from collections import defaultdict
from typing import Dict, Union
class Agent:
@ -17,7 +18,7 @@ class Agent:
# empty memory
self._memory = []
# a dictionary of conversations, default value is list
self._conversations = defaultdict(list)
self._oai_conversations = defaultdict(list)
self._name = name
self._system_message = system_message
@ -30,22 +31,95 @@ class Agent:
"""Remember something."""
self._memory.append(memory)
def _send(self, message, recipient):
@staticmethod
def _message_to_dict(message: Union[Dict, str]):
"""Convert a message to a dictionary.
The message can be a string or a dictionary. The string with be put in the "content" field of the new dictionary.
"""
if isinstance(message, str):
return {"content": message}
else:
return message
def _append_oai_message(self, message: Union[Dict, str], role, conversation_id):
    """Record a message in the openai-format conversation history.

    A string message is first wrapped into a dict under "content". A dict
    message that carries neither "content" nor "function_call" is not a
    valid openai message and is silently dropped.

    Args:
        message (dict or str): message to be appended to the openai conversation.
        role (str): role of the message, can be "assistant" or "function".
        conversation_id (str): id of the conversation, should be the name of the recipient or sender.
    """
    msg_dict = self._message_to_dict(message)
    # Keep only the fields oai understands so the stored entry can be passed to oai directly.
    oai_message = {key: msg_dict[key] for key in ("content", "function_call", "name") if key in msg_dict}
    if "content" in oai_message or "function_call" in oai_message:
        # An explicit "function" role on the message wins over the caller-supplied role.
        oai_message["role"] = "function" if msg_dict.get("role") == "function" else role
        self._oai_conversations[conversation_id].append(oai_message)
def _send(self, message: Union[Dict, str], recipient):
    """Send a message to another agent."""
    # Messages this agent composes are logged with role "assistant"; a message
    # that already carries role "function" keeps that role unchanged.
    self._append_oai_message(message, "assistant", recipient.name)
    recipient.receive(message, self)
def _receive(self, message, sender):
    """Receive a message from another agent."""
    # Print a separator line, then a "sender (to recipient):" header, then the raw message.
    print("\n", "-" * 80, "\n", flush=True)
    print(sender.name, "(to", f"{self.name}):", flush=True)
    print(message, flush=True)
    # Incoming messages are logged with role "user" under the sender's name.
    self._conversations[sender.name].append({"content": message, "role": "user"})
def _receive(self, message: Union[Dict, str], sender):
    """Receive a message from another agent: print it, then log it.

    Args:
        message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (All fields are optional).
            1. "content": content of the message, can be None.
            2. "function_call": a dictionary containing the function name and arguments.
            3. "role": role of the message, can be "assistant", "user", "function".
               This field is only needed to distinguish between "function" or "assistant"/"user".
            4. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
        sender: sender of an Agent instance.
    """
    message = self._message_to_dict(message)
    # print the message received
    print(sender.name, "(to", f"{self.name}):\n", flush=True)
    if message.get("role") == "function":
        # The message is the result of a function execution; show which function produced it.
        func_print = f"***** Response from calling function \"{message['name']}\" *****"
        print(func_print, flush=True)
        print(message["content"], flush=True)
        print("*" * len(func_print), flush=True)
    else:
        if message.get("content") is not None:
            print(message["content"], flush=True)
        if "function_call" in message:
            # The sender is requesting a function call; show the name and arguments,
            # with placeholders when either field is missing.
            func_print = f"***** Suggested function Call: {message['function_call'].get('name', '(No function name found)')} *****"
            print(func_print, flush=True)
            print(
                "Arguments: \n",
                message["function_call"].get("arguments", "(No arguments found)"),
                flush=True,
                sep="",
            )
            print("*" * len(func_print), flush=True)
    print("\n", "-" * 80, flush=True, sep="")
    # When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
    self._append_oai_message(message, "user", sender.name)
def receive(self, message: Union[Dict, str], sender):
    """Entry point for receiving a message from another agent.

    This method is called by the sender. Subclasses should override it to
    perform follow-up actions after the message has been processed.

    Args:
        message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (All fields are optional).
            1. "content": content of the message, can be None.
            2. "function_call": a dictionary containing the function name and arguments.
            3. "role": role of the message, can be "assistant", "user", "function".
               This field is only needed to distinguish between "function" or "assistant"/"user".
            4. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
        sender: sender of an Agent instance.
    """
    self._receive(message, sender)
    # perform actions based on the message

View File

@ -1,6 +1,7 @@
from .agent import Agent
from flaml.autogen.code_utils import DEFAULT_MODEL
from flaml import oai
from typing import Dict, Union
class AssistantAgent(Agent):
@ -33,16 +34,15 @@ class AssistantAgent(Agent):
self._config.update(config)
self._sender_dict = {}
def receive(self, message, sender):
def receive(self, message: Union[Dict, str], sender):
if sender.name not in self._sender_dict:
self._sender_dict[sender.name] = sender
self._conversations[sender.name] = [{"content": self._system_message, "role": "system"}]
self._oai_conversations[sender.name] = [{"content": self._system_message, "role": "system"}]
super().receive(message, sender)
responses = oai.ChatCompletion.create(messages=self._conversations[sender.name], **self._config)
# TODO: handle function_call
response = oai.ChatCompletion.extract_text(responses)[0]
self._send(response, sender)
responses = oai.ChatCompletion.create(messages=self._oai_conversations[sender.name], **self._config)
self._send(oai.ChatCompletion.extract_text_or_function_call(responses)[0], sender)
def reset(self):
self._sender_dict.clear()
self._conversations.clear()
self._oai_conversations.clear()

View File

@ -84,6 +84,10 @@ Problem: """,
def is_termination_msg(x):
"""Check if a message is a termination message."""
if isinstance(x, dict):
x = x.get("content")
if x is None:
return False
cb = extract_code(x)
contain_code = False
for c in cb:
@ -129,6 +133,7 @@ class MathUserProxyAgent(UserProxyAgent):
name="MathChatAgent", # default set to MathChatAgent
system_message="",
work_dir=None,
function_map=defaultdict(callable),
human_input_mode="NEVER", # Fully automated
max_consecutive_auto_reply=None,
is_termination_msg=is_termination_msg,
@ -150,11 +155,12 @@ class MathUserProxyAgent(UserProxyAgent):
the number of auto reply reaches the max_consecutive_auto_reply.
(3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions.
max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
The limit only plays a role when human_input_mode is not "ALWAYS".
is_termination_msg (function): a function that takes a message and returns a boolean value.
This function is used to determine if a received message is a termination message.
is_termination_msg (function): a function that takes a message in the form of a dictionary and returns a boolean value indicating if this received message is a termination message.
The dict can contain the following keys: "content", "role", "name", "function_call".
use_docker (bool): whether to use docker to execute the code.
max_invalid_q_per_step (int): (ADDED) the maximum number of invalid queries per step.
**config (dict): other configurations.
@ -163,6 +169,7 @@ class MathUserProxyAgent(UserProxyAgent):
name=name,
system_message=system_message,
work_dir=work_dir,
function_map=function_map,
human_input_mode=human_input_mode,
max_consecutive_auto_reply=max_consecutive_auto_reply,
is_termination_msg=is_termination_msg,
@ -208,7 +215,7 @@ class MathUserProxyAgent(UserProxyAgent):
return PROMPTS[prompt_type] + problem
def _reset(self):
self._conversations.clear()
self._oai_conversations.clear()
self._valid_q_count = 0
self._total_q_count = 0
self._accum_invalid_q_per_step = 0
@ -288,6 +295,7 @@ class MathUserProxyAgent(UserProxyAgent):
def auto_reply(self, message, sender, default_reply=""):
"""Generate an auto reply."""
message = message.get("content", "")
code_blocks = extract_code(message)
if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN:
@ -391,7 +399,7 @@ class WolframAlphaAPIWrapper(BaseModel):
extra = Extra.forbid
@root_validator()
@root_validator(skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
wolfram_alpha_appid = get_from_dict_or_env(values, "wolfram_alpha_appid", "WOLFRAM_ALPHA_APPID")

View File

@ -1,7 +1,8 @@
from typing import Union
from .agent import Agent
from flaml.autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang
from collections import defaultdict
import json
from typing import Dict, Union
class UserProxyAgent(Agent):
@ -15,6 +16,7 @@ class UserProxyAgent(Agent):
system_message="",
work_dir=None,
human_input_mode="ALWAYS",
function_map={},
max_consecutive_auto_reply=None,
is_termination_msg=None,
use_docker=True,
@ -34,11 +36,12 @@ class UserProxyAgent(Agent):
the number of auto reply reaches the max_consecutive_auto_reply.
(3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions.
max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
The limit only plays a role when human_input_mode is not "ALWAYS".
is_termination_msg (function): a function that takes a message and returns a boolean value.
This function is used to determine if a received message is a termination message.
is_termination_msg (function): a function that takes a message in the form of a dictionary and returns a boolean value indicating if this received message is a termination message.
The dict can contain the following keys: "content", "role", "name", "function_call".
use_docker (bool or str): bool value of whether to use docker to execute the code,
or str value of the docker image name to use.
**config (dict): other configurations.
@ -47,7 +50,7 @@ class UserProxyAgent(Agent):
self._work_dir = work_dir
self._human_input_mode = human_input_mode
self._is_termination_msg = (
is_termination_msg if is_termination_msg is not None else (lambda x: x == "TERMINATE")
is_termination_msg if is_termination_msg is not None else (lambda x: x.get("content") == "TERMINATE")
)
self._config = config
self._max_consecutive_auto_reply = (
@ -56,6 +59,8 @@ class UserProxyAgent(Agent):
self._consecutive_auto_reply_counter = defaultdict(int)
self._use_docker = use_docker
self._function_map = function_map
@property
def use_docker(self) -> Union[bool, str]:
"""bool value of whether to use docker to execute the code,
@ -89,8 +94,8 @@ class UserProxyAgent(Agent):
)
logs = logs.decode("utf-8")
else:
# TODO: could this happen?
exitcode, logs, image = 1, f"unknown language {lang}"
# In case the language is not supported, we return an error message.
exitcode, logs, image = 1, f"unknown language {lang}", self._use_docker
# raise NotImplementedError
self._use_docker = image
logs_all += "\n" + logs
@ -98,11 +103,86 @@ class UserProxyAgent(Agent):
return exitcode, logs_all
return exitcode, logs_all
def auto_reply(self, message, sender, default_reply=""):
@staticmethod
def _format_json_str(jstr):
"""Remove newlines outside of quotes, and hanlde JSON escape sequences.
1. this function removes the newline in the query outside of quotes otherwise json.loads(s) will fail.
Ex 1:
"{\n"tool": "python",\n"query": "print('hello')\nprint('world')"\n}" -> "{"tool": "python","query": "print('hello')\nprint('world')"}"
Ex 2:
"{\n \"location\": \"Boston, MA\"\n}" -> "{"location": "Boston, MA"}"
2. this function also handles JSON escape sequences inside quotes,
Ex 1:
'{"args": "a\na\na\ta"}' -> '{"args": "a\\na\\na\\ta"}'
"""
result = []
inside_quotes = False
last_char = " "
for char in jstr:
if last_char != "\\" and char == '"':
inside_quotes = not inside_quotes
last_char = char
if not inside_quotes and char == "\n":
continue
if inside_quotes and char == "\n":
char = "\\n"
if inside_quotes and char == "\t":
char = "\\t"
result.append(char)
return "".join(result)
def _execute_function(self, func_call):
    """Execute a function call and return the result.

    Args:
        func_call: a dictionary extracted from openai message at key "function_call" with keys "name" and "arguments".

    Returns:
        A tuple of (is_exec_success, result_dict).
        is_exec_success (boolean): whether the execution is successful.
        result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function".
    """
    func_name = func_call.get("name", "")
    func = self._function_map.get(func_name, None)

    is_exec_success = False
    if func is not None:
        # Extract arguments from a json-like string and put it into a dict.
        input_string = self._format_json_str(func_call.get("arguments", "{}"))
        try:
            arguments = json.loads(input_string)
        except json.JSONDecodeError as e:
            arguments = None
            content = f"Error: {e}\n You argument should follow json format."

        # Try to execute the function.
        # NOTE: `arguments == {}` is a valid decode result (a zero-argument
        # function), so only skip execution when decoding actually failed;
        # the previous truthiness check left `content` unbound for "{}".
        if arguments is not None:
            try:
                content = func(**arguments)
                is_exec_success = True
            except Exception as e:
                content = f"Error: {e}"
    else:
        content = f"Error: Function {func_name} not found."

    return is_exec_success, {
        "name": func_name,
        "role": "function",
        "content": str(content),
    }
def auto_reply(self, message: dict, sender, default_reply=""):
"""Generate an auto reply."""
code_blocks = extract_code(message)
if "function_call" in message:
is_exec_success, func_return = self._execute_function(message["function_call"])
self._send(func_return, sender)
return
code_blocks = extract_code(message["content"])
if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN:
# no code block is found, lang should be `UNKNOWN``
# no code block is found, lang should be `UNKNOWN`
self._send(default_reply, sender)
else:
# try to execute the code
@ -110,11 +190,12 @@ class UserProxyAgent(Agent):
exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed"
self._send(f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}", sender)
def receive(self, message, sender):
def receive(self, message: Union[Dict, str], sender):
"""Receive a message from the sender agent.
Once a message is received, this function sends a reply to the sender or simply stop.
The reply can be generated automatically or entered manually by a human.
"""
message = self._message_to_dict(message)
super().receive(message, sender)
# default reply is empty (i.e., no reply, in this case we will try to generate auto reply)
reply = ""

View File

@ -197,7 +197,7 @@
" \"user\",\n",
" human_input_mode=\"NEVER\",\n",
" max_consecutive_auto_reply=10,\n",
" is_termination_msg=lambda x: x.rstrip().endswith(\"TERMINATE\") or x.rstrip().endswith('\"TERMINATE\".'),\n",
" is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\") or x.get(\"content\", \"\").rstrip().endswith('\"TERMINATE\".'),\n",
" work_dir=\"coding\",\n",
" use_docker=False, # set to True if you are using docker\n",
")\n",
@ -464,7 +464,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.9.16"
},
"vscode": {
"interpreter": {

View File

@ -148,7 +148,7 @@
"user = UserProxyAgent(\n",
" name=\"user\",\n",
" human_input_mode=\"ALWAYS\",\n",
" is_termination_msg=lambda x: x.rstrip().endswith(\"TERMINATE\"),\n",
" is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\"),\n",
")\n",
"\n",
"# the purpose of the following line is to log the conversation history\n",

View File

@ -108,7 +108,7 @@
" name=\"user\",\n",
" human_input_mode=\"TERMINATE\",\n",
" max_consecutive_auto_reply=10,\n",
" is_termination_msg=lambda x: x.rstrip().endswith(\"TERMINATE\"),\n",
" is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\"),\n",
" work_dir='web',\n",
")"
]

View File

@ -0,0 +1,40 @@
def test_agent():
    """Smoke-test Agent receive/_send with both str and dict messages."""
    from flaml.autogen.agent import Agent

    dummy_agent_1 = Agent(name="dummy_agent_1")
    dummy_agent_2 = Agent(name="dummy_agent_2")

    dummy_agent_1.receive("hello", dummy_agent_2)  # receive a str
    dummy_agent_1.receive(
        {
            "content": "hello",
        },
        dummy_agent_2,
    )  # receive a dict

    # receive dict without openai fields to be printed, such as "content", 'function_call'. There should be no error raised.
    pre_len = len(dummy_agent_1._oai_conversations["dummy_agent_2"])
    dummy_agent_1.receive({"message": "hello"}, dummy_agent_2)
    # Fix: unified this assertion message with the parallel one below ("a valid", not "an valid").
    assert pre_len == len(
        dummy_agent_1._oai_conversations["dummy_agent_2"]
    ), "When the message is not a valid openai message, it should not be appended to the oai conversation."

    dummy_agent_1._send("hello", dummy_agent_2)  # send a str
    dummy_agent_1._send(
        {
            "content": "hello",
        },
        dummy_agent_2,
    )  # send a dict

    # send dict with no openai fields
    pre_len = len(dummy_agent_1._oai_conversations["dummy_agent_2"])
    dummy_agent_1._send({"message": "hello"}, dummy_agent_2)  # send dict with wrong field
    assert pre_len == len(
        dummy_agent_1._oai_conversations["dummy_agent_2"]
    ), "When the message is not a valid openai message, it should not be appended to the oai conversation."


if __name__ == "__main__":
    test_agent()

View File

@ -23,7 +23,7 @@ def test_gpt35(human_input_mode="NEVER", max_consecutive_auto_reply=5):
"user",
work_dir=f"{here}/test_agent_scripts",
human_input_mode=human_input_mode,
is_termination_msg=lambda x: x.rstrip().endswith("TERMINATE"),
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
max_consecutive_auto_reply=max_consecutive_auto_reply,
use_docker="python:3",
)
@ -51,7 +51,7 @@ def test_create_execute_script(human_input_mode="NEVER", max_consecutive_auto_re
"user",
human_input_mode=human_input_mode,
max_consecutive_auto_reply=max_consecutive_auto_reply,
is_termination_msg=lambda x: x.rstrip().endswith("TERMINATE"),
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
)
assistant.receive(
"""Create and execute a script to plot a rocket without using matplotlib""",

View File

@ -60,5 +60,68 @@ def test_eval_math_responses():
print(eval_math_responses(**arguments))
def test_json_extraction():
    """Check UserProxyAgent._format_json_str on representative json-like strings."""
    from flaml.autogen.agent import UserProxyAgent

    user = UserProxyAgent(name="test", use_docker=False)

    # Newlines between JSON tokens are removed.
    jstr = '{\n"location": "Boston, MA"\n}'
    assert user._format_json_str(jstr) == '{"location": "Boston, MA"}'

    # Raw newlines inside a quoted value are re-escaped.
    jstr = '{\n"code": "python",\n"query": "x=3\nprint(x)"}'
    assert user._format_json_str(jstr) == '{"code": "python","query": "x=3\\nprint(x)"}'

    # Escaped quotes inside a value are preserved untouched.
    jstr = '{"code": "a=\\"hello\\""}'
    assert user._format_json_str(jstr) == '{"code": "a=\\"hello\\""}'
def test_execute_function():
    """Exercise UserProxyAgent._execute_function with plain functions and bound methods."""
    from flaml.autogen.agent import UserProxyAgent

    # 1. test calling a simple function
    def add_num(num_to_be_added):
        given_num = 10
        return num_to_be_added + given_num

    proxy = UserProxyAgent(name="test", function_map={"add_num": add_num})

    # correct execution
    valid_call = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}
    assert proxy._execute_function(func_call=valid_call)[1]["content"] == "15"

    # function name called is wrong or doesn't exist
    missing_func_call = {"name": "subtract_num", "arguments": '{ "num_to_be_added": 5 }'}
    assert "Error: Function" in proxy._execute_function(func_call=missing_func_call)[1]["content"]

    # arguments passed is not in correct json format
    bad_json_call = {
        "name": "add_num",
        "arguments": '{ "num_to_be_added": 5, given_num: 10 }',
    }  # should be "given_num" with quotes
    assert (
        "You argument should follow json format." in proxy._execute_function(func_call=bad_json_call)[1]["content"]
    )

    # function execution error with wrong arguments passed
    extra_arg_call = {"name": "add_num", "arguments": '{ "num_to_be_added": 5, "given_num": 10 }'}
    assert "Error: " in proxy._execute_function(func_call=extra_arg_call)[1]["content"]

    # 2. test calling a class method
    class AddNum:
        def __init__(self, given_num):
            self.given_num = given_num

        def add(self, num_to_be_added):
            self.given_num = num_to_be_added + self.given_num
            return self.given_num

    proxy = UserProxyAgent(name="test", function_map={"add_num": AddNum(given_num=10).add})
    method_call = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}
    # The bound method mutates its instance, so repeated calls accumulate.
    assert proxy._execute_function(func_call=method_call)[1]["content"] == "15"
    assert proxy._execute_function(func_call=method_call)[1]["content"] == "20"


if __name__ == "__main__":
    test_json_extraction()
    test_execute_function()
    test_eval_math_responses()

View File

@ -38,7 +38,7 @@ user_proxy = UserProxyAgent(
name="user_proxy",
human_input_mode="NEVER", # in this mode, the agent will never solicit human input but always auto reply
max_consecutive_auto_reply=10, # the maximum number of consecutive auto replies
is_termination_msg=lambda x: x.rstrip().endswith("TERMINATE") or x.rstrip().endswith('"TERMINATE".'), # the function to determine whether a message is a termination message
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE") or x.get("content", "").rstrip().endswith('"TERMINATE".'), # the function to determine whether a message is a termination message
work_dir=".",
)