Implement 'candidate_func' parameter to filter down the pool of candidates for selection (#5954)

## Summary of Changes
- Added 'candidate_func' to 'SelectorGroupChat' to narrow-down the pool
of candidate speakers.
- Introduced a test in tests/test_group_chat_endpoint.py to validate its
functionality.
- Updated the selector group chat user guide with an example
demonstrating 'candidate_func'.

## Why are these changes needed?
- These changes adds a new parameter `candidate_func` to
`SelectorGroupChat` that helps user narrow-down the set of agents for
speaker selection, allowing users to automatically select next speaker
from a smaller pool of agents.

## Related issue number
Closes #5828

## Checks
- [x] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.

---------

Signed-off-by: Abhijeetsingh Meena <abhijeet040403@gmail.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Abhijeetsingh Meena 2025-03-18 02:33:25 +05:30 committed by GitHub
parent 8f8ee0478a
commit c4e07e86d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 279 additions and 8 deletions

View File

@ -45,6 +45,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
allow_repeated_speaker: bool,
selector_func: Callable[[Sequence[AgentEvent | ChatMessage]], str | None] | None,
max_selector_attempts: int,
candidate_func: Callable[[Sequence[AgentEvent | ChatMessage]], List[str]] | None,
) -> None:
super().__init__(
name,
@ -63,6 +64,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
self._allow_repeated_speaker = allow_repeated_speaker
self._selector_func = selector_func
self._max_selector_attempts = max_selector_attempts
self._candidate_func = candidate_func
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
pass
@ -107,6 +109,25 @@ class SelectorGroupChatManager(BaseGroupChatManager):
# Skip the model based selection.
return speaker
# Use the candidate function to filter participants if provided
if self._candidate_func is not None:
participants = self._candidate_func(thread)
if not participants:
raise ValueError("Candidate function must return a non-empty list of participant names.")
if not all(p in self._participant_names for p in participants):
raise ValueError(
f"Candidate function returned invalid participant names: {participants}. "
f"Expected one of: {self._participant_names}."
)
else:
# Construct the candidate agent list to be selected from, skip the previous speaker if not allowed.
if self._previous_speaker is not None and not self._allow_repeated_speaker:
participants = [p for p in self._participant_names if p != self._previous_speaker]
else:
participants = list(self._participant_names)
assert len(participants) > 0
# Construct the history of the conversation.
history_messages: List[str] = []
for msg in thread:
@ -136,13 +157,6 @@ class SelectorGroupChatManager(BaseGroupChatManager):
roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
roles = roles.strip()
# Construct the candidate agent list to be selected from, skip the previous speaker if not allowed.
if self._previous_speaker is not None and not self._allow_repeated_speaker:
participants = [p for p in self._participant_names if p != self._previous_speaker]
else:
participants = list(self._participant_names)
assert len(participants) > 0
# Select the next speaker.
if len(participants) > 1:
agent_name = await self._select_speaker(roles, participants, history, self._max_selector_attempts)
@ -277,6 +291,10 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
function that takes the conversation history and returns the name of the next speaker.
If provided, this function will be used to override the model to select the next speaker.
If the function returns None, the model will be used to select the next speaker.
candidate_func (Callable[[Sequence[AgentEvent | ChatMessage]], List[str]], optional):
A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
Raises:
@ -417,6 +435,7 @@ Read the above conversation. Then select the next role from {participants} to pl
allow_repeated_speaker: bool = False,
max_selector_attempts: int = 3,
selector_func: Callable[[Sequence[AgentEvent | ChatMessage]], str | None] | None = None,
candidate_func: Callable[[Sequence[AgentEvent | ChatMessage]], List[str]] | None = None,
):
super().__init__(
participants,
@ -434,6 +453,7 @@ Read the above conversation. Then select the next role from {participants} to pl
self._allow_repeated_speaker = allow_repeated_speaker
self._selector_func = selector_func
self._max_selector_attempts = max_selector_attempts
self._candidate_func = candidate_func
def _create_group_chat_manager_factory(
self,
@ -462,6 +482,7 @@ Read the above conversation. Then select the next role from {participants} to pl
self._allow_repeated_speaker,
self._selector_func,
self._max_selector_attempts,
self._candidate_func,
)
def _to_config(self) -> SelectorGroupChatConfig:

View File

@ -725,6 +725,47 @@ async def test_selector_group_chat_custom_selector(runtime: AgentRuntime | None)
)
@pytest.mark.asyncio
async def test_selector_group_chat_custom_candidate_func(runtime: AgentRuntime | None) -> None:
model_client = ReplayChatCompletionClient(["agent3"])
agent1 = _EchoAgent("agent1", description="echo agent 1")
agent2 = _EchoAgent("agent2", description="echo agent 2")
agent3 = _EchoAgent("agent3", description="echo agent 3")
agent4 = _EchoAgent("agent4", description="echo agent 4")
def _candidate_func(messages: Sequence[AgentEvent | ChatMessage]) -> List[str]:
if len(messages) == 0:
return ["agent1"]
elif messages[-1].source == "agent1":
return ["agent2"]
elif messages[-1].source == "agent2":
return ["agent2", "agent3"] # will generate agent3
elif messages[-1].source == "agent3":
return ["agent4"]
else:
return ["agent1"]
termination = MaxMessageTermination(6)
team = SelectorGroupChat(
participants=[agent1, agent2, agent3, agent4],
model_client=model_client,
candidate_func=_candidate_func,
termination_condition=termination,
runtime=runtime,
)
result = await team.run(task="task")
assert len(result.messages) == 6
assert result.messages[1].source == "agent1"
assert result.messages[2].source == "agent2"
assert result.messages[3].source == "agent3"
assert result.messages[4].source == "agent4"
assert result.messages[5].source == "agent1"
assert (
result.stop_reason is not None
and result.stop_reason == "Maximum number of messages 6 reached, current message count: 6"
)
class _HandOffAgent(BaseChatAgent):
def __init__(self, name: str, description: str, next_agent: str) -> None:
super().__init__(name, description)

View File

@ -1,7 +1,13 @@
import os
from typing import List, Sequence
import pytest
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
)
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.ui import Console
from autogen_core.models import ChatCompletionClient
@ -27,6 +33,51 @@ async def _test_selector_group_chat(model_client: ChatCompletionClient) -> None:
await Console(team.run_stream(task="Draft a short email about organizing a holiday party for new year."))
async def _test_selector_group_chat_with_candidate_func(model_client: ChatCompletionClient) -> None:
filtered_participants = ["developer", "tester"]
def dummy_candidate_func(thread: Sequence[AgentEvent | ChatMessage]) -> List[str]:
# Dummy candidate function that will return
# only return developer and reviewer
return filtered_participants
developer = AssistantAgent(
"developer",
description="Writes and implements code based on requirements.",
model_client=model_client,
system_message="You are a software developer working on a new feature.",
)
tester = AssistantAgent(
"tester",
description="Writes and executes test cases to validate the implementation.",
model_client=model_client,
system_message="You are a software tester ensuring the feature works correctly.",
)
project_manager = AssistantAgent(
"project_manager",
description="Oversees the project and ensures alignment with the broader goals.",
model_client=model_client,
system_message="You are a project manager ensuring the team meets the project goals.",
)
team = SelectorGroupChat(
participants=[developer, tester, project_manager],
model_client=model_client,
max_turns=3,
candidate_func=dummy_candidate_func,
)
task = "Create a detailed implementation plan for adding dark mode in a React app and review it for feasibility and improvements."
async for message in team.run_stream(task=task):
if not isinstance(message, TaskResult):
if message.source == "user": # ignore the first 'user' message
continue
assert message.source in filtered_participants, "Candidate function didn't filter the participants"
@pytest.mark.asyncio
async def test_selector_group_chat_gemini() -> None:
try:
@ -39,6 +90,7 @@ async def test_selector_group_chat_gemini() -> None:
api_key=api_key,
)
await _test_selector_group_chat(model_client)
await _test_selector_group_chat_with_candidate_func(model_client)
@pytest.mark.asyncio
@ -53,3 +105,4 @@ async def test_selector_group_chat_openai() -> None:
api_key=api_key,
)
await _test_selector_group_chat(model_client)
await _test_selector_group_chat_with_candidate_func(model_client)