mirror of
				https://github.com/microsoft/autogen.git
				synced 2025-10-31 01:40:58 +00:00 
			
		
		
		
	 f113c9a959
			
		
	
	
		f113c9a959
		
			
		
	
	
	
	
		
			
			* Move core samples to /python/samples * Fix proto check * Add sample code check workflow * Update pyright settings; fix types
		
			
				
	
	
		
			215 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			215 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import asyncio
 | |
| import random
 | |
| from typing import Awaitable, Callable, List
 | |
| from uuid import uuid4
 | |
| 
 | |
| from _types import GroupChatMessage, MessageChunk, RequestToSpeak, UIAgentConfig
 | |
| from autogen_core import DefaultTopicId, MessageContext, RoutedAgent, message_handler
 | |
| from autogen_core.models import (
 | |
|     AssistantMessage,
 | |
|     ChatCompletionClient,
 | |
|     LLMMessage,
 | |
|     SystemMessage,
 | |
|     UserMessage,
 | |
| )
 | |
| from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
 | |
| from rich.console import Console
 | |
| from rich.markdown import Markdown
 | |
| 
 | |
| 
 | |
class BaseGroupChatAgent(RoutedAgent):
    """A group chat participant using an LLM.

    Every ``GroupChatMessage`` this agent receives is appended to its private
    history (preceded by a "Transferred to <source>" marker). When it receives a
    ``RequestToSpeak`` it asks its model for a reply, prints the reply to the
    console, and publishes it to both the UI topic and the group chat topic.
    """

    def __init__(
        self,
        description: str,
        group_chat_topic_type: str,
        model_client: ChatCompletionClient,
        system_message: str,
        ui_config: UIAgentConfig,
    ) -> None:
        """Create a participant.

        Args:
            description: Agent description forwarded to ``RoutedAgent``.
            group_chat_topic_type: Topic type this agent's replies are published to.
            model_client: Client used to generate this agent's replies.
            system_message: Persona prompt, prepended to every model call.
            ui_config: Settings (topic type, delays) used to stream replies to the UI.
        """
        super().__init__(description=description)
        self._group_chat_topic_type = group_chat_topic_type
        self._model_client = model_client
        self._system_message = SystemMessage(content=system_message)
        self._chat_history: List[LLMMessage] = []
        self._ui_config = ui_config
        self.console = Console()

    @message_handler
    async def handle_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
        """Record another participant's message, preceded by a hand-off marker."""
        self._chat_history.extend(
            [
                UserMessage(content=f"Transferred to {message.body.source}", source="system"),  # type: ignore[union-attr]
                message.body,
            ]
        )

    @message_handler
    async def handle_request_to_speak(self, message: RequestToSpeak, ctx: MessageContext) -> None:
        """Produce this agent's turn and publish it to the UI and the group chat.

        Raises:
            AssertionError: if the model returns non-string content.
        """
        self._chat_history.append(
            UserMessage(content=f"Transferred to {self.id.type}, adopt the persona immediately.", source="system")
        )
        # Forward the cancellation token so that cancelling the RequestToSpeak also
        # cancels the (potentially long) model call.
        completion = await self._model_client.create(
            [self._system_message] + self._chat_history,
            cancellation_token=ctx.cancellation_token,
        )
        assert isinstance(
            completion.content, str
        ), f"Completion content must be a string, but is: {type(completion.content)}"
        self._chat_history.append(AssistantMessage(content=completion.content, source=self.id.type))

        console_message = f"\n{'-'*80}\n**{self.id.type}**: {completion.content}"
        self.console.print(Markdown(console_message))

        await publish_message_to_ui_and_backend(
            runtime=self,
            source=self.id.type,
            user_message=completion.content,
            ui_config=self._ui_config,
            group_chat_topic_type=self._group_chat_topic_type,
        )
 | |
| 
 | |
| 
 | |
class GroupChatManager(RoutedAgent):
    """Chooses the next speaker of the group chat with an LLM.

    For every ``GroupChatMessage`` received, the manager appends the message to its
    history and asks the model to pick the next role. The participant asked last
    time is left out of the candidate list. The manager then either publishes a
    ``RequestToSpeak`` to the chosen participant's topic or, when the model answers
    ``FINISH``, announces the end of the conversation on the UI.
    """

    def __init__(
        self,
        model_client: ChatCompletionClient,
        participant_topic_types: List[str],
        participant_descriptions: List[str],
        ui_config: UIAgentConfig,
        max_rounds: int = 3,
    ) -> None:
        """Create the manager.

        Args:
            model_client: Client used to select the next speaker.
            participant_topic_types: Topic type of each participant. The selected one
                receives ``RequestToSpeak`` on ``DefaultTopicId(type=...)``.
            participant_descriptions: Descriptions paired positionally with
                ``participant_topic_types``. A length mismatch raises ``ValueError``
                when the selector prompt is built (``zip(..., strict=True)``).
            ui_config: Streaming settings used for the closing announcement.
            max_rounds: Only quoted to the model as a hint for when to finish; it is
                not enforced by this class.
        """
        super().__init__("Group chat manager")
        self._model_client = model_client
        # NOTE(review): never updated or read in this class; kept for compatibility.
        self._num_rounds = 0
        self._participant_topic_types = participant_topic_types
        # Only UserMessage bodies are accepted (see the assert in handle_message).
        self._chat_history: List[UserMessage] = []
        self._max_rounds = max_rounds
        self.console = Console()
        self._participant_descriptions = participant_descriptions
        self._previous_participant_topic_type: str | None = None
        self._ui_config = ui_config

    @message_handler
    async def handle_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
        """Record the message, then ask the next participant to speak or finish the chat.

        Raises:
            AssertionError: if the message body is not a ``UserMessage`` or the model
                returns non-string content.
            ValueError: if the model's answer names no known participant.
        """
        assert isinstance(message.body, UserMessage)
        self._chat_history.append(message.body)

        system_message = SystemMessage(content=self._build_selector_prompt())
        completion = await self._model_client.create([system_message], cancellation_token=ctx.cancellation_token)

        assert isinstance(
            completion.content, str
        ), f"Completion content must be a string, but is: {type(completion.content)}"

        if self._is_finish(completion.content):
            finish_msg = "I think it's enough iterations on the story! Thanks for collaborating!"
            manager_message = f"\n{'-'*80}\n Manager ({id(self)}): {finish_msg}"
            await publish_message_to_ui(
                runtime=self, source=self.id.type, user_message=finish_msg, ui_config=self._ui_config
            )
            self.console.print(Markdown(manager_message))
            return

        selected_topic_type = self._match_topic_type(completion.content)
        if selected_topic_type is None:
            raise ValueError(f"Invalid role selected: {completion.content}")

        self._previous_participant_topic_type = selected_topic_type
        self.console.print(
            Markdown(f"\n{'-'*80}\n Manager ({id(self)}): Asking `{selected_topic_type}` to speak")
        )
        await self.publish_message(RequestToSpeak(), DefaultTopicId(type=selected_topic_type))

    def _format_history(self) -> str:
        """Render the chat history as one ``source: content`` line per message."""
        lines: List[str] = []
        for msg in self._chat_history:
            if isinstance(msg.content, str):
                lines.append(f"{msg.source}: {msg.content}")
            elif isinstance(msg.content, list):
                # Multi-part content: only text parts can go into a text prompt;
                # joining non-str parts (e.g. images) directly would raise TypeError.
                text_parts = [part for part in msg.content if isinstance(part, str)]
                lines.append(f"{msg.source}: {', '.join(text_parts)}")
        return "\n".join(lines)

    def _build_selector_prompt(self) -> str:
        """Build the system prompt asking the model to pick the next role or FINISH.

        The participant asked last time is excluded from both the role list and the
        candidate list.
        """
        history = self._format_history()
        roles = "\n".join(
            [
                f"{topic_type}: {description}".strip()
                for topic_type, description in zip(
                    self._participant_topic_types, self._participant_descriptions, strict=True
                )
                if topic_type != self._previous_participant_topic_type
            ]
        )
        participants = str(
            [
                topic_type
                for topic_type in self._participant_topic_types
                if topic_type != self._previous_participant_topic_type
            ]
        )
        return f"""You are in a role play game. The following roles are available:
{roles}.
Read the following conversation. Then select the next role from {participants} to play. Only return the role.

{history}

Read the above conversation. Then select the next role from {participants} to play. if you think it's enough talking (for example they have talked for {self._max_rounds} rounds), return 'FINISH'.
"""

    @staticmethod
    def _is_finish(content: str) -> bool:
        """Return True if the model's answer is the FINISH sentinel.

        Surrounding whitespace, quotes and trailing punctuation are ignored, because
        the prompt itself shows the sentinel quoted as 'FINISH'.
        """
        return content.strip().strip("'\"`.!").upper() == "FINISH"

    def _match_topic_type(self, content: str) -> str | None:
        """Map the model's answer to a participant topic type, or None if none matches.

        An exact (case-insensitive, quote-stripped) answer wins. Otherwise the answer
        is searched for participant names, longest name first, so that a name that
        is a substring of another (e.g. "writer" vs "writer_editor") cannot shadow
        it. Like before, every participant is considered here, including the one
        excluded from the prompt.
        """
        normalized = content.strip().strip("'\"`.!").lower()
        for topic_type in self._participant_topic_types:
            if topic_type.lower() == normalized:
                return topic_type
        lowered = content.lower()
        # sorted() is stable, so equal-length names keep their original order.
        for topic_type in sorted(self._participant_topic_types, key=len, reverse=True):
            if topic_type.lower() in lowered:
                return topic_type
        return None
 | |
| 
 | |
| 
 | |
class UIAgent(RoutedAgent):
    """Handles UI-related tasks and message processing for the distributed group chat system.

    Each ``MessageChunk`` delivered to this agent is passed unchanged to the async
    callback supplied at construction time, and the handler waits for that callback
    to complete.
    """

    def __init__(self, on_message_chunk_func: Callable[[MessageChunk], Awaitable[None]]) -> None:
        """Store the coroutine function that renders incoming chunks."""
        super().__init__("UI Agent")
        # Awaited once per received MessageChunk.
        self._on_message_chunk_func = on_message_chunk_func

    @message_handler
    async def handle_message_chunk(self, message: MessageChunk, ctx: MessageContext) -> None:
        """Forward ``message`` to the registered callback."""
        render_chunk = self._on_message_chunk_func
        pending = render_chunk(message)
        await pending
 | |
| 
 | |
| 
 | |
async def publish_message_to_ui(
    runtime: RoutedAgent | GrpcWorkerAgentRuntime,
    source: str,
    user_message: str,
    ui_config: UIAgentConfig,
) -> None:
    """Stream ``user_message`` to the UI topic as a series of ``MessageChunk`` messages.

    Each word is published as its own chunk, and all chunks share one freshly
    generated ``message_id``. After every word chunk the function sleeps for a random
    time in ``[ui_config.min_delay, ui_config.max_delay]`` to simulate token
    streaming. A final chunk with ``finished=True`` marks the end of the message.

    The whitespace that follows each word (including newlines) is kept in the chunk
    text. Concatenating the chunks therefore preserves the line and paragraph
    structure of ``user_message``. A word that has no following whitespace still
    gets a single trailing space, as before.

    Args:
        runtime: Agent or worker runtime used to publish the chunks.
        source: Author name attached to every chunk.
        user_message: Full text to stream.
        ui_config: Provides the UI topic type and the delay bounds.
    """
    import re  # only needed here, to split while keeping inter-word whitespace

    message_id = str(uuid4())
    # Each match is one word plus the whitespace that follows it. A plain
    # ``split()`` re-joined with " " would flatten newlines, breaking markdown.
    for match in re.finditer(r"\S+\s*", user_message):
        token = match.group()
        text = token if token[-1].isspace() else token + " "
        await runtime.publish_message(
            MessageChunk(message_id=message_id, text=text, author=source, finished=False),
            DefaultTopicId(type=ui_config.topic_type),
        )
        await asyncio.sleep(random.uniform(ui_config.min_delay, ui_config.max_delay))

    # Terminal chunk telling the UI the message is complete.
    await runtime.publish_message(
        MessageChunk(message_id=message_id, text=" ", author=source, finished=True),
        DefaultTopicId(type=ui_config.topic_type),
    )
 | |
| 
 | |
| 
 | |
async def publish_message_to_ui_and_backend(
    runtime: RoutedAgent | GrpcWorkerAgentRuntime,
    source: str,
    user_message: str,
    ui_config: UIAgentConfig,
    group_chat_topic_type: str,
) -> None:
    """Stream ``user_message`` to the UI, then publish it to the group chat topic.

    The UI stream (``publish_message_to_ui``) is awaited to completion first. Only
    afterwards is the text published to ``group_chat_topic_type`` as a
    ``GroupChatMessage`` wrapping a ``UserMessage`` whose source is ``source``.
    """
    # 1) Front end: word-by-word stream.
    await publish_message_to_ui(runtime=runtime, source=source, user_message=user_message, ui_config=ui_config)

    # 2) Back end: the whole message, for the other participants / manager.
    backend_payload = GroupChatMessage(body=UserMessage(content=user_message, source=source))
    backend_topic = DefaultTopicId(type=group_chat_topic_type)
    await runtime.publish_message(backend_payload, topic_id=backend_topic)
 |