diff --git a/python/.gitignore b/python/.gitignore index 1e5623e46..65be202c2 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -166,6 +166,7 @@ cython_debug/ # Generated proto files src/agnext/worker/protos/agent* +src/agnext/application/protos/agent* # Generated log files log.jsonl @@ -174,4 +175,4 @@ log.jsonl docs/**/jupyter_execute # Temporary files -tmp_code_*.py \ No newline at end of file +tmp_code_*.py diff --git a/python/docs/src/conf.py b/python/docs/src/conf.py index fc5befd29..55b258215 100644 --- a/python/docs/src/conf.py +++ b/python/docs/src/conf.py @@ -28,7 +28,7 @@ apidoc_template_dir = "_apidoc_templates" apidoc_separate_modules = True apidoc_extra_args = ["--no-toc"] napoleon_custom_sections = [("Returns", "params_style")] -apidoc_excluded_paths = ["./worker/protos/"] +apidoc_excluded_paths = ["./application/protos/"] templates_path = [] exclude_patterns = ["reference/agnext.rst"] diff --git a/python/docs/src/index.rst b/python/docs/src/index.rst index 0f0512004..1c6b3c9a7 100644 --- a/python/docs/src/index.rst +++ b/python/docs/src/index.rst @@ -59,7 +59,6 @@ To learn about the core concepts of AGNext, read the `overview None: MESSAGE_TYPE_REGISTRY.add_type(AuditText) MESSAGE_TYPE_REGISTRY.add_type(AuditorAlert) agnext_logger.info("1") - await runtime.setup_channel("localhost:5145") + await runtime.start("localhost:5145") agnext_logger.info("2") @@ -30,7 +30,7 @@ async def main() -> None: await asyncio.sleep(1000000) except KeyboardInterrupt: pass - await runtime.close_channel() + await runtime.stop() if __name__ == "__main__": diff --git a/python/samples/worker/run_host.py b/python/samples/worker/run_host.py new file mode 100644 index 000000000..eed1d586e --- /dev/null +++ b/python/samples/worker/run_host.py @@ -0,0 +1,53 @@ +import asyncio +import signal + +import grpc +from agnext.application import HostRuntimeServicer +from agnext.application.protos import agent_worker_pb2_grpc + + +async def serve(server: grpc.aio.Server) -> None: 
# type: ignore + await server.start() + print("Server started") + await server.wait_for_termination() + + +async def main() -> None: + server = grpc.aio.server() + agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(HostRuntimeServicer(), server) + server.add_insecure_port("[::]:50051") + + # Set up signal handling for graceful shutdown + loop = asyncio.get_running_loop() + + shutdown_event = asyncio.Event() + + def signal_handler() -> None: + print("Received exit signal, shutting down gracefully...") + shutdown_event.set() + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + # Start server in background task + serve_task = asyncio.create_task(serve(server)) + + # Wait for the signal to trigger the shutdown event + await shutdown_event.wait() + + # Graceful shutdown + await server.stop(5) # 5 second grace period + await serve_task + print("Server stopped") + + +if __name__ == "__main__": + import logging + + logging.basicConfig(level=logging.WARNING) + logging.getLogger("agnext").setLevel(logging.DEBUG) + + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("Server shutdown interrupted.") diff --git a/python/samples/worker/run_worker_pub_sub.py b/python/samples/worker/run_worker_pub_sub.py new file mode 100644 index 000000000..e46aa5f21 --- /dev/null +++ b/python/samples/worker/run_worker_pub_sub.py @@ -0,0 +1,87 @@ +import asyncio +import logging +from dataclasses import dataclass + +from agnext.application import WorkerAgentRuntime +from agnext.components import TypeRoutedAgent, message_handler +from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext + + +@dataclass +class AskToGreet: + content: str + + +@dataclass +class Greeting: + content: str + + +@dataclass +class ReturnedGreeting: + content: str + + +@dataclass +class Feedback: + content: str + + +@dataclass +class ReturnedFeedback: + content: str + + +class ReceiveAgent(TypeRoutedAgent): + def __init__(self) -> None: + 
super().__init__("Receive Agent") + + @message_handler + async def on_greet(self, message: Greeting, ctx: MessageContext) -> None: + await self.publish_message(ReturnedGreeting(f"Returned greeting: {message.content}")) + + @message_handler + async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None: + await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}")) + + +class GreeterAgent(TypeRoutedAgent): + def __init__(self) -> None: + super().__init__("Greeter Agent") + + @message_handler + async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None: + await self.publish_message(Greeting(f"Hello, {message.content}!")) + + @message_handler + async def on_returned_greet(self, message: ReturnedGreeting, ctx: MessageContext) -> None: + await self.publish_message(Feedback(f"Feedback: {message.content}")) + + +async def main() -> None: + runtime = WorkerAgentRuntime() + MESSAGE_TYPE_REGISTRY.add_type(Greeting) + MESSAGE_TYPE_REGISTRY.add_type(AskToGreet) + MESSAGE_TYPE_REGISTRY.add_type(Feedback) + MESSAGE_TYPE_REGISTRY.add_type(ReturnedGreeting) + MESSAGE_TYPE_REGISTRY.add_type(ReturnedFeedback) + await runtime.start(host_connection_string="localhost:50051") + + await runtime.register("reciever", lambda: ReceiveAgent()) + await runtime.register("greeter", lambda: GreeterAgent()) + + await runtime.publish_message(AskToGreet("Hello World!"), namespace="default") + + # Just to keep the runtime running + try: + await asyncio.sleep(1000000) + except KeyboardInterrupt: + pass + await runtime.stop() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + logger = logging.getLogger("agnext") + logger.setLevel(logging.DEBUG) + asyncio.run(main()) diff --git a/python/samples/worker/run_worker_rpc.py b/python/samples/worker/run_worker_rpc.py new file mode 100644 index 000000000..2168182b7 --- /dev/null +++ b/python/samples/worker/run_worker_rpc.py @@ -0,0 +1,74 @@ +import asyncio +import logging +from 
dataclasses import dataclass + +from agnext.application import WorkerAgentRuntime +from agnext.components import TypeRoutedAgent, message_handler +from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, MessageContext + + +@dataclass +class AskToGreet: + content: str + + +@dataclass +class Greeting: + content: str + + +@dataclass +class Feedback: + content: str + + +class ReceiveAgent(TypeRoutedAgent): + def __init__(self) -> None: + super().__init__("Receive Agent") + + @message_handler + async def on_greet(self, message: Greeting, ctx: MessageContext) -> Greeting: + return Greeting(content=f"Received: {message.content}") + + @message_handler + async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None: + print(f"Feedback received: {message.content}") + + +class GreeterAgent(TypeRoutedAgent): + def __init__(self, receive_agent_id: AgentId) -> None: + super().__init__("Greeter Agent") + self._receive_agent_id = receive_agent_id + + @message_handler + async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None: + response = await self.send_message(Greeting(f"Hello, {message.content}!"), recipient=self._receive_agent_id) + await self.publish_message(Feedback(f"Feedback: {response.content}")) + + +async def main() -> None: + runtime = WorkerAgentRuntime() + MESSAGE_TYPE_REGISTRY.add_type(Greeting) + MESSAGE_TYPE_REGISTRY.add_type(AskToGreet) + MESSAGE_TYPE_REGISTRY.add_type(Feedback) + await runtime.start(host_connection_string="localhost:50051") + + await runtime.register("reciever", lambda: ReceiveAgent()) + reciever = await runtime.get("reciever") + await runtime.register("greeter", lambda: GreeterAgent(reciever)) + + await runtime.publish_message(AskToGreet("Hello World!"), namespace="default") + + # Just to keep the runtime running + try: + await asyncio.sleep(1000000) + except KeyboardInterrupt: + pass + await runtime.stop() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + logger = 
logging.getLogger("agnext") + logger.setLevel(logging.DEBUG) + asyncio.run(main()) diff --git a/python/src/agnext/application/__init__.py b/python/src/agnext/application/__init__.py index 72aac4583..b7379f251 100644 --- a/python/src/agnext/application/__init__.py +++ b/python/src/agnext/application/__init__.py @@ -2,6 +2,8 @@ The :mod:`agnext.application` module provides implementations of core components that are used to compose an application """ +from ._host_runtime_servicer import HostRuntimeServicer from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime +from ._worker_runtime import WorkerAgentRuntime -__all__ = ["SingleThreadedAgentRuntime"] +__all__ = ["SingleThreadedAgentRuntime", "WorkerAgentRuntime", "HostRuntimeServicer"] diff --git a/python/src/agnext/application/_host_runtime_servicer.py b/python/src/agnext/application/_host_runtime_servicer.py new file mode 100644 index 000000000..a44a55ed1 --- /dev/null +++ b/python/src/agnext/application/_host_runtime_servicer.py @@ -0,0 +1,150 @@ +import asyncio +import logging +from _collections_abc import AsyncIterator, Iterator +from asyncio import Future, Task +from typing import Any, Dict, Set + +import grpc + +from .protos import agent_worker_pb2, agent_worker_pb2_grpc + +logger = logging.getLogger("agnext") +event_logger = logging.getLogger("agnext.events") + + +class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer): + """A gRPC servicer that hosts message delivery service for agents.""" + + def __init__(self) -> None: + self._client_id = 0 + self._client_id_lock = asyncio.Lock() + self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {} + self._agent_type_to_client_id: Dict[str, int] = {} + self._pending_requests: Dict[int, Dict[str, Future[Any]]] = {} + self._background_tasks: Set[Task[Any]] = set() + + async def OpenChannel( # type: ignore + self, + request_iterator: AsyncIterator[agent_worker_pb2.Message], + context: 
grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message], + ) -> Iterator[agent_worker_pb2.Message] | AsyncIterator[agent_worker_pb2.Message]: # type: ignore + # Acquire the lock to get a new client id. + async with self._client_id_lock: + self._client_id += 1 + client_id = self._client_id + + # Register the client with the server and create a send queue for the client. + send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue() + self._send_queues[client_id] = send_queue + logger.info(f"Client {client_id} connected.") + + try: + # Concurrently handle receiving messages from the client and sending messages to the client. + # This task will receive messages from the client. + receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator)) + + # Return an async generator that will yield messages from the send queue to the client. + while True: + message = await send_queue.get() + # Yield the message to the client. + try: + yield message + except Exception as e: + logger.error(f"Failed to send message to client {client_id}: {e}", exc_info=True) + break + # Wait for the receiving task to finish. + await receiving_task + + finally: + # Clean up the client connection. + del self._send_queues[client_id] + # Cancel pending requests sent to this client. + for future in self._pending_requests.pop(client_id, {}).values(): + future.cancel() + logger.info(f"Client {client_id} disconnected.") + + async def _receive_messages( + self, client_id: int, request_iterator: AsyncIterator[agent_worker_pb2.Message] + ) -> None: + # Receive messages from the client and process them. 
+ async for message in request_iterator: + oneofcase = message.WhichOneof("message") + match oneofcase: + case "request": + request: agent_worker_pb2.RpcRequest = message.request + logger.info(f"Received request message: {request}") + task = asyncio.create_task(self._process_request(request, client_id)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + case "response": + response: agent_worker_pb2.RpcResponse = message.response + logger.info(f"Received response message: {response}") + task = asyncio.create_task(self._process_response(response, client_id)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + case "event": + event: agent_worker_pb2.Event = message.event + logger.info(f"Received event message: {event}") + task = asyncio.create_task(self._process_event(event)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + case "registerAgentType": + register_agent_type: agent_worker_pb2.RegisterAgentType = message.registerAgentType + logger.info(f"Received register agent type message: {register_agent_type}") + task = asyncio.create_task(self._process_register_agent_type(register_agent_type, client_id)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + case None: + logger.warning("Received empty message") + + async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None: + # Deliver the message to a client given the target agent type. 
+ target_client_id = self._agent_type_to_client_id.get(request.target.name) + if target_client_id is None: + logger.error(f"Agent {request.target.name} not found, failed to deliver message.") + return + target_send_queue = self._send_queues.get(target_client_id) + if target_send_queue is None: + logger.error(f"Client {target_client_id} not found, failed to deliver message.") + return + await target_send_queue.put(agent_worker_pb2.Message(request=request)) + + # Create a future to wait for the response. + future = asyncio.get_event_loop().create_future() + self._pending_requests.setdefault(client_id, {})[request.request_id] = future + + # Create a task to wait for the response and send it back to the client. + send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id)) + self._background_tasks.add(send_response_task) + send_response_task.add_done_callback(self._background_tasks.discard) + + async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None: + response = await future + message = agent_worker_pb2.Message(response=response) + send_queue = self._send_queues.get(client_id) + if send_queue is None: + logger.error(f"Client {client_id} not found, failed to send response message.") + return + await send_queue.put(message) + + async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None: + # Setting the result of the future will send the response back to the original sender. + future = self._pending_requests[client_id].pop(response.request_id) + future.set_result(response) + + async def _process_event(self, event: agent_worker_pb2.Event) -> None: + # Deliver the event to all the clients. + # TODO: deliver based on subscriptions. 
+ for send_queue in self._send_queues.values(): + await send_queue.put(agent_worker_pb2.Message(event=event)) + + async def _process_register_agent_type( + self, register_agent_type: agent_worker_pb2.RegisterAgentType, client_id: int + ) -> None: + # Register the agent type with the host runtime. + if register_agent_type.type in self._agent_type_to_client_id: + existing_client_id = self._agent_type_to_client_id[register_agent_type.type] + logger.warning( + f"Agent type {register_agent_type.type} already registered with client {existing_client_id}, overwriting the client mapping to client {client_id}." + ) + self._agent_type_to_client_id[register_agent_type.type] = client_id diff --git a/python/src/agnext/application/_single_threaded_agent_runtime.py b/python/src/agnext/application/_single_threaded_agent_runtime.py index a30840dec..6a44cb573 100644 --- a/python/src/agnext/application/_single_threaded_agent_runtime.py +++ b/python/src/agnext/application/_single_threaded_agent_runtime.py @@ -12,8 +12,6 @@ from dataclasses import dataclass from enum import Enum from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast -from agnext.core import MessageContext - from ..core import ( MESSAGE_TYPE_REGISTRY, Agent, @@ -23,6 +21,7 @@ from ..core import ( AgentProxy, AgentRuntime, CancellationToken, + MessageContext, ) from ..core.exceptions import MessageDroppedException from ..core.intervention import DropMessage, InterventionHandler diff --git a/python/src/agnext/worker/worker_runtime.py b/python/src/agnext/application/_worker_runtime.py similarity index 66% rename from python/src/agnext/worker/worker_runtime.py rename to python/src/agnext/application/_worker_runtime.py index 2b4a5ce45..556f0a7bf 100644 --- a/python/src/agnext/worker/worker_runtime.py +++ b/python/src/agnext/application/_worker_runtime.py @@ -2,12 +2,9 @@ import asyncio import inspect import json import logging -import threading import warnings from 
asyncio import Future, Task from collections import defaultdict -from collections.abc import Sequence -from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, @@ -50,40 +47,6 @@ if TYPE_CHECKING: logger = logging.getLogger("agnext") event_logger = logging.getLogger("agnext.events") - -@dataclass(kw_only=True) -class PublishMessageEnvelope: - """A message envelope for publishing messages to all agents that can handle - the message of the type T.""" - - message: Any - cancellation_token: CancellationToken - sender: AgentId | None - namespace: str - - -@dataclass(kw_only=True) -class SendMessageEnvelope: - """A message envelope for sending a message to a specific agent that can handle - the message of the type T.""" - - message: Any - sender: AgentId | None - recipient: AgentId - future: Future[Any] - cancellation_token: CancellationToken - - -@dataclass(kw_only=True) -class ResponseMessageEnvelope: - """A message envelope for sending a response to a message.""" - - message: Any - future: Future[Any] - sender: AgentId - recipient: AgentId | None - - P = ParamSpec("P") T = TypeVar("T", bound=Agent) @@ -99,7 +62,7 @@ class QueueAsyncIterable(AsyncIterator[Any], AsyncIterable[Any]): return self -class RuntimeConnection: +class HostConnection: DEFAULT_GRPC_CONFIG: ClassVar[Mapping[str, Any]] = { "methodConfig": [ { @@ -129,9 +92,6 @@ class RuntimeConnection: channel = grpc.aio.insecure_channel( connection_string, options=[("grpc.service_config", json.dumps(grpc_config))] ) - # logger.info("awaiting channel_ready") - # await channel.channel_ready() - # logger.info("channel_ready") instance = cls(channel) instance._connection_task = asyncio.create_task( instance._connect(channel, instance._send_queue, instance._recv_queue) @@ -176,102 +136,69 @@ class RuntimeConnection: async def recv(self) -> Message: logger.info("Getting message from queue") return await self._recv_queue.get() - logger.info("Got message from queue") class 
WorkerAgentRuntime(AgentRuntime): def __init__(self) -> None: - self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = [] - # (namespace, type) -> List[AgentId] self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set) self._agent_factories: Dict[ str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]] ] = {} - # If empty, then all namespaces are valid for that agent type - self._valid_namespaces: Dict[str, Sequence[str]] = {} self._instantiated_agents: Dict[AgentId, Agent] = {} self._known_namespaces: set[str] = set() self._read_task: None | Task[None] = None self._running = False self._pending_requests: Dict[str, Future[Any]] = {} - self._pending_requests_lock = threading.Lock() + self._pending_requests_lock = asyncio.Lock() self._next_request_id = 0 - self._runtime_connection: RuntimeConnection | None = None + self._host_connection: HostConnection | None = None + self._background_tasks: Set[Task[Any]] = set() - async def setup_channel(self, connection_string: str) -> None: - logger.info(f"connecting to: {connection_string}") - self._runtime_connection = await RuntimeConnection.from_connection_string(connection_string) + async def start(self, host_connection_string: str) -> None: + if self._running: + raise ValueError("Runtime is already running.") + logger.info(f"Connecting to host: {host_connection_string}") + self._host_connection = await HostConnection.from_connection_string(host_connection_string) logger.info("connection") if self._read_task is None: - self._read_task = asyncio.create_task(self.run_read_loop()) + self._read_task = asyncio.create_task(self._run_read_loop()) self._running = True - async def send_register_agent_type(self, agent_type: str) -> None: - assert self._runtime_connection is not None - message = Message(registerAgentType=RegisterAgentType(type=agent_type)) - await self._runtime_connection.send(message) - 
logger.info("Sent registerAgentType message for %s", agent_type) - - async def run_read_loop(self) -> None: + async def _run_read_loop(self) -> None: logger.info("Starting read loop") # TODO: catch exceptions and reconnect while self._running: try: - message = await self._runtime_connection.recv() # type: ignore + message = await self._host_connection.recv() # type: ignore logger.info("Got message: %s", message) oneofcase = Message.WhichOneof(message, "message") match oneofcase: case "registerAgentType": - logger.warn("Cant handle registerAgentType") + logger.warn("Cant handle registerAgentType, skipping.") case "request": - # request: RpcRequest = message.request - # source = AgentId(request.source.name, request.source.namespace) - # target = AgentId(request.target.name, request.target.namespace) - - raise NotImplementedError("Sending messages is not yet implemented.") + request: RpcRequest = message.request + task = asyncio.create_task(self._process_request(request)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) case "response": response: RpcResponse = message.response - future = self._pending_requests.pop(response.request_id) - if len(response.error) > 0: - future.set_exception(Exception(response.error)) - break - future.set_result(response.result) + task = asyncio.create_task(self._process_response(response)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) case "event": event: Event = message.event - message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.type) - # namespace = event.namespace - namespace = "default" - - logger.info("Got event: %s", message) - for agent_id in self._per_type_subscribers[ - (namespace, MESSAGE_TYPE_REGISTRY.type_name(message)) - ]: - logger.info("Sending message to %s", agent_id) - agent = await self._get_agent(agent_id) - message_context = MessageContext( - # TODO: should sender be in the proto even for published events? 
- sender=None, - # TODO: topic_id - topic_id=None, - is_rpc=False, - cancellation_token=CancellationToken(), - ) - try: - await agent.on_message(message, ctx=message_context) - logger.info("%s handled event %s", agent_id, message) - except Exception as e: - event_logger.error("Error handling message", exc_info=e) - - logger.warn("Cant handle event") + task = asyncio.create_task(self._process_event(event)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) case None: logger.warn("No message") except Exception as e: logger.error("Error in read loop", exc_info=e) - async def close_channel(self) -> None: + async def stop(self) -> None: self._running = False - if self._runtime_connection is not None: - await self._runtime_connection.close() + if self._host_connection is not None: + await self._host_connection.close() if self._read_task is not None: await self._read_task @@ -279,7 +206,6 @@ class WorkerAgentRuntime(AgentRuntime): def _known_agent_names(self) -> Set[str]: return set(self._agent_factories.keys()) - # Returns the response of the message async def send_message( self, message: Any, @@ -288,25 +214,32 @@ class WorkerAgentRuntime(AgentRuntime): sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, ) -> Any: - assert self._runtime_connection is not None + if not self._running: + raise ValueError("Runtime must be running when sending message.") + assert self._host_connection is not None # create a new future for the result future = asyncio.get_event_loop().create_future() - with self._pending_requests_lock: + async with self._pending_requests_lock: self._next_request_id += 1 request_id = self._next_request_id - request_id_str = str(request_id) - self._pending_requests[request_id_str] = future - sender = cast(AgentId, sender) - runtime_message = Message( - request=RpcRequest( - request_id=request_id_str, - target=AgentIdProto(name=recipient.type, namespace=recipient.key), - 
source=AgentIdProto(name=sender.type, namespace=sender.key), - data=message, - ) + request_id_str = str(request_id) + self._pending_requests[request_id_str] = future + sender = cast(AgentId, sender) + method = MESSAGE_TYPE_REGISTRY.type_name(message) + serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=method) + runtime_message = Message( + request=RpcRequest( + request_id=request_id_str, + target=AgentIdProto(name=recipient.type, namespace=recipient.key), + source=AgentIdProto(name=sender.type, namespace=sender.key), + method=method, + data=serialized_message, ) - # TODO: Find a way to handle timeouts/errors - asyncio.create_task(self._runtime_connection.send(runtime_message)) + ) + # TODO: Find a way to handle timeouts/errors + task = asyncio.create_task(self._host_connection.send(runtime_message)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) return await future async def publish_message( @@ -317,26 +250,24 @@ class WorkerAgentRuntime(AgentRuntime): sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, ) -> None: - assert self._runtime_connection is not None + if not self._running: + raise ValueError("Runtime must be running when publishing message.") + assert self._host_connection is not None sender_namespace = sender.key if sender is not None else None explicit_namespace = namespace if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace: raise ValueError( f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}" ) - assert explicit_namespace is not None or sender_namespace is not None actual_namespace = cast(str, explicit_namespace or sender_namespace) await self._process_seen_namespace(actual_namespace) message_type = MESSAGE_TYPE_REGISTRY.type_name(message) serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type) message = 
Message(event=Event(namespace=actual_namespace, type=message_type, data=serialized_message)) - - async def write_message() -> None: - assert self._runtime_connection is not None - await self._runtime_connection.send(message) - - await asyncio.create_task(write_message()) + task = asyncio.create_task(self._host_connection.send(message)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) async def save_state(self) -> Mapping[str, Any]: raise NotImplementedError("Saving state is not yet implemented.") @@ -358,6 +289,8 @@ class WorkerAgentRuntime(AgentRuntime): name: str, agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], ) -> None: + if not self._running: + raise ValueError("Runtime must be running when registering agent.") if name in self._agent_factories: raise ValueError(f"Agent with name {name} already exists.") self._agent_factories[name] = agent_factory @@ -366,7 +299,75 @@ class WorkerAgentRuntime(AgentRuntime): for namespace in self._known_namespaces: await self._get_agent(AgentId(type=name, key=namespace)) - await self.send_register_agent_type(name) + assert self._host_connection is not None + message = Message(registerAgentType=RegisterAgentType(type=name)) + await self._host_connection.send(message) + logger.info("Sent registerAgentType message for %s", name) + + async def _process_request(self, request: RpcRequest) -> None: + assert self._host_connection is not None + target = AgentId(request.target.name, request.target.namespace) + source = AgentId(request.source.name, request.source.namespace) + + try: + logging.info(f"Processing request from {source} to {target}") + target_agent = await self._get_agent(target) + message_context = MessageContext( + sender=source, + topic_id=None, + is_rpc=True, + cancellation_token=CancellationToken(), + ) + message = MESSAGE_TYPE_REGISTRY.deserialize(request.data, type_name=request.method) + response = await 
target_agent.on_message(message, ctx=message_context) + serialized_response = MESSAGE_TYPE_REGISTRY.serialize(response, type_name=request.method) + response_message = Message( + response=RpcResponse( + request_id=request.request_id, + result=serialized_response, + ) + ) + except BaseException as e: + response_message = Message( + response=RpcResponse( + request_id=request.request_id, + error=str(e), + ) + ) + + # Send the response. + await self._host_connection.send(response_message) + + async def _process_response(self, response: RpcResponse) -> None: + # TODO: deserialize the response and set the future result + future = self._pending_requests.pop(response.request_id) + if len(response.error) > 0: + future.set_exception(Exception(response.error)) + else: + future.set_result(response.result) + + async def _process_event(self, event: Event) -> None: + message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.type) + namespace = event.namespace + responses: List[Awaitable[Any]] = [] + for agent_id in self._per_type_subscribers[(namespace, MESSAGE_TYPE_REGISTRY.type_name(message))]: + # TODO: skip the sender? 
+ message_context = MessageContext( + sender=None, + topic_id=None, + is_rpc=False, + cancellation_token=CancellationToken(), + ) + agent = await self._get_agent(agent_id) + future = agent.on_message(message, ctx=message_context) + responses.append(future) + + try: + _ = await asyncio.gather(*responses) + except BaseException as e: + if isinstance(e, asyncio.CancelledError): + return + event_logger.error("Error handling event message", exc_info=e) async def _invoke_agent_factory( self, diff --git a/python/src/agnext/worker/protos/__init__.py b/python/src/agnext/application/protos/__init__.py similarity index 51% rename from python/src/agnext/worker/protos/__init__.py rename to python/src/agnext/application/protos/__init__.py index da9c3c155..d03694c3d 100644 --- a/python/src/agnext/worker/protos/__init__.py +++ b/python/src/agnext/application/protos/__init__.py @@ -9,11 +9,21 @@ sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) from typing import TYPE_CHECKING -from .agent_worker_pb2 import Event, Message, RegisterAgentType, RpcRequest, RpcResponse, AgentId -from .agent_worker_pb2_grpc import AgentRpcStub +from .agent_worker_pb2 import AgentId, Event, Message, RegisterAgentType, RpcRequest, RpcResponse +from .agent_worker_pb2_grpc import AgentRpcServicer, AgentRpcStub, add_AgentRpcServicer_to_server if TYPE_CHECKING: from .agent_worker_pb2_grpc import AgentRpcAsyncStub - __all__ = ["RpcRequest", "RpcResponse", "Event", "RegisterAgentType", "AgentRpcAsyncStub", "AgentRpcStub", "Message", "AgentId"] + + __all__ = [ + "RpcRequest", + "RpcResponse", + "Event", + "RegisterAgentType", + "AgentRpcAsyncStub", + "AgentRpcStub", + "Message", + "AgentId", + ] else: __all__ = ["RpcRequest", "RpcResponse", "Event", "RegisterAgentType", "AgentRpcStub", "Message", "AgentId"] diff --git a/python/src/agnext/worker/__init__.py b/python/src/agnext/worker/__init__.py deleted file mode 100644 index a0c6457f5..000000000 --- a/python/src/agnext/worker/__init__.py +++ 
/dev/null @@ -1,7 +0,0 @@ -""" -The :mod:`agnext.worker` module provides a set of classes for creating distributed agents -""" - -from .worker_runtime import WorkerAgentRuntime - -__all__ = ["WorkerAgentRuntime"] diff --git a/python/worker_example.py b/python/worker_example.py deleted file mode 100644 index 00823ec9e..000000000 --- a/python/worker_example.py +++ /dev/null @@ -1,42 +0,0 @@ -from agnext.worker.worker_runtime import WorkerAgentRuntime -from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import CancellationToken, AgentId -import logging -import asyncio -import os - -from dataclasses import dataclass - -@dataclass -class ExampleMessagePayload: - content: str - - -class ExampleAgent(TypeRoutedAgent): - def __init__(self) -> None: - super().__init__("Example Agent") - - @message_handler - async def on_example_payload(self, message: ExampleMessagePayload, cancellation_token: CancellationToken) -> None: - upper_case = message.content.upper() - await self.publish_message(ExampleMessagePayload(content=upper_case)) - - -async def main() -> None: - logger = logging.getLogger("main") - runtime = WorkerAgentRuntime() - await runtime.setup_channel(os.environ["AGENT_HOST"]) - - runtime.register("ExampleAgent", lambda: ExampleAgent()) - while True: - try: - res = await runtime.send_message("testing!", recipient=AgentId(name="greeter", namespace="testing"), sender=AgentId(name="ExampleAgent", namespace="testing")) - logger.info("Response: %s", res) - except Exception as e: - logger.warning("Error: %s", e) - await asyncio.sleep(5) - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - asyncio.run(main()) -