Support function_call in autogen/agent (#1091)

* update funccall

* code format

* update to comments

* update notebook

* remove test for py3.7

* allow funccall to class functions

* add test and clean up notebook

* revise notebook and test

* update

* update mathagent

* Update flaml/autogen/agent/agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Update flaml/autogen/agent/user_proxy_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* revise to comments

* revise function call design, notebook and test. add doc

* code format

* ad message_to_dict function

* update mathproxyagent

* revise docstr

* update

* Update flaml/autogen/agent/math_user_proxy_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Update flaml/autogen/agent/math_user_proxy_agent.py

Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>

* Update flaml/autogen/agent/user_proxy_agent.py

Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>

* simply funccall in userproxyagent, rewind auto-gen.md, revise to comments

* code format

* update

* remove notebook for another pr

* revise oai_conversation part  in agent, revise function exec in user_proxy_agent

* update test_funccall

* update

* update

* fix pydantic version

* Update test/autogen/test_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* fix bug

* fix bug

* update

* update is_termination_msg to accept dict

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>
Co-authored-by: Li Jiang <bnujli@gmail.com>
This commit is contained in:
Yiran Wu 2023-07-06 06:08:44 +08:00 committed by GitHub
parent dd9202bb01
commit ca10b286cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 306 additions and 39 deletions

View File

@ -1,5 +1,6 @@
from .agent import Agent from .agent import Agent
from .assistant_agent import AssistantAgent from .assistant_agent import AssistantAgent
from .user_proxy_agent import UserProxyAgent from .user_proxy_agent import UserProxyAgent
from .math_user_proxy_agent import MathUserProxyAgent
__all__ = ["Agent", "AssistantAgent", "UserProxyAgent"] __all__ = ["Agent", "AssistantAgent", "UserProxyAgent", "MathUserProxyAgent"]

View File

@ -1,4 +1,5 @@
from collections import defaultdict from collections import defaultdict
from typing import Dict, Union
class Agent: class Agent:
@ -17,7 +18,7 @@ class Agent:
# empty memory # empty memory
self._memory = [] self._memory = []
# a dictionary of conversations, default value is list # a dictionary of conversations, default value is list
self._conversations = defaultdict(list) self._oai_conversations = defaultdict(list)
self._name = name self._name = name
self._system_message = system_message self._system_message = system_message
@ -30,22 +31,95 @@ class Agent:
"""Remember something.""" """Remember something."""
self._memory.append(memory) self._memory.append(memory)
def _send(self, message, recipient): @staticmethod
def _message_to_dict(message: Union[Dict, str]):
"""Convert a message to a dictionary.
The message can be a string or a dictionary. The string with be put in the "content" field of the new dictionary.
"""
if isinstance(message, str):
return {"content": message}
else:
return message
def _append_oai_message(self, message: Union[Dict, str], role, conversation_id):
"""Append a message to the openai conversation.
If the message received is a string, it will be put in the "content" field of the new dictionary.
If the message received is a dictionary but does not have any of the two fields "content" or "function_call",
this message is not a valid openai message and will be ignored.
Args:
message (dict or str): message to be appended to the openai conversation.
role (str): role of the message, can be "assistant" or "function".
conversation_id (str): id of the conversation, should be the name of the recipient or sender.
"""
message = self._message_to_dict(message)
# create openai message to be appended to the openai conversation that can be passed to oai directly.
oai_message = {k: message[k] for k in ("content", "function_call", "name") if k in message}
if "content" not in oai_message and "function_call" not in oai_message:
return
oai_message["role"] = "function" if message.get("role") == "function" else role
self._oai_conversations[conversation_id].append(oai_message)
def _send(self, message: Union[Dict, str], recipient):
"""Send a message to another agent.""" """Send a message to another agent."""
self._conversations[recipient.name].append({"content": message, "role": "assistant"}) # When the agent composes and sends the message, the role of the message is "assistant". (If 'role' exists and is 'function', it will remain unchanged.)
self._append_oai_message(message, "assistant", recipient.name)
recipient.receive(message, self) recipient.receive(message, self)
def _receive(self, message, sender): def _receive(self, message: Union[Dict, str], sender):
"""Receive a message from another agent.""" """Receive a message from another agent.
print("\n", "-" * 80, "\n", flush=True)
print(sender.name, "(to", f"{self.name}):", flush=True)
print(message, flush=True)
self._conversations[sender.name].append({"content": message, "role": "user"})
def receive(self, message, sender): Args:
message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (All fields are optional).
1. "content": content of the message, can be None.
2. "function_call": a dictionary containing the function name and arguments.
3. "role": role of the message, can be "assistant", "user", "function".
This field is only needed to distinguish between "function" or "assistant"/"user".
4. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
sender: sender of an Agent instance.
"""
message = self._message_to_dict(message)
# print the message received
print(sender.name, "(to", f"{self.name}):\n", flush=True)
if message.get("role") == "function":
func_print = f"***** Response from calling function \"{message['name']}\" *****"
print(func_print, flush=True)
print(message["content"], flush=True)
print("*" * len(func_print), flush=True)
else:
if message.get("content") is not None:
print(message["content"], flush=True)
if "function_call" in message:
func_print = f"***** Suggested function Call: {message['function_call'].get('name', '(No function name found)')} *****"
print(func_print, flush=True)
print(
"Arguments: \n",
message["function_call"].get("arguments", "(No arguments found)"),
flush=True,
sep="",
)
print("*" * len(func_print), flush=True)
print("\n", "-" * 80, flush=True, sep="")
# When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
self._append_oai_message(message, "user", sender.name)
def receive(self, message: Union[Dict, str], sender):
"""Receive a message from another agent. """Receive a message from another agent.
This method is called by the sender. This method is called by the sender.
It needs to be overriden by the subclass to perform followup actions. It needs to be overriden by the subclass to perform followup actions.
Args:
message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (All fields are optional).
1. "content": content of the message, can be None.
2. "function_call": a dictionary containing the function name and arguments.
3. "role": role of the message, can be "assistant", "user", "function".
This field is only needed to distinguish between "function" or "assistant"/"user".
4. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
sender: sender of an Agent instance.
""" """
self._receive(message, sender) self._receive(message, sender)
# perform actions based on the message # perform actions based on the message

View File

@ -1,6 +1,7 @@
from .agent import Agent from .agent import Agent
from flaml.autogen.code_utils import DEFAULT_MODEL from flaml.autogen.code_utils import DEFAULT_MODEL
from flaml import oai from flaml import oai
from typing import Dict, Union
class AssistantAgent(Agent): class AssistantAgent(Agent):
@ -33,16 +34,15 @@ class AssistantAgent(Agent):
self._config.update(config) self._config.update(config)
self._sender_dict = {} self._sender_dict = {}
def receive(self, message, sender): def receive(self, message: Union[Dict, str], sender):
if sender.name not in self._sender_dict: if sender.name not in self._sender_dict:
self._sender_dict[sender.name] = sender self._sender_dict[sender.name] = sender
self._conversations[sender.name] = [{"content": self._system_message, "role": "system"}] self._oai_conversations[sender.name] = [{"content": self._system_message, "role": "system"}]
super().receive(message, sender) super().receive(message, sender)
responses = oai.ChatCompletion.create(messages=self._conversations[sender.name], **self._config) responses = oai.ChatCompletion.create(messages=self._oai_conversations[sender.name], **self._config)
# TODO: handle function_call self._send(oai.ChatCompletion.extract_text_or_function_call(responses)[0], sender)
response = oai.ChatCompletion.extract_text(responses)[0]
self._send(response, sender)
def reset(self): def reset(self):
self._sender_dict.clear() self._sender_dict.clear()
self._conversations.clear() self._oai_conversations.clear()

View File

@ -84,6 +84,10 @@ Problem: """,
def is_termination_msg(x): def is_termination_msg(x):
"""Check if a message is a termination message.""" """Check if a message is a termination message."""
if isinstance(x, dict):
x = x.get("content")
if x is None:
return False
cb = extract_code(x) cb = extract_code(x)
contain_code = False contain_code = False
for c in cb: for c in cb:
@ -129,6 +133,7 @@ class MathUserProxyAgent(UserProxyAgent):
name="MathChatAgent", # default set to MathChatAgent name="MathChatAgent", # default set to MathChatAgent
system_message="", system_message="",
work_dir=None, work_dir=None,
function_map=defaultdict(callable),
human_input_mode="NEVER", # Fully automated human_input_mode="NEVER", # Fully automated
max_consecutive_auto_reply=None, max_consecutive_auto_reply=None,
is_termination_msg=is_termination_msg, is_termination_msg=is_termination_msg,
@ -150,11 +155,12 @@ class MathUserProxyAgent(UserProxyAgent):
the number of auto reply reaches the max_consecutive_auto_reply. the number of auto reply reaches the max_consecutive_auto_reply.
(3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True. when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions.
max_consecutive_auto_reply (int): the maximum number of consecutive auto replies. max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
The limit only plays a role when human_input_mode is not "ALWAYS". The limit only plays a role when human_input_mode is not "ALWAYS".
is_termination_msg (function): a function that takes a message and returns a boolean value. is_termination_msg (function): a function that takes a message in the form of a dictionary and returns a boolean value indicating if this received message is a termination message.
This function is used to determine if a received message is a termination message. The dict can contain the following keys: "content", "role", "name", "function_call".
use_docker (bool): whether to use docker to execute the code. use_docker (bool): whether to use docker to execute the code.
max_invalid_q_per_step (int): (ADDED) the maximum number of invalid queries per step. max_invalid_q_per_step (int): (ADDED) the maximum number of invalid queries per step.
**config (dict): other configurations. **config (dict): other configurations.
@ -163,6 +169,7 @@ class MathUserProxyAgent(UserProxyAgent):
name=name, name=name,
system_message=system_message, system_message=system_message,
work_dir=work_dir, work_dir=work_dir,
function_map=function_map,
human_input_mode=human_input_mode, human_input_mode=human_input_mode,
max_consecutive_auto_reply=max_consecutive_auto_reply, max_consecutive_auto_reply=max_consecutive_auto_reply,
is_termination_msg=is_termination_msg, is_termination_msg=is_termination_msg,
@ -208,7 +215,7 @@ class MathUserProxyAgent(UserProxyAgent):
return PROMPTS[prompt_type] + problem return PROMPTS[prompt_type] + problem
def _reset(self): def _reset(self):
self._conversations.clear() self._oai_conversations.clear()
self._valid_q_count = 0 self._valid_q_count = 0
self._total_q_count = 0 self._total_q_count = 0
self._accum_invalid_q_per_step = 0 self._accum_invalid_q_per_step = 0
@ -288,6 +295,7 @@ class MathUserProxyAgent(UserProxyAgent):
def auto_reply(self, message, sender, default_reply=""): def auto_reply(self, message, sender, default_reply=""):
"""Generate an auto reply.""" """Generate an auto reply."""
message = message.get("content", "")
code_blocks = extract_code(message) code_blocks = extract_code(message)
if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN: if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN:
@ -391,7 +399,7 @@ class WolframAlphaAPIWrapper(BaseModel):
extra = Extra.forbid extra = Extra.forbid
@root_validator() @root_validator(skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
wolfram_alpha_appid = get_from_dict_or_env(values, "wolfram_alpha_appid", "WOLFRAM_ALPHA_APPID") wolfram_alpha_appid = get_from_dict_or_env(values, "wolfram_alpha_appid", "WOLFRAM_ALPHA_APPID")

View File

@ -1,7 +1,8 @@
from typing import Union
from .agent import Agent from .agent import Agent
from flaml.autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang from flaml.autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang
from collections import defaultdict from collections import defaultdict
import json
from typing import Dict, Union
class UserProxyAgent(Agent): class UserProxyAgent(Agent):
@ -15,6 +16,7 @@ class UserProxyAgent(Agent):
system_message="", system_message="",
work_dir=None, work_dir=None,
human_input_mode="ALWAYS", human_input_mode="ALWAYS",
function_map={},
max_consecutive_auto_reply=None, max_consecutive_auto_reply=None,
is_termination_msg=None, is_termination_msg=None,
use_docker=True, use_docker=True,
@ -34,11 +36,12 @@ class UserProxyAgent(Agent):
the number of auto reply reaches the max_consecutive_auto_reply. the number of auto reply reaches the max_consecutive_auto_reply.
(3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True. when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions.
max_consecutive_auto_reply (int): the maximum number of consecutive auto replies. max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
The limit only plays a role when human_input_mode is not "ALWAYS". The limit only plays a role when human_input_mode is not "ALWAYS".
is_termination_msg (function): a function that takes a message and returns a boolean value. is_termination_msg (function): a function that takes a message in the form of a dictionary and returns a boolean value indicating if this received message is a termination message.
This function is used to determine if a received message is a termination message. The dict can contain the following keys: "content", "role", "name", "function_call".
use_docker (bool or str): bool value of whether to use docker to execute the code, use_docker (bool or str): bool value of whether to use docker to execute the code,
or str value of the docker image name to use. or str value of the docker image name to use.
**config (dict): other configurations. **config (dict): other configurations.
@ -47,7 +50,7 @@ class UserProxyAgent(Agent):
self._work_dir = work_dir self._work_dir = work_dir
self._human_input_mode = human_input_mode self._human_input_mode = human_input_mode
self._is_termination_msg = ( self._is_termination_msg = (
is_termination_msg if is_termination_msg is not None else (lambda x: x == "TERMINATE") is_termination_msg if is_termination_msg is not None else (lambda x: x.get("content") == "TERMINATE")
) )
self._config = config self._config = config
self._max_consecutive_auto_reply = ( self._max_consecutive_auto_reply = (
@ -56,6 +59,8 @@ class UserProxyAgent(Agent):
self._consecutive_auto_reply_counter = defaultdict(int) self._consecutive_auto_reply_counter = defaultdict(int)
self._use_docker = use_docker self._use_docker = use_docker
self._function_map = function_map
@property @property
def use_docker(self) -> Union[bool, str]: def use_docker(self) -> Union[bool, str]:
"""bool value of whether to use docker to execute the code, """bool value of whether to use docker to execute the code,
@ -89,8 +94,8 @@ class UserProxyAgent(Agent):
) )
logs = logs.decode("utf-8") logs = logs.decode("utf-8")
else: else:
# TODO: could this happen? # In case the language is not supported, we return an error message.
exitcode, logs, image = 1, f"unknown language {lang}" exitcode, logs, image = 1, f"unknown language {lang}", self._use_docker
# raise NotImplementedError # raise NotImplementedError
self._use_docker = image self._use_docker = image
logs_all += "\n" + logs logs_all += "\n" + logs
@ -98,11 +103,86 @@ class UserProxyAgent(Agent):
return exitcode, logs_all return exitcode, logs_all
return exitcode, logs_all return exitcode, logs_all
def auto_reply(self, message, sender, default_reply=""): @staticmethod
def _format_json_str(jstr):
"""Remove newlines outside of quotes, and hanlde JSON escape sequences.
1. this function removes the newline in the query outside of quotes otherwise json.loads(s) will fail.
Ex 1:
"{\n"tool": "python",\n"query": "print('hello')\nprint('world')"\n}" -> "{"tool": "python","query": "print('hello')\nprint('world')"}"
Ex 2:
"{\n \"location\": \"Boston, MA\"\n}" -> "{"location": "Boston, MA"}"
2. this function also handles JSON escape sequences inside quotes,
Ex 1:
'{"args": "a\na\na\ta"}' -> '{"args": "a\\na\\na\\ta"}'
"""
result = []
inside_quotes = False
last_char = " "
for char in jstr:
if last_char != "\\" and char == '"':
inside_quotes = not inside_quotes
last_char = char
if not inside_quotes and char == "\n":
continue
if inside_quotes and char == "\n":
char = "\\n"
if inside_quotes and char == "\t":
char = "\\t"
result.append(char)
return "".join(result)
def _execute_function(self, func_call):
"""Execute a function call and return the result.
Args:
func_call: a dictionary extracted from openai message at key "function_call" with keys "name" and "arguments".
Returns:
A tuple of (is_exec_success, result_dict).
is_exec_success (boolean): whether the execution is successful.
result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function".
"""
func_name = func_call.get("name", "")
func = self._function_map.get(func_name, None)
is_exec_success = False
if func is not None:
# Extract arguments from a json-like string and put it into a dict.
input_string = self._format_json_str(func_call.get("arguments", "{}"))
try:
arguments = json.loads(input_string)
except json.JSONDecodeError as e:
arguments = None
content = f"Error: {e}\n You argument should follow json format."
# Try to execute the function
if arguments:
try:
content = func(**arguments)
is_exec_success = True
except Exception as e:
content = f"Error: {e}"
else:
content = f"Error: Function {func_name} not found."
return is_exec_success, {
"name": func_name,
"role": "function",
"content": str(content),
}
def auto_reply(self, message: dict, sender, default_reply=""):
"""Generate an auto reply.""" """Generate an auto reply."""
code_blocks = extract_code(message) if "function_call" in message:
is_exec_success, func_return = self._execute_function(message["function_call"])
self._send(func_return, sender)
return
code_blocks = extract_code(message["content"])
if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN: if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN:
# no code block is found, lang should be `UNKNOWN`` # no code block is found, lang should be `UNKNOWN`
self._send(default_reply, sender) self._send(default_reply, sender)
else: else:
# try to execute the code # try to execute the code
@ -110,11 +190,12 @@ class UserProxyAgent(Agent):
exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed" exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed"
self._send(f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}", sender) self._send(f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}", sender)
def receive(self, message, sender): def receive(self, message: Union[Dict, str], sender):
"""Receive a message from the sender agent. """Receive a message from the sender agent.
Once a message is received, this function sends a reply to the sender or simply stop. Once a message is received, this function sends a reply to the sender or simply stop.
The reply can be generated automatically or entered manually by a human. The reply can be generated automatically or entered manually by a human.
""" """
message = self._message_to_dict(message)
super().receive(message, sender) super().receive(message, sender)
# default reply is empty (i.e., no reply, in this case we will try to generate auto reply) # default reply is empty (i.e., no reply, in this case we will try to generate auto reply)
reply = "" reply = ""

View File

@ -197,7 +197,7 @@
" \"user\",\n", " \"user\",\n",
" human_input_mode=\"NEVER\",\n", " human_input_mode=\"NEVER\",\n",
" max_consecutive_auto_reply=10,\n", " max_consecutive_auto_reply=10,\n",
" is_termination_msg=lambda x: x.rstrip().endswith(\"TERMINATE\") or x.rstrip().endswith('\"TERMINATE\".'),\n", " is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\") or x.get(\"content\", \"\").rstrip().endswith('\"TERMINATE\".'),\n",
" work_dir=\"coding\",\n", " work_dir=\"coding\",\n",
" use_docker=False, # set to True if you are using docker\n", " use_docker=False, # set to True if you are using docker\n",
")\n", ")\n",
@ -464,7 +464,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.15" "version": "3.9.16"
}, },
"vscode": { "vscode": {
"interpreter": { "interpreter": {

View File

@ -148,7 +148,7 @@
"user = UserProxyAgent(\n", "user = UserProxyAgent(\n",
" name=\"user\",\n", " name=\"user\",\n",
" human_input_mode=\"ALWAYS\",\n", " human_input_mode=\"ALWAYS\",\n",
" is_termination_msg=lambda x: x.rstrip().endswith(\"TERMINATE\"),\n", " is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\"),\n",
")\n", ")\n",
"\n", "\n",
"# the purpose of the following line is to log the conversation history\n", "# the purpose of the following line is to log the conversation history\n",

View File

@ -108,7 +108,7 @@
" name=\"user\",\n", " name=\"user\",\n",
" human_input_mode=\"TERMINATE\",\n", " human_input_mode=\"TERMINATE\",\n",
" max_consecutive_auto_reply=10,\n", " max_consecutive_auto_reply=10,\n",
" is_termination_msg=lambda x: x.rstrip().endswith(\"TERMINATE\"),\n", " is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\"),\n",
" work_dir='web',\n", " work_dir='web',\n",
")" ")"
] ]

View File

@ -0,0 +1,40 @@
def test_agent():
from flaml.autogen.agent import Agent
dummy_agent_1 = Agent(name="dummy_agent_1")
dummy_agent_2 = Agent(name="dummy_agent_2")
dummy_agent_1.receive("hello", dummy_agent_2) # receive a str
dummy_agent_1.receive(
{
"content": "hello",
},
dummy_agent_2,
) # receive a dict
# receive dict without openai fields to be printed, such as "content", 'function_call'. There should be no error raised.
pre_len = len(dummy_agent_1._oai_conversations["dummy_agent_2"])
dummy_agent_1.receive({"message": "hello"}, dummy_agent_2)
assert pre_len == len(
dummy_agent_1._oai_conversations["dummy_agent_2"]
), "When the message is not an valid openai message, it should not be appended to the oai conversation."
dummy_agent_1._send("hello", dummy_agent_2) # send a str
dummy_agent_1._send(
{
"content": "hello",
},
dummy_agent_2,
) # send a dict
# receive dict with no openai fields
pre_len = len(dummy_agent_1._oai_conversations["dummy_agent_2"])
dummy_agent_1._send({"message": "hello"}, dummy_agent_2) # send dict with wrong field
assert pre_len == len(
dummy_agent_1._oai_conversations["dummy_agent_2"]
), "When the message is not a valid openai message, it should not be appended to the oai conversation."
if __name__ == "__main__":
test_agent()

View File

@ -23,7 +23,7 @@ def test_gpt35(human_input_mode="NEVER", max_consecutive_auto_reply=5):
"user", "user",
work_dir=f"{here}/test_agent_scripts", work_dir=f"{here}/test_agent_scripts",
human_input_mode=human_input_mode, human_input_mode=human_input_mode,
is_termination_msg=lambda x: x.rstrip().endswith("TERMINATE"), is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
max_consecutive_auto_reply=max_consecutive_auto_reply, max_consecutive_auto_reply=max_consecutive_auto_reply,
use_docker="python:3", use_docker="python:3",
) )
@ -51,7 +51,7 @@ def test_create_execute_script(human_input_mode="NEVER", max_consecutive_auto_re
"user", "user",
human_input_mode=human_input_mode, human_input_mode=human_input_mode,
max_consecutive_auto_reply=max_consecutive_auto_reply, max_consecutive_auto_reply=max_consecutive_auto_reply,
is_termination_msg=lambda x: x.rstrip().endswith("TERMINATE"), is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
) )
assistant.receive( assistant.receive(
"""Create and execute a script to plot a rocket without using matplotlib""", """Create and execute a script to plot a rocket without using matplotlib""",

View File

@ -60,5 +60,68 @@ def test_eval_math_responses():
print(eval_math_responses(**arguments)) print(eval_math_responses(**arguments))
def test_json_extraction():
from flaml.autogen.agent import UserProxyAgent
user = UserProxyAgent(name="test", use_docker=False)
jstr = '{\n"location": "Boston, MA"\n}'
assert user._format_json_str(jstr) == '{"location": "Boston, MA"}'
jstr = '{\n"code": "python",\n"query": "x=3\nprint(x)"}'
assert user._format_json_str(jstr) == '{"code": "python","query": "x=3\\nprint(x)"}'
jstr = '{"code": "a=\\"hello\\""}'
assert user._format_json_str(jstr) == '{"code": "a=\\"hello\\""}'
def test_execute_function():
from flaml.autogen.agent import UserProxyAgent
# 1. test calling a simple function
def add_num(num_to_be_added):
given_num = 10
return num_to_be_added + given_num
user = UserProxyAgent(name="test", function_map={"add_num": add_num})
# correct execution
correct_args = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}
assert user._execute_function(func_call=correct_args)[1]["content"] == "15"
# function name called is wrong or doesn't exist
wrong_func_name = {"name": "subtract_num", "arguments": '{ "num_to_be_added": 5 }'}
assert "Error: Function" in user._execute_function(func_call=wrong_func_name)[1]["content"]
# arguments passed is not in correct json format
wrong_json_format = {
"name": "add_num",
"arguments": '{ "num_to_be_added": 5, given_num: 10 }',
} # should be "given_num" with quotes
assert (
"You argument should follow json format." in user._execute_function(func_call=wrong_json_format)[1]["content"]
)
# function execution error with wrong arguments passed
wrong_args = {"name": "add_num", "arguments": '{ "num_to_be_added": 5, "given_num": 10 }'}
assert "Error: " in user._execute_function(func_call=wrong_args)[1]["content"]
# 2. test calling a class method
class AddNum:
def __init__(self, given_num):
self.given_num = given_num
def add(self, num_to_be_added):
self.given_num = num_to_be_added + self.given_num
return self.given_num
user = UserProxyAgent(name="test", function_map={"add_num": AddNum(given_num=10).add})
func_call = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}
assert user._execute_function(func_call=func_call)[1]["content"] == "15"
assert user._execute_function(func_call=func_call)[1]["content"] == "20"
if __name__ == "__main__": if __name__ == "__main__":
test_json_extraction()
test_execute_function()
test_eval_math_responses() test_eval_math_responses()

View File

@ -38,7 +38,7 @@ user_proxy = UserProxyAgent(
name="user_proxy", name="user_proxy",
human_input_mode="NEVER", # in this mode, the agent will never solicit human input but always auto reply human_input_mode="NEVER", # in this mode, the agent will never solicit human input but always auto reply
max_consecutive_auto_reply=10, # the maximum number of consecutive auto replies max_consecutive_auto_reply=10, # the maximum number of consecutive auto replies
is_termination_msg=lambda x: x.rstrip().endswith("TERMINATE") or x.rstrip().endswith('"TERMINATE".'), # the function to determine whether a message is a termination message is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE") or x.get("content", "").rstrip().endswith('"TERMINATE".'), # the function to determine whether a message is a termination message
work_dir=".", work_dir=".",
) )