autogen/python/examples/patterns/mixture_of_agents_direct.py
Jack Gerrits a13c971b16 Change send/publish api to better support async and represent reality (#137)
* Make send and publish better represent reality

* fix team-one
2024-06-27 13:40:12 -04:00

147 lines
5.2 KiB
Python

"""
This example demonstrates the mixture-of-agents pattern, implemented using direct
messaging and asynchronous gathering of results.
Mixture of agents: https://github.com/togethercomputer/moa
The example consists of two types of agents: reference agents and an aggregator agent.
The aggregator agent distributes tasks to reference agents and aggregates the results.
The reference agents handle each task independently and return the results to the aggregator agent.
"""
import asyncio
from dataclasses import dataclass
from typing import List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.models import ChatCompletionClient, OpenAIChatCompletionClient, SystemMessage, UserMessage
from agnext.core import AgentId, CancellationToken
@dataclass()
class ReferenceAgentTask:
    """A question handed from the aggregator to one reference agent."""

    # Prompt text the reference agent sends to its model as the user turn.
    task: str
@dataclass()
class ReferenceAgentTaskResult:
    """The reply a reference agent returns for a single task."""

    # Text content of the reference agent's model response.
    result: str
@dataclass()
class AggregatorTask:
    """The top-level question submitted to the aggregator agent."""

    # Question text; the aggregator forwards it unchanged to each reference agent.
    task: str
@dataclass()
class AggregatorTaskResult:
    """The aggregator's final, synthesized answer."""

    # Text content of the aggregator model's response.
    result: str
class ReferenceAgent(TypeRoutedAgent):
    """Answers each task on its own with a single model call.

    Args:
        description: Description forwarded to ``TypeRoutedAgent``.
        system_messages: System prompt(s) placed before the task in every request.
        model_client: Chat completion client used to produce the answer.
    """

    def __init__(
        self,
        description: str,
        system_messages: List[SystemMessage],
        model_client: ChatCompletionClient,
    ) -> None:
        super().__init__(description)
        self._model_client = model_client
        self._system_messages = system_messages

    @message_handler
    async def handle_task(
        self, message: ReferenceAgentTask, cancellation_token: CancellationToken
    ) -> ReferenceAgentTaskResult:
        """Send the task to the model after the system messages and return the model's reply."""
        prompt = UserMessage(content=message.task, source=self.metadata["name"])
        conversation = self._system_messages + [prompt]
        completion = await self._model_client.create(conversation)
        answer = completion.content
        # Only a plain-text reply is handled here; anything else fails loudly.
        assert isinstance(answer, str)
        return ReferenceAgentTaskResult(result=answer)
class AggregatorAgent(TypeRoutedAgent):
    """Fans a task out to reference agents and has its own model merge their answers.

    Args:
        description: Description forwarded to ``TypeRoutedAgent``.
        system_messages: System prompt(s) for the aggregation request.
        model_client: Chat completion client used to synthesize the final answer.
        references: IDs of the reference agents that receive every task.
    """

    def __init__(
        self,
        description: str,
        system_messages: List[SystemMessage],
        model_client: ChatCompletionClient,
        references: List[AgentId],
    ) -> None:
        super().__init__(description)
        self._references = references
        self._model_client = model_client
        self._system_messages = system_messages

    @message_handler
    async def handle_task(self, message: AggregatorTask, cancellation_token: CancellationToken) -> AggregatorTaskResult:
        """Forward the task to every reference agent, then ask the model to combine their replies."""
        subtask = ReferenceAgentTask(task=message.task)
        # Each send is awaited in turn, and asyncio.gather then waits on whatever
        # those awaits produced. NOTE(review): this relies on send_message
        # resolving to an awaitable (e.g. a future for the pending reply) rather
        # than the reply itself; confirm against agnext's send_message.
        pending = [await self.send_message(subtask, reference) for reference in self._references]
        results: List[ReferenceAgentTaskResult] = await asyncio.gather(*pending)
        # The reference answers, separated by blank lines, form the single user turn.
        # Note: message.task itself is not included in this request.
        merged = "\n\n".join(r.result for r in results)
        summary_request = UserMessage(content=merged, source=self.metadata["name"])
        response = await self._model_client.create(self._system_messages + [summary_request])
        final_text = response.content
        assert isinstance(final_text, str)
        return AggregatorTaskResult(result=final_text)
async def main() -> None:
    """Run one mixture-of-agents query end to end and print the aggregated answer.

    Registers three reference agents that share a system prompt and model but
    sample at different temperatures (0.1, 0.5, 1.0), plus an aggregator that
    queries all three and synthesizes their replies.
    """
    from typing import Callable

    runtime = SingleThreadedAgentRuntime()

    def make_reference_factory(index: int, temperature: float) -> Callable[[], ReferenceAgent]:
        # Bind index/temperature through this helper so each registration gets its
        # own zero-argument factory. A lambda written directly inside the
        # comprehension below would late-bind the loop variables.
        def create() -> ReferenceAgent:
            return ReferenceAgent(
                description=f"Reference Agent {index}",
                system_messages=[SystemMessage("You are a helpful assistant that can answer questions.")],
                model_client=OpenAIChatCompletionClient(model="gpt-3.5-turbo", temperature=temperature),
            )

        return create

    # Registered in the same order and under the same names as before:
    # ReferenceAgent1 (0.1), ReferenceAgent2 (0.5), ReferenceAgent3 (1.0).
    references = [
        runtime.register_and_get(f"ReferenceAgent{index}", make_reference_factory(index, temperature))
        for index, temperature in enumerate((0.1, 0.5, 1.0), start=1)
    ]
    agg = runtime.register_and_get(
        "AggregatorAgent",
        lambda: AggregatorAgent(
            description="Aggregator Agent",
            system_messages=[
                SystemMessage(
                    "...synthesize these responses into a single, high-quality response... Responses from models:"
                )
            ],
            model_client=OpenAIChatCompletionClient(model="gpt-3.5-turbo"),
            # Fresh list per factory call, matching the original literal [ref1, ref2, ref3].
            references=list(references),
        ),
    )
    result = await runtime.send_message(AggregatorTask(task="What are something fun to do in SF?"), agg)
    # Drive the single-threaded runtime one step at a time until the reply is available.
    while not result.done():
        await runtime.process_next()
    print(result.result())
if __name__ == "__main__":
    import logging

    # Default every logger to WARNING, but enable DEBUG output for the agnext package.
    logging.basicConfig(level=logging.WARNING)
    agnext_logger = logging.getLogger("agnext")
    agnext_logger.setLevel(logging.DEBUG)
    asyncio.run(main())