autogen/python/examples/patterns/mixture_of_agents_direct.py

141 lines
4.9 KiB
Python
Raw Normal View History

"""This example demonstrates the mixture of agents implemented using direct
messaging and async gathering of results.
Mixture of agents: https://github.com/togethercomputer/moa"""
import asyncio
from dataclasses import dataclass
from typing import Callable, List

from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.models import ChatCompletionClient, OpenAIChatCompletionClient, SystemMessage, UserMessage
from agnext.core import AgentId, CancellationToken
@dataclass
class ReferenceAgentTask:
    """Request sent by the aggregator to a single reference agent."""

    # Prompt text the reference agent forwards to its model as the user message.
    task: str
@dataclass
class ReferenceAgentTaskResult:
    """Reply a reference agent returns for a ReferenceAgentTask."""

    # The model's text answer for the task.
    result: str
@dataclass
class AggregatorTask:
    """Top-level request handled by the aggregator agent."""

    # The question that is fanned out to every reference agent.
    task: str
@dataclass
class AggregatorTaskResult:
    """Reply the aggregator agent returns for an AggregatorTask."""

    # The synthesized text answer produced by the aggregator's model.
    result: str
class ReferenceAgent(TypeRoutedAgent):
    """The reference agent that handles each task independently.

    Every task is answered with one model call built from the configured
    system messages plus the task text; no history is kept between tasks.
    """

    def __init__(
        self,
        description: str,
        system_messages: List[SystemMessage],
        model_client: ChatCompletionClient,
    ) -> None:
        """Create the agent.

        Args:
            description: Description passed to the base agent.
            system_messages: Messages prepended to every model request.
            model_client: Client used to generate the answer.
        """
        super().__init__(description)
        self._system_messages = system_messages
        self._model_client = model_client

    @message_handler
    async def handle_task(
        self, message: ReferenceAgentTask, cancellation_token: CancellationToken
    ) -> ReferenceAgentTaskResult:
        """Send the task to the model and respond with its text answer.

        Raises:
            TypeError: If the model response content is not a string.
        """
        task_message = UserMessage(content=message.task, source=self.metadata["name"])
        response = await self._model_client.create(self._system_messages + [task_message])
        # Explicit check rather than ``assert``: asserts are removed under ``python -O``,
        # which would let non-text content leak into the result unnoticed.
        if not isinstance(response.content, str):
            raise TypeError(f"Expected text content from the model, got {type(response.content).__name__}.")
        return ReferenceAgentTaskResult(result=response.content)
class AggregatorAgent(TypeRoutedAgent):
    """The aggregator agent that distributes tasks to reference agents and aggregates the results."""

    def __init__(
        self,
        description: str,
        system_messages: List[SystemMessage],
        model_client: ChatCompletionClient,
        references: List[AgentId],
    ) -> None:
        """Create the agent.

        Args:
            description: Description passed to the base agent.
            system_messages: Aggregation prompt prepended to the synthesis request.
            model_client: Client used to synthesize the final answer.
            references: Agents that each receive a copy of the task.
        """
        super().__init__(description)
        self._system_messages = system_messages
        self._model_client = model_client
        self._references = references

    @message_handler
    async def handle_task(self, message: AggregatorTask, cancellation_token: CancellationToken) -> AggregatorTaskResult:
        """Send the task to every reference agent, then synthesize their answers.

        The reference requests are awaited together with ``asyncio.gather``, so
        their answers come back in the same order as ``self._references``. The
        numbered answers are appended after the configured system messages, and
        the original task is sent as the user message so the model knows which
        question the candidate answers respond to.

        Raises:
            TypeError: If the model response content is not a string.
        """
        ref_task = ReferenceAgentTask(task=message.task)
        results: List[ReferenceAgentTaskResult] = await asyncio.gather(
            *[self.send_message(ref_task, ref) for ref in self._references]
        )
        # Number the candidates so the model can tell where one answer ends and the next begins.
        numbered_responses = "\n\n".join(f"{index}. {r.result}" for index, r in enumerate(results, start=1))
        response = await self._model_client.create(
            self._system_messages
            + [
                SystemMessage(numbered_responses),
                # Without the task itself the model would be synthesizing answers to an unseen question.
                UserMessage(content=message.task, source=self.metadata["name"]),
            ]
        )
        # Explicit check rather than ``assert``: asserts are removed under ``python -O``.
        if not isinstance(response.content, str):
            raise TypeError(f"Expected text content from the model, got {type(response.content).__name__}.")
        return AggregatorTaskResult(result=response.content)
def _make_reference_agent_factory(index: int, temperature: float) -> Callable[[], ReferenceAgent]:
    """Return a zero-argument factory that builds reference agent number *index*.

    A separate function gives each factory its own closure over ``index`` and
    ``temperature``; a lambda written directly inside a loop would late-bind and
    every factory would see the last iteration's values.
    """

    def factory() -> ReferenceAgent:
        return ReferenceAgent(
            description=f"Reference Agent {index}",
            system_messages=[SystemMessage("You are a helpful assistant that can answer questions.")],
            model_client=OpenAIChatCompletionClient(model="gpt-3.5-turbo", temperature=temperature),
        )

    return factory


async def main() -> None:
    """Register the agents, run one mixture-of-agents query, and print the result."""
    runtime = SingleThreadedAgentRuntime()
    # The reference agents share a prompt but sample at increasing temperatures,
    # giving the aggregator a spread of candidate answers.
    references = [
        runtime.register_and_get(f"ReferenceAgent{index}", _make_reference_agent_factory(index, temperature))
        for index, temperature in enumerate((0.1, 0.5, 1.0), start=1)
    ]
    agg = runtime.register_and_get(
        "AggregatorAgent",
        lambda: AggregatorAgent(
            description="Aggregator Agent",
            system_messages=[
                SystemMessage(
                    "...synthesize these responses into a single, high-quality response... Responses from models:"
                )
            ],
            model_client=OpenAIChatCompletionClient(model="gpt-3.5-turbo"),
            references=references,
        ),
    )
    result = runtime.send_message(AggregatorTask(task="What are something fun to do in SF?"), agg)
    # Drive the runtime by hand: process queued messages until the aggregator's reply resolves.
    while not result.done():
        await runtime.process_next()
    print(result.result())
if __name__ == "__main__":
    import logging

    # Keep the root logger quiet, but surface the agent runtime's debug trace.
    logging.basicConfig(level=logging.WARNING)
    agnext_logger = logging.getLogger("agnext")
    agnext_logger.setLevel(logging.DEBUG)

    asyncio.run(main())