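"""Example: a round-robin group chat built on agnext.

Three LLM-backed participants take turns speaking. A manager agent selects
the next speaker in round-robin order and, after a fixed number of rounds,
publishes a Termination message that an intervention handler uses to stop
the runtime.
"""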
import asyncio
from dataclasses import dataclass
from typing import Any, List

from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.models import (
    AssistantMessage,
    ChatCompletionClient,
    LLMMessage,
    OpenAIChatCompletionClient,
    SystemMessage,
    UserMessage,
)
from agnext.core import AgentId, CancellationToken
from agnext.core.intervention import DefaultInterventionHandler


@dataclass
class Message:
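    """A chat message, carrying the sender's name and the text content."""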
    source: str
    content: str


@dataclass
class RequestToSpeak:
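    """Sent by the manager to ask a participant to produce its next message."""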
    pass


@dataclass
class Termination:
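    """Published by the manager to signal that the conversation is over."""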
    pass


class RoundRobinGroupChatManager(TypeRoutedAgent):
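    """Selects the next speaker in round-robin order and ends the conversation
    after a fixed number of rounds by publishing a Termination message."""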
    def __init__(
        self,
        description: str,
        participants: List[AgentId],
        num_rounds: int,
    ) -> None:
        super().__init__(description)
        self._participants = participants
        self._num_rounds = num_rounds
        self._round_count = 0

    @message_handler
    async def handle_message(self, message: Message, cancellation_token: CancellationToken) -> None:
        # Select the next speaker in a round-robin fashion.
        speaker = self._participants[self._round_count % len(self._participants)]
        self._round_count += 1
        if self._round_count == self._num_rounds * len(self._participants):
            # End the conversation after the specified number of rounds.
            self.publish_message(Termination())
            return
        # Send a request-to-speak message to the selected speaker.
        self.send_message(RequestToSpeak(), speaker)


class GroupChatParticipant(TypeRoutedAgent):
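    """An LLM-backed participant that records the conversation and, when asked
    to speak, replies based on its recent message history."""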
    def __init__(
        self,
        description: str,
        system_messages: List[SystemMessage],
        model_client: ChatCompletionClient,
    ) -> None:
        super().__init__(description)
        self._system_messages = system_messages
        self._model_client = model_client
        self._memory: List[Message] = []

    @message_handler
    async def handle_message(self, message: Message, cancellation_token: CancellationToken) -> None:
        # Record every incoming message for use as model context later.
        self._memory.append(message)

    @message_handler
    async def handle_request_to_speak(self, message: RequestToSpeak, cancellation_token: CancellationToken) -> None:
        # Generate a response to the conversation so far; stay silent if
        # nothing has been said yet.
        if not self._memory:
            return
        # Map the last 10 messages to LLM messages: this agent's own turns
        # become assistant messages, everyone else's become user messages.
        llm_messages: List[LLMMessage] = []
        for m in self._memory[-10:]:
            if m.source == self.metadata["name"]:
                llm_messages.append(AssistantMessage(content=m.content, source=self.metadata["name"]))
            else:
                llm_messages.append(UserMessage(content=m.content, source=m.source))
        response = await self._model_client.create(self._system_messages + llm_messages)
        assert isinstance(response.content, str)
        # Remember the new message and broadcast it to the group.
        speech = Message(content=response.content, source=self.metadata["name"])
        self._memory.append(speech)
        self.publish_message(speech)


class TerminationHandler(DefaultInterventionHandler):
    """An intervention handler that watches for Termination messages and
    records when one has been published."""

    def __init__(self) -> None:
        self._terminated = False

    async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:
        # Note the termination; the message itself is passed through unchanged.
        if isinstance(message, Termination):
            self._terminated = True
        return message

    @property
    def terminated(self) -> bool:
        return self._terminated


async def main() -> None:
    # Create the termination handler.
    termination_handler = TerminationHandler()

    # Create the runtime, attaching the termination handler so it can observe
    # every published message.
    runtime = SingleThreadedAgentRuntime(intervention_handler=termination_handler)

    # Register the participants.
    agent1 = runtime.register_and_get(
        "DataScientist",
        lambda: GroupChatParticipant(
            description="A data scientist",
            system_messages=[SystemMessage("You are a data scientist.")],
            model_client=OpenAIChatCompletionClient(model="gpt-3.5-turbo"),
        ),
    )
    agent2 = runtime.register_and_get(
        "Engineer",
        lambda: GroupChatParticipant(
            description="An engineer",
            system_messages=[SystemMessage("You are an engineer.")],
            model_client=OpenAIChatCompletionClient(model="gpt-3.5-turbo"),
        ),
    )
    agent3 = runtime.register_and_get(
        "Artist",
        lambda: GroupChatParticipant(
            description="An artist",
            system_messages=[SystemMessage("You are an artist.")],
            model_client=OpenAIChatCompletionClient(model="gpt-3.5-turbo"),
        ),
    )

    # Register the group chat manager.
    runtime.register(
        "GroupChatManager",
        lambda: RoundRobinGroupChatManager(
            description="A group chat manager",
            participants=[agent1, agent2, agent3],
            num_rounds=3,
        ),
    )

    # Start the conversation by broadcasting an initial message.
    runtime.publish_message(Message(content="Hello, everyone!", source="Moderator"), namespace="default")

    # Process messages until the termination handler reports termination.
    while not termination_handler.terminated:
        await runtime.process_next()


if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.WARNING)
    # Enable debug logging for the agnext runtime to trace message flow.
    logging.getLogger("agnext").setLevel(logging.DEBUG)
    asyncio.run(main())