mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-01 18:29:49 +00:00
Support function_call in autogen/agent (#1091)
* update funccall * code format * update to comments * update notebook * remove test for py3.7 * allow funccall to class functions * add test and clean up notebook * revise notebook and test * update * update mathagent * Update flaml/autogen/agent/agent.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> * Update flaml/autogen/agent/user_proxy_agent.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> * revise to comments * revise function call design, notebook and test. add doc * code format * ad message_to_dict function * update mathproxyagent * revise docstr * update * Update flaml/autogen/agent/math_user_proxy_agent.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> * Update flaml/autogen/agent/math_user_proxy_agent.py Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu> * Update flaml/autogen/agent/user_proxy_agent.py Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu> * simply funccall in userproxyagent, rewind auto-gen.md, revise to comments * code format * update * remove notebook for another pr * revise oai_conversation part in agent, revise function exec in user_proxy_agent * update test_funccall * update * update * fix pydantic version * Update test/autogen/test_agent.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> * fix bug * fix bug * update * update is_termination_msg to accept dict --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com> Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu> Co-authored-by: Li Jiang <bnujli@gmail.com>
This commit is contained in:
parent
dd9202bb01
commit
ca10b286cc
@ -1,5 +1,6 @@
|
||||
from .agent import Agent
|
||||
from .assistant_agent import AssistantAgent
|
||||
from .user_proxy_agent import UserProxyAgent
|
||||
from .math_user_proxy_agent import MathUserProxyAgent
|
||||
|
||||
__all__ = ["Agent", "AssistantAgent", "UserProxyAgent"]
|
||||
__all__ = ["Agent", "AssistantAgent", "UserProxyAgent", "MathUserProxyAgent"]
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Union
|
||||
|
||||
|
||||
class Agent:
|
||||
@ -17,7 +18,7 @@ class Agent:
|
||||
# empty memory
|
||||
self._memory = []
|
||||
# a dictionary of conversations, default value is list
|
||||
self._conversations = defaultdict(list)
|
||||
self._oai_conversations = defaultdict(list)
|
||||
self._name = name
|
||||
self._system_message = system_message
|
||||
|
||||
@ -30,22 +31,95 @@ class Agent:
|
||||
"""Remember something."""
|
||||
self._memory.append(memory)
|
||||
|
||||
def _send(self, message, recipient):
|
||||
@staticmethod
|
||||
def _message_to_dict(message: Union[Dict, str]):
|
||||
"""Convert a message to a dictionary.
|
||||
|
||||
The message can be a string or a dictionary. The string with be put in the "content" field of the new dictionary.
|
||||
"""
|
||||
if isinstance(message, str):
|
||||
return {"content": message}
|
||||
else:
|
||||
return message
|
||||
|
||||
def _append_oai_message(self, message: Union[Dict, str], role, conversation_id):
|
||||
"""Append a message to the openai conversation.
|
||||
|
||||
If the message received is a string, it will be put in the "content" field of the new dictionary.
|
||||
If the message received is a dictionary but does not have any of the two fields "content" or "function_call",
|
||||
this message is not a valid openai message and will be ignored.
|
||||
|
||||
Args:
|
||||
message (dict or str): message to be appended to the openai conversation.
|
||||
role (str): role of the message, can be "assistant" or "function".
|
||||
conversation_id (str): id of the conversation, should be the name of the recipient or sender.
|
||||
"""
|
||||
message = self._message_to_dict(message)
|
||||
# create openai message to be appended to the openai conversation that can be passed to oai directly.
|
||||
oai_message = {k: message[k] for k in ("content", "function_call", "name") if k in message}
|
||||
if "content" not in oai_message and "function_call" not in oai_message:
|
||||
return
|
||||
|
||||
oai_message["role"] = "function" if message.get("role") == "function" else role
|
||||
self._oai_conversations[conversation_id].append(oai_message)
|
||||
|
||||
def _send(self, message: Union[Dict, str], recipient):
|
||||
"""Send a message to another agent."""
|
||||
self._conversations[recipient.name].append({"content": message, "role": "assistant"})
|
||||
# When the agent composes and sends the message, the role of the message is "assistant". (If 'role' exists and is 'function', it will remain unchanged.)
|
||||
self._append_oai_message(message, "assistant", recipient.name)
|
||||
recipient.receive(message, self)
|
||||
|
||||
def _receive(self, message, sender):
|
||||
"""Receive a message from another agent."""
|
||||
print("\n", "-" * 80, "\n", flush=True)
|
||||
print(sender.name, "(to", f"{self.name}):", flush=True)
|
||||
print(message, flush=True)
|
||||
self._conversations[sender.name].append({"content": message, "role": "user"})
|
||||
def _receive(self, message: Union[Dict, str], sender):
|
||||
"""Receive a message from another agent.
|
||||
|
||||
def receive(self, message, sender):
|
||||
Args:
|
||||
message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (All fields are optional).
|
||||
1. "content": content of the message, can be None.
|
||||
2. "function_call": a dictionary containing the function name and arguments.
|
||||
3. "role": role of the message, can be "assistant", "user", "function".
|
||||
This field is only needed to distinguish between "function" or "assistant"/"user".
|
||||
4. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
|
||||
sender: sender of an Agent instance.
|
||||
"""
|
||||
message = self._message_to_dict(message)
|
||||
# print the message received
|
||||
print(sender.name, "(to", f"{self.name}):\n", flush=True)
|
||||
if message.get("role") == "function":
|
||||
func_print = f"***** Response from calling function \"{message['name']}\" *****"
|
||||
print(func_print, flush=True)
|
||||
print(message["content"], flush=True)
|
||||
print("*" * len(func_print), flush=True)
|
||||
else:
|
||||
if message.get("content") is not None:
|
||||
print(message["content"], flush=True)
|
||||
if "function_call" in message:
|
||||
func_print = f"***** Suggested function Call: {message['function_call'].get('name', '(No function name found)')} *****"
|
||||
print(func_print, flush=True)
|
||||
print(
|
||||
"Arguments: \n",
|
||||
message["function_call"].get("arguments", "(No arguments found)"),
|
||||
flush=True,
|
||||
sep="",
|
||||
)
|
||||
print("*" * len(func_print), flush=True)
|
||||
print("\n", "-" * 80, flush=True, sep="")
|
||||
|
||||
# When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
|
||||
self._append_oai_message(message, "user", sender.name)
|
||||
|
||||
def receive(self, message: Union[Dict, str], sender):
|
||||
"""Receive a message from another agent.
|
||||
This method is called by the sender.
|
||||
It needs to be overriden by the subclass to perform followup actions.
|
||||
|
||||
Args:
|
||||
message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (All fields are optional).
|
||||
1. "content": content of the message, can be None.
|
||||
2. "function_call": a dictionary containing the function name and arguments.
|
||||
3. "role": role of the message, can be "assistant", "user", "function".
|
||||
This field is only needed to distinguish between "function" or "assistant"/"user".
|
||||
4. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
|
||||
sender: sender of an Agent instance.
|
||||
"""
|
||||
self._receive(message, sender)
|
||||
# perform actions based on the message
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from .agent import Agent
|
||||
from flaml.autogen.code_utils import DEFAULT_MODEL
|
||||
from flaml import oai
|
||||
from typing import Dict, Union
|
||||
|
||||
|
||||
class AssistantAgent(Agent):
|
||||
@ -33,16 +34,15 @@ class AssistantAgent(Agent):
|
||||
self._config.update(config)
|
||||
self._sender_dict = {}
|
||||
|
||||
def receive(self, message, sender):
|
||||
def receive(self, message: Union[Dict, str], sender):
|
||||
if sender.name not in self._sender_dict:
|
||||
self._sender_dict[sender.name] = sender
|
||||
self._conversations[sender.name] = [{"content": self._system_message, "role": "system"}]
|
||||
self._oai_conversations[sender.name] = [{"content": self._system_message, "role": "system"}]
|
||||
|
||||
super().receive(message, sender)
|
||||
responses = oai.ChatCompletion.create(messages=self._conversations[sender.name], **self._config)
|
||||
# TODO: handle function_call
|
||||
response = oai.ChatCompletion.extract_text(responses)[0]
|
||||
self._send(response, sender)
|
||||
responses = oai.ChatCompletion.create(messages=self._oai_conversations[sender.name], **self._config)
|
||||
self._send(oai.ChatCompletion.extract_text_or_function_call(responses)[0], sender)
|
||||
|
||||
def reset(self):
|
||||
self._sender_dict.clear()
|
||||
self._conversations.clear()
|
||||
self._oai_conversations.clear()
|
||||
|
||||
@ -84,6 +84,10 @@ Problem: """,
|
||||
|
||||
def is_termination_msg(x):
|
||||
"""Check if a message is a termination message."""
|
||||
if isinstance(x, dict):
|
||||
x = x.get("content")
|
||||
if x is None:
|
||||
return False
|
||||
cb = extract_code(x)
|
||||
contain_code = False
|
||||
for c in cb:
|
||||
@ -129,6 +133,7 @@ class MathUserProxyAgent(UserProxyAgent):
|
||||
name="MathChatAgent", # default set to MathChatAgent
|
||||
system_message="",
|
||||
work_dir=None,
|
||||
function_map=defaultdict(callable),
|
||||
human_input_mode="NEVER", # Fully automated
|
||||
max_consecutive_auto_reply=None,
|
||||
is_termination_msg=is_termination_msg,
|
||||
@ -150,11 +155,12 @@ class MathUserProxyAgent(UserProxyAgent):
|
||||
the number of auto reply reaches the max_consecutive_auto_reply.
|
||||
(3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
|
||||
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
|
||||
function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions.
|
||||
max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
|
||||
default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
|
||||
The limit only plays a role when human_input_mode is not "ALWAYS".
|
||||
is_termination_msg (function): a function that takes a message and returns a boolean value.
|
||||
This function is used to determine if a received message is a termination message.
|
||||
is_termination_msg (function): a function that takes a message in the form of a dictionary and returns a boolean value indicating if this received message is a termination message.
|
||||
The dict can contain the following keys: "content", "role", "name", "function_call".
|
||||
use_docker (bool): whether to use docker to execute the code.
|
||||
max_invalid_q_per_step (int): (ADDED) the maximum number of invalid queries per step.
|
||||
**config (dict): other configurations.
|
||||
@ -163,6 +169,7 @@ class MathUserProxyAgent(UserProxyAgent):
|
||||
name=name,
|
||||
system_message=system_message,
|
||||
work_dir=work_dir,
|
||||
function_map=function_map,
|
||||
human_input_mode=human_input_mode,
|
||||
max_consecutive_auto_reply=max_consecutive_auto_reply,
|
||||
is_termination_msg=is_termination_msg,
|
||||
@ -208,7 +215,7 @@ class MathUserProxyAgent(UserProxyAgent):
|
||||
return PROMPTS[prompt_type] + problem
|
||||
|
||||
def _reset(self):
|
||||
self._conversations.clear()
|
||||
self._oai_conversations.clear()
|
||||
self._valid_q_count = 0
|
||||
self._total_q_count = 0
|
||||
self._accum_invalid_q_per_step = 0
|
||||
@ -288,6 +295,7 @@ class MathUserProxyAgent(UserProxyAgent):
|
||||
|
||||
def auto_reply(self, message, sender, default_reply=""):
|
||||
"""Generate an auto reply."""
|
||||
message = message.get("content", "")
|
||||
code_blocks = extract_code(message)
|
||||
|
||||
if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN:
|
||||
@ -391,7 +399,7 @@ class WolframAlphaAPIWrapper(BaseModel):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
@root_validator(skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
wolfram_alpha_appid = get_from_dict_or_env(values, "wolfram_alpha_appid", "WOLFRAM_ALPHA_APPID")
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from typing import Union
|
||||
from .agent import Agent
|
||||
from flaml.autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang
|
||||
from collections import defaultdict
|
||||
import json
|
||||
from typing import Dict, Union
|
||||
|
||||
|
||||
class UserProxyAgent(Agent):
|
||||
@ -15,6 +16,7 @@ class UserProxyAgent(Agent):
|
||||
system_message="",
|
||||
work_dir=None,
|
||||
human_input_mode="ALWAYS",
|
||||
function_map={},
|
||||
max_consecutive_auto_reply=None,
|
||||
is_termination_msg=None,
|
||||
use_docker=True,
|
||||
@ -34,11 +36,12 @@ class UserProxyAgent(Agent):
|
||||
the number of auto reply reaches the max_consecutive_auto_reply.
|
||||
(3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
|
||||
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
|
||||
function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions.
|
||||
max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
|
||||
default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
|
||||
The limit only plays a role when human_input_mode is not "ALWAYS".
|
||||
is_termination_msg (function): a function that takes a message and returns a boolean value.
|
||||
This function is used to determine if a received message is a termination message.
|
||||
is_termination_msg (function): a function that takes a message in the form of a dictionary and returns a boolean value indicating if this received message is a termination message.
|
||||
The dict can contain the following keys: "content", "role", "name", "function_call".
|
||||
use_docker (bool or str): bool value of whether to use docker to execute the code,
|
||||
or str value of the docker image name to use.
|
||||
**config (dict): other configurations.
|
||||
@ -47,7 +50,7 @@ class UserProxyAgent(Agent):
|
||||
self._work_dir = work_dir
|
||||
self._human_input_mode = human_input_mode
|
||||
self._is_termination_msg = (
|
||||
is_termination_msg if is_termination_msg is not None else (lambda x: x == "TERMINATE")
|
||||
is_termination_msg if is_termination_msg is not None else (lambda x: x.get("content") == "TERMINATE")
|
||||
)
|
||||
self._config = config
|
||||
self._max_consecutive_auto_reply = (
|
||||
@ -56,6 +59,8 @@ class UserProxyAgent(Agent):
|
||||
self._consecutive_auto_reply_counter = defaultdict(int)
|
||||
self._use_docker = use_docker
|
||||
|
||||
self._function_map = function_map
|
||||
|
||||
@property
|
||||
def use_docker(self) -> Union[bool, str]:
|
||||
"""bool value of whether to use docker to execute the code,
|
||||
@ -89,8 +94,8 @@ class UserProxyAgent(Agent):
|
||||
)
|
||||
logs = logs.decode("utf-8")
|
||||
else:
|
||||
# TODO: could this happen?
|
||||
exitcode, logs, image = 1, f"unknown language {lang}"
|
||||
# In case the language is not supported, we return an error message.
|
||||
exitcode, logs, image = 1, f"unknown language {lang}", self._use_docker
|
||||
# raise NotImplementedError
|
||||
self._use_docker = image
|
||||
logs_all += "\n" + logs
|
||||
@ -98,11 +103,86 @@ class UserProxyAgent(Agent):
|
||||
return exitcode, logs_all
|
||||
return exitcode, logs_all
|
||||
|
||||
def auto_reply(self, message, sender, default_reply=""):
|
||||
@staticmethod
|
||||
def _format_json_str(jstr):
|
||||
"""Remove newlines outside of quotes, and hanlde JSON escape sequences.
|
||||
|
||||
1. this function removes the newline in the query outside of quotes otherwise json.loads(s) will fail.
|
||||
Ex 1:
|
||||
"{\n"tool": "python",\n"query": "print('hello')\nprint('world')"\n}" -> "{"tool": "python","query": "print('hello')\nprint('world')"}"
|
||||
Ex 2:
|
||||
"{\n \"location\": \"Boston, MA\"\n}" -> "{"location": "Boston, MA"}"
|
||||
|
||||
2. this function also handles JSON escape sequences inside quotes,
|
||||
Ex 1:
|
||||
'{"args": "a\na\na\ta"}' -> '{"args": "a\\na\\na\\ta"}'
|
||||
"""
|
||||
result = []
|
||||
inside_quotes = False
|
||||
last_char = " "
|
||||
for char in jstr:
|
||||
if last_char != "\\" and char == '"':
|
||||
inside_quotes = not inside_quotes
|
||||
last_char = char
|
||||
if not inside_quotes and char == "\n":
|
||||
continue
|
||||
if inside_quotes and char == "\n":
|
||||
char = "\\n"
|
||||
if inside_quotes and char == "\t":
|
||||
char = "\\t"
|
||||
result.append(char)
|
||||
return "".join(result)
|
||||
|
||||
def _execute_function(self, func_call):
|
||||
"""Execute a function call and return the result.
|
||||
|
||||
Args:
|
||||
func_call: a dictionary extracted from openai message at key "function_call" with keys "name" and "arguments".
|
||||
|
||||
Returns:
|
||||
A tuple of (is_exec_success, result_dict).
|
||||
is_exec_success (boolean): whether the execution is successful.
|
||||
result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function".
|
||||
"""
|
||||
func_name = func_call.get("name", "")
|
||||
func = self._function_map.get(func_name, None)
|
||||
|
||||
is_exec_success = False
|
||||
if func is not None:
|
||||
# Extract arguments from a json-like string and put it into a dict.
|
||||
input_string = self._format_json_str(func_call.get("arguments", "{}"))
|
||||
try:
|
||||
arguments = json.loads(input_string)
|
||||
except json.JSONDecodeError as e:
|
||||
arguments = None
|
||||
content = f"Error: {e}\n You argument should follow json format."
|
||||
|
||||
# Try to execute the function
|
||||
if arguments:
|
||||
try:
|
||||
content = func(**arguments)
|
||||
is_exec_success = True
|
||||
except Exception as e:
|
||||
content = f"Error: {e}"
|
||||
else:
|
||||
content = f"Error: Function {func_name} not found."
|
||||
|
||||
return is_exec_success, {
|
||||
"name": func_name,
|
||||
"role": "function",
|
||||
"content": str(content),
|
||||
}
|
||||
|
||||
def auto_reply(self, message: dict, sender, default_reply=""):
|
||||
"""Generate an auto reply."""
|
||||
code_blocks = extract_code(message)
|
||||
if "function_call" in message:
|
||||
is_exec_success, func_return = self._execute_function(message["function_call"])
|
||||
self._send(func_return, sender)
|
||||
return
|
||||
|
||||
code_blocks = extract_code(message["content"])
|
||||
if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN:
|
||||
# no code block is found, lang should be `UNKNOWN``
|
||||
# no code block is found, lang should be `UNKNOWN`
|
||||
self._send(default_reply, sender)
|
||||
else:
|
||||
# try to execute the code
|
||||
@ -110,11 +190,12 @@ class UserProxyAgent(Agent):
|
||||
exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed"
|
||||
self._send(f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}", sender)
|
||||
|
||||
def receive(self, message, sender):
|
||||
def receive(self, message: Union[Dict, str], sender):
|
||||
"""Receive a message from the sender agent.
|
||||
Once a message is received, this function sends a reply to the sender or simply stop.
|
||||
The reply can be generated automatically or entered manually by a human.
|
||||
"""
|
||||
message = self._message_to_dict(message)
|
||||
super().receive(message, sender)
|
||||
# default reply is empty (i.e., no reply, in this case we will try to generate auto reply)
|
||||
reply = ""
|
||||
|
||||
@ -197,7 +197,7 @@
|
||||
" \"user\",\n",
|
||||
" human_input_mode=\"NEVER\",\n",
|
||||
" max_consecutive_auto_reply=10,\n",
|
||||
" is_termination_msg=lambda x: x.rstrip().endswith(\"TERMINATE\") or x.rstrip().endswith('\"TERMINATE\".'),\n",
|
||||
" is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\") or x.get(\"content\", \"\").rstrip().endswith('\"TERMINATE\".'),\n",
|
||||
" work_dir=\"coding\",\n",
|
||||
" use_docker=False, # set to True if you are using docker\n",
|
||||
")\n",
|
||||
@ -464,7 +464,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.9.16"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
||||
@ -148,7 +148,7 @@
|
||||
"user = UserProxyAgent(\n",
|
||||
" name=\"user\",\n",
|
||||
" human_input_mode=\"ALWAYS\",\n",
|
||||
" is_termination_msg=lambda x: x.rstrip().endswith(\"TERMINATE\"),\n",
|
||||
" is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\"),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# the purpose of the following line is to log the conversation history\n",
|
||||
|
||||
@ -108,7 +108,7 @@
|
||||
" name=\"user\",\n",
|
||||
" human_input_mode=\"TERMINATE\",\n",
|
||||
" max_consecutive_auto_reply=10,\n",
|
||||
" is_termination_msg=lambda x: x.rstrip().endswith(\"TERMINATE\"),\n",
|
||||
" is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\"),\n",
|
||||
" work_dir='web',\n",
|
||||
")"
|
||||
]
|
||||
|
||||
40
test/autogen/test_agent.py
Normal file
40
test/autogen/test_agent.py
Normal file
@ -0,0 +1,40 @@
|
||||
def test_agent():
|
||||
from flaml.autogen.agent import Agent
|
||||
|
||||
dummy_agent_1 = Agent(name="dummy_agent_1")
|
||||
dummy_agent_2 = Agent(name="dummy_agent_2")
|
||||
|
||||
dummy_agent_1.receive("hello", dummy_agent_2) # receive a str
|
||||
dummy_agent_1.receive(
|
||||
{
|
||||
"content": "hello",
|
||||
},
|
||||
dummy_agent_2,
|
||||
) # receive a dict
|
||||
|
||||
# receive dict without openai fields to be printed, such as "content", 'function_call'. There should be no error raised.
|
||||
pre_len = len(dummy_agent_1._oai_conversations["dummy_agent_2"])
|
||||
dummy_agent_1.receive({"message": "hello"}, dummy_agent_2)
|
||||
assert pre_len == len(
|
||||
dummy_agent_1._oai_conversations["dummy_agent_2"]
|
||||
), "When the message is not an valid openai message, it should not be appended to the oai conversation."
|
||||
|
||||
dummy_agent_1._send("hello", dummy_agent_2) # send a str
|
||||
dummy_agent_1._send(
|
||||
{
|
||||
"content": "hello",
|
||||
},
|
||||
dummy_agent_2,
|
||||
) # send a dict
|
||||
|
||||
# receive dict with no openai fields
|
||||
pre_len = len(dummy_agent_1._oai_conversations["dummy_agent_2"])
|
||||
dummy_agent_1._send({"message": "hello"}, dummy_agent_2) # send dict with wrong field
|
||||
|
||||
assert pre_len == len(
|
||||
dummy_agent_1._oai_conversations["dummy_agent_2"]
|
||||
), "When the message is not a valid openai message, it should not be appended to the oai conversation."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_agent()
|
||||
@ -23,7 +23,7 @@ def test_gpt35(human_input_mode="NEVER", max_consecutive_auto_reply=5):
|
||||
"user",
|
||||
work_dir=f"{here}/test_agent_scripts",
|
||||
human_input_mode=human_input_mode,
|
||||
is_termination_msg=lambda x: x.rstrip().endswith("TERMINATE"),
|
||||
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
|
||||
max_consecutive_auto_reply=max_consecutive_auto_reply,
|
||||
use_docker="python:3",
|
||||
)
|
||||
@ -51,7 +51,7 @@ def test_create_execute_script(human_input_mode="NEVER", max_consecutive_auto_re
|
||||
"user",
|
||||
human_input_mode=human_input_mode,
|
||||
max_consecutive_auto_reply=max_consecutive_auto_reply,
|
||||
is_termination_msg=lambda x: x.rstrip().endswith("TERMINATE"),
|
||||
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
|
||||
)
|
||||
assistant.receive(
|
||||
"""Create and execute a script to plot a rocket without using matplotlib""",
|
||||
|
||||
@ -60,5 +60,68 @@ def test_eval_math_responses():
|
||||
print(eval_math_responses(**arguments))
|
||||
|
||||
|
||||
def test_json_extraction():
|
||||
from flaml.autogen.agent import UserProxyAgent
|
||||
|
||||
user = UserProxyAgent(name="test", use_docker=False)
|
||||
|
||||
jstr = '{\n"location": "Boston, MA"\n}'
|
||||
assert user._format_json_str(jstr) == '{"location": "Boston, MA"}'
|
||||
|
||||
jstr = '{\n"code": "python",\n"query": "x=3\nprint(x)"}'
|
||||
assert user._format_json_str(jstr) == '{"code": "python","query": "x=3\\nprint(x)"}'
|
||||
|
||||
jstr = '{"code": "a=\\"hello\\""}'
|
||||
assert user._format_json_str(jstr) == '{"code": "a=\\"hello\\""}'
|
||||
|
||||
|
||||
def test_execute_function():
|
||||
from flaml.autogen.agent import UserProxyAgent
|
||||
|
||||
# 1. test calling a simple function
|
||||
def add_num(num_to_be_added):
|
||||
given_num = 10
|
||||
return num_to_be_added + given_num
|
||||
|
||||
user = UserProxyAgent(name="test", function_map={"add_num": add_num})
|
||||
|
||||
# correct execution
|
||||
correct_args = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}
|
||||
assert user._execute_function(func_call=correct_args)[1]["content"] == "15"
|
||||
|
||||
# function name called is wrong or doesn't exist
|
||||
wrong_func_name = {"name": "subtract_num", "arguments": '{ "num_to_be_added": 5 }'}
|
||||
assert "Error: Function" in user._execute_function(func_call=wrong_func_name)[1]["content"]
|
||||
|
||||
# arguments passed is not in correct json format
|
||||
wrong_json_format = {
|
||||
"name": "add_num",
|
||||
"arguments": '{ "num_to_be_added": 5, given_num: 10 }',
|
||||
} # should be "given_num" with quotes
|
||||
assert (
|
||||
"You argument should follow json format." in user._execute_function(func_call=wrong_json_format)[1]["content"]
|
||||
)
|
||||
|
||||
# function execution error with wrong arguments passed
|
||||
wrong_args = {"name": "add_num", "arguments": '{ "num_to_be_added": 5, "given_num": 10 }'}
|
||||
assert "Error: " in user._execute_function(func_call=wrong_args)[1]["content"]
|
||||
|
||||
# 2. test calling a class method
|
||||
class AddNum:
|
||||
def __init__(self, given_num):
|
||||
self.given_num = given_num
|
||||
|
||||
def add(self, num_to_be_added):
|
||||
self.given_num = num_to_be_added + self.given_num
|
||||
return self.given_num
|
||||
|
||||
user = UserProxyAgent(name="test", function_map={"add_num": AddNum(given_num=10).add})
|
||||
func_call = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}
|
||||
assert user._execute_function(func_call=func_call)[1]["content"] == "15"
|
||||
assert user._execute_function(func_call=func_call)[1]["content"] == "20"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_json_extraction()
|
||||
test_execute_function()
|
||||
test_eval_math_responses()
|
||||
|
||||
@ -38,7 +38,7 @@ user_proxy = UserProxyAgent(
|
||||
name="user_proxy",
|
||||
human_input_mode="NEVER", # in this mode, the agent will never solicit human input but always auto reply
|
||||
max_consecutive_auto_reply=10, # the maximum number of consecutive auto replies
|
||||
is_termination_msg=lambda x: x.rstrip().endswith("TERMINATE") or x.rstrip().endswith('"TERMINATE".'), # the function to determine whether a message is a termination message
|
||||
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE") or x.get("content", "").rstrip().endswith('"TERMINATE".'), # the function to determine whether a message is a termination message
|
||||
work_dir=".",
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user