Using a more robust "reflection_with_llm" summary method (#1575)

* summary exception

* badrequest error

* test

* skip reason

* error

* address func call in summary

* reflection_with_llm enhancement and tests

* remove old

* update notebook

* update notebook
Qingyun Wu 2024-02-07 12:17:05 -05:00 committed by GitHub
parent e0fa6ee55b
commit 2a2e466932
14 changed files with 769 additions and 768 deletions
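For context, the change makes `initiate_chat` (and its async variant) return a `ChatResult` whose `summary` can be produced by an extra LLM call that reflects over the full transcript. A minimal caller-side sketch, assuming an `OAI_CONFIG_LIST` file is available and using hypothetical agent names:

```python
import autogen

# Assumption: an OAI_CONFIG_LIST file with model credentials exists.
config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")

assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
user_proxy = autogen.UserProxyAgent(
    "user_proxy", human_input_mode="NEVER", code_execution_config=False
)

# summary_method="reflection_with_llm" asks the LLM to reflect over the whole
# conversation; "last_msg" just reuses the final message.
res = user_proxy.initiate_chat(
    assistant,
    message="How much is 123.45 USD in EUR?",
    summary_method="reflection_with_llm",
)
print("Summary:", res.summary)       # LLM-written takeaway of the chat
print("Cost:", res.cost)             # token/cost accounting for the chat
print("History:", res.chat_history)  # full message list
```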

View File

@@ -63,7 +63,7 @@ class ConversableAgent(Agent):
     DEFAULT_CONFIG = {}  # An empty configuration
     MAX_CONSECUTIVE_AUTO_REPLY = 100  # maximum number of consecutive auto replies (subject to future change)
-    DEFAULT_summary_prompt = "Summarize the takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out."
+    DEFAULT_summary_prompt = "Summarize the takeaway from the conversation. Do not add any introductory phrases."

     llm_config: Union[Dict, Literal[False]]

     def __init__(
@@ -822,66 +822,51 @@ class ConversableAgent(Agent):
         """
         agent = self if agent is None else agent
         summary = ""
-        if method == "last_msg":
-            try:
-                summary = agent.last_message(self)["content"]
-                summary = summary.replace("TERMINATE", "")
-            except (IndexError, AttributeError):
-                warnings.warn("Cannot extract summary from last message.", UserWarning)
-        elif method == "reflection_with_llm":
+        if method == "reflection_with_llm":
             prompt = ConversableAgent.DEFAULT_summary_prompt if prompt is None else prompt
             if not isinstance(prompt, str):
                 raise ValueError("The summary_prompt must be a string.")
             msg_list = agent._groupchat.messages if hasattr(agent, "_groupchat") else agent.chat_messages[self]
             try:
-                summary = self._llm_response_preparer(prompt, msg_list, llm_agent=agent, cache=cache)
+                summary = self._reflection_with_llm(prompt, msg_list, llm_agent=agent, cache=cache)
             except BadRequestError as e:
                 warnings.warn(f"Cannot extract summary using reflection_with_llm: {e}", UserWarning)
+        elif method == "last_msg" or method is None:
+            try:
+                summary = agent.last_message(self)["content"].replace("TERMINATE", "")
+            except (IndexError, AttributeError) as e:
+                warnings.warn(f"Cannot extract summary using last_msg: {e}", UserWarning)
         else:
-            warnings.warn("No summary_method provided or summary_method is not supported: ")
+            warnings.warn(f"Unsupported summary method: {method}", UserWarning)
         return summary

-    def _llm_response_preparer(
+    def _reflection_with_llm(
         self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[Cache] = None
     ) -> str:
-        """Default summary preparer with llm
+        """Get a chat summary using reflection with an llm client based on the conversation history.

         Args:
-            prompt (str): The prompt used to extract the final response from the transcript.
+            prompt (str): The prompt (in this method it is used as system prompt) used to get the summary.
             messages (list): The messages generated as part of a chat conversation.
+            llm_agent: the agent with an llm client.
+            cache (Cache or None): the cache client to be used for this conversation.
         """
-        _messages = [
-            {
-                "role": "system",
-                "content": """Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:""",
-            }
-        ]
-        for message in messages:
-            message = copy.deepcopy(message)
-            message["role"] = "user"
-            _messages.append(message)
-        _messages.append(
+        system_msg = [
             {
                 "role": "system",
                 "content": prompt,
             }
-        )
+        ]
+        messages = messages + system_msg

         if llm_agent and llm_agent.client is not None:
             llm_client = llm_agent.client
         elif self.client is not None:
             llm_client = self.client
         else:
             raise ValueError("No OpenAIWrapper client is found.")
-        response = llm_client.create(context=None, messages=_messages, cache=cache)
-        extracted_response = llm_client.extract_text_or_completion_object(response)[0]
-        if not isinstance(extracted_response, str) and hasattr(extracted_response, "model_dump"):
-            return str(extracted_response.model_dump(mode="dict"))
-        else:
-            return extracted_response
+        response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache)
+        return response

     def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[Agent, ChatResult]:
         """(Experimental) Initiate chats with multiple agents.
@@ -1021,7 +1006,12 @@ class ConversableAgent(Agent):
             return False, None
         if messages is None:
             messages = self._oai_messages[sender]
+        extracted_response = self._generate_oai_reply_from_client(
+            client, self._oai_system_message + messages, self.client_cache
+        )
+        return True, extracted_response
+
+    def _generate_oai_reply_from_client(self, llm_client, messages, cache):
         # unroll tool_responses
         all_messages = []
         for message in messages:
@@ -1035,13 +1025,12 @@ class ConversableAgent(Agent):
                 all_messages.append(message)

         # TODO: #1143 handle token limit exceeded error
-        response = client.create(
+        response = llm_client.create(
             context=messages[-1].pop("context", None),
-            messages=self._oai_system_message + all_messages,
-            cache=self.client_cache,
+            messages=all_messages,
+            cache=cache,
         )
-        extracted_response = client.extract_text_or_completion_object(response)[0]
+        extracted_response = llm_client.extract_text_or_completion_object(response)[0]

         if extracted_response is None:
             warnings.warn("Extracted_response is None.", UserWarning)
@@ -1056,7 +1045,7 @@ class ConversableAgent(Agent):
         )
         for tool_call in extracted_response.get("tool_calls") or []:
             tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"])
-        return True, extracted_response
+        return extracted_response

     async def a_generate_oai_reply(
         self,
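
The core of the new summary method is simple: append the summary prompt as a trailing system message and route the transcript through the shared client helper. A standalone sketch of that flow, with illustrative names (the real implementation above additionally unrolls tool responses via `_generate_oai_reply_from_client`):

```python
from typing import Dict, List

def reflect_with_llm(llm_client, prompt: str, messages: List[Dict]) -> str:
    """Illustrative restatement of _reflection_with_llm's message handling."""
    # The summary prompt rides along as a final system message, so the model
    # sees the entire conversation followed by the instruction to summarize it.
    reflection_messages = list(messages) + [{"role": "system", "content": prompt}]
    response = llm_client.create(context=None, messages=reflection_messages, cache=None)
    # OpenAIWrapper.extract_text_or_completion_object returns a list; take the first item.
    return llm_client.extract_text_or_completion_object(response)[0]
```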

File diff suppressed because one or more lines are too long

View File

@@ -52,7 +52,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "id": "dca301a4",
    "metadata": {},
    "outputs": [],
@@ -122,7 +122,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "id": "9fb85afb",
    "metadata": {},
    "outputs": [],
@@ -249,7 +249,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 3,
    "id": "d5518947",
    "metadata": {},
    "outputs": [
@@ -264,9 +264,9 @@
     "--------------------------------------------------------------------------------\n",
     "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
     "\n",
-    "\u001b[32m***** Suggested tool Call (call_ubo7cKE3TKumGHkqGjQtZisy): currency_calculator *****\u001b[0m\n",
+    "\u001b[32m***** Suggested tool Call (call_Ak49uR4cwLWyPKs5T2gK9bMg): currency_calculator *****\u001b[0m\n",
     "Arguments: \n",
-    "{\"base_amount\":123.45,\"base_currency\":\"USD\",\"quote_currency\":\"EUR\"}\n",
+    "{\"base_amount\":123.45}\n",
     "\u001b[32m************************************************************************************\u001b[0m\n",
     "\n",
     "--------------------------------------------------------------------------------\n",
@ -276,7 +276,7 @@
"\n", "\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n", "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n", "\n",
"\u001b[32m***** Response from calling tool \"call_ubo7cKE3TKumGHkqGjQtZisy\" *****\u001b[0m\n", "\u001b[32m***** Response from calling tool \"call_Ak49uR4cwLWyPKs5T2gK9bMg\" *****\u001b[0m\n",
"112.22727272727272 EUR\n", "112.22727272727272 EUR\n",
"\u001b[32m**********************************************************************\u001b[0m\n", "\u001b[32m**********************************************************************\u001b[0m\n",
"\n", "\n",
@@ -302,12 +302,29 @@
    "source": [
     "with Cache.disk():\n",
     "    # start the conversation\n",
-    "    user_proxy.initiate_chat(\n",
-    "        chatbot,\n",
-    "        message=\"How much is 123.45 USD in EUR?\",\n",
+    "    res = user_proxy.initiate_chat(\n",
+    "        chatbot, message=\"How much is 123.45 USD in EUR?\", summary_method=\"reflection_with_llm\"\n",
     "    )"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "4b5a0edc",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Chat summary: 123.45 USD is equivalent to approximately 112.23 EUR.\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"Chat summary:\", res.summary)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "bd9d61cf",
@@ -326,7 +343,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 5,
    "id": "7b3d8b58",
    "metadata": {},
    "outputs": [],
@@ -432,7 +449,7 @@
     "--------------------------------------------------------------------------------\n",
     "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
     "\n",
-    "\u001b[32m***** Suggested tool Call (call_0VuU2rATuOgYrGmcBnXzPXlh): currency_calculator *****\u001b[0m\n",
+    "\u001b[32m***** Suggested tool Call (call_G64JQKQBT2rI4vnuA4iz1vmE): currency_calculator *****\u001b[0m\n",
     "Arguments: \n",
     "{\"base\":{\"currency\":\"EUR\",\"amount\":112.23},\"quote_currency\":\"USD\"}\n",
     "\u001b[32m************************************************************************************\u001b[0m\n",
@@ -444,14 +461,14 @@
     "\n",
     "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
     "\n",
-    "\u001b[32m***** Response from calling tool \"call_0VuU2rATuOgYrGmcBnXzPXlh\" *****\u001b[0m\n",
+    "\u001b[32m***** Response from calling tool \"call_G64JQKQBT2rI4vnuA4iz1vmE\" *****\u001b[0m\n",
     "{\"currency\":\"USD\",\"amount\":123.45300000000002}\n",
     "\u001b[32m**********************************************************************\u001b[0m\n",
     "\n",
     "--------------------------------------------------------------------------------\n",
     "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
     "\n",
-    "112.23 Euros is approximately 123.45 US Dollars.\n",
+    "112.23 Euros is equivalent to approximately 123.45 US Dollars.\n",
     "\n",
     "--------------------------------------------------------------------------------\n",
     "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
@@ -470,15 +487,32 @@
    "source": [
     "with Cache.disk():\n",
     "    # start the conversation\n",
-    "    user_proxy.initiate_chat(\n",
-    "        chatbot,\n",
-    "        message=\"How much is 112.23 Euros in US Dollars?\",\n",
+    "    res = user_proxy.initiate_chat(\n",
+    "        chatbot, message=\"How much is 112.23 Euros in US Dollars?\", summary_method=\"reflection_with_llm\"\n",
     "    )"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 10,
+   "id": "4799f60c",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Chat summary: 112.23 Euros is approximately 123.45 US Dollars.\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"Chat summary:\", res.summary)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
    "id": "0064d9cd",
    "metadata": {},
    "outputs": [
@@ -493,7 +527,7 @@
     "--------------------------------------------------------------------------------\n",
     "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
     "\n",
-    "\u001b[32m***** Suggested tool Call (call_A6lqMu7s5SyDvftTSeQTtPcj): currency_calculator *****\u001b[0m\n",
+    "\u001b[32m***** Suggested tool Call (call_qv2SwJHpKrG73btxNzUnYBoR): currency_calculator *****\u001b[0m\n",
     "Arguments: \n",
     "{\"base\":{\"currency\":\"USD\",\"amount\":123.45},\"quote_currency\":\"EUR\"}\n",
     "\u001b[32m************************************************************************************\u001b[0m\n",
@ -505,7 +539,7 @@
"\n", "\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n", "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n", "\n",
"\u001b[32m***** Response from calling tool \"call_A6lqMu7s5SyDvftTSeQTtPcj\" *****\u001b[0m\n", "\u001b[32m***** Response from calling tool \"call_qv2SwJHpKrG73btxNzUnYBoR\" *****\u001b[0m\n",
"{\"currency\":\"EUR\",\"amount\":112.22727272727272}\n", "{\"currency\":\"EUR\",\"amount\":112.22727272727272}\n",
"\u001b[32m**********************************************************************\u001b[0m\n", "\u001b[32m**********************************************************************\u001b[0m\n",
"\n", "\n",
@@ -531,7 +565,7 @@
    "source": [
     "with Cache.disk():\n",
     "    # start the conversation\n",
-    "    user_proxy.initiate_chat(\n",
+    "    res = user_proxy.initiate_chat(\n",
     "        chatbot,\n",
     "        message=\"How much is 123.45 US Dollars in Euros?\",\n",
     "    )"
@@ -539,11 +573,21 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "06137f23",
+   "execution_count": 15,
+   "id": "80b2b42c",
    "metadata": {},
-   "outputs": [],
-   "source": []
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Chat history: [{'content': 'How much is 123.45 US Dollars in Euros?', 'role': 'assistant'}, {'tool_calls': [{'id': 'call_qv2SwJHpKrG73btxNzUnYBoR', 'function': {'arguments': '{\"base\":{\"currency\":\"USD\",\"amount\":123.45},\"quote_currency\":\"EUR\"}', 'name': 'currency_calculator'}, 'type': 'function'}], 'content': None, 'role': 'assistant'}, {'content': '{\"currency\":\"EUR\",\"amount\":112.22727272727272}', 'tool_responses': [{'tool_call_id': 'call_qv2SwJHpKrG73btxNzUnYBoR', 'role': 'tool', 'content': '{\"currency\":\"EUR\",\"amount\":112.22727272727272}'}], 'role': 'tool'}, {'content': '123.45 US Dollars is approximately 112.23 Euros.', 'role': 'user'}, {'content': '', 'role': 'assistant'}, {'content': 'TERMINATE', 'role': 'user'}]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"Chat history:\", res.chat_history)"
+   ]
   }
  ],
 "metadata": {

File diff suppressed because one or more lines are too long

View File

@@ -108,10 +108,12 @@ def test_agent_usage():
     )
     math_problem = "$x^3=125$. What is x?"
-    ai_user_proxy.initiate_chat(
+    res = ai_user_proxy.initiate_chat(
         assistant,
         message=math_problem,
+        summary_method="reflection_with_llm",
     )
+    print("Result summary:", res.summary)

     # test print
     captured_output = io.StringIO()

View File

@@ -55,11 +55,12 @@ def test_ai_user_proxy_agent():
     assistant.reset()

     math_problem = "$x^3=125$. What is x?"
-    ai_user_proxy.initiate_chat(
+    res = ai_user_proxy.initiate_chat(
         assistant,
         message=math_problem,
     )
     print(conversations)
+    print("Result summary:", res.summary)

 @pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")
@@ -149,7 +150,7 @@ def test_create_execute_script(human_input_mode="NEVER", max_consecutive_auto_re
         max_consecutive_auto_reply=max_consecutive_auto_reply,
         is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
     )
-    user.initiate_chat(
+    res = user.initiate_chat(
         assistant,
         message="""Create a temp.py file with the following content:
 ```
@@ -157,12 +158,14 @@ print('Hello world!')
 ```""",
     )
     print(conversations)
+    print("Result summary:", res.summary)
     # autogen.ChatCompletion.print_usage_summary()
     # autogen.ChatCompletion.start_logging(compact=False)
-    user.send("""Execute temp.py""", assistant)
+    res = user.send("""Execute temp.py""", assistant)
     # print(autogen.ChatCompletion.logged_history)
     # autogen.ChatCompletion.print_usage_summary()
     # autogen.ChatCompletion.stop_logging()
+    print("Execution result summary:", res.summary)

 @pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")

View File

@@ -153,14 +153,18 @@ async def test_stream():
     user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, position=2, config={"news_stream": data})

-    await user_proxy.a_initiate_chat(
-        assistant,
-        message="""Give me investment suggestion in 3 bullet points.""",
+    chat_res = await user_proxy.a_initiate_chat(
+        assistant, message="""Give me investment suggestion in 3 bullet points.""", summary_method="reflection_with_llm"
     )
+    print("Chat summary:", chat_res.summary)
+    print("Chat cost:", chat_res.cost)
     while not data_task.done() and not data_task.cancelled():
         reply = await user_proxy.a_generate_reply(sender=assistant)
         if reply is not None:
-            await user_proxy.a_send(reply, assistant)
+            res = await user_proxy.a_send(reply, assistant)
+            print("Chat summary and cost:", res.summary, res.cost)

 if __name__ == "__main__":

View File

@@ -38,6 +38,12 @@ async def test_async_get_human_input():
     await user_proxy.a_initiate_chat(assistant, clear_history=True, message="Hello.")

     # Test without message
-    await user_proxy.a_initiate_chat(assistant, clear_history=True)
+    res = await user_proxy.a_initiate_chat(assistant, clear_history=True, summary_method="reflection_with_llm")

     # Assert that custom a_get_human_input was called at least once
     user_proxy.a_get_human_input.assert_called()
+    print("Result summary:", res.summary)
+    print("Human input:", res.human_input)
+
+
+if __name__ == "__main__":
+    asyncio.run(test_async_get_human_input())

View File

@@ -234,23 +234,27 @@ def test_update_function():
         },
         is_remove=False,
     )
-    user_proxy.initiate_chat(
+    res1 = user_proxy.initiate_chat(
         assistant,
         message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
+        summary_method="reflection_with_llm",
     )
     messages1 = assistant.chat_messages[user_proxy][-1]["content"]
     print(messages1)
+    print("Chat summary and cost", res1.summary, res1.cost)

     assistant.update_function_signature("greet_user", is_remove=True)
-    user_proxy.initiate_chat(
+    res2 = user_proxy.initiate_chat(
         assistant,
         message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
+        summary_method="reflection_with_llm",
     )
     messages2 = assistant.chat_messages[user_proxy][-1]["content"]
     print(messages2)
     # The model should know about the function in the context of the conversation
     assert "greet_user" in messages1
     assert "greet_user" not in messages2
+    print("Chat summary and cost", res2.summary, res2.cost)

 if __name__ == "__main__":

View File

@@ -95,10 +95,14 @@ async def test_function_call_groupchat(key, value, sync):
     manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config_no_function)

     if sync:
-        observer.initiate_chat(manager, message="Let's start the game!")
+        res = observer.initiate_chat(manager, message="Let's start the game!", summary_method="reflection_with_llm")
     else:
-        await observer.a_initiate_chat(manager, message="Let's start the game!")
+        res = await observer.a_initiate_chat(
+            manager, message="Let's start the game!", summary_method="reflection_with_llm"
+        )
     assert func.call_count >= 1, "The function get_random_number should be called at least once."
+    print("Chat summary:", res.summary)
+    print("Chat cost:", res.cost)

 def test_no_function_map():

View File

@@ -606,12 +606,15 @@ def test_clear_agents_history():
     # testing pure "clear history" statement
     with mock.patch.object(builtins, "input", lambda _: "clear history. How you doing?"):
-        agent1.initiate_chat(group_chat_manager, message="hello")
+        res = agent1.initiate_chat(group_chat_manager, message="hello", summary_method="last_msg")
     agent1_history = list(agent1._oai_messages.values())[0]
     agent2_history = list(agent2._oai_messages.values())[0]
     assert agent1_history == [{"content": "How you doing?", "name": "sam", "role": "user"}]
     assert agent2_history == [{"content": "How you doing?", "name": "sam", "role": "user"}]
     assert groupchat.messages == [{"content": "How you doing?", "name": "sam", "role": "user"}]
+    print("Chat summary", res.summary)
+    print("Chat cost", res.cost)
+    print("Chat history", res.chat_history)

     # testing clear history for defined agent
     with mock.patch.object(builtins, "input", lambda _: "clear history bob. How you doing?"):

View File

@@ -34,9 +34,14 @@ def test_get_human_input():
     user_proxy.register_reply([autogen.Agent, None], autogen.ConversableAgent.a_check_termination_and_human_reply)

-    user_proxy.initiate_chat(assistant, clear_history=True, message="Hello.")
+    res = user_proxy.initiate_chat(assistant, clear_history=True, message="Hello.")
+    print("Result summary:", res.summary)
+    print("Human input:", res.human_input)

     # Test without supplying messages parameter
-    user_proxy.initiate_chat(assistant, clear_history=True)
+    res = user_proxy.initiate_chat(assistant, clear_history=True)
+    print("Result summary:", res.summary)
+    print("Human input:", res.human_input)

     # Assert that custom_a_get_human_input was called at least once
     user_proxy.get_human_input.assert_called()

View File

@@ -55,8 +55,10 @@ def test_math_user_proxy_agent():
     #     message=mathproxyagent.generate_init_message(math_problem),
     #     sender=mathproxyagent,
     # )
-    mathproxyagent.initiate_chat(assistant, problem=math_problem)
+    res = mathproxyagent.initiate_chat(assistant, problem=math_problem)
     print(conversations)
+    print("Chat summary:", res.summary)
+    print("Chat history:", res.chat_history)

 def test_add_remove_print():

View File

@@ -165,23 +165,29 @@ def test_update_tool():
         },
         is_remove=False,
     )
-    user_proxy.initiate_chat(
+    res = user_proxy.initiate_chat(
         assistant,
         message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
     )
     messages1 = assistant.chat_messages[user_proxy][-1]["content"]
-    print(messages1)
+    print("Message:", messages1)
+    print("Summary:", res.summary)
+    assert (
+        messages1.replace("TERMINATE", "") == res.summary
+    ), "Message (removing TERMINATE) and summary should be the same"

     assistant.update_tool_signature("greet_user", is_remove=True)
-    user_proxy.initiate_chat(
+    res = user_proxy.initiate_chat(
         assistant,
         message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
+        summary_method="reflection_with_llm",
     )
     messages2 = assistant.chat_messages[user_proxy][-1]["content"]
-    print(messages2)
+    print("Message2:", messages2)
     # The model should know about the function in the context of the conversation
     assert "greet_user" in messages1
     assert "greet_user" not in messages2
+    print("Summary2:", res.summary)

@@ -366,7 +372,7 @@ async def test_async_multi_tool_call():
 if __name__ == "__main__":
-    # test_update_tool()
+    test_update_tool()
     # test_eval_math_responses()
     # test_multi_tool_call()
-    test_eval_math_responses_api_style_function()
+    # test_eval_math_responses_api_style_function()