Update speaker selector in GroupChat and update some notebooks (#688)

* Add speaker selection methods

* Update groupchat RAG

* Update seed to cache_seed

* Update RetrieveChat notebook

* Update parameter name

* Add test

* Add more tests

* Add mock to test

* Add mock to test

* Fix typo speaking

* Add gracefully exit manual input

* Update round_robin docstring

* Add method checking

* Remove participant roles

* Fix versions in notebooks

* Minimize installation overhead

* Fix missing lower()

* Add comments for try_count 3

* Update warning for n_agents < 3

* Update warning for n_agents < 3

* Add test_n_agents_less_than_3

* Add a function for manual select

* Update version in notebooks

* Fixed bugs that allow speakers to go twice in a row even when allow_repeat_speaker = False

---------

Co-authored-by: Adam Fourney <adamfo@microsoft.com>
This commit is contained in:
Li Jiang 2023-11-17 21:56:11 +08:00 committed by GitHub
parent 3ab8c97eb6
commit 370ebf5e00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 2033 additions and 3721 deletions

View File

@ -40,7 +40,7 @@ jobs:
python -m pip install --upgrade pip wheel python -m pip install --upgrade pip wheel
pip install -e . pip install -e .
python -c "import autogen" python -c "import autogen"
pip install -e. pytest pip install -e. pytest mock
pip uninstall -y openai pip uninstall -y openai
- name: Install unstructured if not windows - name: Install unstructured if not windows
if: matrix.os != 'windows-2019' if: matrix.os != 'windows-2019'

View File

@ -1,5 +1,6 @@
import logging import logging
import sys import sys
import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import re import re
@ -21,6 +22,13 @@ class GroupChat:
When set to True and when a message is a function call suggestion, When set to True and when a message is a function call suggestion,
the next speaker will be chosen from an agent which contains the corresponding function name the next speaker will be chosen from an agent which contains the corresponding function name
in its `function_map`. in its `function_map`.
- speaker_selection_method: the method for selecting the next speaker. Default is "auto".
Could be any of the following (case insensitive), will raise ValueError if not recognized:
- "auto": the next speaker is selected automatically by LLM.
- "manual": the next speaker is selected manually by user input.
- "random": the next speaker is selected randomly.
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True.
""" """
agents: List[Agent] agents: List[Agent]
@ -28,6 +36,10 @@ class GroupChat:
max_round: int = 10 max_round: int = 10
admin_name: str = "Admin" admin_name: str = "Admin"
func_call_filter: bool = True func_call_filter: bool = True
speaker_selection_method: str = "auto"
allow_repeat_speaker: bool = True
_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
@property @property
def agent_names(self) -> List[str]: def agent_names(self) -> List[str]:
@ -55,13 +67,61 @@ class GroupChat:
def select_speaker_msg(self, agents: List[Agent]): def select_speaker_msg(self, agents: List[Agent]):
"""Return the message for selecting the next speaker.""" """Return the message for selecting the next speaker."""
return f"""You are in a role play game. The following roles are available: return f"""You are in a role play game. The following roles are available:
{self._participant_roles()}. {self._participant_roles(agents)}.
Read the following conversation. Read the following conversation.
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.""" Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""
def manual_select_speaker(self, agents: List[Agent]) -> Agent:
"""Manually select the next speaker."""
print("Please select the next speaker from the following list:")
_n_agents = len(agents)
for i in range(_n_agents):
print(f"{i+1}: {agents[i].name}")
try_count = 0
# Assume the user will enter a valid number within 3 tries, otherwise use auto selection to avoid blocking.
while try_count <= 3:
try_count += 1
if try_count >= 3:
print(f"You have tried {try_count} times. The next speaker will be selected automatically.")
break
try:
i = input("Enter the number of the next speaker (enter nothing or `q` to use auto selection): ")
if i == "" or i == "q":
break
i = int(i)
if i > 0 and i <= _n_agents:
return agents[i - 1]
else:
raise ValueError
except ValueError:
print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
return None
def select_speaker(self, last_speaker: Agent, selector: ConversableAgent): def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
"""Select the next speaker.""" """Select the next speaker."""
if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
raise ValueError(
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). "
)
agents = self.agents
n_agents = len(agents)
# Warn if GroupChat is underpopulated
if n_agents < 2:
raise ValueError(
f"GroupChat is underpopulated with {n_agents} agents. "
"Please add more agents to the GroupChat or use direct communication instead."
)
elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and self.allow_repeat_speaker:
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. "
"It is recommended to set speaker_selection_method to 'round_robin' or allow_repeat_speaker to False."
"Or, use direct communication instead."
)
if self.func_call_filter and self.messages and "function_call" in self.messages[-1]: if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
# find agents with the right function_map which contains the function name # find agents with the right function_map which contains the function name
agents = [ agents = [
@ -80,14 +140,20 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
f"No agent can execute the function {self.messages[-1]['name']}. " f"No agent can execute the function {self.messages[-1]['name']}. "
"Please check the function_map of the agents." "Please check the function_map of the agents."
) )
else:
agents = self.agents # remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
# Warn if GroupChat is underpopulated agents = agents if self.allow_repeat_speaker else [agent for agent in agents if agent != last_speaker]
n_agents = len(agents)
if n_agents < 3: if self.speaker_selection_method.lower() == "manual":
logger.warning( selected_agent = self.manual_select_speaker(agents)
f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient." if selected_agent:
) return selected_agent
elif self.speaker_selection_method.lower() == "round_robin":
return self.next_agent(last_speaker, agents)
elif self.speaker_selection_method.lower() == "random":
return random.choice(agents)
# auto speaker selection
selector.update_system_message(self.select_speaker_msg(agents)) selector.update_system_message(self.select_speaker_msg(agents))
final, name = selector.generate_oai_reply( final, name = selector.generate_oai_reply(
self.messages self.messages
@ -99,26 +165,31 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
] ]
) )
if not final: if not final:
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id # the LLM client is None, thus no reply is generated. Use round robin instead.
return self.next_agent(last_speaker, agents) return self.next_agent(last_speaker, agents)
# If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified # If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified
mentions = self._mentioned_agents(name, agents) mentions = self._mentioned_agents(name, agents)
if len(mentions) == 1: if len(mentions) == 1:
name = next(iter(mentions)) name = next(iter(mentions))
else:
logger.warning(
f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}"
)
# Return the result # Return the result
try: try:
return self.agent_by_name(name) return self.agent_by_name(name)
except ValueError: except ValueError:
logger.warning(
f"GroupChat select_speaker failed to resolve the next speaker's name. Speaker selection will default to the next speaker in the list. This is because the speaker selection OAI call returned:\n{name}"
)
return self.next_agent(last_speaker, agents) return self.next_agent(last_speaker, agents)
def _participant_roles(self): def _participant_roles(self, agents: List[Agent] = None) -> str:
# Default to all agents registered
if agents is None:
agents = self.agents
roles = [] roles = []
for agent in self.agents: for agent in agents:
if agent.system_message.strip() == "": if agent.system_message.strip() == "":
logger.warning( logger.warning(
f"The agent '{agent.name}' has an empty system_message, and may not work well with GroupChat." f"The agent '{agent.name}' has an empty system_message, and may not work well with GroupChat."

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -29,10 +29,19 @@
"\n", "\n",
"AutoGen requires `Python>=3.8`. To run this notebook example, please install the [retrievechat] option.\n", "AutoGen requires `Python>=3.8`. To run this notebook example, please install the [retrievechat] option.\n",
"```bash\n", "```bash\n",
"pip install \"pyautogen[retrievechat] flaml[automl] qdrant_client[fastembed]\"\n", "pip install \"pyautogen[retrievechat]~=0.2.0b5\" \"flaml[automl]\" \"qdrant_client[fastembed]\"\n",
"```" "```"
] ]
}, },
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# %pip install \"pyautogen[retrievechat]~=0.2.0b5\" \"flaml[automl]\" \"qdrant_client[fastembed]\""
]
},
{ {
"attachments": {}, "attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
@ -165,7 +174,7 @@
" system_message=\"You are a helpful assistant.\",\n", " system_message=\"You are a helpful assistant.\",\n",
" llm_config={\n", " llm_config={\n",
" \"timeout\": 600,\n", " \"timeout\": 600,\n",
" \"seed\": 42,\n", " \"cache_seed\": 42,\n",
" \"config_list\": config_list,\n", " \"config_list\": config_list,\n",
" },\n", " },\n",
")\n", ")\n",
@ -1224,7 +1233,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.6" "version": "3.10.12"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -24,7 +24,7 @@
"\n", "\n",
"AutoGen requires `Python>=3.8`. To run this notebook example, please install:\n", "AutoGen requires `Python>=3.8`. To run this notebook example, please install:\n",
"```bash\n", "```bash\n",
"pip install pyautogen\n", "pip install \"pyautogen~=0.2.0b5\"\n",
"```" "```"
] ]
}, },
@ -34,7 +34,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# %pip install --quiet pyautogen~=0.1.0" "# %pip install --quiet \"pyautogen~=0.2.0b5\""
] ]
}, },
{ {
@ -85,7 +85,7 @@
"\n", "\n",
"llm_config={\n", "llm_config={\n",
" \"timeout\": 600,\n", " \"timeout\": 600,\n",
" \"seed\": 44, # change the seed for different trials\n", " \"cache_seed\": 44, # change the seed for different trials\n",
" \"config_list\": autogen.config_list_from_json(\n", " \"config_list\": autogen.config_list_from_json(\n",
" \"OAI_CONFIG_LIST\",\n", " \"OAI_CONFIG_LIST\",\n",
" filter_dict={\"model\": [\"gpt-4-32k\"]},\n", " filter_dict={\"model\": [\"gpt-4-32k\"]},\n",

View File

@ -46,6 +46,7 @@ setuptools.setup(
"pre-commit", "pre-commit",
"pytest-asyncio", "pytest-asyncio",
"pytest>=6.1.1", "pytest>=6.1.1",
"mock",
], ],
"blendsearch": ["flaml[blendsearch]"], "blendsearch": ["flaml[blendsearch]"],
"mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"], "mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],

View File

@ -1,4 +1,6 @@
import pytest import pytest
import mock
import builtins
import autogen import autogen
import json import json
@ -8,7 +10,7 @@ def test_func_call_groupchat():
"alice", "alice",
human_input_mode="NEVER", human_input_mode="NEVER",
llm_config=False, llm_config=False,
default_auto_reply="This is alice sepaking.", default_auto_reply="This is alice speaking.",
) )
agent2 = autogen.ConversableAgent( agent2 = autogen.ConversableAgent(
"bob", "bob",
@ -56,7 +58,7 @@ def test_chat_manager():
max_consecutive_auto_reply=2, max_consecutive_auto_reply=2,
human_input_mode="NEVER", human_input_mode="NEVER",
llm_config=False, llm_config=False,
default_auto_reply="This is alice sepaking.", default_auto_reply="This is alice speaking.",
) )
agent2 = autogen.ConversableAgent( agent2 = autogen.ConversableAgent(
"bob", "bob",
@ -83,6 +85,150 @@ def test_chat_manager():
agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}}) agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}})
def _test_selection_method(method: str):
agent1 = autogen.ConversableAgent(
"alice",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is alice speaking.",
)
agent2 = autogen.ConversableAgent(
"bob",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is bob speaking.",
)
agent3 = autogen.ConversableAgent(
"charlie",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is charlie speaking.",
)
groupchat = autogen.GroupChat(
agents=[agent1, agent2, agent3],
messages=[],
max_round=6,
speaker_selection_method=method,
allow_repeat_speaker=False if method == "manual" else True,
)
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
if method == "round_robin":
agent1.initiate_chat(group_chat_manager, message="This is alice speaking.")
assert len(agent1.chat_messages[group_chat_manager]) == 6
assert len(groupchat.messages) == 6
assert [msg["content"] for msg in agent1.chat_messages[group_chat_manager]] == [
"This is alice speaking.",
"This is bob speaking.",
"This is charlie speaking.",
] * 2
elif method == "auto":
agent1.initiate_chat(group_chat_manager, message="This is alice speaking.")
assert len(agent1.chat_messages[group_chat_manager]) == 6
assert len(groupchat.messages) == 6
elif method == "random":
agent1.initiate_chat(group_chat_manager, message="This is alice speaking.")
assert len(agent1.chat_messages[group_chat_manager]) == 6
assert len(groupchat.messages) == 6
elif method == "manual":
for user_input in ["", "q", "x", "1", "10"]:
with mock.patch.object(builtins, "input", lambda _: user_input):
group_chat_manager.reset()
agent1.reset()
agent2.reset()
agent3.reset()
agent1.initiate_chat(group_chat_manager, message="This is alice speaking.")
if user_input == "1":
assert len(agent1.chat_messages[group_chat_manager]) == 6
assert len(groupchat.messages) == 6
assert [msg["content"] for msg in agent1.chat_messages[group_chat_manager]] == [
"This is alice speaking.",
"This is bob speaking.",
"This is alice speaking.",
"This is bob speaking.",
"This is alice speaking.",
"This is bob speaking.",
]
else:
assert len(agent1.chat_messages[group_chat_manager]) == 6
assert len(groupchat.messages) == 6
elif method == "wrong":
with pytest.raises(ValueError):
agent1.initiate_chat(group_chat_manager, message="This is alice speaking.")
def test_speaker_selection_method():
for method in ["auto", "round_robin", "random", "manual", "wrong", "RounD_roBin"]:
_test_selection_method(method)
def _test_n_agents_less_than_3(method):
agent1 = autogen.ConversableAgent(
"alice",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is alice speaking.",
)
agent2 = autogen.ConversableAgent(
"bob",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is bob speaking.",
)
# test two agents
groupchat = autogen.GroupChat(
agents=[agent1, agent2],
messages=[],
max_round=6,
speaker_selection_method=method,
allow_repeat_speaker=True if method == "random" else False,
)
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
agent1.initiate_chat(group_chat_manager, message="This is alice speaking.")
assert len(agent1.chat_messages[group_chat_manager]) == 6
assert len(groupchat.messages) == 6
if method != "random" or method.lower() == "round_robin":
assert [msg["content"] for msg in agent1.chat_messages[group_chat_manager]] == [
"This is alice speaking.",
"This is bob speaking.",
] * 3
# test one agent
groupchat = autogen.GroupChat(
agents=[agent1],
messages=[],
max_round=6,
speaker_selection_method="round_robin",
allow_repeat_speaker=False,
)
with pytest.raises(ValueError):
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
agent1.initiate_chat(group_chat_manager, message="This is alice speaking.")
# test zero agent
groupchat = autogen.GroupChat(
agents=[],
messages=[],
max_round=6,
speaker_selection_method="round_robin",
allow_repeat_speaker=False,
)
with pytest.raises(ValueError):
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
agent1.initiate_chat(group_chat_manager, message="This is alice speaking.")
def test_n_agents_less_than_3():
for method in ["auto", "round_robin", "random", "RounD_roBin"]:
_test_n_agents_less_than_3(method)
def test_plugin(): def test_plugin():
# Give another Agent class ability to manage group chat # Give another Agent class ability to manage group chat
agent1 = autogen.ConversableAgent( agent1 = autogen.ConversableAgent(
@ -90,7 +236,7 @@ def test_plugin():
max_consecutive_auto_reply=2, max_consecutive_auto_reply=2,
human_input_mode="NEVER", human_input_mode="NEVER",
llm_config=False, llm_config=False,
default_auto_reply="This is alice sepaking.", default_auto_reply="This is alice speaking.",
) )
agent2 = autogen.ConversableAgent( agent2 = autogen.ConversableAgent(
"bob", "bob",
@ -185,8 +331,10 @@ def test_agent_mentions():
if __name__ == "__main__": if __name__ == "__main__":
test_func_call_groupchat() # test_func_call_groupchat()
# test_broadcast() # test_broadcast()
test_chat_manager() # test_chat_manager()
# test_plugin() # test_plugin()
test_speaker_selection_method()
test_n_agents_less_than_3()
# test_agent_mentions() # test_agent_mentions()