autogen/python/samples/worker/run_worker_pub_sub.py

88 lines
2.3 KiB
Python
Raw Normal View History

import asyncio
import logging
from dataclasses import dataclass
from agnext.application import WorkerAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext
@dataclass
class AskToGreet:
    """Request asking GreeterAgent to produce a greeting.

    GreeterAgent interpolates ``content`` into the text of the Greeting it publishes.
    """

    content: str  # text inserted into the greeting
@dataclass
class Greeting:
    """Greeting published by GreeterAgent and handled by ReceiveAgent."""

    content: str  # full greeting text
@dataclass
class ReturnedGreeting:
    """ReceiveAgent's echo of a Greeting; handled by GreeterAgent."""

    content: str  # echoed greeting text
@dataclass
class Feedback:
    """Feedback GreeterAgent publishes after a ReturnedGreeting; handled by ReceiveAgent."""

    content: str  # feedback text
@dataclass
class ReturnedFeedback:
    """ReceiveAgent's echo of a Feedback message.

    No agent defined in this file has a handler for this type.
    """

    content: str  # echoed feedback text
class ReceiveAgent(TypeRoutedAgent):
    """Agent that echoes greetings and feedback back as ``Returned*`` messages.

    Handlers are selected by the annotated ``message`` type, so those
    annotations must stay exactly as written.
    """

    def __init__(self) -> None:
        super().__init__("Receive Agent")

    @message_handler
    async def on_greet(self, message: Greeting, ctx: MessageContext) -> None:
        """Answer a Greeting by publishing a ReturnedGreeting that quotes it."""
        reply = ReturnedGreeting(f"Returned greeting: {message.content}")
        await self.publish_message(reply)

    @message_handler
    async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
        """Answer a Feedback by publishing a ReturnedFeedback that quotes it."""
        reply = ReturnedFeedback(f"Returned feedback: {message.content}")
        await self.publish_message(reply)
class GreeterAgent(TypeRoutedAgent):
    """Agent that starts the greeting exchange and reacts to echoed greetings.

    It turns an AskToGreet into a published Greeting, and answers each
    ReturnedGreeting with a published Feedback.
    """

    def __init__(self) -> None:
        super().__init__("Greeter Agent")

    @message_handler
    async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
        """Publish a Greeting built from the request's content."""
        greeting = Greeting(f"Hello, {message.content}!")
        await self.publish_message(greeting)

    @message_handler
    async def on_returned_greet(self, message: ReturnedGreeting, ctx: MessageContext) -> None:
        """Publish Feedback quoting the greeting that ReceiveAgent echoed back."""
        feedback = Feedback(f"Feedback: {message.content}")
        await self.publish_message(feedback)
async def main() -> None:
    """Run this worker: register message types and agents, seed one exchange, then idle.

    Connects to the host at ``localhost:50051``, registers the receive and
    greeter agents, and publishes an initial ``AskToGreet`` into the
    ``"default"`` namespace. It then waits so the agents can keep exchanging
    messages. Once ``start`` has succeeded, the runtime is always stopped on
    the way out. This includes a failure while registering/publishing, and
    cancellation of the wait (e.g. Ctrl-C under ``asyncio.run``).
    """
    runtime = WorkerAgentRuntime()
    # Every message type exchanged by the agents is added to the registry.
    # Presumably this lets the worker runtime (de)serialize them; TODO confirm.
    for message_type in (Greeting, AskToGreet, Feedback, ReturnedGreeting, ReturnedFeedback):
        MESSAGE_TYPE_REGISTRY.add_type(message_type)
    await runtime.start(host_connection_string="localhost:50051")
    try:
        # NOTE(review): "reciever" is a misspelling of "receiver". It is kept
        # as-is because the registered name may be referenced outside this file.
        await runtime.register("reciever", lambda: ReceiveAgent())
        await runtime.register("greeter", lambda: GreeterAgent())
        await runtime.publish_message(AskToGreet("Hello World!"), namespace="default")
        # Just to keep the runtime running.
        # Under asyncio.run, Ctrl-C normally reaches this coroutine as task
        # cancellation (CancelledError), not KeyboardInterrupt. So cleanup lives
        # in `finally` rather than in an `except KeyboardInterrupt` handler.
        await asyncio.sleep(1000000)
    finally:
        await runtime.stop()
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    # Set the agnext logger to DEBUG explicitly so runtime internals show up in this sample.
    logger = logging.getLogger("agnext")
    logger.setLevel(logging.DEBUG)
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # asyncio.run surfaces Ctrl-C here as KeyboardInterrupt, not inside main().
        # Treat it as a normal shutdown request instead of printing a traceback.
        pass