mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-24 21:49:42 +00:00
Implement 'candidate_func' parameter to filter down the pool of candidates for selection (#5954)
## Summary of Changes - Added 'candidate_func' to 'SelectorGroupChat' to narrow-down the pool of candidate speakers. - Introduced a test in tests/test_group_chat_endpoint.py to validate its functionality. - Updated the selector group chat user guide with an example demonstrating 'candidate_func'. ## Why are these changes needed? - These changes adds a new parameter `candidate_func` to `SelectorGroupChat` that helps user narrow-down the set of agents for speaker selection, allowing users to automatically select next speaker from a smaller pool of agents. ## Related issue number Closes #5828 ## Checks - [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. --------- Signed-off-by: Abhijeetsingh Meena <abhijeet040403@gmail.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
8f8ee0478a
commit
c4e07e86d8
@ -45,6 +45,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
allow_repeated_speaker: bool,
|
||||
selector_func: Callable[[Sequence[AgentEvent | ChatMessage]], str | None] | None,
|
||||
max_selector_attempts: int,
|
||||
candidate_func: Callable[[Sequence[AgentEvent | ChatMessage]], List[str]] | None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
@ -63,6 +64,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
self._allow_repeated_speaker = allow_repeated_speaker
|
||||
self._selector_func = selector_func
|
||||
self._max_selector_attempts = max_selector_attempts
|
||||
self._candidate_func = candidate_func
|
||||
|
||||
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
|
||||
pass
|
||||
@ -107,6 +109,25 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
# Skip the model based selection.
|
||||
return speaker
|
||||
|
||||
# Use the candidate function to filter participants if provided
|
||||
if self._candidate_func is not None:
|
||||
participants = self._candidate_func(thread)
|
||||
if not participants:
|
||||
raise ValueError("Candidate function must return a non-empty list of participant names.")
|
||||
if not all(p in self._participant_names for p in participants):
|
||||
raise ValueError(
|
||||
f"Candidate function returned invalid participant names: {participants}. "
|
||||
f"Expected one of: {self._participant_names}."
|
||||
)
|
||||
else:
|
||||
# Construct the candidate agent list to be selected from, skip the previous speaker if not allowed.
|
||||
if self._previous_speaker is not None and not self._allow_repeated_speaker:
|
||||
participants = [p for p in self._participant_names if p != self._previous_speaker]
|
||||
else:
|
||||
participants = list(self._participant_names)
|
||||
|
||||
assert len(participants) > 0
|
||||
|
||||
# Construct the history of the conversation.
|
||||
history_messages: List[str] = []
|
||||
for msg in thread:
|
||||
@ -136,13 +157,6 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
|
||||
roles = roles.strip()
|
||||
|
||||
# Construct the candidate agent list to be selected from, skip the previous speaker if not allowed.
|
||||
if self._previous_speaker is not None and not self._allow_repeated_speaker:
|
||||
participants = [p for p in self._participant_names if p != self._previous_speaker]
|
||||
else:
|
||||
participants = list(self._participant_names)
|
||||
assert len(participants) > 0
|
||||
|
||||
# Select the next speaker.
|
||||
if len(participants) > 1:
|
||||
agent_name = await self._select_speaker(roles, participants, history, self._max_selector_attempts)
|
||||
@ -277,6 +291,10 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
|
||||
function that takes the conversation history and returns the name of the next speaker.
|
||||
If provided, this function will be used to override the model to select the next speaker.
|
||||
If the function returns None, the model will be used to select the next speaker.
|
||||
candidate_func (Callable[[Sequence[AgentEvent | ChatMessage]], List[str]], optional):
|
||||
A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
|
||||
selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
|
||||
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
|
||||
|
||||
|
||||
Raises:
|
||||
@ -417,6 +435,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
||||
allow_repeated_speaker: bool = False,
|
||||
max_selector_attempts: int = 3,
|
||||
selector_func: Callable[[Sequence[AgentEvent | ChatMessage]], str | None] | None = None,
|
||||
candidate_func: Callable[[Sequence[AgentEvent | ChatMessage]], List[str]] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
participants,
|
||||
@ -434,6 +453,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
||||
self._allow_repeated_speaker = allow_repeated_speaker
|
||||
self._selector_func = selector_func
|
||||
self._max_selector_attempts = max_selector_attempts
|
||||
self._candidate_func = candidate_func
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
@ -462,6 +482,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
||||
self._allow_repeated_speaker,
|
||||
self._selector_func,
|
||||
self._max_selector_attempts,
|
||||
self._candidate_func,
|
||||
)
|
||||
|
||||
def _to_config(self) -> SelectorGroupChatConfig:
|
||||
|
||||
@ -725,6 +725,47 @@ async def test_selector_group_chat_custom_selector(runtime: AgentRuntime | None)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_selector_group_chat_custom_candidate_func(runtime: AgentRuntime | None) -> None:
|
||||
model_client = ReplayChatCompletionClient(["agent3"])
|
||||
agent1 = _EchoAgent("agent1", description="echo agent 1")
|
||||
agent2 = _EchoAgent("agent2", description="echo agent 2")
|
||||
agent3 = _EchoAgent("agent3", description="echo agent 3")
|
||||
agent4 = _EchoAgent("agent4", description="echo agent 4")
|
||||
|
||||
def _candidate_func(messages: Sequence[AgentEvent | ChatMessage]) -> List[str]:
|
||||
if len(messages) == 0:
|
||||
return ["agent1"]
|
||||
elif messages[-1].source == "agent1":
|
||||
return ["agent2"]
|
||||
elif messages[-1].source == "agent2":
|
||||
return ["agent2", "agent3"] # will generate agent3
|
||||
elif messages[-1].source == "agent3":
|
||||
return ["agent4"]
|
||||
else:
|
||||
return ["agent1"]
|
||||
|
||||
termination = MaxMessageTermination(6)
|
||||
team = SelectorGroupChat(
|
||||
participants=[agent1, agent2, agent3, agent4],
|
||||
model_client=model_client,
|
||||
candidate_func=_candidate_func,
|
||||
termination_condition=termination,
|
||||
runtime=runtime,
|
||||
)
|
||||
result = await team.run(task="task")
|
||||
assert len(result.messages) == 6
|
||||
assert result.messages[1].source == "agent1"
|
||||
assert result.messages[2].source == "agent2"
|
||||
assert result.messages[3].source == "agent3"
|
||||
assert result.messages[4].source == "agent4"
|
||||
assert result.messages[5].source == "agent1"
|
||||
assert (
|
||||
result.stop_reason is not None
|
||||
and result.stop_reason == "Maximum number of messages 6 reached, current message count: 6"
|
||||
)
|
||||
|
||||
|
||||
class _HandOffAgent(BaseChatAgent):
|
||||
def __init__(self, name: str, description: str, next_agent: str) -> None:
|
||||
super().__init__(name, description)
|
||||
|
||||
@ -1,7 +1,13 @@
|
||||
import os
|
||||
from typing import List, Sequence
|
||||
|
||||
import pytest
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.base import TaskResult
|
||||
from autogen_agentchat.messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
)
|
||||
from autogen_agentchat.teams import SelectorGroupChat
|
||||
from autogen_agentchat.ui import Console
|
||||
from autogen_core.models import ChatCompletionClient
|
||||
@ -27,6 +33,51 @@ async def _test_selector_group_chat(model_client: ChatCompletionClient) -> None:
|
||||
await Console(team.run_stream(task="Draft a short email about organizing a holiday party for new year."))
|
||||
|
||||
|
||||
async def _test_selector_group_chat_with_candidate_func(model_client: ChatCompletionClient) -> None:
|
||||
filtered_participants = ["developer", "tester"]
|
||||
|
||||
def dummy_candidate_func(thread: Sequence[AgentEvent | ChatMessage]) -> List[str]:
|
||||
# Dummy candidate function that will return
|
||||
# only return developer and reviewer
|
||||
return filtered_participants
|
||||
|
||||
developer = AssistantAgent(
|
||||
"developer",
|
||||
description="Writes and implements code based on requirements.",
|
||||
model_client=model_client,
|
||||
system_message="You are a software developer working on a new feature.",
|
||||
)
|
||||
|
||||
tester = AssistantAgent(
|
||||
"tester",
|
||||
description="Writes and executes test cases to validate the implementation.",
|
||||
model_client=model_client,
|
||||
system_message="You are a software tester ensuring the feature works correctly.",
|
||||
)
|
||||
|
||||
project_manager = AssistantAgent(
|
||||
"project_manager",
|
||||
description="Oversees the project and ensures alignment with the broader goals.",
|
||||
model_client=model_client,
|
||||
system_message="You are a project manager ensuring the team meets the project goals.",
|
||||
)
|
||||
|
||||
team = SelectorGroupChat(
|
||||
participants=[developer, tester, project_manager],
|
||||
model_client=model_client,
|
||||
max_turns=3,
|
||||
candidate_func=dummy_candidate_func,
|
||||
)
|
||||
|
||||
task = "Create a detailed implementation plan for adding dark mode in a React app and review it for feasibility and improvements."
|
||||
|
||||
async for message in team.run_stream(task=task):
|
||||
if not isinstance(message, TaskResult):
|
||||
if message.source == "user": # ignore the first 'user' message
|
||||
continue
|
||||
assert message.source in filtered_participants, "Candidate function didn't filter the participants"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_selector_group_chat_gemini() -> None:
|
||||
try:
|
||||
@ -39,6 +90,7 @@ async def test_selector_group_chat_gemini() -> None:
|
||||
api_key=api_key,
|
||||
)
|
||||
await _test_selector_group_chat(model_client)
|
||||
await _test_selector_group_chat_with_candidate_func(model_client)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -53,3 +105,4 @@ async def test_selector_group_chat_openai() -> None:
|
||||
api_key=api_key,
|
||||
)
|
||||
await _test_selector_group_chat(model_client)
|
||||
await _test_selector_group_chat_with_candidate_func(model_client)
|
||||
|
||||
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user