2023-07-25 16:46:11 -07:00
|
|
|
import pytest
|
2023-09-11 17:07:35 -07:00
|
|
|
from flaml.autogen.agentchat import ConversableAgent
|
2023-07-06 06:08:44 +08:00
|
|
|
|
|
|
|
|
2023-08-07 11:41:58 -07:00
|
|
|
def test_trigger():
    """Check how the ``trigger`` argument of ``register_reply`` selects a reply function.

    Reply functions are registered one after another on the same agent. After
    each registration a new chat is started and the last reply seen by the
    initiator is compared with the expected text. The most recently registered
    matching trigger determines the reply. A trigger that does not match leaves
    the previously effective reply in place.
    """
    responder = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
    initiator = ConversableAgent("a1", max_consecutive_auto_reply=0, human_input_mode="NEVER")

    def reply_with(text):
        # A factory binds ``text`` per call, so every reply function keeps its own string
        # (avoids the late-binding closure pitfall of lambdas created in a loop).
        # The parameter names must stay as-is: they mirror the reply-function signature.
        return lambda recipient, messages, sender, config: (True, text)

    # (trigger, text returned by the newly registered reply, reply expected afterwards)
    scenarios = [
        (initiator, "hello", "hello"),  # an agent instance
        ("a1", "hello a1", "hello a1"),  # an agent name
        (ConversableAgent, "hello conversable agent", "hello conversable agent"),  # an agent class
        (lambda sender: sender.name.startswith("a"), "hello a", "hello a"),  # a predicate that matches
        (lambda sender: sender.name.startswith("b"), "hello b", "hello a"),  # a predicate that does not match
        (["agent2", initiator], "hello agent2 or agent1", "hello agent2 or agent1"),  # a list containing a match
        (["agent2", "agent3"], "hello agent2 or agent3", "hello agent2 or agent1"),  # a list without a match
    ]
    for trigger, text, expected in scenarios:
        responder.register_reply(trigger, reply_with(text))
        initiator.initiate_chat(responder, message="hi")
        assert initiator.last_message(responder)["content"] == expected

    # An unsupported trigger type is rejected both at registration and at match time.
    with pytest.raises(ValueError):
        responder.register_reply(1, reply_with("hi"))
    with pytest.raises(ValueError):
        responder._match_trigger(1, initiator)
|
|
|
|
|
|
|
|
|
2023-08-03 02:17:20 -07:00
|
|
|
def test_context():
    """Send templated messages that carry a ``context`` dict.

    Nothing is asserted. The console output expected for each send is noted
    inline and must be checked by eye.
    """
    receiver = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
    sender = ConversableAgent("a1", max_consecutive_auto_reply=0, human_input_mode="NEVER")

    def send_with_context(content):
        # Build a fresh message and context dict on every call, as the inline literals did.
        sender.send({"content": content, "context": {"name": "there"}}, receiver)

    # String template without allow_format_str_template: expect "hello {name}" verbatim.
    send_with_context("hello {name}")

    # Callable content receives the context: expect "hello there".
    send_with_context(lambda context: f"hello {context['name']}")

    # With allow_format_str_template enabled on the receiver: expect "hello there".
    receiver.llm_config = {"allow_format_str_template": True}
    send_with_context("hello {name}")
|
|
|
|
|
|
|
|
|
|
|
|
def test_max_consecutive_auto_reply():
    """Verify the consecutive auto-reply limit, its reset between chats, and
    ``stop_reply_at_receive``.
    """
    responder = ConversableAgent("a0", max_consecutive_auto_reply=2, llm_config=False, human_input_mode="NEVER")
    initiator = ConversableAgent("a1", max_consecutive_auto_reply=0, human_input_mode="NEVER")

    # The limit reads the same with and without a sender argument, before and after an update.
    assert responder.max_consecutive_auto_reply() == responder.max_consecutive_auto_reply(initiator) == 2
    responder.update_max_consecutive_auto_reply(1)
    assert responder.max_consecutive_auto_reply() == responder.max_consecutive_auto_reply(initiator) == 1

    initiator.initiate_chat(responder, message="hello")
    assert responder._consecutive_auto_reply_counter[initiator] == 1

    # A new chat resets the counter, so the responder auto-replies once more.
    initiator.initiate_chat(responder, message="hello again")
    assert initiator.last_message(responder)["role"] == "user"
    assert len(initiator.chat_messages[responder]) == 2
    assert len(responder.chat_messages[initiator]) == 2
    assert responder._consecutive_auto_reply_counter[initiator] == 1

    # The counter already equals the limit of 1, so "bye" gets no auto reply.
    initiator.send(message="bye", recipient=responder)
    assert initiator.last_message(responder)["role"] == "assistant"

    # Starting a chat without clearing history keeps the earlier messages.
    initiator.initiate_chat(responder, clear_history=False, message="hi")
    assert len(initiator.chat_messages[responder]) > 2
    assert len(responder.chat_messages[initiator]) > 2

    # stop_reply_at_receive only flips the caller's own flag.
    assert initiator.reply_at_receive[responder] == responder.reply_at_receive[initiator] is True
    initiator.stop_reply_at_receive(responder)
    assert initiator.reply_at_receive[responder] is False and responder.reply_at_receive[initiator] is True
|
2023-08-03 02:17:20 -07:00
|
|
|
|
2023-08-04 07:26:58 -07:00
|
|
|
|
2023-09-11 17:07:35 -07:00
|
|
|
def test_conversable_agent():
    """Exercise receive/send with str and dict messages, rejection of invalid
    messages, and ``update_system_message``.
    """
    dummy_agent_1 = ConversableAgent(name="dummy_agent_1", human_input_mode="ALWAYS")
    dummy_agent_2 = ConversableAgent(name="dummy_agent_2", human_input_mode="TERMINATE")

    dummy_agent_1.receive("hello", dummy_agent_2)  # receive a str
    dummy_agent_1.receive(
        {
            "content": "hello {name}",
            "context": {
                "name": "dummy_agent_2",
            },
        },
        dummy_agent_2,
    )  # receive a dict
    # The context field is kept on the stored message.
    assert "context" in dummy_agent_1.chat_messages[dummy_agent_2][-1]

    # Receiving a dict with no OpenAI message field (e.g. "content", "function_call")
    # must raise ValueError and leave the conversation unchanged.
    pre_len = len(dummy_agent_1.chat_messages[dummy_agent_2])
    with pytest.raises(ValueError):
        dummy_agent_1.receive({"message": "hello"}, dummy_agent_2)
    assert pre_len == len(
        dummy_agent_1.chat_messages[dummy_agent_2]
    ), "When the message is not a valid openai message, it should not be appended to the oai conversation."

    dummy_agent_1.send("TERMINATE", dummy_agent_2)  # send a str
    dummy_agent_1.send(
        {
            "content": "TERMINATE",
        },
        dummy_agent_2,
    )  # send a dict

    # Sending a dict with no OpenAI message field must likewise raise ValueError
    # and leave the conversation unchanged.
    pre_len = len(dummy_agent_1.chat_messages[dummy_agent_2])
    with pytest.raises(ValueError):
        dummy_agent_1.send({"message": "hello"}, dummy_agent_2)
    assert pre_len == len(
        dummy_agent_1.chat_messages[dummy_agent_2]
    ), "When the message is not a valid openai message, it should not be appended to the oai conversation."

    # update system message
    dummy_agent_1.update_system_message("new system message")
    assert dummy_agent_1.system_message == "new system message"
|
2023-07-31 19:22:30 -07:00
|
|
|
|
2023-07-06 06:08:44 +08:00
|
|
|
|
2023-08-25 06:50:22 -04:00
|
|
|
def test_generate_reply():
    """Check that ``generate_reply`` executes a function call through ``function_map``.

    Two call shapes are covered: ``messages`` passed explicitly with no sender,
    and ``sender`` given with no messages, so the stored conversation with that
    sender is used.
    """

    def add_num(num_to_be_added):
        """Return ``num_to_be_added + 10``."""
        given_num = 10
        return num_to_be_added + given_num

    dummy_agent_2 = ConversableAgent(name="user_proxy", human_input_mode="TERMINATE", function_map={"add_num": add_num})
    # A single assistant message requesting add_num(5); the expected reply content is "15".
    messages = [{"function_call": {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}, "role": "assistant"}]

    # when sender is None, messages is provided
    assert (
        dummy_agent_2.generate_reply(messages=messages, sender=None)["content"] == "15"
    ), "generate_reply not working when sender is None"

    # when sender is provided, messages is None: seed the stored conversation with that sender
    dummy_agent_1 = ConversableAgent(name="dummy_agent_1", human_input_mode="ALWAYS")
    dummy_agent_2._oai_messages[dummy_agent_1] = messages
    assert (
        dummy_agent_2.generate_reply(messages=None, sender=dummy_agent_1)["content"] == "15"
    ), "generate_reply not working when messages is None"
|
|
|
|
|
|
|
|
|
2023-07-06 06:08:44 +08:00
|
|
|
if __name__ == "__main__":
    # Run tests directly without pytest; uncomment others as needed.
    # None of these test functions take arguments.
    test_trigger()
    # test_context()
    # test_max_consecutive_auto_reply()
    # test_conversable_agent()
    # test_generate_reply()
|