# Source: autogen/python/examples/patterns/group_chat_pub_sub.py
# (Python, 167 lines, 5.2 KiB)

import asyncio
from dataclasses import dataclass
from typing import Any, List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.models import (
AssistantMessage,
ChatCompletionClient,
LLMMessage,
OpenAIChatCompletionClient,
SystemMessage,
UserMessage,
)
from agnext.core import AgentId, CancellationToken
from agnext.core.intervention import DefaultInterventionHandler
@dataclass
class Message:
    """A chat message exchanged in the group conversation."""

    # Name of whoever produced the text (an agent name, or e.g. "Moderator").
    source: str
    # The text of the message.
    content: str
@dataclass
class RequestToSpeak:
    """Sent by the manager to the participant whose turn it is to speak."""
@dataclass
class Termination:
    """Published by the manager to signal that the conversation is over."""
class RoundRobinGroupChatManager(TypeRoutedAgent):
    """Group chat manager that gives the floor to participants in round-robin order.

    Every ``Message`` the manager receives (the opening message as well as each
    participant's reply) prompts it to pick the next speaker and send it a
    ``RequestToSpeak``. Once every participant has been asked to speak
    ``num_rounds`` times, the manager publishes ``Termination`` instead.
    """

    def __init__(
        self,
        description: str,
        participants: List[AgentId],
        num_rounds: int,
    ) -> None:
        """Create the manager.

        Args:
            description: Description passed to the base agent.
            participants: Agents to cycle through, in speaking order.
            num_rounds: Number of full passes over ``participants`` before the
                conversation is terminated.
        """
        super().__init__(description)
        self._participants = participants
        self._num_rounds = num_rounds
        # Number of RequestToSpeak messages issued so far.
        self._round_count = 0

    @message_handler
    async def handle_message(self, message: Message, cancellation_token: CancellationToken) -> None:
        """Ask the next participant to speak, or end the chat once all turns are used."""
        # Check the turn budget *before* choosing a speaker, so exactly
        # num_rounds * len(participants) turns are granted. Using ">=" keeps the
        # chat terminated if further messages arrive afterwards, and it also
        # short-circuits the empty-participants case before the modulo below.
        if self._round_count >= self._num_rounds * len(self._participants):
            self.publish_message(Termination())
            return
        # Select the next speaker in a round-robin fashion.
        speaker = self._participants[self._round_count % len(self._participants)]
        self._round_count += 1
        self.send_message(RequestToSpeak(), speaker)
class GroupChatParticipant(TypeRoutedAgent):
    """A group chat member whose replies are generated by a chat-completion model.

    It records every ``Message`` it receives. When asked to speak, it sends the
    most recent part of that history to the model and publishes the reply.
    """

    def __init__(
        self,
        description: str,
        system_messages: List[SystemMessage],
        model_client: ChatCompletionClient,
        history_length: int = 10,
    ) -> None:
        """Create a participant.

        Args:
            description: Description passed to the base agent.
            system_messages: Prepended to every model request.
            model_client: Client used to generate replies.
            history_length: How many of the most recent remembered messages are
                sent to the model on each turn (default 10).

        Raises:
            ValueError: If ``history_length`` is less than 1.
        """
        # A window of 0 would be silently wrong: ``lst[-0:]`` is the whole list.
        if history_length < 1:
            raise ValueError("history_length must be at least 1")
        super().__init__(description)
        self._system_messages = system_messages
        self._model_client = model_client
        self._history_length = history_length
        self._memory: List[Message] = []

    @message_handler
    async def handle_message(self, message: Message, cancellation_token: CancellationToken) -> None:
        """Remember a message published to the group."""
        self._memory.append(message)

    @message_handler
    async def handle_request_to_speak(self, message: RequestToSpeak, cancellation_token: CancellationToken) -> None:
        """Generate a reply from recent history and publish it to the group.

        Raises:
            TypeError: If the model returns non-text content.
        """
        # Nothing heard yet, so there is nothing to respond to. Note that
        # returning without publishing means the manager receives no new
        # Message, so the round-robin does not advance.
        if not self._memory:
            return
        name = self.metadata["name"]
        # Our own past messages are assistant turns; everyone else's are user turns.
        llm_messages: List[LLMMessage] = []
        for m in self._memory[-self._history_length:]:
            if m.source == name:
                llm_messages.append(AssistantMessage(content=m.content, source=name))
            else:
                llm_messages.append(UserMessage(content=m.content, source=m.source))
        response = await self._model_client.create(self._system_messages + llm_messages)
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if not isinstance(response.content, str):
            raise TypeError(
                f"Expected a text response from the model, got {type(response.content).__name__}"
            )
        speech = Message(content=response.content, source=name)
        # Record our own contribution locally before publishing it.
        self._memory.append(speech)
        self.publish_message(speech)
class TerminationHandler(DefaultInterventionHandler):
    """A handler that listens for termination messages.

    It observes every published message and records when a ``Termination``
    goes by. Messages are passed through unchanged.
    """

    def __init__(self) -> None:
        # Cooperative initialisation of the base handler.
        super().__init__()
        self._terminated = False

    async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:
        """Flag termination if *message* is a ``Termination``; return it unchanged."""
        if isinstance(message, Termination):
            self._terminated = True
        return message

    @property
    def terminated(self) -> bool:
        """Whether a ``Termination`` message has been published."""
        return self._terminated
async def main() -> None:
    """Set up a three-member round-robin group chat and run it until it ends.

    Registers three model-backed participants and a round-robin manager,
    publishes an opening message, then processes runtime messages one at a
    time until the termination handler has seen a ``Termination``.
    """
    termination_handler = TerminationHandler()
    # The runtime routes every publish through the handler; that is how the
    # loop at the bottom learns the chat is over.
    runtime = SingleThreadedAgentRuntime(intervention_handler=termination_handler)

    def register_participant(name: str, description: str, prompt: str) -> AgentId:
        # Each call creates its own closure, so every factory captures its own
        # description/prompt (no late-binding surprises from a shared loop var).
        return runtime.register_and_get(
            name,
            lambda: GroupChatParticipant(
                description=description,
                system_messages=[SystemMessage(prompt)],
                model_client=OpenAIChatCompletionClient(model="gpt-3.5-turbo"),
            ),
        )

    # Registration order matters: it is also the speaking order.
    participants: List[AgentId] = [
        register_participant("DataScientist", "A data scientist", "You are a data scientist."),
        register_participant("Engineer", "An engineer", "You are an engineer."),
        register_participant("Artist", "An artist", "You are an artist."),
    ]

    runtime.register(
        "GroupChatManager",
        lambda: RoundRobinGroupChatManager(
            description="A group chat manager",
            participants=list(participants),
            num_rounds=3,
        ),
    )

    # Kick things off, then process messages until someone publishes Termination.
    opening = Message(content="Hello, everyone!", source="Moderator")
    runtime.publish_message(opening, namespace="default")
    while not termination_handler.terminated:
        await runtime.process_next()
if __name__ == "__main__":
    import logging

    # Keep the root logger at WARNING, but turn on full debug output for the
    # agnext runtime so the message flow is visible while the example runs.
    logging.basicConfig(level=logging.WARNING)
    agnext_logger = logging.getLogger("agnext")
    agnext_logger.setLevel(logging.DEBUG)
    asyncio.run(main())