import logging
from typing import Any, Callable, List, Mapping

from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.memory import ChatMemory
from agnext.components.models import ChatCompletionClient
from agnext.core import AgentId, AgentProxy, CancellationToken

from ..types import (
    Message,
    MultiModalMessage,
    PublishNow,
    Reset,
    TextMessage,
)
from ._group_chat_utils import select_speaker

logger = logging.getLogger("agnext.events")


class GroupChatManager(TypeRoutedAgent):
    """An agent that manages a group chat through event-driven orchestration.

    Args:
        description (str): The description of the agent.
        participants (List[AgentId]): The list of participants in the group chat.
            The list is copied at construction time.
        memory (ChatMemory[Message]): The memory to store and retrieve messages.
        model_client (ChatCompletionClient, optional): The client to use for the model.
            If provided, the agent will use the model to select the next speaker.
            If not provided, the agent will select the next speaker from the list of participants
            according to the order given.
        termination_word (str, optional): The word that terminates the group chat. Defaults to "TERMINATE".
            When a text message's content contains this word, no next speaker is selected and the
            message is not added to memory.
        transitions (Mapping[AgentId, List[AgentId]], optional): The transitions between agents.
            Keys are the agents, and values are the list of agents that can follow the key agent.
            Defaults to None, meaning no transitions. If provided, the group chat manager will use
            the transitions to select the next speaker. If a transition is not provided for an agent,
            the choices fall back to all participants. If no model client is provided, a transition
            must have a single value. The mapping is copied at construction time.
        on_message_received (Callable[[TextMessage | MultiModalMessage], None], optional): A custom
            handler to call when a message is received. Defaults to None.

    Raises:
        ValueError: If a transition list is empty, if a transition key or value is not one of the
            participants, or if a transition lists more than one agent while no model client is provided.
    """

    def __init__(
        self,
        description: str,
        participants: List[AgentId],
        memory: ChatMemory[Message],
        model_client: ChatCompletionClient | None = None,
        termination_word: str = "TERMINATE",
        transitions: Mapping[AgentId, List[AgentId]] | None = None,
        on_message_received: Callable[[TextMessage | MultiModalMessage], None] | None = None,
    ):
        super().__init__(description)
        # Use None as the default instead of a shared mutable `{}` literal.
        if transitions is None:
            transitions = {}
        self._memory = memory
        self._client = model_client
        # Copy so later mutation of the caller's list cannot desync it from the proxies built below.
        self._participants = list(participants)
        self._participant_proxies = {p: AgentProxy(p, self.runtime) for p in self._participants}
        self._termination_word = termination_word
        for key, value in transitions.items():
            if not value:
                # Make sure no empty transitions are provided.
                raise ValueError(f"Empty transition list provided for {key.name}.")
            if key not in self._participants:
                # Make sure all keys are in the list of participants.
                raise ValueError(f"Transition key {key.name} not found in participants.")
            for v in value:
                if v not in self._participants:
                    # Make sure all values are in the list of participants.
                    raise ValueError(f"Transition value {v.name} not found in participants.")
            if self._client is None:
                # Make sure there is only one transition for each key if no model client is provided.
                if len(value) > 1:
                    raise ValueError(f"Multiple transitions provided for {key.name} but no model client is provided.")
        # Store a validated copy so the caller cannot bypass validation by mutating the mapping later.
        self._transitions = {key: list(value) for key, value in transitions.items()}
        self._on_message_received = on_message_received

    @message_handler()
    async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
        """Handle a reset message. This method clears the memory."""
        await self._memory.clear()

    @message_handler()
    async def on_new_message(
        self, message: TextMessage | MultiModalMessage, cancellation_token: CancellationToken
    ) -> None:
        """Handle a message.

        This method adds the message to the memory, selects the next speaker, and sends
        a PublishNow message to the selected speaker to ask it to publish a response. A text
        message containing the termination word ends the chat: it is neither stored nor answered.
        """
        # Call the custom on_message_received handler if provided.
        if self._on_message_received is not None:
            self._on_message_received(message)

        # Check if the message contains the termination word.
        if isinstance(message, TextMessage) and self._termination_word in message.content:
            # Terminate the group chat by not selecting the next speaker.
            return

        # Save the message to chat memory.
        await self._memory.add_message(message)

        # Get the last speaker. Messages from non-participants (e.g. a user) yield None.
        last_speaker_name = message.source
        last_speaker_index = next((i for i, p in enumerate(self._participants) if p.name == last_speaker_name), None)

        # Get the candidates for the next speaker. Default to all participants.
        candidates = self._participants
        if last_speaker_index is not None:
            logger.debug(f"Last speaker: {last_speaker_name}")
            last_speaker = self._participants[last_speaker_index]
            allowed = self._transitions.get(last_speaker)
            if allowed is not None:
                # Keep participant order while restricting to the allowed followers.
                candidates = [c for c in self._participants if c in allowed]

        logger.debug(f"Group chat manager next speaker candidates: {[c.name for c in candidates]}")

        # Select speaker.
        if len(candidates) == 0:
            speaker = None
        elif len(candidates) == 1:
            speaker = candidates[0]
        else:
            # More than one candidate, select the next speaker.
            if self._client is None:
                # Without a model client, __init__ guarantees every transition has a single value,
                # so multiple candidates can only mean "no transition applied".
                assert candidates == self._participants
                if last_speaker_index is not None:
                    # Round-robin: pick the participant after the last speaker.
                    next_speaker_index = (last_speaker_index + 1) % len(self._participants)
                    speaker = self._participants[next_speaker_index]
                else:
                    # If no last speaker, select the first speaker.
                    speaker = candidates[0]
            else:
                # If a model client is provided, select the speaker based on the transitions and the model.
                speaker_index = await select_speaker(
                    self._memory, self._client, [self._participant_proxies[c] for c in candidates]
                )
                speaker = candidates[speaker_index]

        logger.debug(f"Group chat manager selected speaker: {speaker.name if speaker is not None else None}")

        if speaker is not None:
            # Send the message to the selected speaker to ask it to publish a response.
            await self.send_message(PublishNow(), speaker)

    def save_state(self) -> Mapping[str, Any]:
        """Return the serializable state: the memory state and the termination word."""
        return {
            "memory": self._memory.save_state(),
            "termination_word": self._termination_word,
        }

    def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore the memory state and the termination word from a mapping produced by save_state."""
        self._memory.load_state(state["memory"])
        self._termination_word = state["termination_word"]