mirror of
				https://github.com/microsoft/autogen.git
				synced 2025-10-31 01:40:58 +00:00 
			
		
		
		
	Update speaker selector in GroupChat and update some notebooks (#688)
* Add speaker selection methods * Update groupchat RAG * Update seed to cache_seed * Update RetrieveChat notebook * Update parameter name * Add test * Add more tests * Add mock to test * Add mock to test * Fix typo speaking * Add gracefully exit manual input * Update round_robin docstring * Add method checking * Remove participant roles * Fix versions in notebooks * Minimize installation overhead * Fix missing lower() * Add comments for try_count 3 * Update warning for n_agents < 3 * Update warning for n_agents < 3 * Add test_n_agents_less_than_3 * Add a function for manual select * Update version in notebooks * Fixed bugs that allow speakers to go twice in a row even when allow_repeat_speaker = False --------- Co-authored-by: Adam Fourney <adamfo@microsoft.com>
This commit is contained in:
		
							parent
							
								
									3ab8c97eb6
								
							
						
					
					
						commit
						370ebf5e00
					
				
							
								
								
									
										2
									
								
								.github/workflows/build.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/build.yml
									
									
									
									
										vendored
									
									
								
							| @ -40,7 +40,7 @@ jobs: | ||||
|           python -m pip install --upgrade pip wheel | ||||
|           pip install -e . | ||||
|           python -c "import autogen" | ||||
|           pip install -e. pytest | ||||
|           pip install -e. pytest mock | ||||
|           pip uninstall -y openai | ||||
|       - name: Install unstructured if not windows | ||||
|         if: matrix.os != 'windows-2019' | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| import logging | ||||
| import sys | ||||
| import random | ||||
| from dataclasses import dataclass | ||||
| from typing import Dict, List, Optional, Union | ||||
| import re | ||||
| @ -21,6 +22,13 @@ class GroupChat: | ||||
|         When set to True and when a message is a function call suggestion, | ||||
|         the next speaker will be chosen from an agent which contains the corresponding function name | ||||
|         in its `function_map`. | ||||
|     - speaker_selection_method: the method for selecting the next speaker. Default is "auto". | ||||
|         Could be any of the following (case insensitive), will raise ValueError if not recognized: | ||||
|         - "auto": the next speaker is selected automatically by LLM. | ||||
|         - "manual": the next speaker is selected manually by user input. | ||||
|         - "random": the next speaker is selected randomly. | ||||
|         - "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`. | ||||
|     - allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True. | ||||
|     """ | ||||
| 
 | ||||
|     agents: List[Agent] | ||||
| @ -28,6 +36,10 @@ class GroupChat: | ||||
|     max_round: int = 10 | ||||
|     admin_name: str = "Admin" | ||||
|     func_call_filter: bool = True | ||||
|     speaker_selection_method: str = "auto" | ||||
|     allow_repeat_speaker: bool = True | ||||
| 
 | ||||
|     _VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"] | ||||
| 
 | ||||
|     @property | ||||
|     def agent_names(self) -> List[str]: | ||||
| @ -55,13 +67,61 @@ class GroupChat: | ||||
|     def select_speaker_msg(self, agents: List[Agent]): | ||||
|         """Return the message for selecting the next speaker.""" | ||||
|         return f"""You are in a role play game. The following roles are available: | ||||
| {self._participant_roles()}. | ||||
| {self._participant_roles(agents)}. | ||||
| 
 | ||||
| Read the following conversation. | ||||
| Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.""" | ||||
| 
 | ||||
|     def manual_select_speaker(self, agents: List[Agent]) -> Agent: | ||||
|         """Manually select the next speaker.""" | ||||
| 
 | ||||
|         print("Please select the next speaker from the following list:") | ||||
|         _n_agents = len(agents) | ||||
|         for i in range(_n_agents): | ||||
|             print(f"{i+1}: {agents[i].name}") | ||||
|         try_count = 0 | ||||
|         # Assume the user will enter a valid number within 3 tries, otherwise use auto selection to avoid blocking. | ||||
|         while try_count <= 3: | ||||
|             try_count += 1 | ||||
|             if try_count >= 3: | ||||
|                 print(f"You have tried {try_count} times. The next speaker will be selected automatically.") | ||||
|                 break | ||||
|             try: | ||||
|                 i = input("Enter the number of the next speaker (enter nothing or `q` to use auto selection): ") | ||||
|                 if i == "" or i == "q": | ||||
|                     break | ||||
|                 i = int(i) | ||||
|                 if i > 0 and i <= _n_agents: | ||||
|                     return agents[i - 1] | ||||
|                 else: | ||||
|                     raise ValueError | ||||
|             except ValueError: | ||||
|                 print(f"Invalid input. Please enter a number between 1 and {_n_agents}.") | ||||
|         return None | ||||
| 
 | ||||
|     def select_speaker(self, last_speaker: Agent, selector: ConversableAgent): | ||||
|         """Select the next speaker.""" | ||||
|         if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS: | ||||
|             raise ValueError( | ||||
|                 f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. " | ||||
|                 f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). " | ||||
|             ) | ||||
| 
 | ||||
|         agents = self.agents | ||||
|         n_agents = len(agents) | ||||
|         # Warn if GroupChat is underpopulated | ||||
|         if n_agents < 2: | ||||
|             raise ValueError( | ||||
|                 f"GroupChat is underpopulated with {n_agents} agents. " | ||||
|                 "Please add more agents to the GroupChat or use direct communication instead." | ||||
|             ) | ||||
|         elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and self.allow_repeat_speaker: | ||||
|             logger.warning( | ||||
|                 f"GroupChat is underpopulated with {n_agents} agents. " | ||||
|                 "It is recommended to set speaker_selection_method to 'round_robin' or allow_repeat_speaker to False." | ||||
|                 "Or, use direct communication instead." | ||||
|             ) | ||||
| 
 | ||||
|         if self.func_call_filter and self.messages and "function_call" in self.messages[-1]: | ||||
|             # find agents with the right function_map which contains the function name | ||||
|             agents = [ | ||||
| @ -80,14 +140,20 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only | ||||
|                         f"No agent can execute the function {self.messages[-1]['name']}. " | ||||
|                         "Please check the function_map of the agents." | ||||
|                     ) | ||||
|         else: | ||||
|             agents = self.agents | ||||
|             # Warn if GroupChat is underpopulated | ||||
|             n_agents = len(agents) | ||||
|             if n_agents < 3: | ||||
|                 logger.warning( | ||||
|                     f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient." | ||||
|                 ) | ||||
| 
 | ||||
|         # remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False | ||||
|         agents = agents if self.allow_repeat_speaker else [agent for agent in agents if agent != last_speaker] | ||||
| 
 | ||||
|         if self.speaker_selection_method.lower() == "manual": | ||||
|             selected_agent = self.manual_select_speaker(agents) | ||||
|             if selected_agent: | ||||
|                 return selected_agent | ||||
|         elif self.speaker_selection_method.lower() == "round_robin": | ||||
|             return self.next_agent(last_speaker, agents) | ||||
|         elif self.speaker_selection_method.lower() == "random": | ||||
|             return random.choice(agents) | ||||
| 
 | ||||
|         # auto speaker selection | ||||
|         selector.update_system_message(self.select_speaker_msg(agents)) | ||||
|         final, name = selector.generate_oai_reply( | ||||
|             self.messages | ||||
| @ -99,26 +165,31 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only | ||||
|             ] | ||||
|         ) | ||||
|         if not final: | ||||
|             # i = self._random.randint(0, len(self._agent_names) - 1)  # randomly pick an id | ||||
|             # the LLM client is None, thus no reply is generated. Use round robin instead. | ||||
|             return self.next_agent(last_speaker, agents) | ||||
| 
 | ||||
|         # If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified | ||||
|         mentions = self._mentioned_agents(name, agents) | ||||
|         if len(mentions) == 1: | ||||
|             name = next(iter(mentions)) | ||||
|         else: | ||||
|             logger.warning( | ||||
|                 f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}" | ||||
|             ) | ||||
| 
 | ||||
|         # Return the result | ||||
|         try: | ||||
|             return self.agent_by_name(name) | ||||
|         except ValueError: | ||||
|             logger.warning( | ||||
|                 f"GroupChat select_speaker failed to resolve the next speaker's name. Speaker selection will default to the next speaker in the list. This is because the speaker selection OAI call returned:\n{name}" | ||||
|             ) | ||||
|             return self.next_agent(last_speaker, agents) | ||||
| 
 | ||||
|     def _participant_roles(self): | ||||
|     def _participant_roles(self, agents: List[Agent] = None) -> str: | ||||
|         # Default to all agents registered | ||||
|         if agents is None: | ||||
|             agents = self.agents | ||||
| 
 | ||||
|         roles = [] | ||||
|         for agent in self.agents: | ||||
|         for agent in agents: | ||||
|             if agent.system_message.strip() == "": | ||||
|                 logger.warning( | ||||
|                     f"The agent '{agent.name}' has an empty system_message, and may not work well with GroupChat." | ||||
|  | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -29,10 +29,19 @@ | ||||
|     "\n", | ||||
|     "AutoGen requires `Python>=3.8`. To run this notebook example, please install the [retrievechat] option.\n", | ||||
|     "```bash\n", | ||||
|     "pip install \"pyautogen[retrievechat] flaml[automl] qdrant_client[fastembed]\"\n", | ||||
|     "pip install \"pyautogen[retrievechat]~=0.2.0b5\" \"flaml[automl]\" \"qdrant_client[fastembed]\"\n", | ||||
|     "```" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "# %pip install \"pyautogen[retrievechat]~=0.2.0b5\" \"flaml[automl]\" \"qdrant_client[fastembed]\"" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "attachments": {}, | ||||
|    "cell_type": "markdown", | ||||
| @ -165,7 +174,7 @@ | ||||
|     "    system_message=\"You are a helpful assistant.\",\n", | ||||
|     "    llm_config={\n", | ||||
|     "        \"timeout\": 600,\n", | ||||
|     "        \"seed\": 42,\n", | ||||
|     "        \"cache_seed\": 42,\n", | ||||
|     "        \"config_list\": config_list,\n", | ||||
|     "    },\n", | ||||
|     ")\n", | ||||
| @ -1224,7 +1233,7 @@ | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.11.6" | ||||
|    "version": "3.10.12" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  | ||||
| @ -24,7 +24,7 @@ | ||||
|     "\n", | ||||
|     "AutoGen requires `Python>=3.8`. To run this notebook example, please install:\n", | ||||
|     "```bash\n", | ||||
|     "pip install pyautogen\n", | ||||
|     "pip install \"pyautogen~=0.2.0b5\"\n", | ||||
|     "```" | ||||
|    ] | ||||
|   }, | ||||
| @ -34,7 +34,7 @@ | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "# %pip install --quiet pyautogen~=0.1.0" | ||||
|     "# %pip install --quiet \"pyautogen~=0.2.0b5\"" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
| @ -85,7 +85,7 @@ | ||||
|     "\n", | ||||
|     "llm_config={\n", | ||||
|     "    \"timeout\": 600,\n", | ||||
|     "    \"seed\": 44,  # change the seed for different trials\n", | ||||
|     "    \"cache_seed\": 44,  # change the seed for different trials\n", | ||||
|     "    \"config_list\": autogen.config_list_from_json(\n", | ||||
|     "        \"OAI_CONFIG_LIST\",\n", | ||||
|     "        filter_dict={\"model\": [\"gpt-4-32k\"]},\n", | ||||
|  | ||||
							
								
								
									
										1
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								setup.py
									
									
									
									
									
								
							| @ -46,6 +46,7 @@ setuptools.setup( | ||||
|             "pre-commit", | ||||
|             "pytest-asyncio", | ||||
|             "pytest>=6.1.1", | ||||
|             "mock", | ||||
|         ], | ||||
|         "blendsearch": ["flaml[blendsearch]"], | ||||
|         "mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"], | ||||
|  | ||||
| @ -1,4 +1,6 @@ | ||||
| import pytest | ||||
| import mock | ||||
| import builtins | ||||
| import autogen | ||||
| import json | ||||
| 
 | ||||
| @ -8,7 +10,7 @@ def test_func_call_groupchat(): | ||||
|         "alice", | ||||
|         human_input_mode="NEVER", | ||||
|         llm_config=False, | ||||
|         default_auto_reply="This is alice sepaking.", | ||||
|         default_auto_reply="This is alice speaking.", | ||||
|     ) | ||||
|     agent2 = autogen.ConversableAgent( | ||||
|         "bob", | ||||
| @ -56,7 +58,7 @@ def test_chat_manager(): | ||||
|         max_consecutive_auto_reply=2, | ||||
|         human_input_mode="NEVER", | ||||
|         llm_config=False, | ||||
|         default_auto_reply="This is alice sepaking.", | ||||
|         default_auto_reply="This is alice speaking.", | ||||
|     ) | ||||
|     agent2 = autogen.ConversableAgent( | ||||
|         "bob", | ||||
| @ -83,6 +85,150 @@ def test_chat_manager(): | ||||
|         agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}}) | ||||
| 
 | ||||
| 
 | ||||
| def _test_selection_method(method: str): | ||||
|     agent1 = autogen.ConversableAgent( | ||||
|         "alice", | ||||
|         max_consecutive_auto_reply=10, | ||||
|         human_input_mode="NEVER", | ||||
|         llm_config=False, | ||||
|         default_auto_reply="This is alice speaking.", | ||||
|     ) | ||||
|     agent2 = autogen.ConversableAgent( | ||||
|         "bob", | ||||
|         max_consecutive_auto_reply=10, | ||||
|         human_input_mode="NEVER", | ||||
|         llm_config=False, | ||||
|         default_auto_reply="This is bob speaking.", | ||||
|     ) | ||||
|     agent3 = autogen.ConversableAgent( | ||||
|         "charlie", | ||||
|         max_consecutive_auto_reply=10, | ||||
|         human_input_mode="NEVER", | ||||
|         llm_config=False, | ||||
|         default_auto_reply="This is charlie speaking.", | ||||
|     ) | ||||
| 
 | ||||
|     groupchat = autogen.GroupChat( | ||||
|         agents=[agent1, agent2, agent3], | ||||
|         messages=[], | ||||
|         max_round=6, | ||||
|         speaker_selection_method=method, | ||||
|         allow_repeat_speaker=False if method == "manual" else True, | ||||
|     ) | ||||
|     group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False) | ||||
| 
 | ||||
|     if method == "round_robin": | ||||
|         agent1.initiate_chat(group_chat_manager, message="This is alice speaking.") | ||||
|         assert len(agent1.chat_messages[group_chat_manager]) == 6 | ||||
|         assert len(groupchat.messages) == 6 | ||||
|         assert [msg["content"] for msg in agent1.chat_messages[group_chat_manager]] == [ | ||||
|             "This is alice speaking.", | ||||
|             "This is bob speaking.", | ||||
|             "This is charlie speaking.", | ||||
|         ] * 2 | ||||
|     elif method == "auto": | ||||
|         agent1.initiate_chat(group_chat_manager, message="This is alice speaking.") | ||||
|         assert len(agent1.chat_messages[group_chat_manager]) == 6 | ||||
|         assert len(groupchat.messages) == 6 | ||||
|     elif method == "random": | ||||
|         agent1.initiate_chat(group_chat_manager, message="This is alice speaking.") | ||||
|         assert len(agent1.chat_messages[group_chat_manager]) == 6 | ||||
|         assert len(groupchat.messages) == 6 | ||||
|     elif method == "manual": | ||||
|         for user_input in ["", "q", "x", "1", "10"]: | ||||
|             with mock.patch.object(builtins, "input", lambda _: user_input): | ||||
|                 group_chat_manager.reset() | ||||
|                 agent1.reset() | ||||
|                 agent2.reset() | ||||
|                 agent3.reset() | ||||
|                 agent1.initiate_chat(group_chat_manager, message="This is alice speaking.") | ||||
|                 if user_input == "1": | ||||
|                     assert len(agent1.chat_messages[group_chat_manager]) == 6 | ||||
|                     assert len(groupchat.messages) == 6 | ||||
|                     assert [msg["content"] for msg in agent1.chat_messages[group_chat_manager]] == [ | ||||
|                         "This is alice speaking.", | ||||
|                         "This is bob speaking.", | ||||
|                         "This is alice speaking.", | ||||
|                         "This is bob speaking.", | ||||
|                         "This is alice speaking.", | ||||
|                         "This is bob speaking.", | ||||
|                     ] | ||||
|                 else: | ||||
|                     assert len(agent1.chat_messages[group_chat_manager]) == 6 | ||||
|                     assert len(groupchat.messages) == 6 | ||||
|     elif method == "wrong": | ||||
|         with pytest.raises(ValueError): | ||||
|             agent1.initiate_chat(group_chat_manager, message="This is alice speaking.") | ||||
| 
 | ||||
| 
 | ||||
| def test_speaker_selection_method(): | ||||
|     for method in ["auto", "round_robin", "random", "manual", "wrong", "RounD_roBin"]: | ||||
|         _test_selection_method(method) | ||||
| 
 | ||||
| 
 | ||||
| def _test_n_agents_less_than_3(method): | ||||
|     agent1 = autogen.ConversableAgent( | ||||
|         "alice", | ||||
|         max_consecutive_auto_reply=10, | ||||
|         human_input_mode="NEVER", | ||||
|         llm_config=False, | ||||
|         default_auto_reply="This is alice speaking.", | ||||
|     ) | ||||
|     agent2 = autogen.ConversableAgent( | ||||
|         "bob", | ||||
|         max_consecutive_auto_reply=10, | ||||
|         human_input_mode="NEVER", | ||||
|         llm_config=False, | ||||
|         default_auto_reply="This is bob speaking.", | ||||
|     ) | ||||
|     # test two agents | ||||
|     groupchat = autogen.GroupChat( | ||||
|         agents=[agent1, agent2], | ||||
|         messages=[], | ||||
|         max_round=6, | ||||
|         speaker_selection_method=method, | ||||
|         allow_repeat_speaker=True if method == "random" else False, | ||||
|     ) | ||||
|     group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False) | ||||
|     agent1.initiate_chat(group_chat_manager, message="This is alice speaking.") | ||||
|     assert len(agent1.chat_messages[group_chat_manager]) == 6 | ||||
|     assert len(groupchat.messages) == 6 | ||||
|     if method != "random" or method.lower() == "round_robin": | ||||
|         assert [msg["content"] for msg in agent1.chat_messages[group_chat_manager]] == [ | ||||
|             "This is alice speaking.", | ||||
|             "This is bob speaking.", | ||||
|         ] * 3 | ||||
| 
 | ||||
|     # test one agent | ||||
|     groupchat = autogen.GroupChat( | ||||
|         agents=[agent1], | ||||
|         messages=[], | ||||
|         max_round=6, | ||||
|         speaker_selection_method="round_robin", | ||||
|         allow_repeat_speaker=False, | ||||
|     ) | ||||
|     with pytest.raises(ValueError): | ||||
|         group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False) | ||||
|         agent1.initiate_chat(group_chat_manager, message="This is alice speaking.") | ||||
| 
 | ||||
|     # test zero agent | ||||
|     groupchat = autogen.GroupChat( | ||||
|         agents=[], | ||||
|         messages=[], | ||||
|         max_round=6, | ||||
|         speaker_selection_method="round_robin", | ||||
|         allow_repeat_speaker=False, | ||||
|     ) | ||||
|     with pytest.raises(ValueError): | ||||
|         group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False) | ||||
|         agent1.initiate_chat(group_chat_manager, message="This is alice speaking.") | ||||
| 
 | ||||
| 
 | ||||
| def test_n_agents_less_than_3(): | ||||
|     for method in ["auto", "round_robin", "random", "RounD_roBin"]: | ||||
|         _test_n_agents_less_than_3(method) | ||||
| 
 | ||||
| 
 | ||||
| def test_plugin(): | ||||
|     # Give another Agent class ability to manage group chat | ||||
|     agent1 = autogen.ConversableAgent( | ||||
| @ -90,7 +236,7 @@ def test_plugin(): | ||||
|         max_consecutive_auto_reply=2, | ||||
|         human_input_mode="NEVER", | ||||
|         llm_config=False, | ||||
|         default_auto_reply="This is alice sepaking.", | ||||
|         default_auto_reply="This is alice speaking.", | ||||
|     ) | ||||
|     agent2 = autogen.ConversableAgent( | ||||
|         "bob", | ||||
| @ -185,8 +331,10 @@ def test_agent_mentions(): | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     test_func_call_groupchat() | ||||
|     # test_func_call_groupchat() | ||||
|     # test_broadcast() | ||||
|     test_chat_manager() | ||||
|     # test_chat_manager() | ||||
|     # test_plugin() | ||||
|     test_speaker_selection_method() | ||||
|     test_n_agents_less_than_3() | ||||
|     # test_agent_mentions() | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Li Jiang
						Li Jiang