Using a more robust "reflection_with_llm" summary method (#1575)

* summary exception

* badrequest error

* test

* skip reason

* error

* address func call in summary

* reflection_with_llm enhancement and tests

* remove old

* update notebook

* update notebook
This commit is contained in:
Qingyun Wu 2024-02-07 12:17:05 -05:00 committed by GitHub
parent e0fa6ee55b
commit 2a2e466932
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 769 additions and 768 deletions

View File

@ -63,7 +63,7 @@ class ConversableAgent(Agent):
DEFAULT_CONFIG = {} # An empty configuration
MAX_CONSECUTIVE_AUTO_REPLY = 100 # maximum number of consecutive auto replies (subject to future change)
DEFAULT_summary_prompt = "Summarize the takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out."
DEFAULT_summary_prompt = "Summarize the takeaway from the conversation. Do not add any introductory phrases."
llm_config: Union[Dict, Literal[False]]
def __init__(
@ -822,66 +822,51 @@ class ConversableAgent(Agent):
"""
agent = self if agent is None else agent
summary = ""
if method == "last_msg":
try:
summary = agent.last_message(self)["content"]
summary = summary.replace("TERMINATE", "")
except (IndexError, AttributeError):
warnings.warn("Cannot extract summary from last message.", UserWarning)
elif method == "reflection_with_llm":
if method == "reflection_with_llm":
prompt = ConversableAgent.DEFAULT_summary_prompt if prompt is None else prompt
if not isinstance(prompt, str):
raise ValueError("The summary_prompt must be a string.")
msg_list = agent._groupchat.messages if hasattr(agent, "_groupchat") else agent.chat_messages[self]
try:
summary = self._llm_response_preparer(prompt, msg_list, llm_agent=agent, cache=cache)
summary = self._reflection_with_llm(prompt, msg_list, llm_agent=agent, cache=cache)
except BadRequestError as e:
warnings.warn(f"Cannot extract summary using reflection_with_llm: {e}", UserWarning)
elif method == "last_msg" or method is None:
try:
summary = agent.last_message(self)["content"].replace("TERMINATE", "")
except (IndexError, AttributeError) as e:
warnings.warn(f"Cannot extract summary using last_msg: {e}", UserWarning)
else:
warnings.warn("No summary_method provided or summary_method is not supported: ")
warnings.warn(f"Unsupported summary method: {method}", UserWarning)
return summary
def _llm_response_preparer(
def _reflection_with_llm(
self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[Cache] = None
) -> str:
"""Default summary preparer with llm
"""Get a chat summary using reflection with an llm client based on the conversation history.
Args:
prompt (str): The prompt used to extract the final response from the transcript.
prompt (str): The prompt (in this method it is used as system prompt) used to get the summary.
messages (list): The messages generated as part of a chat conversation.
llm_agent: the agent with an llm client.
cache (Cache or None): the cache client to be used for this conversation.
"""
_messages = [
{
"role": "system",
"content": """Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:""",
}
]
for message in messages:
message = copy.deepcopy(message)
message["role"] = "user"
_messages.append(message)
_messages.append(
system_msg = [
{
"role": "system",
"content": prompt,
}
)
]
messages = messages + system_msg
if llm_agent and llm_agent.client is not None:
llm_client = llm_agent.client
elif self.client is not None:
llm_client = self.client
else:
raise ValueError("No OpenAIWrapper client is found.")
response = llm_client.create(context=None, messages=_messages, cache=cache)
extracted_response = llm_client.extract_text_or_completion_object(response)[0]
if not isinstance(extracted_response, str) and hasattr(extracted_response, "model_dump"):
return str(extracted_response.model_dump(mode="dict"))
else:
return extracted_response
response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache)
return response
def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[Agent, ChatResult]:
"""(Experimental) Initiate chats with multiple agents.
@ -1021,7 +1006,12 @@ class ConversableAgent(Agent):
return False, None
if messages is None:
messages = self._oai_messages[sender]
extracted_response = self._generate_oai_reply_from_client(
client, self._oai_system_message + messages, self.client_cache
)
return True, extracted_response
def _generate_oai_reply_from_client(self, llm_client, messages, cache):
# unroll tool_responses
all_messages = []
for message in messages:
@ -1035,13 +1025,12 @@ class ConversableAgent(Agent):
all_messages.append(message)
# TODO: #1143 handle token limit exceeded error
response = client.create(
response = llm_client.create(
context=messages[-1].pop("context", None),
messages=self._oai_system_message + all_messages,
cache=self.client_cache,
messages=all_messages,
cache=cache,
)
extracted_response = client.extract_text_or_completion_object(response)[0]
extracted_response = llm_client.extract_text_or_completion_object(response)[0]
if extracted_response is None:
warnings.warn("Extracted_response is None.", UserWarning)
@ -1056,7 +1045,7 @@ class ConversableAgent(Agent):
)
for tool_call in extracted_response.get("tool_calls") or []:
tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"])
return True, extracted_response
return extracted_response
async def a_generate_oai_reply(
self,

File diff suppressed because one or more lines are too long

View File

@ -52,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "dca301a4",
"metadata": {},
"outputs": [],
@ -122,7 +122,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "9fb85afb",
"metadata": {},
"outputs": [],
@ -249,7 +249,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"id": "d5518947",
"metadata": {},
"outputs": [
@ -264,9 +264,9 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_ubo7cKE3TKumGHkqGjQtZisy): currency_calculator *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_Ak49uR4cwLWyPKs5T2gK9bMg): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
"{\"base_amount\":123.45,\"base_currency\":\"USD\",\"quote_currency\":\"EUR\"}\n",
"{\"base_amount\":123.45}\n",
"\u001b[32m************************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
@ -276,7 +276,7 @@
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"call_ubo7cKE3TKumGHkqGjQtZisy\" *****\u001b[0m\n",
"\u001b[32m***** Response from calling tool \"call_Ak49uR4cwLWyPKs5T2gK9bMg\" *****\u001b[0m\n",
"112.22727272727272 EUR\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
@ -302,12 +302,29 @@
"source": [
"with Cache.disk():\n",
" # start the conversation\n",
" user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"How much is 123.45 USD in EUR?\",\n",
" res = user_proxy.initiate_chat(\n",
" chatbot, message=\"How much is 123.45 USD in EUR?\", summary_method=\"reflection_with_llm\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4b5a0edc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Chat summary: 123.45 USD is equivalent to approximately 112.23 EUR.\n"
]
}
],
"source": [
"print(\"Chat summary:\", res.summary)"
]
},
{
"cell_type": "markdown",
"id": "bd9d61cf",
@ -326,7 +343,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"id": "7b3d8b58",
"metadata": {},
"outputs": [],
@ -432,7 +449,7 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_0VuU2rATuOgYrGmcBnXzPXlh): currency_calculator *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_G64JQKQBT2rI4vnuA4iz1vmE): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
"{\"base\":{\"currency\":\"EUR\",\"amount\":112.23},\"quote_currency\":\"USD\"}\n",
"\u001b[32m************************************************************************************\u001b[0m\n",
@ -444,14 +461,14 @@
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"call_0VuU2rATuOgYrGmcBnXzPXlh\" *****\u001b[0m\n",
"\u001b[32m***** Response from calling tool \"call_G64JQKQBT2rI4vnuA4iz1vmE\" *****\u001b[0m\n",
"{\"currency\":\"USD\",\"amount\":123.45300000000002}\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"112.23 Euros is approximately 123.45 US Dollars.\n",
"112.23 Euros is equivalent to approximately 123.45 US Dollars.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
@ -470,15 +487,32 @@
"source": [
"with Cache.disk():\n",
" # start the conversation\n",
" user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"How much is 112.23 Euros in US Dollars?\",\n",
" res = user_proxy.initiate_chat(\n",
" chatbot, message=\"How much is 112.23 Euros in US Dollars?\", summary_method=\"reflection_with_llm\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "4799f60c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Chat summary: 112.23 Euros is approximately 123.45 US Dollars.\n"
]
}
],
"source": [
"print(\"Chat summary:\", res.summary)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0064d9cd",
"metadata": {},
"outputs": [
@ -493,7 +527,7 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_A6lqMu7s5SyDvftTSeQTtPcj): currency_calculator *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_qv2SwJHpKrG73btxNzUnYBoR): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
"{\"base\":{\"currency\":\"USD\",\"amount\":123.45},\"quote_currency\":\"EUR\"}\n",
"\u001b[32m************************************************************************************\u001b[0m\n",
@ -505,7 +539,7 @@
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"call_A6lqMu7s5SyDvftTSeQTtPcj\" *****\u001b[0m\n",
"\u001b[32m***** Response from calling tool \"call_qv2SwJHpKrG73btxNzUnYBoR\" *****\u001b[0m\n",
"{\"currency\":\"EUR\",\"amount\":112.22727272727272}\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
@ -531,7 +565,7 @@
"source": [
"with Cache.disk():\n",
" # start the conversation\n",
" user_proxy.initiate_chat(\n",
" res = user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"How much is 123.45 US Dollars in Euros?\",\n",
" )"
@ -539,11 +573,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "06137f23",
"execution_count": 15,
"id": "80b2b42c",
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Chat history: [{'content': 'How much is 123.45 US Dollars in Euros?', 'role': 'assistant'}, {'tool_calls': [{'id': 'call_qv2SwJHpKrG73btxNzUnYBoR', 'function': {'arguments': '{\"base\":{\"currency\":\"USD\",\"amount\":123.45},\"quote_currency\":\"EUR\"}', 'name': 'currency_calculator'}, 'type': 'function'}], 'content': None, 'role': 'assistant'}, {'content': '{\"currency\":\"EUR\",\"amount\":112.22727272727272}', 'tool_responses': [{'tool_call_id': 'call_qv2SwJHpKrG73btxNzUnYBoR', 'role': 'tool', 'content': '{\"currency\":\"EUR\",\"amount\":112.22727272727272}'}], 'role': 'tool'}, {'content': '123.45 US Dollars is approximately 112.23 Euros.', 'role': 'user'}, {'content': '', 'role': 'assistant'}, {'content': 'TERMINATE', 'role': 'user'}]\n"
]
}
],
"source": [
"print(\"Chat history:\", res.chat_history)"
]
}
],
"metadata": {

File diff suppressed because one or more lines are too long

View File

@ -108,10 +108,12 @@ def test_agent_usage():
)
math_problem = "$x^3=125$. What is x?"
ai_user_proxy.initiate_chat(
res = ai_user_proxy.initiate_chat(
assistant,
message=math_problem,
summary_method="reflection_with_llm",
)
print("Result summary:", res.summary)
# test print
captured_output = io.StringIO()

View File

@ -55,11 +55,12 @@ def test_ai_user_proxy_agent():
assistant.reset()
math_problem = "$x^3=125$. What is x?"
ai_user_proxy.initiate_chat(
res = ai_user_proxy.initiate_chat(
assistant,
message=math_problem,
)
print(conversations)
print("Result summary:", res.summary)
@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")
@ -149,7 +150,7 @@ def test_create_execute_script(human_input_mode="NEVER", max_consecutive_auto_re
max_consecutive_auto_reply=max_consecutive_auto_reply,
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
)
user.initiate_chat(
res = user.initiate_chat(
assistant,
message="""Create a temp.py file with the following content:
```
@ -157,12 +158,14 @@ print('Hello world!')
```""",
)
print(conversations)
print("Result summary:", res.summary)
# autogen.ChatCompletion.print_usage_summary()
# autogen.ChatCompletion.start_logging(compact=False)
user.send("""Execute temp.py""", assistant)
res = user.send("""Execute temp.py""", assistant)
# print(autogen.ChatCompletion.logged_history)
# autogen.ChatCompletion.print_usage_summary()
# autogen.ChatCompletion.stop_logging()
print("Execution result summary:", res.summary)
@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")

View File

@ -153,14 +153,18 @@ async def test_stream():
user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, position=2, config={"news_stream": data})
await user_proxy.a_initiate_chat(
assistant,
message="""Give me investment suggestion in 3 bullet points.""",
chat_res = await user_proxy.a_initiate_chat(
assistant, message="""Give me investment suggestion in 3 bullet points.""", summary_method="reflection_with_llm"
)
print("Chat summary:", chat_res.summary)
print("Chat cost:", chat_res.cost)
while not data_task.done() and not data_task.cancelled():
reply = await user_proxy.a_generate_reply(sender=assistant)
if reply is not None:
await user_proxy.a_send(reply, assistant)
res = await user_proxy.a_send(reply, assistant)
print("Chat summary and cost:", res.summary, res.cost)
if __name__ == "__main__":

View File

@ -38,6 +38,12 @@ async def test_async_get_human_input():
await user_proxy.a_initiate_chat(assistant, clear_history=True, message="Hello.")
# Test without message
await user_proxy.a_initiate_chat(assistant, clear_history=True)
res = await user_proxy.a_initiate_chat(assistant, clear_history=True, summary_method="reflection_with_llm")
# Assert that custom a_get_human_input was called at least once
user_proxy.a_get_human_input.assert_called()
print("Result summary:", res.summary)
print("Human input:", res.human_input)
if __name__ == "__main__":
asyncio.run(test_async_get_human_input())

View File

@ -234,23 +234,27 @@ def test_update_function():
},
is_remove=False,
)
user_proxy.initiate_chat(
res1 = user_proxy.initiate_chat(
assistant,
message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
summary_method="reflection_with_llm",
)
messages1 = assistant.chat_messages[user_proxy][-1]["content"]
print(messages1)
print("Chat summary and cost", res1.summary, res1.cost)
assistant.update_function_signature("greet_user", is_remove=True)
user_proxy.initiate_chat(
res2 = user_proxy.initiate_chat(
assistant,
message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
summary_method="reflection_with_llm",
)
messages2 = assistant.chat_messages[user_proxy][-1]["content"]
print(messages2)
# The model should know about the function in the context of the conversation
assert "greet_user" in messages1
assert "greet_user" not in messages2
print("Chat summary and cost", res2.summary, res2.cost)
if __name__ == "__main__":

View File

@ -95,10 +95,14 @@ async def test_function_call_groupchat(key, value, sync):
manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config_no_function)
if sync:
observer.initiate_chat(manager, message="Let's start the game!")
res = observer.initiate_chat(manager, message="Let's start the game!", summary_method="reflection_with_llm")
else:
await observer.a_initiate_chat(manager, message="Let's start the game!")
res = await observer.a_initiate_chat(
manager, message="Let's start the game!", summary_method="reflection_with_llm"
)
assert func.call_count >= 1, "The function get_random_number should be called at least once."
print("Chat summary:", res.summary)
print("Chat cost:", res.cost)
def test_no_function_map():

View File

@ -606,12 +606,15 @@ def test_clear_agents_history():
# testing pure "clear history" statement
with mock.patch.object(builtins, "input", lambda _: "clear history. How you doing?"):
agent1.initiate_chat(group_chat_manager, message="hello")
res = agent1.initiate_chat(group_chat_manager, message="hello", summary_method="last_msg")
agent1_history = list(agent1._oai_messages.values())[0]
agent2_history = list(agent2._oai_messages.values())[0]
assert agent1_history == [{"content": "How you doing?", "name": "sam", "role": "user"}]
assert agent2_history == [{"content": "How you doing?", "name": "sam", "role": "user"}]
assert groupchat.messages == [{"content": "How you doing?", "name": "sam", "role": "user"}]
print("Chat summary", res.summary)
print("Chat cost", res.cost)
print("Chat history", res.chat_history)
# testing clear history for defined agent
with mock.patch.object(builtins, "input", lambda _: "clear history bob. How you doing?"):

View File

@ -34,9 +34,14 @@ def test_get_human_input():
user_proxy.register_reply([autogen.Agent, None], autogen.ConversableAgent.a_check_termination_and_human_reply)
user_proxy.initiate_chat(assistant, clear_history=True, message="Hello.")
res = user_proxy.initiate_chat(assistant, clear_history=True, message="Hello.")
print("Result summary:", res.summary)
print("Human input:", res.human_input)
# Test without supplying messages parameter
user_proxy.initiate_chat(assistant, clear_history=True)
res = user_proxy.initiate_chat(assistant, clear_history=True)
print("Result summary:", res.summary)
print("Human input:", res.human_input)
# Assert that custom_a_get_human_input was called at least once
user_proxy.get_human_input.assert_called()

View File

@ -55,8 +55,10 @@ def test_math_user_proxy_agent():
# message=mathproxyagent.generate_init_message(math_problem),
# sender=mathproxyagent,
# )
mathproxyagent.initiate_chat(assistant, problem=math_problem)
res = mathproxyagent.initiate_chat(assistant, problem=math_problem)
print(conversations)
print("Chat summary:", res.summary)
print("Chat history:", res.chat_history)
def test_add_remove_print():

View File

@ -165,23 +165,29 @@ def test_update_tool():
},
is_remove=False,
)
user_proxy.initiate_chat(
res = user_proxy.initiate_chat(
assistant,
message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
)
messages1 = assistant.chat_messages[user_proxy][-1]["content"]
print(messages1)
print("Message:", messages1)
print("Summary:", res.summary)
assert (
messages1.replace("TERMINATE", "") == res.summary
), "Message (removing TERMINATE) and summary should be the same"
assistant.update_tool_signature("greet_user", is_remove=True)
user_proxy.initiate_chat(
res = user_proxy.initiate_chat(
assistant,
message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
summary_method="reflection_with_llm",
)
messages2 = assistant.chat_messages[user_proxy][-1]["content"]
print(messages2)
print("Message2:", messages2)
# The model should know about the function in the context of the conversation
assert "greet_user" in messages1
assert "greet_user" not in messages2
print("Summary2:", res.summary)
@pytest.mark.skipif(not TOOL_ENABLED, reason="openai>=1.1.0 not installed")
@ -366,7 +372,7 @@ async def test_async_multi_tool_call():
if __name__ == "__main__":
# test_update_tool()
test_update_tool()
# test_eval_math_responses()
# test_multi_tool_call()
test_eval_math_responses_api_style_function()
# test_eval_math_responses_api_style_function()