mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-15 17:44:33 +00:00
Support function_call in autogen/agent (#1091)
* update funccall * code format * update to comments * update notebook * remove test for py3.7 * allow funccall to class functions * add test and clean up notebook * revise notebook and test * update * update mathagent * Update flaml/autogen/agent/agent.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> * Update flaml/autogen/agent/user_proxy_agent.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> * revise to comments * revise function call design, notebook and test. add doc * code format * ad message_to_dict function * update mathproxyagent * revise docstr * update * Update flaml/autogen/agent/math_user_proxy_agent.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> * Update flaml/autogen/agent/math_user_proxy_agent.py Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu> * Update flaml/autogen/agent/user_proxy_agent.py Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu> * simply funccall in userproxyagent, rewind auto-gen.md, revise to comments * code format * update * remove notebook for another pr * revise oai_conversation part in agent, revise function exec in user_proxy_agent * update test_funccall * update * update * fix pydantic version * Update test/autogen/test_agent.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> * fix bug * fix bug * update * update is_termination_msg to accept dict --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com> Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu> Co-authored-by: Li Jiang <bnujli@gmail.com>
This commit is contained in:
parent
dd9202bb01
commit
ca10b286cc
@ -1,5 +1,6 @@
|
|||||||
from .agent import Agent
|
from .agent import Agent
|
||||||
from .assistant_agent import AssistantAgent
|
from .assistant_agent import AssistantAgent
|
||||||
from .user_proxy_agent import UserProxyAgent
|
from .user_proxy_agent import UserProxyAgent
|
||||||
|
from .math_user_proxy_agent import MathUserProxyAgent
|
||||||
|
|
||||||
__all__ = ["Agent", "AssistantAgent", "UserProxyAgent"]
|
__all__ = ["Agent", "AssistantAgent", "UserProxyAgent", "MathUserProxyAgent"]
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
|
||||||
class Agent:
|
class Agent:
|
||||||
@ -17,7 +18,7 @@ class Agent:
|
|||||||
# empty memory
|
# empty memory
|
||||||
self._memory = []
|
self._memory = []
|
||||||
# a dictionary of conversations, default value is list
|
# a dictionary of conversations, default value is list
|
||||||
self._conversations = defaultdict(list)
|
self._oai_conversations = defaultdict(list)
|
||||||
self._name = name
|
self._name = name
|
||||||
self._system_message = system_message
|
self._system_message = system_message
|
||||||
|
|
||||||
@ -30,22 +31,95 @@ class Agent:
|
|||||||
"""Remember something."""
|
"""Remember something."""
|
||||||
self._memory.append(memory)
|
self._memory.append(memory)
|
||||||
|
|
||||||
def _send(self, message, recipient):
|
@staticmethod
|
||||||
|
def _message_to_dict(message: Union[Dict, str]):
|
||||||
|
"""Convert a message to a dictionary.
|
||||||
|
|
||||||
|
The message can be a string or a dictionary. The string with be put in the "content" field of the new dictionary.
|
||||||
|
"""
|
||||||
|
if isinstance(message, str):
|
||||||
|
return {"content": message}
|
||||||
|
else:
|
||||||
|
return message
|
||||||
|
|
||||||
|
def _append_oai_message(self, message: Union[Dict, str], role, conversation_id):
|
||||||
|
"""Append a message to the openai conversation.
|
||||||
|
|
||||||
|
If the message received is a string, it will be put in the "content" field of the new dictionary.
|
||||||
|
If the message received is a dictionary but does not have any of the two fields "content" or "function_call",
|
||||||
|
this message is not a valid openai message and will be ignored.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (dict or str): message to be appended to the openai conversation.
|
||||||
|
role (str): role of the message, can be "assistant" or "function".
|
||||||
|
conversation_id (str): id of the conversation, should be the name of the recipient or sender.
|
||||||
|
"""
|
||||||
|
message = self._message_to_dict(message)
|
||||||
|
# create openai message to be appended to the openai conversation that can be passed to oai directly.
|
||||||
|
oai_message = {k: message[k] for k in ("content", "function_call", "name") if k in message}
|
||||||
|
if "content" not in oai_message and "function_call" not in oai_message:
|
||||||
|
return
|
||||||
|
|
||||||
|
oai_message["role"] = "function" if message.get("role") == "function" else role
|
||||||
|
self._oai_conversations[conversation_id].append(oai_message)
|
||||||
|
|
||||||
|
def _send(self, message: Union[Dict, str], recipient):
|
||||||
"""Send a message to another agent."""
|
"""Send a message to another agent."""
|
||||||
self._conversations[recipient.name].append({"content": message, "role": "assistant"})
|
# When the agent composes and sends the message, the role of the message is "assistant". (If 'role' exists and is 'function', it will remain unchanged.)
|
||||||
|
self._append_oai_message(message, "assistant", recipient.name)
|
||||||
recipient.receive(message, self)
|
recipient.receive(message, self)
|
||||||
|
|
||||||
def _receive(self, message, sender):
|
def _receive(self, message: Union[Dict, str], sender):
|
||||||
"""Receive a message from another agent."""
|
"""Receive a message from another agent.
|
||||||
print("\n", "-" * 80, "\n", flush=True)
|
|
||||||
print(sender.name, "(to", f"{self.name}):", flush=True)
|
|
||||||
print(message, flush=True)
|
|
||||||
self._conversations[sender.name].append({"content": message, "role": "user"})
|
|
||||||
|
|
||||||
def receive(self, message, sender):
|
Args:
|
||||||
|
message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (All fields are optional).
|
||||||
|
1. "content": content of the message, can be None.
|
||||||
|
2. "function_call": a dictionary containing the function name and arguments.
|
||||||
|
3. "role": role of the message, can be "assistant", "user", "function".
|
||||||
|
This field is only needed to distinguish between "function" or "assistant"/"user".
|
||||||
|
4. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
|
||||||
|
sender: sender of an Agent instance.
|
||||||
|
"""
|
||||||
|
message = self._message_to_dict(message)
|
||||||
|
# print the message received
|
||||||
|
print(sender.name, "(to", f"{self.name}):\n", flush=True)
|
||||||
|
if message.get("role") == "function":
|
||||||
|
func_print = f"***** Response from calling function \"{message['name']}\" *****"
|
||||||
|
print(func_print, flush=True)
|
||||||
|
print(message["content"], flush=True)
|
||||||
|
print("*" * len(func_print), flush=True)
|
||||||
|
else:
|
||||||
|
if message.get("content") is not None:
|
||||||
|
print(message["content"], flush=True)
|
||||||
|
if "function_call" in message:
|
||||||
|
func_print = f"***** Suggested function Call: {message['function_call'].get('name', '(No function name found)')} *****"
|
||||||
|
print(func_print, flush=True)
|
||||||
|
print(
|
||||||
|
"Arguments: \n",
|
||||||
|
message["function_call"].get("arguments", "(No arguments found)"),
|
||||||
|
flush=True,
|
||||||
|
sep="",
|
||||||
|
)
|
||||||
|
print("*" * len(func_print), flush=True)
|
||||||
|
print("\n", "-" * 80, flush=True, sep="")
|
||||||
|
|
||||||
|
# When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
|
||||||
|
self._append_oai_message(message, "user", sender.name)
|
||||||
|
|
||||||
|
def receive(self, message: Union[Dict, str], sender):
|
||||||
"""Receive a message from another agent.
|
"""Receive a message from another agent.
|
||||||
This method is called by the sender.
|
This method is called by the sender.
|
||||||
It needs to be overriden by the subclass to perform followup actions.
|
It needs to be overriden by the subclass to perform followup actions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (All fields are optional).
|
||||||
|
1. "content": content of the message, can be None.
|
||||||
|
2. "function_call": a dictionary containing the function name and arguments.
|
||||||
|
3. "role": role of the message, can be "assistant", "user", "function".
|
||||||
|
This field is only needed to distinguish between "function" or "assistant"/"user".
|
||||||
|
4. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
|
||||||
|
sender: sender of an Agent instance.
|
||||||
"""
|
"""
|
||||||
self._receive(message, sender)
|
self._receive(message, sender)
|
||||||
# perform actions based on the message
|
# perform actions based on the message
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from .agent import Agent
|
from .agent import Agent
|
||||||
from flaml.autogen.code_utils import DEFAULT_MODEL
|
from flaml.autogen.code_utils import DEFAULT_MODEL
|
||||||
from flaml import oai
|
from flaml import oai
|
||||||
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
|
||||||
class AssistantAgent(Agent):
|
class AssistantAgent(Agent):
|
||||||
@ -33,16 +34,15 @@ class AssistantAgent(Agent):
|
|||||||
self._config.update(config)
|
self._config.update(config)
|
||||||
self._sender_dict = {}
|
self._sender_dict = {}
|
||||||
|
|
||||||
def receive(self, message, sender):
|
def receive(self, message: Union[Dict, str], sender):
|
||||||
if sender.name not in self._sender_dict:
|
if sender.name not in self._sender_dict:
|
||||||
self._sender_dict[sender.name] = sender
|
self._sender_dict[sender.name] = sender
|
||||||
self._conversations[sender.name] = [{"content": self._system_message, "role": "system"}]
|
self._oai_conversations[sender.name] = [{"content": self._system_message, "role": "system"}]
|
||||||
|
|
||||||
super().receive(message, sender)
|
super().receive(message, sender)
|
||||||
responses = oai.ChatCompletion.create(messages=self._conversations[sender.name], **self._config)
|
responses = oai.ChatCompletion.create(messages=self._oai_conversations[sender.name], **self._config)
|
||||||
# TODO: handle function_call
|
self._send(oai.ChatCompletion.extract_text_or_function_call(responses)[0], sender)
|
||||||
response = oai.ChatCompletion.extract_text(responses)[0]
|
|
||||||
self._send(response, sender)
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self._sender_dict.clear()
|
self._sender_dict.clear()
|
||||||
self._conversations.clear()
|
self._oai_conversations.clear()
|
||||||
|
|||||||
@ -84,6 +84,10 @@ Problem: """,
|
|||||||
|
|
||||||
def is_termination_msg(x):
|
def is_termination_msg(x):
|
||||||
"""Check if a message is a termination message."""
|
"""Check if a message is a termination message."""
|
||||||
|
if isinstance(x, dict):
|
||||||
|
x = x.get("content")
|
||||||
|
if x is None:
|
||||||
|
return False
|
||||||
cb = extract_code(x)
|
cb = extract_code(x)
|
||||||
contain_code = False
|
contain_code = False
|
||||||
for c in cb:
|
for c in cb:
|
||||||
@ -129,6 +133,7 @@ class MathUserProxyAgent(UserProxyAgent):
|
|||||||
name="MathChatAgent", # default set to MathChatAgent
|
name="MathChatAgent", # default set to MathChatAgent
|
||||||
system_message="",
|
system_message="",
|
||||||
work_dir=None,
|
work_dir=None,
|
||||||
|
function_map=defaultdict(callable),
|
||||||
human_input_mode="NEVER", # Fully automated
|
human_input_mode="NEVER", # Fully automated
|
||||||
max_consecutive_auto_reply=None,
|
max_consecutive_auto_reply=None,
|
||||||
is_termination_msg=is_termination_msg,
|
is_termination_msg=is_termination_msg,
|
||||||
@ -150,11 +155,12 @@ class MathUserProxyAgent(UserProxyAgent):
|
|||||||
the number of auto reply reaches the max_consecutive_auto_reply.
|
the number of auto reply reaches the max_consecutive_auto_reply.
|
||||||
(3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
|
(3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
|
||||||
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
|
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
|
||||||
|
function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions.
|
||||||
max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
|
max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
|
||||||
default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
|
default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
|
||||||
The limit only plays a role when human_input_mode is not "ALWAYS".
|
The limit only plays a role when human_input_mode is not "ALWAYS".
|
||||||
is_termination_msg (function): a function that takes a message and returns a boolean value.
|
is_termination_msg (function): a function that takes a message in the form of a dictionary and returns a boolean value indicating if this received message is a termination message.
|
||||||
This function is used to determine if a received message is a termination message.
|
The dict can contain the following keys: "content", "role", "name", "function_call".
|
||||||
use_docker (bool): whether to use docker to execute the code.
|
use_docker (bool): whether to use docker to execute the code.
|
||||||
max_invalid_q_per_step (int): (ADDED) the maximum number of invalid queries per step.
|
max_invalid_q_per_step (int): (ADDED) the maximum number of invalid queries per step.
|
||||||
**config (dict): other configurations.
|
**config (dict): other configurations.
|
||||||
@ -163,6 +169,7 @@ class MathUserProxyAgent(UserProxyAgent):
|
|||||||
name=name,
|
name=name,
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
work_dir=work_dir,
|
work_dir=work_dir,
|
||||||
|
function_map=function_map,
|
||||||
human_input_mode=human_input_mode,
|
human_input_mode=human_input_mode,
|
||||||
max_consecutive_auto_reply=max_consecutive_auto_reply,
|
max_consecutive_auto_reply=max_consecutive_auto_reply,
|
||||||
is_termination_msg=is_termination_msg,
|
is_termination_msg=is_termination_msg,
|
||||||
@ -208,7 +215,7 @@ class MathUserProxyAgent(UserProxyAgent):
|
|||||||
return PROMPTS[prompt_type] + problem
|
return PROMPTS[prompt_type] + problem
|
||||||
|
|
||||||
def _reset(self):
|
def _reset(self):
|
||||||
self._conversations.clear()
|
self._oai_conversations.clear()
|
||||||
self._valid_q_count = 0
|
self._valid_q_count = 0
|
||||||
self._total_q_count = 0
|
self._total_q_count = 0
|
||||||
self._accum_invalid_q_per_step = 0
|
self._accum_invalid_q_per_step = 0
|
||||||
@ -288,6 +295,7 @@ class MathUserProxyAgent(UserProxyAgent):
|
|||||||
|
|
||||||
def auto_reply(self, message, sender, default_reply=""):
|
def auto_reply(self, message, sender, default_reply=""):
|
||||||
"""Generate an auto reply."""
|
"""Generate an auto reply."""
|
||||||
|
message = message.get("content", "")
|
||||||
code_blocks = extract_code(message)
|
code_blocks = extract_code(message)
|
||||||
|
|
||||||
if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN:
|
if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN:
|
||||||
@ -391,7 +399,7 @@ class WolframAlphaAPIWrapper(BaseModel):
|
|||||||
|
|
||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(skip_on_failure=True)
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
wolfram_alpha_appid = get_from_dict_or_env(values, "wolfram_alpha_appid", "WOLFRAM_ALPHA_APPID")
|
wolfram_alpha_appid = get_from_dict_or_env(values, "wolfram_alpha_appid", "WOLFRAM_ALPHA_APPID")
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
from typing import Union
|
|
||||||
from .agent import Agent
|
from .agent import Agent
|
||||||
from flaml.autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang
|
from flaml.autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
import json
|
||||||
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
|
||||||
class UserProxyAgent(Agent):
|
class UserProxyAgent(Agent):
|
||||||
@ -15,6 +16,7 @@ class UserProxyAgent(Agent):
|
|||||||
system_message="",
|
system_message="",
|
||||||
work_dir=None,
|
work_dir=None,
|
||||||
human_input_mode="ALWAYS",
|
human_input_mode="ALWAYS",
|
||||||
|
function_map={},
|
||||||
max_consecutive_auto_reply=None,
|
max_consecutive_auto_reply=None,
|
||||||
is_termination_msg=None,
|
is_termination_msg=None,
|
||||||
use_docker=True,
|
use_docker=True,
|
||||||
@ -34,11 +36,12 @@ class UserProxyAgent(Agent):
|
|||||||
the number of auto reply reaches the max_consecutive_auto_reply.
|
the number of auto reply reaches the max_consecutive_auto_reply.
|
||||||
(3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
|
(3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
|
||||||
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
|
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
|
||||||
|
function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions.
|
||||||
max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
|
max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
|
||||||
default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
|
default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
|
||||||
The limit only plays a role when human_input_mode is not "ALWAYS".
|
The limit only plays a role when human_input_mode is not "ALWAYS".
|
||||||
is_termination_msg (function): a function that takes a message and returns a boolean value.
|
is_termination_msg (function): a function that takes a message in the form of a dictionary and returns a boolean value indicating if this received message is a termination message.
|
||||||
This function is used to determine if a received message is a termination message.
|
The dict can contain the following keys: "content", "role", "name", "function_call".
|
||||||
use_docker (bool or str): bool value of whether to use docker to execute the code,
|
use_docker (bool or str): bool value of whether to use docker to execute the code,
|
||||||
or str value of the docker image name to use.
|
or str value of the docker image name to use.
|
||||||
**config (dict): other configurations.
|
**config (dict): other configurations.
|
||||||
@ -47,7 +50,7 @@ class UserProxyAgent(Agent):
|
|||||||
self._work_dir = work_dir
|
self._work_dir = work_dir
|
||||||
self._human_input_mode = human_input_mode
|
self._human_input_mode = human_input_mode
|
||||||
self._is_termination_msg = (
|
self._is_termination_msg = (
|
||||||
is_termination_msg if is_termination_msg is not None else (lambda x: x == "TERMINATE")
|
is_termination_msg if is_termination_msg is not None else (lambda x: x.get("content") == "TERMINATE")
|
||||||
)
|
)
|
||||||
self._config = config
|
self._config = config
|
||||||
self._max_consecutive_auto_reply = (
|
self._max_consecutive_auto_reply = (
|
||||||
@ -56,6 +59,8 @@ class UserProxyAgent(Agent):
|
|||||||
self._consecutive_auto_reply_counter = defaultdict(int)
|
self._consecutive_auto_reply_counter = defaultdict(int)
|
||||||
self._use_docker = use_docker
|
self._use_docker = use_docker
|
||||||
|
|
||||||
|
self._function_map = function_map
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def use_docker(self) -> Union[bool, str]:
|
def use_docker(self) -> Union[bool, str]:
|
||||||
"""bool value of whether to use docker to execute the code,
|
"""bool value of whether to use docker to execute the code,
|
||||||
@ -89,8 +94,8 @@ class UserProxyAgent(Agent):
|
|||||||
)
|
)
|
||||||
logs = logs.decode("utf-8")
|
logs = logs.decode("utf-8")
|
||||||
else:
|
else:
|
||||||
# TODO: could this happen?
|
# In case the language is not supported, we return an error message.
|
||||||
exitcode, logs, image = 1, f"unknown language {lang}"
|
exitcode, logs, image = 1, f"unknown language {lang}", self._use_docker
|
||||||
# raise NotImplementedError
|
# raise NotImplementedError
|
||||||
self._use_docker = image
|
self._use_docker = image
|
||||||
logs_all += "\n" + logs
|
logs_all += "\n" + logs
|
||||||
@ -98,11 +103,86 @@ class UserProxyAgent(Agent):
|
|||||||
return exitcode, logs_all
|
return exitcode, logs_all
|
||||||
return exitcode, logs_all
|
return exitcode, logs_all
|
||||||
|
|
||||||
def auto_reply(self, message, sender, default_reply=""):
|
@staticmethod
|
||||||
|
def _format_json_str(jstr):
|
||||||
|
"""Remove newlines outside of quotes, and hanlde JSON escape sequences.
|
||||||
|
|
||||||
|
1. this function removes the newline in the query outside of quotes otherwise json.loads(s) will fail.
|
||||||
|
Ex 1:
|
||||||
|
"{\n"tool": "python",\n"query": "print('hello')\nprint('world')"\n}" -> "{"tool": "python","query": "print('hello')\nprint('world')"}"
|
||||||
|
Ex 2:
|
||||||
|
"{\n \"location\": \"Boston, MA\"\n}" -> "{"location": "Boston, MA"}"
|
||||||
|
|
||||||
|
2. this function also handles JSON escape sequences inside quotes,
|
||||||
|
Ex 1:
|
||||||
|
'{"args": "a\na\na\ta"}' -> '{"args": "a\\na\\na\\ta"}'
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
inside_quotes = False
|
||||||
|
last_char = " "
|
||||||
|
for char in jstr:
|
||||||
|
if last_char != "\\" and char == '"':
|
||||||
|
inside_quotes = not inside_quotes
|
||||||
|
last_char = char
|
||||||
|
if not inside_quotes and char == "\n":
|
||||||
|
continue
|
||||||
|
if inside_quotes and char == "\n":
|
||||||
|
char = "\\n"
|
||||||
|
if inside_quotes and char == "\t":
|
||||||
|
char = "\\t"
|
||||||
|
result.append(char)
|
||||||
|
return "".join(result)
|
||||||
|
|
||||||
|
def _execute_function(self, func_call):
|
||||||
|
"""Execute a function call and return the result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func_call: a dictionary extracted from openai message at key "function_call" with keys "name" and "arguments".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (is_exec_success, result_dict).
|
||||||
|
is_exec_success (boolean): whether the execution is successful.
|
||||||
|
result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function".
|
||||||
|
"""
|
||||||
|
func_name = func_call.get("name", "")
|
||||||
|
func = self._function_map.get(func_name, None)
|
||||||
|
|
||||||
|
is_exec_success = False
|
||||||
|
if func is not None:
|
||||||
|
# Extract arguments from a json-like string and put it into a dict.
|
||||||
|
input_string = self._format_json_str(func_call.get("arguments", "{}"))
|
||||||
|
try:
|
||||||
|
arguments = json.loads(input_string)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
arguments = None
|
||||||
|
content = f"Error: {e}\n You argument should follow json format."
|
||||||
|
|
||||||
|
# Try to execute the function
|
||||||
|
if arguments:
|
||||||
|
try:
|
||||||
|
content = func(**arguments)
|
||||||
|
is_exec_success = True
|
||||||
|
except Exception as e:
|
||||||
|
content = f"Error: {e}"
|
||||||
|
else:
|
||||||
|
content = f"Error: Function {func_name} not found."
|
||||||
|
|
||||||
|
return is_exec_success, {
|
||||||
|
"name": func_name,
|
||||||
|
"role": "function",
|
||||||
|
"content": str(content),
|
||||||
|
}
|
||||||
|
|
||||||
|
def auto_reply(self, message: dict, sender, default_reply=""):
|
||||||
"""Generate an auto reply."""
|
"""Generate an auto reply."""
|
||||||
code_blocks = extract_code(message)
|
if "function_call" in message:
|
||||||
|
is_exec_success, func_return = self._execute_function(message["function_call"])
|
||||||
|
self._send(func_return, sender)
|
||||||
|
return
|
||||||
|
|
||||||
|
code_blocks = extract_code(message["content"])
|
||||||
if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN:
|
if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN:
|
||||||
# no code block is found, lang should be `UNKNOWN``
|
# no code block is found, lang should be `UNKNOWN`
|
||||||
self._send(default_reply, sender)
|
self._send(default_reply, sender)
|
||||||
else:
|
else:
|
||||||
# try to execute the code
|
# try to execute the code
|
||||||
@ -110,11 +190,12 @@ class UserProxyAgent(Agent):
|
|||||||
exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed"
|
exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed"
|
||||||
self._send(f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}", sender)
|
self._send(f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}", sender)
|
||||||
|
|
||||||
def receive(self, message, sender):
|
def receive(self, message: Union[Dict, str], sender):
|
||||||
"""Receive a message from the sender agent.
|
"""Receive a message from the sender agent.
|
||||||
Once a message is received, this function sends a reply to the sender or simply stop.
|
Once a message is received, this function sends a reply to the sender or simply stop.
|
||||||
The reply can be generated automatically or entered manually by a human.
|
The reply can be generated automatically or entered manually by a human.
|
||||||
"""
|
"""
|
||||||
|
message = self._message_to_dict(message)
|
||||||
super().receive(message, sender)
|
super().receive(message, sender)
|
||||||
# default reply is empty (i.e., no reply, in this case we will try to generate auto reply)
|
# default reply is empty (i.e., no reply, in this case we will try to generate auto reply)
|
||||||
reply = ""
|
reply = ""
|
||||||
|
|||||||
@ -197,7 +197,7 @@
|
|||||||
" \"user\",\n",
|
" \"user\",\n",
|
||||||
" human_input_mode=\"NEVER\",\n",
|
" human_input_mode=\"NEVER\",\n",
|
||||||
" max_consecutive_auto_reply=10,\n",
|
" max_consecutive_auto_reply=10,\n",
|
||||||
" is_termination_msg=lambda x: x.rstrip().endswith(\"TERMINATE\") or x.rstrip().endswith('\"TERMINATE\".'),\n",
|
" is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\") or x.get(\"content\", \"\").rstrip().endswith('\"TERMINATE\".'),\n",
|
||||||
" work_dir=\"coding\",\n",
|
" work_dir=\"coding\",\n",
|
||||||
" use_docker=False, # set to True if you are using docker\n",
|
" use_docker=False, # set to True if you are using docker\n",
|
||||||
")\n",
|
")\n",
|
||||||
@ -464,7 +464,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.9.15"
|
"version": "3.9.16"
|
||||||
},
|
},
|
||||||
"vscode": {
|
"vscode": {
|
||||||
"interpreter": {
|
"interpreter": {
|
||||||
|
|||||||
@ -148,7 +148,7 @@
|
|||||||
"user = UserProxyAgent(\n",
|
"user = UserProxyAgent(\n",
|
||||||
" name=\"user\",\n",
|
" name=\"user\",\n",
|
||||||
" human_input_mode=\"ALWAYS\",\n",
|
" human_input_mode=\"ALWAYS\",\n",
|
||||||
" is_termination_msg=lambda x: x.rstrip().endswith(\"TERMINATE\"),\n",
|
" is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\"),\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# the purpose of the following line is to log the conversation history\n",
|
"# the purpose of the following line is to log the conversation history\n",
|
||||||
|
|||||||
@ -108,7 +108,7 @@
|
|||||||
" name=\"user\",\n",
|
" name=\"user\",\n",
|
||||||
" human_input_mode=\"TERMINATE\",\n",
|
" human_input_mode=\"TERMINATE\",\n",
|
||||||
" max_consecutive_auto_reply=10,\n",
|
" max_consecutive_auto_reply=10,\n",
|
||||||
" is_termination_msg=lambda x: x.rstrip().endswith(\"TERMINATE\"),\n",
|
" is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\"),\n",
|
||||||
" work_dir='web',\n",
|
" work_dir='web',\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
|
|||||||
40
test/autogen/test_agent.py
Normal file
40
test/autogen/test_agent.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
def test_agent():
|
||||||
|
from flaml.autogen.agent import Agent
|
||||||
|
|
||||||
|
dummy_agent_1 = Agent(name="dummy_agent_1")
|
||||||
|
dummy_agent_2 = Agent(name="dummy_agent_2")
|
||||||
|
|
||||||
|
dummy_agent_1.receive("hello", dummy_agent_2) # receive a str
|
||||||
|
dummy_agent_1.receive(
|
||||||
|
{
|
||||||
|
"content": "hello",
|
||||||
|
},
|
||||||
|
dummy_agent_2,
|
||||||
|
) # receive a dict
|
||||||
|
|
||||||
|
# receive dict without openai fields to be printed, such as "content", 'function_call'. There should be no error raised.
|
||||||
|
pre_len = len(dummy_agent_1._oai_conversations["dummy_agent_2"])
|
||||||
|
dummy_agent_1.receive({"message": "hello"}, dummy_agent_2)
|
||||||
|
assert pre_len == len(
|
||||||
|
dummy_agent_1._oai_conversations["dummy_agent_2"]
|
||||||
|
), "When the message is not an valid openai message, it should not be appended to the oai conversation."
|
||||||
|
|
||||||
|
dummy_agent_1._send("hello", dummy_agent_2) # send a str
|
||||||
|
dummy_agent_1._send(
|
||||||
|
{
|
||||||
|
"content": "hello",
|
||||||
|
},
|
||||||
|
dummy_agent_2,
|
||||||
|
) # send a dict
|
||||||
|
|
||||||
|
# receive dict with no openai fields
|
||||||
|
pre_len = len(dummy_agent_1._oai_conversations["dummy_agent_2"])
|
||||||
|
dummy_agent_1._send({"message": "hello"}, dummy_agent_2) # send dict with wrong field
|
||||||
|
|
||||||
|
assert pre_len == len(
|
||||||
|
dummy_agent_1._oai_conversations["dummy_agent_2"]
|
||||||
|
), "When the message is not a valid openai message, it should not be appended to the oai conversation."
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_agent()
|
||||||
@ -23,7 +23,7 @@ def test_gpt35(human_input_mode="NEVER", max_consecutive_auto_reply=5):
|
|||||||
"user",
|
"user",
|
||||||
work_dir=f"{here}/test_agent_scripts",
|
work_dir=f"{here}/test_agent_scripts",
|
||||||
human_input_mode=human_input_mode,
|
human_input_mode=human_input_mode,
|
||||||
is_termination_msg=lambda x: x.rstrip().endswith("TERMINATE"),
|
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
|
||||||
max_consecutive_auto_reply=max_consecutive_auto_reply,
|
max_consecutive_auto_reply=max_consecutive_auto_reply,
|
||||||
use_docker="python:3",
|
use_docker="python:3",
|
||||||
)
|
)
|
||||||
@ -51,7 +51,7 @@ def test_create_execute_script(human_input_mode="NEVER", max_consecutive_auto_re
|
|||||||
"user",
|
"user",
|
||||||
human_input_mode=human_input_mode,
|
human_input_mode=human_input_mode,
|
||||||
max_consecutive_auto_reply=max_consecutive_auto_reply,
|
max_consecutive_auto_reply=max_consecutive_auto_reply,
|
||||||
is_termination_msg=lambda x: x.rstrip().endswith("TERMINATE"),
|
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
|
||||||
)
|
)
|
||||||
assistant.receive(
|
assistant.receive(
|
||||||
"""Create and execute a script to plot a rocket without using matplotlib""",
|
"""Create and execute a script to plot a rocket without using matplotlib""",
|
||||||
|
|||||||
@ -60,5 +60,68 @@ def test_eval_math_responses():
|
|||||||
print(eval_math_responses(**arguments))
|
print(eval_math_responses(**arguments))
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_extraction():
|
||||||
|
from flaml.autogen.agent import UserProxyAgent
|
||||||
|
|
||||||
|
user = UserProxyAgent(name="test", use_docker=False)
|
||||||
|
|
||||||
|
jstr = '{\n"location": "Boston, MA"\n}'
|
||||||
|
assert user._format_json_str(jstr) == '{"location": "Boston, MA"}'
|
||||||
|
|
||||||
|
jstr = '{\n"code": "python",\n"query": "x=3\nprint(x)"}'
|
||||||
|
assert user._format_json_str(jstr) == '{"code": "python","query": "x=3\\nprint(x)"}'
|
||||||
|
|
||||||
|
jstr = '{"code": "a=\\"hello\\""}'
|
||||||
|
assert user._format_json_str(jstr) == '{"code": "a=\\"hello\\""}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_function():
|
||||||
|
from flaml.autogen.agent import UserProxyAgent
|
||||||
|
|
||||||
|
# 1. test calling a simple function
|
||||||
|
def add_num(num_to_be_added):
|
||||||
|
given_num = 10
|
||||||
|
return num_to_be_added + given_num
|
||||||
|
|
||||||
|
user = UserProxyAgent(name="test", function_map={"add_num": add_num})
|
||||||
|
|
||||||
|
# correct execution
|
||||||
|
correct_args = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}
|
||||||
|
assert user._execute_function(func_call=correct_args)[1]["content"] == "15"
|
||||||
|
|
||||||
|
# function name called is wrong or doesn't exist
|
||||||
|
wrong_func_name = {"name": "subtract_num", "arguments": '{ "num_to_be_added": 5 }'}
|
||||||
|
assert "Error: Function" in user._execute_function(func_call=wrong_func_name)[1]["content"]
|
||||||
|
|
||||||
|
# arguments passed is not in correct json format
|
||||||
|
wrong_json_format = {
|
||||||
|
"name": "add_num",
|
||||||
|
"arguments": '{ "num_to_be_added": 5, given_num: 10 }',
|
||||||
|
} # should be "given_num" with quotes
|
||||||
|
assert (
|
||||||
|
"You argument should follow json format." in user._execute_function(func_call=wrong_json_format)[1]["content"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# function execution error with wrong arguments passed
|
||||||
|
wrong_args = {"name": "add_num", "arguments": '{ "num_to_be_added": 5, "given_num": 10 }'}
|
||||||
|
assert "Error: " in user._execute_function(func_call=wrong_args)[1]["content"]
|
||||||
|
|
||||||
|
# 2. test calling a class method
|
||||||
|
class AddNum:
|
||||||
|
def __init__(self, given_num):
|
||||||
|
self.given_num = given_num
|
||||||
|
|
||||||
|
def add(self, num_to_be_added):
|
||||||
|
self.given_num = num_to_be_added + self.given_num
|
||||||
|
return self.given_num
|
||||||
|
|
||||||
|
user = UserProxyAgent(name="test", function_map={"add_num": AddNum(given_num=10).add})
|
||||||
|
func_call = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}
|
||||||
|
assert user._execute_function(func_call=func_call)[1]["content"] == "15"
|
||||||
|
assert user._execute_function(func_call=func_call)[1]["content"] == "20"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
test_json_extraction()
|
||||||
|
test_execute_function()
|
||||||
test_eval_math_responses()
|
test_eval_math_responses()
|
||||||
|
|||||||
@ -38,7 +38,7 @@ user_proxy = UserProxyAgent(
|
|||||||
name="user_proxy",
|
name="user_proxy",
|
||||||
human_input_mode="NEVER", # in this mode, the agent will never solicit human input but always auto reply
|
human_input_mode="NEVER", # in this mode, the agent will never solicit human input but always auto reply
|
||||||
max_consecutive_auto_reply=10, # the maximum number of consecutive auto replies
|
max_consecutive_auto_reply=10, # the maximum number of consecutive auto replies
|
||||||
is_termination_msg=lambda x: x.rstrip().endswith("TERMINATE") or x.rstrip().endswith('"TERMINATE".'), # the function to determine whether a message is a termination message
|
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE") or x.get("content", "").rstrip().endswith('"TERMINATE".'), # the function to determine whether a message is a termination message
|
||||||
work_dir=".",
|
work_dir=".",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user