"""
This example demonstrates the mixture of agents implemented using direct
messaging and async gathering of results.

Mixture of agents: https://github.com/togethercomputer/moa

The example consists of two types of agents: reference agents and an aggregator agent.
The aggregator agent distributes tasks to reference agents and aggregates the results.
The reference agents handle each task independently and return the results to the aggregator agent.
"""

import asyncio
from dataclasses import dataclass
from typing import List

from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.models import ChatCompletionClient, OpenAIChatCompletionClient, SystemMessage, UserMessage
from agnext.core import AgentId, CancellationToken


@dataclass
class ReferenceAgentTask:
    """A task sent from the aggregator to a reference agent."""

    task: str


@dataclass
class ReferenceAgentTaskResult:
    """A reference agent's answer to a ReferenceAgentTask."""

    result: str


@dataclass
class AggregatorTask:
    """The top-level task sent to the aggregator agent."""

    task: str


@dataclass
class AggregatorTaskResult:
    """The aggregator's synthesized answer to an AggregatorTask."""

    result: str


class ReferenceAgent(TypeRoutedAgent):
    """The reference agent that handles each task independently."""

    def __init__(
        self,
        description: str,
        system_messages: List[SystemMessage],
        model_client: ChatCompletionClient,
    ) -> None:
        """Create a reference agent.

        Args:
            description: Human-readable description of the agent.
            system_messages: System prompt(s) prepended to every model call.
            model_client: Chat completion client used to answer tasks.
        """
        super().__init__(description)
        self._system_messages = system_messages
        self._model_client = model_client

    @message_handler
    async def handle_task(
        self, message: ReferenceAgentTask, cancellation_token: CancellationToken
    ) -> ReferenceAgentTaskResult:
        """Handle a task message.

        This method sends the task to the model and responds with the result.
        """
        task_message = UserMessage(content=message.task, source=self.metadata["name"])
        response = await self._model_client.create(self._system_messages + [task_message])
        # Narrow the content type: only plain-text completions are expected here.
        assert isinstance(response.content, str)
        return ReferenceAgentTaskResult(result=response.content)


class AggregatorAgent(TypeRoutedAgent):
    """The aggregator agent that distributes tasks to reference agents and aggregates the results."""

    def __init__(
        self,
        description: str,
        system_messages: List[SystemMessage],
        model_client: ChatCompletionClient,
        references: List[AgentId],
    ) -> None:
        """Create an aggregator agent.

        Args:
            description: Human-readable description of the agent.
            system_messages: Aggregation prompt(s) prepended to the synthesis call.
            model_client: Chat completion client used to synthesize the final answer.
            references: Ids of the reference agents each task is fanned out to.
        """
        super().__init__(description)
        self._system_messages = system_messages
        self._model_client = model_client
        self._references = references

    @message_handler
    async def handle_task(self, message: AggregatorTask, cancellation_token: CancellationToken) -> AggregatorTaskResult:
        """Handle a task message.

        This method sends the task to all reference agents concurrently, then asks
        the model to synthesize their responses into a single answer to the task.
        """
        ref_task = ReferenceAgentTask(task=message.task)
        # asyncio.gather returns results in the same order as self._references.
        results: List[ReferenceAgentTaskResult] = await asyncio.gather(
            *[self.send_message(ref_task, ref) for ref in self._references]
        )
        # Number the responses so the model can distinguish the individual answers.
        numbered_responses = "\n\n".join(f"{i}. {r.result}" for i, r in enumerate(results, start=1))
        # Follow the MoA layout: aggregation prompt, then the reference responses,
        # then the original task as the user turn. Without the task, the model
        # would synthesize the responses without knowing what question they answer.
        response = await self._model_client.create(
            self._system_messages
            + [
                SystemMessage(numbered_responses),
                UserMessage(content=message.task, source=self.metadata["name"]),
            ]
        )
        # Narrow the content type: only plain-text completions are expected here.
        assert isinstance(response.content, str)
        return AggregatorTaskResult(result=response.content)


async def main() -> None:
    """Register three reference agents and one aggregator, run one task, and print the answer."""
    runtime = SingleThreadedAgentRuntime()
    # The reference agents differ only in sampling temperature.
    ref1 = runtime.register_and_get(
        "ReferenceAgent1",
        lambda: ReferenceAgent(
            description="Reference Agent 1",
            system_messages=[SystemMessage("You are a helpful assistant that can answer questions.")],
            model_client=OpenAIChatCompletionClient(model="gpt-3.5-turbo", temperature=0.1),
        ),
    )
    ref2 = runtime.register_and_get(
        "ReferenceAgent2",
        lambda: ReferenceAgent(
            description="Reference Agent 2",
            system_messages=[SystemMessage("You are a helpful assistant that can answer questions.")],
            model_client=OpenAIChatCompletionClient(model="gpt-3.5-turbo", temperature=0.5),
        ),
    )
    ref3 = runtime.register_and_get(
        "ReferenceAgent3",
        lambda: ReferenceAgent(
            description="Reference Agent 3",
            system_messages=[SystemMessage("You are a helpful assistant that can answer questions.")],
            model_client=OpenAIChatCompletionClient(model="gpt-3.5-turbo", temperature=1.0),
        ),
    )
    agg = runtime.register_and_get(
        "AggregatorAgent",
        lambda: AggregatorAgent(
            description="Aggregator Agent",
            system_messages=[
                SystemMessage(
                    "...synthesize these responses into a single, high-quality response... Responses from models:"
                )
            ],
            model_client=OpenAIChatCompletionClient(model="gpt-3.5-turbo"),
            references=[ref1, ref2, ref3],
        ),
    )
    future = runtime.send_message(AggregatorTask(task="What are something fun to do in SF?"), agg)
    # Drive the runtime manually, one process_next() step at a time, until the
    # aggregator's reply resolves the future.
    while not future.done():
        await runtime.process_next()
    print(future.result())


if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.WARNING)
    logging.getLogger("agnext").setLevel(logging.DEBUG)
    asyncio.run(main())