Using a more robust "reflection_with_llm" summary method (#1575)
* summary exception
* badrequest error
* test
* skip reason
* error
* address func call in summary
* reflection_with_llm enhancement and tests
* remove old
* update notebook
This commit is contained in:
parent e0fa6ee55b
commit 2a2e466932
@@ -63,7 +63,7 @@ class ConversableAgent(Agent):
     DEFAULT_CONFIG = {}  # An empty configuration
     MAX_CONSECUTIVE_AUTO_REPLY = 100  # maximum number of consecutive auto replies (subject to future change)

-    DEFAULT_summary_prompt = "Summarize the takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out."
+    DEFAULT_summary_prompt = "Summarize the takeaway from the conversation. Do not add any introductory phrases."
     llm_config: Union[Dict, Literal[False]]

     def __init__(
@@ -822,66 +822,51 @@ class ConversableAgent(Agent):
         """
         agent = self if agent is None else agent
         summary = ""
-        if method == "last_msg":
-            try:
-                summary = agent.last_message(self)["content"]
-                summary = summary.replace("TERMINATE", "")
-            except (IndexError, AttributeError):
-                warnings.warn("Cannot extract summary from last message.", UserWarning)
-        elif method == "reflection_with_llm":
+        if method == "reflection_with_llm":
             prompt = ConversableAgent.DEFAULT_summary_prompt if prompt is None else prompt
             if not isinstance(prompt, str):
                 raise ValueError("The summary_prompt must be a string.")
             msg_list = agent._groupchat.messages if hasattr(agent, "_groupchat") else agent.chat_messages[self]
             try:
-                summary = self._llm_response_preparer(prompt, msg_list, llm_agent=agent, cache=cache)
+                summary = self._reflection_with_llm(prompt, msg_list, llm_agent=agent, cache=cache)
             except BadRequestError as e:
                 warnings.warn(f"Cannot extract summary using reflection_with_llm: {e}", UserWarning)
+        elif method == "last_msg" or method is None:
+            try:
+                summary = agent.last_message(self)["content"].replace("TERMINATE", "")
+            except (IndexError, AttributeError) as e:
+                warnings.warn(f"Cannot extract summary using last_msg: {e}", UserWarning)
         else:
-            warnings.warn("No summary_method provided or summary_method is not supported: ")
+            warnings.warn(f"Unsupported summary method: {method}", UserWarning)
         return summary

-    def _llm_response_preparer(
+    def _reflection_with_llm(
         self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[Cache] = None
     ) -> str:
-        """Default summary preparer with llm
+        """Get a chat summary using reflection with an llm client based on the conversation history.

         Args:
-            prompt (str): The prompt used to extract the final response from the transcript.
+            prompt (str): The prompt (in this method it is used as system prompt) used to get the summary.
             messages (list): The messages generated as part of a chat conversation.
             llm_agent: the agent with an llm client.
             cache (Cache or None): the cache client to be used for this conversation.
         """
-        _messages = [
-            {
-                "role": "system",
-                "content": """Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:""",
-            }
-        ]
-        for message in messages:
-            message = copy.deepcopy(message)
-            message["role"] = "user"
-            _messages.append(message)
-
-        _messages.append(
-            {
-                "role": "system",
-                "content": prompt,
-            }
-        )
+        system_msg = [
+            {
+                "role": "system",
+                "content": prompt,
+            }
+        ]
+
+        messages = messages + system_msg
         if llm_agent and llm_agent.client is not None:
             llm_client = llm_agent.client
         elif self.client is not None:
             llm_client = self.client
         else:
             raise ValueError("No OpenAIWrapper client is found.")

-        response = llm_client.create(context=None, messages=_messages, cache=cache)
-        extracted_response = llm_client.extract_text_or_completion_object(response)[0]
-        if not isinstance(extracted_response, str) and hasattr(extracted_response, "model_dump"):
-            return str(extracted_response.model_dump(mode="dict"))
-        else:
-            return extracted_response
+        response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache)
+        return response

     def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[Agent, ChatResult]:
         """(Experimental) Initiate chats with multiple agents.
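The dispatch above is driven entirely by the `summary_method` argument of `initiate_chat`. A minimal usage sketch, assuming an `OAI_CONFIG_LIST` file and hypothetical agent names (not part of this diff):

import autogen

config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")

assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
user_proxy = autogen.UserProxyAgent("user_proxy", human_input_mode="NEVER", code_execution_config=False)

# "reflection_with_llm" appends the summary prompt as a system message and asks
# the llm client to reflect on the whole transcript; "last_msg" (the default)
# just reuses the final message with "TERMINATE" stripped.
res = user_proxy.initiate_chat(
    assistant,
    message="How much is 123.45 USD in EUR?",
    summary_method="reflection_with_llm",
)
print(res.summary)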
@@ -1021,7 +1006,12 @@ class ConversableAgent(Agent):
             return False, None
         if messages is None:
             messages = self._oai_messages[sender]
+        extracted_response = self._generate_oai_reply_from_client(
+            client, self._oai_system_message + messages, self.client_cache
+        )
+        return True, extracted_response
+
+    def _generate_oai_reply_from_client(self, llm_client, messages, cache):
         # unroll tool_responses
         all_messages = []
         for message in messages:
@@ -1035,13 +1025,12 @@ class ConversableAgent(Agent):
                 all_messages.append(message)

         # TODO: #1143 handle token limit exceeded error
-        response = client.create(
+        response = llm_client.create(
             context=messages[-1].pop("context", None),
-            messages=self._oai_system_message + all_messages,
-            cache=self.client_cache,
+            messages=all_messages,
+            cache=cache,
         )

-        extracted_response = client.extract_text_or_completion_object(response)[0]
+        extracted_response = llm_client.extract_text_or_completion_object(response)[0]

         if extracted_response is None:
             warnings.warn("Extracted_response is None.", UserWarning)
@@ -1056,7 +1045,7 @@ class ConversableAgent(Agent):
             )
         for tool_call in extracted_response.get("tool_calls") or []:
             tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"])
-        return True, extracted_response
+        return extracted_response

     async def a_generate_oai_reply(
         self,
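The net effect of the three hunks above is an extract-method refactor: the client call that used to live inline in `generate_oai_reply` now sits in `_generate_oai_reply_from_client`, with the client, message list, and cache passed in explicitly instead of read from `self`. A condensed sketch of the resulting shape (simplified from the diff, not a drop-in copy):

def _generate_oai_reply_from_client(self, llm_client, messages, cache):
    # The caller chooses the client and cache, and decides whether a system
    # message is prepended; this helper only performs the LLM call.
    response = llm_client.create(
        context=messages[-1].pop("context", None),
        messages=messages,
        cache=cache,
    )
    return llm_client.extract_text_or_completion_object(response)[0]

The reply path prepends the agent's own system message and cache (`client, self._oai_system_message + messages, self.client_cache`), while `_reflection_with_llm` passes the transcript plus the summary prompt with its own client and cache.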
File diff suppressed because one or more lines are too long
@@ -52,7 +52,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 2,
+    "execution_count": 1,
     "id": "dca301a4",
     "metadata": {},
     "outputs": [],
@@ -122,7 +122,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": 2,
     "id": "9fb85afb",
     "metadata": {},
     "outputs": [],
@@ -249,7 +249,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 6,
+    "execution_count": 3,
     "id": "d5518947",
     "metadata": {},
     "outputs": [
@@ -264,9 +264,9 @@
     "--------------------------------------------------------------------------------\n",
     "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
     "\n",
-    "\u001b[32m***** Suggested tool Call (call_ubo7cKE3TKumGHkqGjQtZisy): currency_calculator *****\u001b[0m\n",
+    "\u001b[32m***** Suggested tool Call (call_Ak49uR4cwLWyPKs5T2gK9bMg): currency_calculator *****\u001b[0m\n",
     "Arguments: \n",
-    "{\"base_amount\":123.45,\"base_currency\":\"USD\",\"quote_currency\":\"EUR\"}\n",
+    "{\"base_amount\":123.45}\n",
     "\u001b[32m************************************************************************************\u001b[0m\n",
     "\n",
     "--------------------------------------------------------------------------------\n",
@@ -276,7 +276,7 @@
     "\n",
     "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
     "\n",
-    "\u001b[32m***** Response from calling tool \"call_ubo7cKE3TKumGHkqGjQtZisy\" *****\u001b[0m\n",
+    "\u001b[32m***** Response from calling tool \"call_Ak49uR4cwLWyPKs5T2gK9bMg\" *****\u001b[0m\n",
     "112.22727272727272 EUR\n",
     "\u001b[32m**********************************************************************\u001b[0m\n",
     "\n",
@@ -302,12 +302,29 @@
    "source": [
     "with Cache.disk():\n",
     "    # start the conversation\n",
-    "    user_proxy.initiate_chat(\n",
-    "        chatbot,\n",
-    "        message=\"How much is 123.45 USD in EUR?\",\n",
+    "    res = user_proxy.initiate_chat(\n",
+    "        chatbot, message=\"How much is 123.45 USD in EUR?\", summary_method=\"reflection_with_llm\"\n",
     "    )"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "4b5a0edc",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Chat summary: 123.45 USD is equivalent to approximately 112.23 EUR.\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"Chat summary:\", res.summary)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "bd9d61cf",
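For contrast, a sketch under the same notebook setup: leaving `summary_method` at its default falls back to "last_msg", which reuses the final message instead of making an extra LLM call:

with Cache.disk():
    res = user_proxy.initiate_chat(
        chatbot,
        message="How much is 123.45 USD in EUR?",
        summary_method="last_msg",  # the default; no reflection call is made
    )
print("Chat summary:", res.summary)  # final message with "TERMINATE" stripped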
@@ -326,7 +343,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 7,
+    "execution_count": 5,
     "id": "7b3d8b58",
     "metadata": {},
     "outputs": [],
@@ -432,7 +449,7 @@
     "--------------------------------------------------------------------------------\n",
     "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
     "\n",
-    "\u001b[32m***** Suggested tool Call (call_0VuU2rATuOgYrGmcBnXzPXlh): currency_calculator *****\u001b[0m\n",
+    "\u001b[32m***** Suggested tool Call (call_G64JQKQBT2rI4vnuA4iz1vmE): currency_calculator *****\u001b[0m\n",
     "Arguments: \n",
     "{\"base\":{\"currency\":\"EUR\",\"amount\":112.23},\"quote_currency\":\"USD\"}\n",
     "\u001b[32m************************************************************************************\u001b[0m\n",
@@ -444,14 +461,14 @@
     "\n",
     "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
     "\n",
-    "\u001b[32m***** Response from calling tool \"call_0VuU2rATuOgYrGmcBnXzPXlh\" *****\u001b[0m\n",
+    "\u001b[32m***** Response from calling tool \"call_G64JQKQBT2rI4vnuA4iz1vmE\" *****\u001b[0m\n",
     "{\"currency\":\"USD\",\"amount\":123.45300000000002}\n",
     "\u001b[32m**********************************************************************\u001b[0m\n",
     "\n",
     "--------------------------------------------------------------------------------\n",
     "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
     "\n",
-    "112.23 Euros is approximately 123.45 US Dollars.\n",
+    "112.23 Euros is equivalent to approximately 123.45 US Dollars.\n",
     "\n",
     "--------------------------------------------------------------------------------\n",
     "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
@@ -470,15 +487,32 @@
    "source": [
     "with Cache.disk():\n",
     "    # start the conversation\n",
-    "    user_proxy.initiate_chat(\n",
-    "        chatbot,\n",
-    "        message=\"How much is 112.23 Euros in US Dollars?\",\n",
+    "    res = user_proxy.initiate_chat(\n",
+    "        chatbot, message=\"How much is 112.23 Euros in US Dollars?\", summary_method=\"reflection_with_llm\"\n",
     "    )"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "4799f60c",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Chat summary: 112.23 Euros is approximately 123.45 US Dollars.\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"Chat summary:\", res.summary)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 12,
    "id": "0064d9cd",
    "metadata": {},
    "outputs": [
@@ -493,7 +527,7 @@
     "--------------------------------------------------------------------------------\n",
     "\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
     "\n",
-    "\u001b[32m***** Suggested tool Call (call_A6lqMu7s5SyDvftTSeQTtPcj): currency_calculator *****\u001b[0m\n",
+    "\u001b[32m***** Suggested tool Call (call_qv2SwJHpKrG73btxNzUnYBoR): currency_calculator *****\u001b[0m\n",
     "Arguments: \n",
     "{\"base\":{\"currency\":\"USD\",\"amount\":123.45},\"quote_currency\":\"EUR\"}\n",
     "\u001b[32m************************************************************************************\u001b[0m\n",
@@ -505,7 +539,7 @@
     "\n",
     "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
     "\n",
-    "\u001b[32m***** Response from calling tool \"call_A6lqMu7s5SyDvftTSeQTtPcj\" *****\u001b[0m\n",
+    "\u001b[32m***** Response from calling tool \"call_qv2SwJHpKrG73btxNzUnYBoR\" *****\u001b[0m\n",
     "{\"currency\":\"EUR\",\"amount\":112.22727272727272}\n",
     "\u001b[32m**********************************************************************\u001b[0m\n",
     "\n",
@@ -531,7 +565,7 @@
    "source": [
     "with Cache.disk():\n",
     "    # start the conversation\n",
-    "    user_proxy.initiate_chat(\n",
+    "    res = user_proxy.initiate_chat(\n",
     "        chatbot,\n",
     "        message=\"How much is 123.45 US Dollars in Euros?\",\n",
     "    )"
@@ -539,11 +573,21 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "06137f23",
+   "execution_count": 15,
+   "id": "80b2b42c",
    "metadata": {},
-   "outputs": [],
-   "source": []
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Chat history: [{'content': 'How much is 123.45 US Dollars in Euros?', 'role': 'assistant'}, {'tool_calls': [{'id': 'call_qv2SwJHpKrG73btxNzUnYBoR', 'function': {'arguments': '{\"base\":{\"currency\":\"USD\",\"amount\":123.45},\"quote_currency\":\"EUR\"}', 'name': 'currency_calculator'}, 'type': 'function'}], 'content': None, 'role': 'assistant'}, {'content': '{\"currency\":\"EUR\",\"amount\":112.22727272727272}', 'tool_responses': [{'tool_call_id': 'call_qv2SwJHpKrG73btxNzUnYBoR', 'role': 'tool', 'content': '{\"currency\":\"EUR\",\"amount\":112.22727272727272}'}], 'role': 'tool'}, {'content': '123.45 US Dollars is approximately 112.23 Euros.', 'role': 'user'}, {'content': '', 'role': 'assistant'}, {'content': 'TERMINATE', 'role': 'user'}]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"Chat history:\", res.chat_history)"
+   ]
   }
  ],
  "metadata": {
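Since `res.chat_history` is a plain list of message dicts (as the output above shows), it can be post-processed directly. A small sketch, assuming the same notebook session:

# Count messages that carry tool calls in the recorded history.
tool_call_msgs = sum(1 for m in res.chat_history if m.get("tool_calls"))
print(f"{tool_call_msgs} tool-call message(s) out of {len(res.chat_history)} total")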
File diff suppressed because one or more lines are too long
@@ -108,10 +108,12 @@ def test_agent_usage():
     )

     math_problem = "$x^3=125$. What is x?"
-    ai_user_proxy.initiate_chat(
+    res = ai_user_proxy.initiate_chat(
         assistant,
         message=math_problem,
+        summary_method="reflection_with_llm",
     )
+    print("Result summary:", res.summary)

     # test print
     captured_output = io.StringIO()
@@ -55,11 +55,12 @@ def test_ai_user_proxy_agent():
     assistant.reset()

     math_problem = "$x^3=125$. What is x?"
-    ai_user_proxy.initiate_chat(
+    res = ai_user_proxy.initiate_chat(
         assistant,
         message=math_problem,
     )
     print(conversations)
+    print("Result summary:", res.summary)


 @pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")
@@ -149,7 +150,7 @@ def test_create_execute_script(human_input_mode="NEVER", max_consecutive_auto_re
         max_consecutive_auto_reply=max_consecutive_auto_reply,
         is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
     )
-    user.initiate_chat(
+    res = user.initiate_chat(
         assistant,
         message="""Create a temp.py file with the following content:
 ```
@@ -157,12 +158,14 @@ print('Hello world!')
 ```""",
     )
     print(conversations)
+    print("Result summary:", res.summary)
     # autogen.ChatCompletion.print_usage_summary()
     # autogen.ChatCompletion.start_logging(compact=False)
-    user.send("""Execute temp.py""", assistant)
+    res = user.send("""Execute temp.py""", assistant)
     # print(autogen.ChatCompletion.logged_history)
     # autogen.ChatCompletion.print_usage_summary()
     # autogen.ChatCompletion.stop_logging()
+    print("Execution result summary:", res.summary)


 @pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")
@@ -153,14 +153,18 @@ async def test_stream():

     user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, position=2, config={"news_stream": data})

-    await user_proxy.a_initiate_chat(
-        assistant,
-        message="""Give me investment suggestion in 3 bullet points.""",
+    chat_res = await user_proxy.a_initiate_chat(
+        assistant, message="""Give me investment suggestion in 3 bullet points.""", summary_method="reflection_with_llm"
     )

+    print("Chat summary:", chat_res.summary)
+    print("Chat cost:", chat_res.cost)

     while not data_task.done() and not data_task.cancelled():
         reply = await user_proxy.a_generate_reply(sender=assistant)
         if reply is not None:
-            await user_proxy.a_send(reply, assistant)
+            res = await user_proxy.a_send(reply, assistant)
+            print("Chat summary and cost:", res.summary, res.cost)


 if __name__ == "__main__":
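The async entry point mirrors the sync one. A minimal standalone sketch (asyncio boilerplate added here; `user_proxy` and `assistant` are assumed to be constructed as in the test above):

import asyncio

async def main():
    # a_initiate_chat is the awaitable counterpart of initiate_chat and
    # returns the same ChatResult with summary and cost fields.
    chat_res = await user_proxy.a_initiate_chat(
        assistant,
        message="Give me investment suggestion in 3 bullet points.",
        summary_method="reflection_with_llm",
    )
    print("Chat summary:", chat_res.summary)
    print("Chat cost:", chat_res.cost)

asyncio.run(main())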
@@ -38,6 +38,12 @@ async def test_async_get_human_input():

     await user_proxy.a_initiate_chat(assistant, clear_history=True, message="Hello.")
     # Test without message
-    await user_proxy.a_initiate_chat(assistant, clear_history=True)
+    res = await user_proxy.a_initiate_chat(assistant, clear_history=True, summary_method="reflection_with_llm")
     # Assert that custom a_get_human_input was called at least once
     user_proxy.a_get_human_input.assert_called()
+    print("Result summary:", res.summary)
+    print("Human input:", res.human_input)


 if __name__ == "__main__":
     asyncio.run(test_async_get_human_input())
@@ -234,23 +234,27 @@ def test_update_function():
         },
         is_remove=False,
     )
-    user_proxy.initiate_chat(
+    res1 = user_proxy.initiate_chat(
         assistant,
         message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
+        summary_method="reflection_with_llm",
     )
     messages1 = assistant.chat_messages[user_proxy][-1]["content"]
     print(messages1)
+    print("Chat summary and cost", res1.summary, res1.cost)

     assistant.update_function_signature("greet_user", is_remove=True)
-    user_proxy.initiate_chat(
+    res2 = user_proxy.initiate_chat(
         assistant,
         message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
+        summary_method="reflection_with_llm",
     )
     messages2 = assistant.chat_messages[user_proxy][-1]["content"]
     print(messages2)
     # The model should know about the function in the context of the conversation
     assert "greet_user" in messages1
     assert "greet_user" not in messages2
+    print("Chat summary and cost", res2.summary, res2.cost)


 if __name__ == "__main__":
@@ -95,10 +95,14 @@ async def test_function_call_groupchat(key, value, sync):
     manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config_no_function)

     if sync:
-        observer.initiate_chat(manager, message="Let's start the game!")
+        res = observer.initiate_chat(manager, message="Let's start the game!", summary_method="reflection_with_llm")
     else:
-        await observer.a_initiate_chat(manager, message="Let's start the game!")
+        res = await observer.a_initiate_chat(
+            manager, message="Let's start the game!", summary_method="reflection_with_llm"
+        )
     assert func.call_count >= 1, "The function get_random_number should be called at least once."
+    print("Chat summary:", res.summary)
+    print("Chat cost:", res.cost)


 def test_no_function_map():
@@ -606,12 +606,15 @@ def test_clear_agents_history():

     # testing pure "clear history" statement
     with mock.patch.object(builtins, "input", lambda _: "clear history. How you doing?"):
-        agent1.initiate_chat(group_chat_manager, message="hello")
+        res = agent1.initiate_chat(group_chat_manager, message="hello", summary_method="last_msg")
     agent1_history = list(agent1._oai_messages.values())[0]
     agent2_history = list(agent2._oai_messages.values())[0]
     assert agent1_history == [{"content": "How you doing?", "name": "sam", "role": "user"}]
     assert agent2_history == [{"content": "How you doing?", "name": "sam", "role": "user"}]
     assert groupchat.messages == [{"content": "How you doing?", "name": "sam", "role": "user"}]
+    print("Chat summary", res.summary)
+    print("Chat cost", res.cost)
+    print("Chat history", res.chat_history)

     # testing clear history for defined agent
     with mock.patch.object(builtins, "input", lambda _: "clear history bob. How you doing?"):
@@ -34,9 +34,14 @@ def test_get_human_input():

     user_proxy.register_reply([autogen.Agent, None], autogen.ConversableAgent.a_check_termination_and_human_reply)

-    user_proxy.initiate_chat(assistant, clear_history=True, message="Hello.")
+    res = user_proxy.initiate_chat(assistant, clear_history=True, message="Hello.")
+    print("Result summary:", res.summary)
+    print("Human input:", res.human_input)

     # Test without supplying messages parameter
-    user_proxy.initiate_chat(assistant, clear_history=True)
+    res = user_proxy.initiate_chat(assistant, clear_history=True)
+    print("Result summary:", res.summary)
+    print("Human input:", res.human_input)

     # Assert that custom_a_get_human_input was called at least once
     user_proxy.get_human_input.assert_called()
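Between them, the updated tests read four fields off the returned ChatResult. A compact sketch of inspecting all of them at once (field names exactly as used in these tests; agent names assumed from the surrounding test):

res = user_proxy.initiate_chat(assistant, message="Hello.")
print("summary:     ", res.summary)       # produced per summary_method
print("cost:        ", res.cost)          # token/cost accounting for the chat
print("chat_history:", res.chat_history)  # list of message dicts
print("human_input: ", res.human_input)   # inputs collected from the human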
@@ -55,8 +55,10 @@ def test_math_user_proxy_agent():
     #     message=mathproxyagent.generate_init_message(math_problem),
     #     sender=mathproxyagent,
     # )
-    mathproxyagent.initiate_chat(assistant, problem=math_problem)
+    res = mathproxyagent.initiate_chat(assistant, problem=math_problem)
     print(conversations)
+    print("Chat summary:", res.summary)
+    print("Chat history:", res.chat_history)


 def test_add_remove_print():
@@ -165,23 +165,29 @@ def test_update_tool():
         },
         is_remove=False,
     )
-    user_proxy.initiate_chat(
+    res = user_proxy.initiate_chat(
         assistant,
         message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
     )
     messages1 = assistant.chat_messages[user_proxy][-1]["content"]
-    print(messages1)
+    print("Message:", messages1)
+    print("Summary:", res.summary)
+    assert (
+        messages1.replace("TERMINATE", "") == res.summary
+    ), "Message (removing TERMINATE) and summary should be the same"

     assistant.update_tool_signature("greet_user", is_remove=True)
-    user_proxy.initiate_chat(
+    res = user_proxy.initiate_chat(
         assistant,
         message="What functions do you know about in the context of this conversation? End your response with 'TERMINATE'.",
+        summary_method="reflection_with_llm",
     )
     messages2 = assistant.chat_messages[user_proxy][-1]["content"]
-    print(messages2)
+    print("Message2:", messages2)
     # The model should know about the function in the context of the conversation
     assert "greet_user" in messages1
     assert "greet_user" not in messages2
+    print("Summary2:", res.summary)


 @pytest.mark.skipif(not TOOL_ENABLED, reason="openai>=1.1.0 not installed")
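The new assertion in test_update_tool pins down the "last_msg" contract introduced above: when no summary_method is passed, the summary is exactly the last message with the termination marker removed. The invariant in isolation (names taken from the test):

# With the default "last_msg" summary method, the summary mirrors the
# final message minus "TERMINATE".
res = user_proxy.initiate_chat(assistant, message="... End your response with 'TERMINATE'.")
last = assistant.chat_messages[user_proxy][-1]["content"]
assert last.replace("TERMINATE", "") == res.summary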
@@ -366,7 +372,7 @@ async def test_async_multi_tool_call():


 if __name__ == "__main__":
-    # test_update_tool()
+    test_update_tool()
     # test_eval_math_responses()
     # test_multi_tool_call()
-    test_eval_math_responses_api_style_function()
+    # test_eval_math_responses_api_style_function()