simplify the initiation of chat (#1131)

* simplify the initiation of chat

* version update

* include openai

* completion
Chi Wang 2023-07-17 20:40:41 -07:00 committed by GitHub
parent 7665f73e4b
commit 16f0fcd6f8
16 changed files with 146 additions and 121 deletions
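In short, the commit replaces the old pattern of handing the first message to `assistant.receive(...)` with a single `initiate_chat` call on the user proxy agent. A minimal before/after sketch, assuming the flaml 2.0.0rc import path and placeholder agent configs (the task string is illustrative, not from this diff):

```python
from flaml.autogen.agent import AssistantAgent, UserProxyAgent

# placeholder setup; real usage also passes config_list, work_dir, etc.
assistant = AssistantAgent("assistant")
user = UserProxyAgent("user", human_input_mode="NEVER")

# before this commit: the caller composed the first message and invoked
# the assistant's receive() directly
# assistant.receive("What date is today?", user)

# after this commit: the proxy starts the chat; generate_init_prompt()
# builds the first message and send() delivers it to the recipient
user.initiate_chat(assistant, "What date is today?")
```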

View File

@@ -63,13 +63,13 @@ class Agent:
         oai_message["role"] = "function" if message.get("role") == "function" else role
         self._oai_conversations[conversation_id].append(oai_message)

-    def _send(self, message: Union[Dict, str], recipient):
+    def send(self, message: Union[Dict, str], recipient):
         """Send a message to another agent."""
         # When the agent composes and sends the message, the role of the message is "assistant". (If 'role' exists and is 'function', it will remain unchanged.)
         self._append_oai_message(message, "assistant", recipient.name)
         recipient.receive(message, self)

-    def _receive(self, message: Union[Dict, str], sender):
+    def _receive(self, message: Union[Dict, str], sender: "Agent"):
         """Receive a message from another agent.

         Args:

View File

@@ -42,7 +42,7 @@ class AssistantAgent(Agent):
         super().receive(message, sender)
         responses = oai.ChatCompletion.create(messages=self._oai_conversations[sender.name], **self._config)
-        self._send(oai.ChatCompletion.extract_text_or_function_call(responses)[0], sender)
+        self.send(oai.ChatCompletion.extract_text_or_function_call(responses)[0], sender)

     def reset(self):
         self._sender_dict.clear()

View File

@@ -222,7 +222,7 @@ class MathUserProxyAgent(UserProxyAgent):
         self._previous_code = ""
         self.last_reply = None

-    def _execute_one_python_code(self, pycode):
+    def execute_one_python_code(self, pycode):
         """Execute python code blocks.

         Previous python code will be saved and executed together with the new code.
@@ -278,7 +278,7 @@ class MathUserProxyAgent(UserProxyAgent):
         self._previous_code = tmp
         return output, is_success

-    def _execute_one_wolfram_query(self, query: str):
+    def execute_one_wolfram_query(self, query: str):
         """
         Run one wolfram query and return the output.

         return:
@@ -302,7 +302,7 @@ class MathUserProxyAgent(UserProxyAgent):
             # no code block is found, lang should be `UNKNOWN`
             if default_reply == "":
                 default_reply = "Continue. Please keep solving the problem until you need to query. (If you get to the answer, put it in \\boxed{}.)"
-            self._send(default_reply, sender)
+            self.send(default_reply, sender)
         else:
             is_success, all_success = True, True
             reply = ""
@@ -311,9 +311,9 @@ class MathUserProxyAgent(UserProxyAgent):
                 if not lang:
                     lang = infer_lang(code)
                 if lang == "python":
-                    output, is_success = self._execute_one_python_code(code)
+                    output, is_success = self.execute_one_python_code(code)
                 elif lang == "wolfram":
-                    output, is_success = self._execute_one_wolfram_query(code)
+                    output, is_success = self.execute_one_wolfram_query(code)
                 else:
                     output = "Error: Unknown language."
                     is_success = False
@@ -338,7 +338,7 @@ class MathUserProxyAgent(UserProxyAgent):
             self._accum_invalid_q_per_step = 0
             reply = "Please revisit the problem statement and your reasoning. If you think this step is correct, solve it yourself and continue the next step. Otherwise, correct this step."

-        self._send(reply, sender)
+        self.send(reply, sender)

 # Imported from langchain. Langchain is licensed under MIT License:

View File

@@ -77,7 +77,7 @@ class UserProxyAgent(Agent):
         or str value of the docker image name to use."""
         return self._use_docker

-    def _execute_code(self, code_blocks):
+    def execute_code(self, code_blocks):
         """Execute the code and return the result."""
         logs_all = ""
         for code_block in code_blocks:
@@ -185,19 +185,19 @@ class UserProxyAgent(Agent):
     def auto_reply(self, message: dict, sender, default_reply=""):
         """Generate an auto reply."""
         if "function_call" in message:
-            is_exec_success, func_return = self._execute_function(message["function_call"])
-            self._send(func_return, sender)
+            _, func_return = self._execute_function(message["function_call"])
+            self.send(func_return, sender)
             return

         code_blocks = extract_code(message["content"])
         if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN:
             # no code block is found, lang should be `UNKNOWN`
-            self._send(default_reply, sender)
+            self.send(default_reply, sender)
         else:
             # try to execute the code
-            exitcode, logs = self._execute_code(code_blocks)
+            exitcode, logs = self.execute_code(code_blocks)
             exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed"
-            self._send(f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}", sender)
+            self.send(f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}", sender)

     def receive(self, message: Union[Dict, str], sender):
         """Receive a message from the sender agent.
@@ -230,9 +230,36 @@ class UserProxyAgent(Agent):
             if reply:
                 # reset the consecutive_auto_reply_counter
                 self._consecutive_auto_reply_counter[sender.name] = 0
-                self._send(reply, sender)
+                self.send(reply, sender)
                 return

         self._consecutive_auto_reply_counter[sender.name] += 1
         print("\n>>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...", flush=True)
         self.auto_reply(message, sender, default_reply=reply)

+    def generate_init_prompt(self, *args, **kwargs) -> Union[str, Dict]:
+        """Generate the initial prompt for the agent.
+
+        Override this function to customize the initial prompt based on the user's request.
+        """
+        return args[0]
+
+    def initiate_chat(self, recipient, *args, **kwargs):
+        """Initiate a chat with the recipient agent.
+
+        `generate_init_prompt` is called to generate the initial prompt for the agent.
+
+        Args:
+            recipient: the recipient agent.
+            *args: any additional arguments.
+            **kwargs: any additional keyword arguments.
+        """
+        self.send(self.generate_init_prompt(*args, **kwargs), recipient)
+
+    def register_function(self, function_map: Dict[str, Callable]):
+        """Register functions to the agent.
+
+        Args:
+            function_map: a dictionary mapping function names to functions.
+        """
+        self._function_map.update(function_map)
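Because the default `generate_init_prompt` simply returns `args[0]`, a subclass can override it to build a richer first message, which is exactly what the TSP test later in this commit does. A hedged sketch of that pattern (the class name and template are illustrative, not part of this diff):

```python
class TemplatedUserProxyAgent(UserProxyAgent):
    """Illustrative subclass: formats a stored template into the first message."""

    def __init__(self, *args, template="Solve the following problem:\n{question}", **kwargs):
        super().__init__(*args, **kwargs)
        self._template = template

    def generate_init_prompt(self, question) -> str:
        # initiate_chat(recipient, question=...) forwards its kwargs here
        return self._template.format(question=question)

# usage: proxy.initiate_chat(assistant, question="What is 2 + 2?")
```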

View File

@@ -31,7 +31,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "# %pip install flaml[mathchat]==2.0.0rc2"
+    "# %pip install flaml[mathchat]~=2.0.0rc4"
   ]
  },
  {
@@ -122,14 +122,16 @@
    "    system_message=\"You are a helpful assistant.\",\n",
    "    request_timeout=600, \n",
    "    seed=42, \n",
-    "    config_list=config_list)\n",
+    "    config_list=config_list,\n",
+    ")\n",
    "\n",
    "# 2. create the MathUserProxyAgent instance named \"mathproxyagent\"\n",
    "# By default, the human_input_mode is \"NEVER\", which means the agent will not ask for human input.\n",
    "mathproxyagent = MathUserProxyAgent(\n",
    "    name=\"MathChatAgent\", \n",
    "    human_input_mode=\"NEVER\",\n",
-    "    use_docker=False)"
+    "    use_docker=False,\n",
+    ")"
   ]
  },
  {
@@ -283,11 +285,8 @@
    "# given a math problem, we use the mathproxyagent to generate a prompt to be sent to the assistant as the initial message.\n",
    "# the assistant receives the message and generates a response. The response will be sent back to the mathproxyagent for processing.\n",
    "# The conversation continues until the termination condition is met. In MathChat, the termination condition is the detection of \"\\boxed{}\" in the response.\n",
    "math_problem = \"Find all $x$ that satisfy the inequality $(2x+10)(x+3)<(3x+9)(x+8)$. Express your answer in interval notation.\"\n",
-    "assistant.receive(\n",
-    "    message=mathproxyagent.generate_init_prompt(math_problem),\n",
-    "    sender=mathproxyagent,\n",
-    ")"
+    "mathproxyagent.initiate_chat(assistant, math_problem)"
   ]
  },
  {
@@ -429,11 +428,8 @@
   "source": [
    "assistant.reset()\n",
    "\n",
    "math_problem = \"For what negative value of $k$ is there exactly one solution to the system of equations \\\\begin{align*}\\ny &= 2x^2 + kx + 6 \\\\\\\\\\ny &= -x + 4?\\n\\\\end{align*}\"\n",
-    "assistant.receive(\n",
-    "    mathproxyagent.generate_init_prompt(math_problem),\n",
-    "    mathproxyagent,\n",
-    ")"
+    "mathproxyagent.initiate_chat(assistant, math_problem)"
   ]
  },
  {
@@ -561,11 +557,8 @@
   "source": [
    "assistant.reset()\n",
    "\n",
    "math_problem = \"Find all positive integer values of $c$ such that the equation $x^2-7x+c=0$ only has roots that are real and rational. Express them in decreasing order, separated by commas.\"\n",
-    "assistant.receive(\n",
-    "    mathproxyagent.generate_init_prompt(math_problem),\n",
-    "    mathproxyagent,\n",
-    ")"
+    "mathproxyagent.initiate_chat(assistant, math_problem)"
   ]
  },
  {
@@ -760,11 +753,8 @@
    "assistant.reset() # clear LLM assistant's message history\n",
    "\n",
    "# we set the prompt_type to \"python\", which is a simplified version of the default prompt.\n",
    "math_problem = \"Problem: If $725x + 727y = 1500$ and $729x+ 731y = 1508$, what is the value of $x - y$ ?\"\n",
-    "assistant.receive(\n",
-    "    mathproxyagent.generate_init_prompt(math_problem, prompt_type=\"python\"),\n",
-    "    mathproxyagent,\n",
-    ")"
+    "mathproxyagent.initiate_chat(assistant, math_problem, prompt_type=\"python\")"
   ]
  },
  {
@@ -904,11 +894,8 @@
    "    os.environ[\"WOLFRAM_ALPHA_APPID\"] = open(\"wolfram.txt\").read().strip()\n",
    "\n",
    "# we set the prompt_type to \"two_tools\", which allows the assistant to select wolfram alpha when necessary.\n",
    "math_problem = \"Find all numbers $a$ for which the graph of $y=x^2+a$ and the graph of $y=ax$ intersect. Express your answer in interval notation.\"\n",
-    "assistant.receive(\n",
-    "    mathproxyagent.generate_init_prompt(math_problem, prompt_type=\"two_tools\"),\n",
-    "    mathproxyagent,\n",
-    ")"
+    "mathproxyagent.initiate_chat(assistant, math_problem, prompt_type=\"two_tools\")"
   ]
  }
 ],

View File

@@ -44,7 +44,7 @@
  },
  "outputs": [],
  "source": [
-    "# %pip install flaml[autogen]==2.0.0rc3"
+    "# %pip install flaml[autogen]~=2.0.0rc4"
  ]
 },
 {
@@ -209,9 +209,9 @@
    "    use_docker=False, # set to True if you are using docker\n",
    ")\n",
    "# the assistant receives a message from the user, which contains the task description\n",
-    "assistant.receive(\n",
+    "user.initiate_chat(\n",
+    "    assistant,\n",
    "    \"\"\"What date is today? Compare the year-to-date gain for META and TESLA.\"\"\",\n",
-    "    user,\n",
    ")"
   ]
  },
@@ -314,9 +314,9 @@
   ],
   "source": [
    "# followup of the previous question\n",
-    "assistant.receive(\n",
-    "    \"\"\"Plot a chart of their stock price change YTD and save to stock_price_ytd.png.\"\"\",\n",
-    "    user\n",
+    "user.send(\n",
+    "    recipient=assistant,\n",
+    "    message=\"\"\"Plot a chart of their stock price change YTD and save to stock_price_ytd.png.\"\"\",\n",
    ")"
   ]
  },

View File

@@ -36,7 +36,7 @@
  "metadata": {},
  "outputs": [],
  "source": [
-    "# %pip install flaml[mathchat]==2.0.0rc3"
+    "# %pip install flaml[mathchat]~=2.0.0rc4"
  ]
 },
 {
@@ -79,11 +79,11 @@
  "source": [
   "## Making Function Calls\n",
   "\n",
-   "In this example, we demonstrate function call execution with `AssistantAgent` and `UserProxyAgent`. With the default system prompt of `AssistantAgent`, we allow the LLM assistant to perform tasks with code, and the `UserProxyAgent` would extract code blocks from the LLM response and execute them. With the new \"function_call\" feature, we define a new function using the pre-defined `_execute_code` from `UserProxyAgent` and specify the description of the function in the OpenAI config. \n",
+   "In this example, we demonstrate function call execution with `AssistantAgent` and `UserProxyAgent`. With the default system prompt of `AssistantAgent`, we allow the LLM assistant to perform tasks with code, and the `UserProxyAgent` would extract code blocks from the LLM response and execute them. With the new \"function_call\" feature, we define a new function using the pre-defined `execute_code` from `UserProxyAgent` and specify the description of the function in the OpenAI config. \n",
   "\n",
   "Then, the model has two paths to execute code:\n",
-   "1. Put the code blocks in the response. `UserProxyAgent` will extract and execute the code through the `_execute_code` method in the class.\n",
-   "2. As we put a function description to OpenAI config and passed a function `execute_code_function` to `UserProxyAgent`, the model can also make function calls (will be put in `function_call` field of the API reply). `UserProxyAgent` will execute the function call through a `_execute_function` method."
+   "1. Put the code blocks in the response. `UserProxyAgent` will extract and execute the code through the `execute_code` method in the class.\n",
+   "2. As we put a function description in the OpenAI config and register a function `exec_code` in `UserProxyAgent`, the model can also make function calls (put in the `function_call` field of the API reply). `UserProxyAgent` will execute the function call through the registered `exec_code` method."
  ]
 },
 {
@@ -234,24 +234,26 @@
   "}\n",
   "chatbot = AssistantAgent(\"chatbot\", config_list=config_list, **oai_config)\n",
   "\n",
-   "# use pre-defined execute_code function from a UserProxyAgent instance\n",
-   "# for simplicity, we don't pass in `exec_func` directly to UserProxyAgent because it requires a list of tuple as parameter\n",
-   "# instead, we define a wrapper function to call `exec_func`\n",
-   "exec_func = UserProxyAgent(name=\"execute_code\", work_dir=\"coding\", use_docker=False)._execute_code\n",
-   "\n",
-   "def execute_code(code_type, code):\n",
-   "    return exec_func([(code_type, code)])\n",
-   "\n",
+   "# create a UserProxyAgent instance named \"user\"\n",
   "user = UserProxyAgent(\n",
   "    \"user\",\n",
   "    human_input_mode=\"NEVER\",\n",
-   "    function_map={\"execute_code\": execute_code},\n",
+   "    work_dir=\"coding\",\n",
   ")\n",
   "\n",
+   "# define an `execute_code` function according to the function description\n",
+   "def exec_code(code_type, code):\n",
+   "    # here we reuse the method in the user proxy agent\n",
+   "    # in general, this is not necessary\n",
+   "    return user.execute_code([(code_type, code)])\n",
+   "\n",
+   "# register the `execute_code` function\n",
+   "user.register_function(function_map={\"execute_code\": exec_code})\n",
+   "\n",
   "# start the conversation\n",
-   "chatbot.receive(\n",
+   "user.initiate_chat(\n",
+   "    chatbot,\n",
   "    \"Draw a rocket and save to a file named 'rocket.svg'\",\n",
-   "    user,\n",
   ")\n"
  ]
 },
@@ -289,7 +291,7 @@
  "source": [
   "## Another example with Wolfram Alpha API\n",
   "\n",
-   "We give another example of query Wolfram Alpha API to solve math problem. We use the predefined function from `MathUserProxyAgent()`, we directly pass the class method, `MathUserProxyAgent()._execute_one_wolfram_query`, as the function to be called."
+   "We give another example of querying the Wolfram Alpha API to solve a math problem. We use the predefined function `MathUserProxyAgent().execute_one_wolfram_query` as the function to be called."
  ]
 },
 {
@@ -389,19 +391,18 @@
   "}\n",
   "chatbot = AssistantAgent(\"chatbot\", sys_prompt, config_list=config_list, **oai_config)\n",
   "\n",
-   "\n",
-   "# the key in `function_map` should match the function name passed to OpenAI\n",
-   "# we pass a class instance directly\n",
+   "# the key in `function_map` should match the function name in \"functions\" above\n",
+   "# we register a class instance method directly\n",
   "user = UserProxyAgent(\n",
   "    \"user\",\n",
   "    human_input_mode=\"NEVER\",\n",
-   "    function_map={\"query_wolfram\": MathUserProxyAgent()._execute_one_wolfram_query},\n",
+   "    function_map={\"query_wolfram\": MathUserProxyAgent().execute_one_wolfram_query},\n",
   ")\n",
   "\n",
   "# start the conversation\n",
-   "chatbot.receive(\n",
+   "user.initiate_chat(\n",
+   "    chatbot,\n",
   "    \"Problem: Find all $x$ that satisfy the inequality $(2x+10)(x+3)<(3x+9)(x+8)$. Express your answer in interval notation.\",\n",
-   "    user,\n",
   ")\n"
  ]
 }
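The contract behind `register_function` is worth spelling out: the key in `function_map` must match the `name` declared in the `functions` list of the OpenAI config, because `auto_reply` looks the model's `function_call` name up in `_function_map`. A minimal sketch under that assumption (the import path follows flaml 2.0.0rc and the function itself is illustrative):

```python
from flaml.autogen.agent import UserProxyAgent

user = UserProxyAgent("user", human_input_mode="NEVER")

def add_numbers(a, b):
    # illustrative function; the key below must match the entry named
    # "add_numbers" in the OpenAI "functions" schema
    return str(a + b)

user.register_function(function_map={"add_numbers": add_numbers})
# when the model replies with {"function_call": {"name": "add_numbers", ...}},
# auto_reply() resolves the name via _function_map and executes it
```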

View File

@@ -44,7 +44,7 @@
  },
  "outputs": [],
  "source": [
-    "# %pip install flaml[autogen]==2.0.0rc3"
+    "# %pip install flaml[autogen]~=2.0.0rc4"
  ]
 },
 {

View File

@@ -44,7 +44,7 @@
  },
  "outputs": [],
  "source": [
-    "# %pip install flaml[autogen]==2.0.0rc3"
+    "# %pip install flaml[autogen]~=2.0.0rc4"
  ]
 },
 {
@@ -419,9 +419,9 @@
  ],
  "source": [
   "# the assistant receives a message from the user, which contains the task description\n",
-   "assistant.receive(\n",
+   "user.initiate_chat(\n",
+   "    assistant,\n",
   "    \"\"\"Suggest a fix to an open good first issue of flaml\"\"\",\n",
-   "    user\n",
  ")"
  ]
 },

View File

@@ -48,7 +48,7 @@
  },
  "outputs": [],
  "source": [
-    "# %pip install flaml[autogen]==2.0.0rc3"
+    "# %pip install flaml[autogen]~=2.0.0rc4"
  ]
 },
 {

View File

@@ -82,7 +82,7 @@
  "source": [
   "from flaml import oai\n",
   "\n",
-   "config_list = oai.config_list_openai_aoai(exclude=\"openai\")"
+   "config_list = oai.config_list_openai_aoai()"
  ]
 },
 {

View File

@@ -311,7 +311,7 @@ def test_math(num_samples=-1):
         % data["problem"]
     ]
-    oai.ChatCompletion.set_cache(seed)
+    oai.Completion.set_cache(seed)
     vanilla_config = {
         "model": "text-davinci-003",
         "temperature": 1,
@@ -321,8 +321,8 @@ def test_math(num_samples=-1):
         "stop": "###",
     }
     test_data_sample = test_data[0:3]
-    result = oai.ChatCompletion.test(test_data_sample, eval_math_responses, **vanilla_config)
-    result = oai.ChatCompletion.test(
+    result = oai.Completion.test(test_data_sample, eval_math_responses, **vanilla_config)
+    result = oai.Completion.test(
         test_data_sample,
         eval_math_responses,
         agg_method="median",
@@ -335,13 +335,13 @@ def test_math(num_samples=-1):
     def my_average(results):
         return np.mean(results)

-    result = oai.ChatCompletion.test(
+    result = oai.Completion.test(
         test_data_sample,
         eval_math_responses,
         agg_method=my_median,
         **vanilla_config,
     )
-    result = oai.ChatCompletion.test(
+    result = oai.Completion.test(
         test_data_sample,
         eval_math_responses,
         agg_method={
@@ -355,7 +355,7 @@ def test_math(num_samples=-1):
     print(result)

-    config, _ = oai.ChatCompletion.tune(
+    config, _ = oai.Completion.tune(
         data=tune_data,  # the data for tuning
         metric="expected_success",  # the metric to optimize
         mode="max",  # the optimization mode
@@ -368,7 +368,7 @@ def test_math(num_samples=-1):
         stop="###",  # the stop sequence
     )
     print("tuned config", config)
-    result = oai.ChatCompletion.test(test_data_sample, config_list=oai.config_list_openai_aoai(KEY_LOC), **config)
+    result = oai.Completion.test(test_data_sample, config_list=oai.config_list_openai_aoai(KEY_LOC), **config)
     print("result from tuned config:", result)
     print("empty responses", eval_math_responses([], None))

View File

@@ -19,8 +19,8 @@ def test_agent():
         dummy_agent_1._oai_conversations["dummy_agent_2"]
     ), "When the message is not a valid openai message, it should not be appended to the oai conversation."

-    dummy_agent_1._send("hello", dummy_agent_2)  # send a str
-    dummy_agent_1._send(
+    dummy_agent_1.send("hello", dummy_agent_2)  # send a str
+    dummy_agent_1.send(
         {
             "content": "hello",
         },
@@ -29,7 +29,7 @@ def test_agent():
     # receive dict with no openai fields
     pre_len = len(dummy_agent_1._oai_conversations["dummy_agent_2"])
-    dummy_agent_1._send({"message": "hello"}, dummy_agent_2)  # send dict with wrong field
+    dummy_agent_1.send({"message": "hello"}, dummy_agent_2)  # send dict with wrong field
     assert pre_len == len(
         dummy_agent_1._oai_conversations["dummy_agent_2"]

View File

@@ -29,12 +29,12 @@ def test_gpt35(human_input_mode="NEVER", max_consecutive_auto_reply=5):
         timeout=60,
     )
     coding_task = "Print hello world to a file called hello.txt"
-    assistant.receive(coding_task, user)
+    user.initiate_chat(assistant, coding_task)
     # coding_task = "Create a powerpoint with the text hello world in it."
     # assistant.receive(coding_task, user)
     assistant.reset()
     coding_task = "Save a pandas df with 3 rows and 3 columns to disk."
-    assistant.receive(coding_task, user)
+    user.initiate_chat(assistant, coding_task)
     assert not isinstance(user.use_docker, bool)  # None or str
@@ -54,21 +54,21 @@ def test_create_execute_script(human_input_mode="NEVER", max_consecutive_auto_re
         max_consecutive_auto_reply=max_consecutive_auto_reply,
         is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
     )
-    assistant.receive(
+    user.initiate_chat(
+        assistant,
         """Create and execute a script to plot a rocket without using matplotlib""",
-        user,
     )
     assistant.reset()
-    assistant.receive(
+    user.initiate_chat(
+        assistant,
         """Create a temp.py file with the following content:
 ```
 print('Hello world!')
 ```""",
-        user,
     )
     print(conversations)
     oai.ChatCompletion.start_logging(compact=False)
-    assistant.receive("""Execute temp.py""", user)
+    user.send("""Execute temp.py""", assistant)
     print(oai.ChatCompletion.logged_history)
     oai.ChatCompletion.stop_logging()
@@ -86,26 +86,33 @@ def test_tsp(human_input_mode="NEVER", max_consecutive_auto_reply=10):
         "Can we add a new point to the graph? Its distance should be randomly between 0 - 5 to each of the existing points.",
     ]

+    class TSPUserProxyAgent(UserProxyAgent):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            with open(f"{here}/tsp_prompt.txt", "r") as f:
+                self._prompt = f.read()
+
+        def generate_init_prompt(self, question) -> str:
+            return self._prompt.format(question=question)

     oai.ChatCompletion.start_logging()
     assistant = AssistantAgent("assistant", temperature=0, config_list=config_list)
-    user = UserProxyAgent(
+    user = TSPUserProxyAgent(
         "user",
         work_dir=f"{here}",
         human_input_mode=human_input_mode,
         max_consecutive_auto_reply=max_consecutive_auto_reply,
     )
-    with open(f"{here}/tsp_prompt.txt", "r") as f:
-        prompt = f.read()
     # agent.receive(prompt.format(question=hard_questions[0]), user)
     # agent.receive(prompt.format(question=hard_questions[1]), user)
-    assistant.receive(prompt.format(question=hard_questions[2]), user)
+    user.initiate_chat(assistant, question=hard_questions[2])
     print(oai.ChatCompletion.logged_history)
     oai.ChatCompletion.stop_logging()

 if __name__ == "__main__":
-    test_gpt35()
-    test_create_execute_script(human_input_mode="TERMINATE")
+    # test_gpt35()
+    # test_create_execute_script(human_input_mode="TERMINATE")
     # when GPT-4, i.e., the DEFAULT_MODEL, is used, conversation in the following test
     # should terminate in 2-3 rounds of interactions (because is_termination_msg should be true after 2-3 rounds)
     # although the max_consecutive_auto_reply is set to 10.

View File

@@ -64,24 +64,24 @@ def test_execute_one_python_code():
     # no output found 1
     code = "x=3"
-    assert mathproxyagent._execute_one_python_code(code)[0] == "No output found. Make sure you print the results."
+    assert mathproxyagent.execute_one_python_code(code)[0] == "No output found. Make sure you print the results."

     # no output found 2
     code = "if 4 > 5:\n\tprint('True')"
-    assert mathproxyagent._execute_one_python_code(code)[0] == "No output found."
+    assert mathproxyagent.execute_one_python_code(code)[0] == "No output found."

     # return error
     code = "2+'2'"
-    assert "Error:" in mathproxyagent._execute_one_python_code(code)[0]
+    assert "Error:" in mathproxyagent.execute_one_python_code(code)[0]

     # save previous status
-    mathproxyagent._execute_one_python_code("x=3\ny=x*2")
-    assert mathproxyagent._execute_one_python_code("print(y)")[0].strip() == "6"
+    mathproxyagent.execute_one_python_code("x=3\ny=x*2")
+    assert mathproxyagent.execute_one_python_code("print(y)")[0].strip() == "6"

     code = "print('*'*2001)"
     assert (
-        mathproxyagent._execute_one_python_code(code)[0]
+        mathproxyagent.execute_one_python_code(code)[0]
         == "Your requested query response is too long. You might have made a mistake. Please revise your reasoning and query."
     )
@@ -91,7 +91,7 @@ def test_execute_one_wolfram_query():
     code = "2x=3"
     try:
-        mathproxyagent._execute_one_wolfram_query(code)[0]
+        mathproxyagent.execute_one_wolfram_query(code)[0]
     except ValueError:
         print("Wolfram API key not found. Skip test.")

View File

@@ -43,9 +43,9 @@ user_proxy = UserProxyAgent(
 )
 # the assistant receives a message from the user, which contains the task description
-assistant.receive(
+user_proxy.initiate_chat(
+    assistant,
     """What date is today? Which big tech stock has the largest year-to-date gain this year? How much is the gain?""",
-    user_proxy,
 )
 ```
 In the example above, we create an AssistantAgent named "assistant" to serve as the assistant and a UserProxyAgent named "user_proxy" to serve as a proxy for the human user.
@@ -76,7 +76,7 @@ oai_config = {
     "functions": [
         {
             "name": "execute_code",
-            "description": "Receive a list of python code or shell script and return the execution result.",
+            "description": "Receive python code or a shell script and return the execution result.",
             "parameters": {
                 "type": "object",
                 "properties": {
@@ -99,23 +99,26 @@ oai_config = {
 # create an AssistantAgent instance named "assistant"
 chatbot = AssistantAgent("assistant", config_list=config_list, **oai_config)

-# define your own function. Here we use a pre-defined '_execute_code' function from a UserProxyAgent instance
-# we define a wrapper function to call `exec_func`
-exec_func = UserProxyAgent(name="execute_code", work_dir="coding", use_docker=False)._execute_code
-
-def execute_code(code_type, code):
-    return exec_func([(code_type, code)])
-
-# create a UserProxyAgent instance named "user", the execute_code_function is passed
+# create a UserProxyAgent instance named "user"
 user = UserProxyAgent(
     "user",
     human_input_mode="NEVER",
-    function_map={"execute_code": execute_code},
+    work_dir="coding",
 )

+# define an `execute_code` function according to the function description
+def execute_code(code_type, code):
+    # here we reuse the method in the user proxy agent
+    # in general, this is not necessary
+    return user.execute_code([(code_type, code)])
+
+# register the `execute_code` function
+user.register_function(function_map={"execute_code": execute_code})
+
 # start the conversation
-chatbot.receive(
+user.initiate_chat(
+    chatbot,
     "Draw a rocket and save to a file named 'rocket.svg'",
-    user,
 )
 ```