2024-11-16 03:51:04 +00:00
import asyncio
import random
2024-10-31 11:54:24 +00:00
from typing import Awaitable , Callable , List
2024-11-16 03:51:04 +00:00
from uuid import uuid4
2024-10-28 16:59:58 +00:00
2024-11-16 03:51:04 +00:00
from _types import GroupChatMessage , MessageChunk , RequestToSpeak , UIAgentConfig
2024-12-03 17:00:44 -08:00
from autogen_core import DefaultTopicId , MessageContext , RoutedAgent , message_handler
2024-12-09 13:00:08 -05:00
from autogen_core . models import (
2024-10-28 16:59:58 +00:00
AssistantMessage ,
ChatCompletionClient ,
LLMMessage ,
SystemMessage ,
UserMessage ,
)
2024-12-04 16:23:20 -08:00
from autogen_ext . runtimes . grpc import GrpcWorkerAgentRuntime
2024-10-28 16:59:58 +00:00
from rich . console import Console
from rich . markdown import Markdown
class BaseGroupChatAgent(RoutedAgent):
    """A group chat participant whose turns are generated by an LLM.

    The agent keeps its own copy of the conversation in ``_chat_history``.
    Each ``GroupChatMessage`` it handles is recorded there; when it handles a
    ``RequestToSpeak`` it asks the model client for a completion over the
    system message plus that history, prints the reply to the console, and
    hands it to ``publish_message_to_ui_and_backend``.
    """

    def __init__(
        self,
        description: str,
        group_chat_topic_type: str,
        model_client: ChatCompletionClient,
        system_message: str,
        ui_config: UIAgentConfig,
    ) -> None:
        """Store the collaborators this participant needs.

        Args:
            description: Forwarded to ``RoutedAgent.__init__``.
            group_chat_topic_type: Topic type passed along when the reply is published.
            model_client: Client used to generate this agent's replies.
            system_message: Text wrapped in a ``SystemMessage`` and sent first on every model call.
            ui_config: UI settings passed along when the reply is published.
        """
        super().__init__(description=description)
        self._group_chat_topic_type = group_chat_topic_type
        self._model_client = model_client
        self._system_message = SystemMessage(content=system_message)
        self._chat_history: List[LLMMessage] = []
        self._ui_config = ui_config
        self.console = Console()

    @message_handler
    async def handle_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
        """Record a group chat message, preceded by a note naming its source."""
        self._chat_history.extend(
            [
                UserMessage(content=f" Transferred to {message.body.source} ", source=" system "),  # type: ignore[union-attr]
                message.body,
            ]
        )

    @message_handler
    async def handle_request_to_speak(self, message: RequestToSpeak, ctx: MessageContext) -> None:
        """Generate this agent's turn, log it, and publish it.

        The model call is tied to ``ctx.cancellation_token`` so that cancelling
        the request also cancels the in-flight completion.
        """
        self._chat_history.append(
            UserMessage(content=f" Transferred to {self.id.type} , adopt the persona immediately. ", source=" system ")
        )
        completion = await self._model_client.create(
            [self._system_message] + self._chat_history,
            cancellation_token=ctx.cancellation_token,
        )
        # Narrows the type for the checker; the reply is used as plain text below.
        assert isinstance(completion.content, str)
        self._chat_history.append(AssistantMessage(content=completion.content, source=self.id.type))

        console_message = f" \n {' - ' * 80} \n ** {self.id.type} **: {completion.content} "
        self.console.print(Markdown(console_message))

        await publish_message_to_ui_and_backend(
            runtime=self,
            source=self.id.type,
            user_message=completion.content,
            ui_config=self._ui_config,
            group_chat_topic_type=self._group_chat_topic_type,
        )
class GroupChatManager(RoutedAgent):
    """Moderates the group chat by asking an LLM who should speak next.

    Each ``GroupChatMessage`` it handles is appended to its own history. It
    then builds a selector prompt from that history and the participants'
    descriptions, and either publishes a ``RequestToSpeak`` to the selected
    participant's topic or, when the model answers with the finish sentinel,
    sends a closing note to the UI.
    """

    def __init__(
        self,
        model_client: ChatCompletionClient,
        participant_topic_types: List[str],
        participant_descriptions: List[str],
        ui_config: UIAgentConfig,
        max_rounds: int = 3,
    ) -> None:
        """Store the collaborators and participant metadata.

        Args:
            model_client: Client used to pick the next speaker.
            participant_topic_types: Topic type of each participant; also the role names shown to the model.
            participant_descriptions: Description of each participant, paired by position with
                ``participant_topic_types`` (the pairing is strict, so lengths must match).
            ui_config: UI settings used when the closing note is published.
            max_rounds: Number of rounds quoted in the selector prompt as a hint for when to finish.
        """
        super().__init__(" Group chat manager ")
        self._model_client = model_client
        # Initialised here only; nothing in this class updates it.
        self._num_rounds = 0
        self._participant_topic_types = participant_topic_types
        # Holds message bodies (asserted to be UserMessage on receipt), not the GroupChatMessage envelopes.
        self._chat_history: List[UserMessage] = []
        self._max_rounds = max_rounds
        self.console = Console()
        self._participant_descriptions = participant_descriptions
        self._previous_participant_topic_type: str | None = None
        self._ui_config = ui_config

    def _format_history(self) -> str:
        """Render the recorded messages as one ``source : content`` line each."""
        lines: List[str] = []
        for msg in self._chat_history:
            if isinstance(msg.content, str):
                lines.append(f" {msg.source} : {msg.content} ")
            elif isinstance(msg.content, list):
                # str() keeps a non-string item from making join() raise TypeError.
                lines.append(f" {msg.source} : {' , '.join(str(part) for part in msg.content)} ")
        return " \n ".join(lines)

    def _build_selector_prompt(self) -> str:
        """Build the speaker-selection prompt, leaving out whoever spoke last."""
        history = self._format_history()
        roles = " \n ".join(
            f" {topic_type} : {description} ".strip()
            for topic_type, description in zip(
                self._participant_topic_types, self._participant_descriptions, strict=True
            )
            if topic_type != self._previous_participant_topic_type
        )
        participants = str(
            [
                topic_type
                for topic_type in self._participant_topic_types
                if topic_type != self._previous_participant_topic_type
            ]
        )
        return f""" You are in a role play game. The following roles are available:
{roles} .
Read the following conversation . Then select the next role from {participants} to play . Only return the role .
{history}
Read the above conversation . Then select the next role from {participants} to play . if you think it ' s enough talking (for example they have talked for {self._max_rounds} rounds), return ' FINISH ' .
"""

    @message_handler
    async def handle_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
        """Record the message, then ask the model to pick the next speaker or finish.

        Raises:
            ValueError: If the model's reply is not the finish sentinel and
                contains no participant topic type.
        """
        assert isinstance(message.body, UserMessage)
        self._chat_history.append(message.body)

        system_message = SystemMessage(content=self._build_selector_prompt())
        completion = await self._model_client.create([system_message], cancellation_token=ctx.cancellation_token)
        assert isinstance(
            completion.content, str
        ), f" Completion content must be a string, but is: {type(completion.content)} "

        # The prompt shows the sentinel in quotes and models often add whitespace or a
        # trailing period, so trim those before comparing instead of requiring an exact match.
        verdict = completion.content.strip(" \t\r\n'\"`.").upper()
        if verdict == "FINISH":
            finish_msg = " I think it ' s enough iterations on the story! Thanks for collaborating! "
            manager_message = f" \n {' - ' * 80} \n Manager ( {id(self)} ): {finish_msg} "
            await publish_message_to_ui(
                runtime=self, source=self.id.type, user_message=finish_msg, ui_config=self._ui_config
            )
            self.console.print(Markdown(manager_message))
            return

        # First participant (in configured order) whose topic type appears in the reply, ignoring case.
        reply = completion.content.lower()
        selected_topic_type = next(
            (topic_type for topic_type in self._participant_topic_types if topic_type.lower() in reply),
            None,
        )
        if selected_topic_type is None:
            raise ValueError(f" Invalid role selected: {completion.content} ")

        self._previous_participant_topic_type = selected_topic_type
        self.console.print(
            Markdown(f" \n {' - ' * 80} \n Manager ( {id(self)} ): Asking ` {selected_topic_type} ` to speak ")
        )
        await self.publish_message(RequestToSpeak(), DefaultTopicId(type=selected_topic_type))
2024-11-16 03:51:04 +00:00
class UIAgent(RoutedAgent):
    """Agent that relays every ``MessageChunk`` it receives to a UI callback."""

    def __init__(self, on_message_chunk_func: Callable[[MessageChunk], Awaitable[None]]) -> None:
        """Register the awaitable callback invoked once per incoming chunk.

        Args:
            on_message_chunk_func: Called with each received ``MessageChunk`` and awaited.
        """
        super().__init__(" UI Agent ")
        self._on_message_chunk_func = on_message_chunk_func

    @message_handler
    async def handle_message_chunk(self, message: MessageChunk, ctx: MessageContext) -> None:
        """Pass the chunk to the registered callback and wait for it to finish."""
        forward = self._on_message_chunk_func
        await forward(message)
async def publish_message_to_ui(
    runtime: RoutedAgent | GrpcWorkerAgentRuntime,
    source: str,
    user_message: str,
    ui_config: UIAgentConfig,
) -> None:
    """Stream ``user_message`` to the UI topic one whitespace-separated token at a time.

    Every chunk shares one freshly generated message id. After each token a
    random pause between ``ui_config.min_delay`` and ``ui_config.max_delay``
    is awaited, and a final chunk flagged ``finished=True`` closes the stream.

    Args:
        runtime: Object whose ``publish_message`` is used to send the chunks.
        source: Value placed in each chunk's ``author`` field.
        user_message: Text to stream; it is split on whitespace.
        ui_config: Supplies the UI topic type and the delay bounds.
    """
    stream_id = str(uuid4())

    # One chunk per token: build it, publish it, then pause before the next.
    for word in user_message.split():
        piece = MessageChunk(message_id=stream_id, text=word + " ", author=source, finished=False)
        await runtime.publish_message(piece, DefaultTopicId(type=ui_config.topic_type))
        pause = random.uniform(ui_config.min_delay, ui_config.max_delay)
        await asyncio.sleep(pause)

    # Terminal marker telling the receiver this message is complete.
    closing = MessageChunk(message_id=stream_id, text=" ", author=source, finished=True)
    await runtime.publish_message(closing, DefaultTopicId(type=ui_config.topic_type))
async def publish_message_to_ui_and_backend(
    runtime: RoutedAgent | GrpcWorkerAgentRuntime,
    source: str,
    user_message: str,
    ui_config: UIAgentConfig,
    group_chat_topic_type: str,
) -> None:
    """Send ``user_message`` to the UI, then publish it on the group chat topic.

    Args:
        runtime: Object whose ``publish_message`` is used for the backend publish.
        source: Author of the message; also the ``source`` of the wrapped ``UserMessage``.
        user_message: Text to deliver.
        ui_config: UI settings forwarded to ``publish_message_to_ui``.
        group_chat_topic_type: Topic type the ``GroupChatMessage`` is published to.
    """
    # UI delivery runs to completion before the backend message is built and sent.
    await publish_message_to_ui(runtime, source, user_message, ui_config)

    envelope = GroupChatMessage(body=UserMessage(content=user_message, source=source))
    backend_topic = DefaultTopicId(type=group_chat_topic_type)
    await runtime.publish_message(envelope, topic_id=backend_topic)