Mirror of https://github.com/microsoft/autogen.git (synced 2025-12-29 16:09:07 +00:00)
Add model_context to SelectorGroupChat for enhanced speaker selection (#6330)
## Why are these changes needed?

This PR enhances the `SelectorGroupChat` class by introducing a new `model_context` parameter to support more context-aware speaker selection.

### Changes

- Added a `model_context: ChatCompletionContext | None` parameter to `SelectorGroupChat`.
- Defaults to `UnboundedChatCompletionContext` when `None` is provided, matching `AssistantAgent`'s behavior.
- Updated `_select_speaker` to prepend context messages from `model_context` to the main thread history.
- Refactored history construction into a helper method `construct_message_history`.

## Related issue number

Closes [Issue #6301](https://github.com/org/repo/issues/6301), enabling the group chat manager to use `model_context` for richer, more informed speaker selection decisions.

## Checks

- [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [x] I've made sure all auto checks have passed.

---------

Signed-off-by: Abhijeetsingh Meena <abhijeet040403@gmail.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
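The docstring example added in this PR shows the new parameter in use; the sketch below is a condensed version of that usage. The model name, agent roles, and task are illustrative, and it assumes the `autogen-agentchat` and `autogen-ext` packages with this change applied:

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import TextMentionTermination
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.ui import Console
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o")
    # Only the last 5 messages are kept in the context used for speaker selection.
    model_context = BufferedChatCompletionContext(buffer_size=5)

    writer = AssistantAgent(
        "writer",
        model_client=model_client,
        description="Drafts the answer.",
    )
    reviewer = AssistantAgent(
        "reviewer",
        model_client=model_client,
        description="Reviews the draft.",
        system_message="Review the draft. Reply with TERMINATE when it is acceptable.",
    )

    team = SelectorGroupChat(
        [writer, reviewer],
        model_client=model_client,
        termination_condition=TextMentionTermination("TERMINATE"),
        model_context=model_context,  # new parameter introduced by this PR
    )
    await Console(team.run_stream(task="Write a one-sentence summary of context-aware speaker selection."))


asyncio.run(main())
```

Passing a `BufferedChatCompletionContext` bounds how much history the selector prompt sees; omitting `model_context` falls back to `UnboundedChatCompletionContext`, as described above.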
This commit is contained in: parent 085ff3dd7d, commit 2864fbfc2c
@@ -115,7 +115,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
         )

         # Append all messages to thread
-        self._message_thread.extend(message.messages)
+        await self.update_message_thread(message.messages)

         # Check termination condition after processing all messages
         if await self._apply_termination_condition(message.messages):
@@ -139,6 +139,9 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
                 cancellation_token=ctx.cancellation_token,
             )

+    async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
+        self._message_thread.extend(messages)
+
     @event
     async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
         try:
@@ -146,10 +149,9 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
             delta: List[BaseAgentEvent | BaseChatMessage] = []
             if message.agent_response.inner_messages is not None:
                 for inner_message in message.agent_response.inner_messages:
-                    self._message_thread.append(inner_message)
                     delta.append(inner_message)
-            self._message_thread.append(message.agent_response.chat_message)
             delta.append(message.agent_response.chat_message)
+            await self.update_message_thread(delta)

             # Check if the conversation should be terminated.
             if await self._apply_termination_condition(delta, increment_turn_count=True):
@@ -191,7 +191,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
         if message.agent_response.inner_messages is not None:
             for inner_message in message.agent_response.inner_messages:
                 delta.append(inner_message)
-        self._message_thread.append(message.agent_response.chat_message)
+        await self.update_message_thread([message.agent_response.chat_message])
         delta.append(message.agent_response.chat_message)

         if self._termination_condition is not None:
@@ -263,7 +263,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
         )

         # Save my copy
-        self._message_thread.append(ledger_message)
+        await self.update_message_thread([ledger_message])

         # Log it to the output topic.
         await self.publish_message(
@@ -376,7 +376,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):

         # Broadcast the next step
         message = TextMessage(content=progress_ledger["instruction_or_question"]["answer"], source=self._name)
-        self._message_thread.append(message)  # My copy
+        await self.update_message_thread([message])  # My copy

         await self._log_message(f"Next Speaker: {progress_ledger['next_speaker']['answer']}")
         # Log it to the output topic.
@@ -458,7 +458,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
         assert isinstance(response.content, str)
         message = TextMessage(content=response.content, source=self._name)

-        self._message_thread.append(message)  # My copy
+        await self.update_message_thread([message])  # My copy

         # Log it to the output topic.
         await self.publish_message(
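The hunks above route every append to `self._message_thread` through the new `update_message_thread` coroutine, so a subclass can observe messages in one place instead of patching each call site. A minimal sketch of that hook pattern, using plain strings and lists in place of the real message and context types (these simplified classes are illustrative, not the library's actual ones):

```python
import asyncio
from typing import List, Sequence


class BaseManager:
    """Simplified stand-in for BaseGroupChatManager: owns the shared message thread."""

    def __init__(self) -> None:
        self._message_thread: List[str] = []

    async def update_message_thread(self, messages: Sequence[str]) -> None:
        # Default behaviour: just extend the thread, as the base class does above.
        self._message_thread.extend(messages)


class SelectorManager(BaseManager):
    """Simplified stand-in for SelectorGroupChatManager: also mirrors into a model context."""

    def __init__(self) -> None:
        super().__init__()
        self._model_context: List[str] = []  # stands in for a ChatCompletionContext

    async def update_message_thread(self, messages: Sequence[str]) -> None:
        await super().update_message_thread(messages)
        # Everything routed through the hook is also available for speaker selection.
        self._model_context.extend(messages)


async def demo() -> None:
    manager = SelectorManager()
    await manager.update_message_thread(["user: hello", "agent1: hi"])
    # The thread and the context see the same messages.
    assert manager._message_thread == manager._model_context


asyncio.run(demo())
```

`SelectorGroupChatManager` uses exactly this override point in the diff below to mirror chat messages into its `ChatCompletionContext`.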
@@ -4,11 +4,16 @@ import re
 from inspect import iscoroutinefunction
 from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast

-from autogen_core import AgentRuntime, Component, ComponentModel
+from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
+from autogen_core.model_context import (
+    ChatCompletionContext,
+    UnboundedChatCompletionContext,
+)
 from autogen_core.models import (
     AssistantMessage,
     ChatCompletionClient,
     CreateResult,
     LLMMessage,
     ModelFamily,
     SystemMessage,
     UserMessage,
@@ -22,6 +27,7 @@ from ...base import ChatAgent, TerminationCondition
 from ...messages import (
     BaseAgentEvent,
     BaseChatMessage,
+    HandoffMessage,
     MessageFactory,
     ModelClientStreamingChunkEvent,
     SelectorEvent,
@@ -65,6 +71,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
         max_selector_attempts: int,
         candidate_func: Optional[CandidateFuncType],
         emit_team_events: bool,
+        model_context: ChatCompletionContext | None,
         model_client_streaming: bool = False,
     ) -> None:
         super().__init__(
@@ -90,6 +97,11 @@ class SelectorGroupChatManager(BaseGroupChatManager):
         self._candidate_func = candidate_func
         self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
         self._model_client_streaming = model_client_streaming
+        if model_context is not None:
+            self._model_context = model_context
+        else:
+            self._model_context = UnboundedChatCompletionContext()
+        self._cancellation_token = CancellationToken()

     async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
         pass
@@ -97,6 +109,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
     async def reset(self) -> None:
         self._current_turn = 0
         self._message_thread.clear()
+        await self._model_context.clear()
         if self._termination_condition is not None:
             await self._termination_condition.reset()
         self._previous_speaker = None
@@ -112,16 +125,37 @@ class SelectorGroupChatManager(BaseGroupChatManager):
     async def load_state(self, state: Mapping[str, Any]) -> None:
         selector_state = SelectorManagerState.model_validate(state)
         self._message_thread = [self._message_factory.create(msg) for msg in selector_state.message_thread]
+        await self._add_messages_to_context(
+            self._model_context, [msg for msg in self._message_thread if isinstance(msg, BaseChatMessage)]
+        )
         self._current_turn = selector_state.current_turn
         self._previous_speaker = selector_state.previous_speaker

+    @staticmethod
+    async def _add_messages_to_context(
+        model_context: ChatCompletionContext,
+        messages: Sequence[BaseChatMessage],
+    ) -> None:
+        """
+        Add incoming messages to the model context.
+        """
+        for msg in messages:
+            if isinstance(msg, HandoffMessage):
+                for llm_msg in msg.context:
+                    await model_context.add_message(llm_msg)
+            await model_context.add_message(msg.to_model_message())
+
+    async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
+        self._message_thread.extend(messages)
+        base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)]
+        await self._add_messages_to_context(self._model_context, base_chat_messages)
+
     async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
         """Selects the next speaker in a group chat using a ChatCompletion client,
         with the selector function as override if it returns a speaker name.

         A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
         """

         # Use the selector function if provided.
         if self._selector_func is not None:
             if self._is_selector_func_async:
@@ -163,18 +197,6 @@ class SelectorGroupChatManager(BaseGroupChatManager):

         assert len(participants) > 0

-        # Construct the history of the conversation.
-        history_messages: List[str] = []
-        for msg in thread:
-            if not isinstance(msg, BaseChatMessage):
-                # Only process chat messages.
-                continue
-            message = f"{msg.source}: {msg.to_model_text()}"
-            history_messages.append(
-                message.rstrip() + "\n\n"
-            )  # Create some consistency for how messages are separated in the transcript
-        history = "\n".join(history_messages)
-
         # Construct agent roles.
         # Each agent sould appear on a single line.
         roles = ""
@@ -184,17 +206,34 @@ class SelectorGroupChatManager(BaseGroupChatManager):

         # Select the next speaker.
         if len(participants) > 1:
-            agent_name = await self._select_speaker(roles, participants, history, self._max_selector_attempts)
+            agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts)
         else:
             agent_name = participants[0]
         self._previous_speaker = agent_name
         trace_logger.debug(f"Selected speaker: {agent_name}")
         return agent_name

-    async def _select_speaker(self, roles: str, participants: List[str], history: str, max_attempts: int) -> str:
+    def construct_message_history(self, message_history: List[LLMMessage]) -> str:
+        # Construct the history of the conversation.
+        history_messages: List[str] = []
+        for msg in message_history:
+            if isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage):
+                message = f"{msg.source}: {msg.content}"
+                history_messages.append(
+                    message.rstrip() + "\n\n"
+                )  # Create some consistency for how messages are separated in the transcript
+
+        history: str = "\n".join(history_messages)
+        return history
+
+    async def _select_speaker(self, roles: str, participants: List[str], max_attempts: int) -> str:
+        model_context_messages = await self._model_context.get_messages()
+        model_context_history = self.construct_message_history(model_context_messages)
+
         select_speaker_prompt = self._selector_prompt.format(
-            roles=roles, participants=str(participants), history=history
+            roles=roles, participants=str(participants), history=model_context_history
         )

         select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage]
         if ModelFamily.is_openai(self._model_client.model_info["family"]):
             select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
@@ -312,6 +351,7 @@ class SelectorGroupChatConfig(BaseModel):
     max_selector_attempts: int = 3
     emit_team_events: bool = False
     model_client_streaming: bool = False
+    model_context: ComponentModel | None = None


 class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
@@ -349,6 +389,8 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
             Make sure your custom message types are subclasses of :class:`~autogen_agentchat.messages.BaseAgentEvent` or :class:`~autogen_agentchat.messages.BaseChatMessage`.
         emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
         model_client_streaming (bool, optional): Whether to use streaming for the model client. (This is useful for reasoning models like QwQ). Defaults to False.
+        model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving
+            :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. Messages stored in model context will be used for speaker selection. The initial messages will be cleared when the team is reset.

     Raises:
         ValueError: If the number of participants is less than two or if the selector prompt is invalid.
@@ -463,6 +505,64 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
                 await Console(team.run_stream(task="What is 1 + 1?"))


             asyncio.run(main())

+    A team with custom model context:
+
+        .. code-block:: python
+
+            import asyncio
+
+            from autogen_core.model_context import BufferedChatCompletionContext
+            from autogen_ext.models.openai import OpenAIChatCompletionClient
+
+            from autogen_agentchat.agents import AssistantAgent
+            from autogen_agentchat.conditions import TextMentionTermination
+            from autogen_agentchat.teams import SelectorGroupChat
+            from autogen_agentchat.ui import Console
+
+
+            async def main() -> None:
+                model_client = OpenAIChatCompletionClient(model="gpt-4o")
+                model_context = BufferedChatCompletionContext(buffer_size=5)
+
+                async def lookup_hotel(location: str) -> str:
+                    return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
+
+                async def lookup_flight(origin: str, destination: str) -> str:
+                    return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
+
+                async def book_trip() -> str:
+                    return "Your trip is booked!"
+
+                travel_advisor = AssistantAgent(
+                    "Travel_Advisor",
+                    model_client,
+                    tools=[book_trip],
+                    description="Helps with travel planning.",
+                )
+                hotel_agent = AssistantAgent(
+                    "Hotel_Agent",
+                    model_client,
+                    tools=[lookup_hotel],
+                    description="Helps with hotel booking.",
+                )
+                flight_agent = AssistantAgent(
+                    "Flight_Agent",
+                    model_client,
+                    tools=[lookup_flight],
+                    description="Helps with flight booking.",
+                )
+                termination = TextMentionTermination("TERMINATE")
+                team = SelectorGroupChat(
+                    [travel_advisor, hotel_agent, flight_agent],
+                    model_client=model_client,
+                    termination_condition=termination,
+                    model_context=model_context,
+                )
+                await Console(team.run_stream(task="Book a 3-day trip to new york."))
+
+
+            asyncio.run(main())
     """

@@ -492,6 +592,7 @@ Read the above conversation. Then select the next role from {participants} to pl
         custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
         emit_team_events: bool = False,
         model_client_streaming: bool = False,
+        model_context: ChatCompletionContext | None = None,
     ):
         super().__init__(
             participants,
@@ -513,6 +614,7 @@ Read the above conversation. Then select the next role from {participants} to pl
         self._max_selector_attempts = max_selector_attempts
         self._candidate_func = candidate_func
         self._model_client_streaming = model_client_streaming
+        self._model_context = model_context

     def _create_group_chat_manager_factory(
         self,
@@ -545,6 +647,7 @@ Read the above conversation. Then select the next role from {participants} to pl
             self._max_selector_attempts,
             self._candidate_func,
             self._emit_team_events,
+            self._model_context,
             self._model_client_streaming,
         )

@@ -560,6 +663,7 @@ Read the above conversation. Then select the next role from {participants} to pl
             # selector_func=self._selector_func.dump_component() if self._selector_func else None,
             emit_team_events=self._emit_team_events,
             model_client_streaming=self._model_client_streaming,
+            model_context=self._model_context.dump_component() if self._model_context else None,
         )

     @classmethod
@@ -579,4 +683,5 @@ Read the above conversation. Then select the next role from {participants} to pl
             # else None,
             emit_team_events=config.emit_team_events,
             model_client_streaming=config.model_client_streaming,
+            model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
         )
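The `construct_message_history` helper introduced above is what turns the model-context messages into the `{history}` block of the selector prompt. The standalone restatement below uses the real `autogen_core.models` types; the sample messages are made up for illustration:

```python
from autogen_core.models import AssistantMessage, LLMMessage, UserMessage


def construct_message_history(message_history: list[LLMMessage]) -> str:
    """Format user/assistant messages as 'source: content' transcript entries."""
    history_messages: list[str] = []
    for msg in message_history:
        if isinstance(msg, (UserMessage, AssistantMessage)):
            history_messages.append(f"{msg.source}: {msg.content}".rstrip() + "\n\n")
    return "\n".join(history_messages)


messages: list[LLMMessage] = [
    UserMessage(content="Plan a trip.", source="user"),
    AssistantMessage(content="Which city?", source="Travel_Advisor"),
]
# Prints "user: Plan a trip." and "Travel_Advisor: Which city?" separated by blank lines.
print(construct_message_history(messages))
```

Only `UserMessage` and `AssistantMessage` entries are rendered, so system messages and tool-execution records preloaded into the context do not leak into the speaker-selection transcript.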
@@ -2,10 +2,28 @@ import asyncio
 import json
 import logging
 import tempfile
-from typing import Any, AsyncGenerator, List, Mapping, Sequence
+from typing import Any, AsyncGenerator, Dict, List, Mapping, Sequence

 import pytest
 import pytest_asyncio
+from autogen_core import AgentId, AgentRuntime, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
+from autogen_core.model_context import BufferedChatCompletionContext
+from autogen_core.models import (
+    AssistantMessage,
+    CreateResult,
+    FunctionExecutionResult,
+    FunctionExecutionResultMessage,
+    LLMMessage,
+    RequestUsage,
+    UserMessage,
+)
+from autogen_core.tools import FunctionTool
+from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
+from autogen_ext.models.openai import OpenAIChatCompletionClient
+from autogen_ext.models.replay import ReplayChatCompletionClient
+from pydantic import BaseModel
+from utils import FileLogHandler

 from autogen_agentchat import EVENT_LOGGER_NAME
 from autogen_agentchat.agents import (
     AssistantAgent,
@@ -39,22 +57,6 @@ from autogen_agentchat.teams._group_chat._round_robin_group_chat import RoundRob
 from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
 from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
 from autogen_agentchat.ui import Console
-from autogen_core import AgentId, AgentRuntime, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
-from autogen_core.models import (
-    AssistantMessage,
-    CreateResult,
-    FunctionExecutionResult,
-    FunctionExecutionResultMessage,
-    LLMMessage,
-    RequestUsage,
-    UserMessage,
-)
-from autogen_core.tools import FunctionTool
-from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
-from autogen_ext.models.openai import OpenAIChatCompletionClient
-from autogen_ext.models.replay import ReplayChatCompletionClient
-from pydantic import BaseModel
-from utils import FileLogHandler

 logger = logging.getLogger(EVENT_LOGGER_NAME)
 logger.setLevel(logging.DEBUG)
@@ -698,6 +700,60 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
     assert result2 == result


+@pytest.mark.asyncio
+async def test_selector_group_chat_with_model_context(runtime: AgentRuntime | None) -> None:
+    buffered_context = BufferedChatCompletionContext(buffer_size=5)
+    await buffered_context.add_message(UserMessage(content="[User] Prefilled message", source="user"))
+
+    selector_group_chat_model_client = ReplayChatCompletionClient(
+        ["agent2", "agent1", "agent1", "agent2", "agent1", "agent2", "agent1"]
+    )
+    agent_one_model_client = ReplayChatCompletionClient(
+        ["[Agent One] First generation", "[Agent One] Second generation", "[Agent One] Third generation", "TERMINATE"]
+    )
+    agent_two_model_client = ReplayChatCompletionClient(
+        ["[Agent Two] First generation", "[Agent Two] Second generation", "[Agent Two] Third generation"]
+    )
+
+    agent1 = AssistantAgent("agent1", model_client=agent_one_model_client, description="Assistant agent 1")
+    agent2 = AssistantAgent("agent2", model_client=agent_two_model_client, description="Assistant agent 2")
+
+    termination = TextMentionTermination("TERMINATE")
+    team = SelectorGroupChat(
+        participants=[agent1, agent2],
+        model_client=selector_group_chat_model_client,
+        termination_condition=termination,
+        runtime=runtime,
+        emit_team_events=True,
+        allow_repeated_speaker=True,
+        model_context=buffered_context,
+    )
+    await team.run(
+        task="[GroupChat] Task",
+    )
+
+    messages_to_check = [
+        "user: [User] Prefilled message",
+        "user: [GroupChat] Task",
+        "agent2: [Agent Two] First generation",
+        "agent1: [Agent One] First generation",
+        "agent1: [Agent One] Second generation",
+        "agent2: [Agent Two] Second generation",
+        "agent1: [Agent One] Third generation",
+        "agent2: [Agent Two] Third generation",
+    ]
+
+    create_calls: List[Dict[str, Any]] = selector_group_chat_model_client.create_calls
+    for idx, call in enumerate(create_calls):
+        messages = call["messages"]
+        prompt = messages[0].content
+        prompt_lines = prompt.split("\n")
+        chat_history = [value for value in messages_to_check[max(0, idx - 3) : idx + 2]]
+        assert all(
+            line.strip() in prompt_lines for line in chat_history
+        ), f"Expected all lines {chat_history} to be in prompt, but got {prompt_lines}"
+
+
 @pytest.mark.asyncio
 async def test_selector_group_chat_with_team_event(runtime: AgentRuntime | None) -> None:
     model_client = ReplayChatCompletionClient(
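A note on the new test's assertion window (my reading of it, not something stated in the diff): at selector call `idx`, the `BufferedChatCompletionContext(buffer_size=5)` holds the prefilled message, the task, and the `idx` agent responses produced so far, truncated to the last five, which is exactly the `messages_to_check[max(0, idx - 3) : idx + 2]` slice:

```python
# Reproduce the slice bounds the test expects for each selector call.
buffer_size = 5
for idx in range(7):  # the selector client replays 7 speaker selections
    total = idx + 2                      # prefilled message + task + idx responses so far
    start = max(0, total - buffer_size)  # == max(0, idx - 3)
    print(f"call {idx}: prompt should contain messages_to_check[{start}:{total}]")
```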