Python host and worker runtime for distributed agents. (#369)

* Python host runtime impl

* update

* ignore proto generated files

* move worker runtime to application

* Move example to samples

* Fix import

* fix

* update

* server client

* better shutdown

* fix doc conf

* add type
This commit is contained in:
Eric Zhu 2024-08-19 07:06:41 -07:00 committed by GitHub
parent eb4a5b7df5
commit 5eca0dba4a
15 changed files with 515 additions and 188 deletions

3
python/.gitignore vendored
View File

@ -166,6 +166,7 @@ cython_debug/
# Generated proto files
src/agnext/worker/protos/agent*
src/agnext/application/protos/agent*
# Generated log files
log.jsonl
@ -174,4 +175,4 @@ log.jsonl
docs/**/jupyter_execute
# Temporary files
tmp_code_*.py
tmp_code_*.py

View File

@ -28,7 +28,7 @@ apidoc_template_dir = "_apidoc_templates"
apidoc_separate_modules = True
apidoc_extra_args = ["--no-toc"]
napoleon_custom_sections = [("Returns", "params_style")]
apidoc_excluded_paths = ["./worker/protos/"]
apidoc_excluded_paths = ["./application/protos/"]
templates_path = []
exclude_patterns = ["reference/agnext.rst"]

View File

@ -59,7 +59,6 @@ To learn about the core concepts of AGNext, read the `overview <core-concepts/ov
reference/agnext.components
reference/agnext.application
reference/agnext.core
reference/agnext.worker
.. toctree::
:caption: Other

View File

@ -128,7 +128,7 @@ dependencies = [
[tool.ruff]
line-length = 120
fix = true
exclude = ["build", "dist", "src/agnext/worker/protos"]
exclude = ["build", "dist", "src/agnext/application/protos"]
target-version = "py310"
include = ["src/**", "samples/*.py", "docs/**/*.ipynb"]
@ -145,7 +145,7 @@ ignore = ["F401", "E501"]
[tool.mypy]
files = ["src", "samples", "tests"]
exclude = ["src/agnext/worker/protos"]
exclude = ["src/agnext/application/protos"]
strict = true
python_version = "3.10"
@ -168,7 +168,7 @@ include = ["src", "tests", "samples"]
typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false
exclude = ["src/agnext/worker/protos"]
exclude = ["src/agnext/application/protos"]
[tool.pytest.ini_options]
minversion = "6.0"
@ -178,7 +178,7 @@ testpaths = ["tests"]
dependencies = ["hatch-protobuf", "mypy-protobuf~=3.0"]
generate_pyi = false
proto_paths = ["../protos"]
output_path = "src/agnext/worker/protos"
output_path = "src/agnext/application/protos"
[[tool.hatch.build.hooks.protobuf.generators]]

View File

@ -1,8 +1,8 @@
import asyncio
import logging
from agnext.application import WorkerAgentRuntime
from agnext.core._serialization import MESSAGE_TYPE_REGISTRY
from agnext.worker.worker_runtime import WorkerAgentRuntime
from app import build_app
from dotenv import load_dotenv
from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated
@ -18,7 +18,7 @@ async def main() -> None:
MESSAGE_TYPE_REGISTRY.add_type(AuditText)
MESSAGE_TYPE_REGISTRY.add_type(AuditorAlert)
agnext_logger.info("1")
await runtime.setup_channel("localhost:5145")
await runtime.start("localhost:5145")
agnext_logger.info("2")
@ -30,7 +30,7 @@ async def main() -> None:
await asyncio.sleep(1000000)
except KeyboardInterrupt:
pass
await runtime.close_channel()
await runtime.stop()
if __name__ == "__main__":

View File

@ -0,0 +1,53 @@
import asyncio
import signal
import grpc
from agnext.application import HostRuntimeServicer
from agnext.application.protos import agent_worker_pb2_grpc
async def serve(server: grpc.aio.Server) -> None: # type: ignore
    """Start ``server`` and wait until it has terminated.

    Prints "Server started" once ``server.start()`` has returned, then blocks
    in ``server.wait_for_termination()``.
    """
    await server.start()
    print("Server started")
    await server.wait_for_termination()
async def main() -> None:
    """Run the host gRPC server until a shutdown signal arrives or the server exits.

    A ``HostRuntimeServicer`` is served on the insecure port ``[::]:50051``.
    SIGINT/SIGTERM trigger a graceful stop with a 5 second grace period.

    Raises:
        Exception: whatever ``serve()`` raised, if the server failed to start
            or terminated with an error.
    """
    server = grpc.aio.server()
    agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(HostRuntimeServicer(), server)
    server.add_insecure_port("[::]:50051")

    # Set up signal handling for graceful shutdown.
    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()

    def signal_handler() -> None:
        print("Received exit signal, shutting down gracefully...")
        shutdown_event.set()

    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, signal_handler)
        except NotImplementedError:
            # add_signal_handler is only implemented by Unix event loops. Elsewhere
            # Ctrl+C surfaces as KeyboardInterrupt out of asyncio.run().
            pass

    # Start server in background task.
    serve_task = asyncio.create_task(serve(server))
    shutdown_task = asyncio.create_task(shutdown_event.wait())

    # Wait for either a shutdown signal or the server task finishing on its own.
    # Waiting only on the event would hang forever (and hide the error) if
    # serve() failed, e.g. because the server could not start.
    try:
        await asyncio.wait({serve_task, shutdown_task}, return_when=asyncio.FIRST_COMPLETED)
    finally:
        shutdown_task.cancel()

    if not serve_task.done():
        # Graceful shutdown.
        await server.stop(5)  # 5 second grace period
    # Awaiting the task re-raises any exception raised inside serve().
    await serve_task
    print("Server stopped")
if __name__ == "__main__":
    import logging

    # Root logger at WARNING, but the "agnext" logger at DEBUG.
    logging.basicConfig(level=logging.WARNING)
    logging.getLogger("agnext").setLevel(logging.DEBUG)
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C reached us as KeyboardInterrupt instead of being consumed by
        # the signal handlers that main() installs.
        print("Server shutdown interrupted.")

View File

@ -0,0 +1,87 @@
import asyncio
import logging
from dataclasses import dataclass
from agnext.application import WorkerAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext
@dataclass
class AskToGreet:
    """Message asking for a greeting; handled by ``GreeterAgent.on_ask``."""

    content: str


@dataclass
class Greeting:
    """Greeting published by ``GreeterAgent``; handled by ``ReceiveAgent.on_greet``."""

    content: str


@dataclass
class ReturnedGreeting:
    """Echo of a ``Greeting`` published by ``ReceiveAgent``."""

    content: str


@dataclass
class Feedback:
    """Feedback published by ``GreeterAgent``; handled by ``ReceiveAgent.on_feedback``."""

    content: str


@dataclass
class ReturnedFeedback:
    """Echo of a ``Feedback`` published by ``ReceiveAgent``."""

    content: str
class ReceiveAgent(TypeRoutedAgent):
    """Agent that republishes every greeting and feedback it receives as a "returned" message."""

    def __init__(self) -> None:
        super().__init__("Receive Agent")

    @message_handler
    async def on_greet(self, message: Greeting, ctx: MessageContext) -> None:
        """Publish a ``ReturnedGreeting`` wrapping the content of ``message``."""
        returned = ReturnedGreeting(f"Returned greeting: {message.content}")
        await self.publish_message(returned)

    @message_handler
    async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
        """Publish a ``ReturnedFeedback`` wrapping the content of ``message``."""
        returned = ReturnedFeedback(f"Returned feedback: {message.content}")
        await self.publish_message(returned)
class GreeterAgent(TypeRoutedAgent):
    """Agent that turns ``AskToGreet`` into a ``Greeting`` and ``ReturnedGreeting`` into ``Feedback``."""

    def __init__(self) -> None:
        super().__init__("Greeter Agent")

    @message_handler
    async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
        """Publish a ``Greeting`` built from the content of ``message``."""
        greeting = Greeting(f"Hello, {message.content}!")
        await self.publish_message(greeting)

    @message_handler
    async def on_returned_greet(self, message: ReturnedGreeting, ctx: MessageContext) -> None:
        """Publish ``Feedback`` built from the content of the returned greeting."""
        feedback = Feedback(f"Feedback: {message.content}")
        await self.publish_message(feedback)
async def main() -> None:
    """Connect a worker runtime to the host, register both agents and publish the first message.

    The coroutine then idles so the agents can keep exchanging messages. The
    runtime is stopped when the coroutine is cancelled or interrupted.
    """
    runtime = WorkerAgentRuntime()
    # Every message type that crosses the wire must be known to the registry.
    for message_type in (Greeting, AskToGreet, Feedback, ReturnedGreeting, ReturnedFeedback):
        MESSAGE_TYPE_REGISTRY.add_type(message_type)
    await runtime.start(host_connection_string="localhost:50051")
    try:
        await runtime.register("reciever", lambda: ReceiveAgent())
        await runtime.register("greeter", lambda: GreeterAgent())
        await runtime.publish_message(AskToGreet("Hello World!"), namespace="default")

        # Just to keep the runtime running.
        await asyncio.sleep(1000000)
    finally:
        # Under asyncio.run() a Ctrl+C normally reaches this coroutine as
        # CancelledError rather than KeyboardInterrupt, so cleanup has to live in
        # a finally block to run at all.
        await runtime.stop()
if __name__ == "__main__":
    # Emit DEBUG output from both the root logger and the "agnext" logger.
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger("agnext")
    logger.setLevel(logging.DEBUG)
    asyncio.run(main())

View File

@ -0,0 +1,74 @@
import asyncio
import logging
from dataclasses import dataclass
from agnext.application import WorkerAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, MessageContext
@dataclass
class AskToGreet:
    """Message asking for a greeting; handled by ``GreeterAgent.on_ask``."""

    content: str


@dataclass
class Greeting:
    """Greeting sent to ``ReceiveAgent``, which replies with another ``Greeting``."""

    content: str


@dataclass
class Feedback:
    """Feedback published by ``GreeterAgent``; printed by ``ReceiveAgent.on_feedback``."""

    content: str
class ReceiveAgent(TypeRoutedAgent):
    """Agent that answers greetings directly and prints any feedback it receives."""

    def __init__(self) -> None:
        super().__init__("Receive Agent")

    @message_handler
    async def on_greet(self, message: Greeting, ctx: MessageContext) -> Greeting:
        """Return a ``Greeting`` acknowledging the content of ``message``."""
        acknowledgement = f"Received: {message.content}"
        return Greeting(content=acknowledgement)

    @message_handler
    async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
        """Print the feedback content to stdout."""
        feedback_text = message.content
        print(f"Feedback received: {feedback_text}")
class GreeterAgent(TypeRoutedAgent):
    """Agent that greets a specific receiver via a direct message and publishes its reply as feedback."""

    def __init__(self, receive_agent_id: AgentId) -> None:
        super().__init__("Greeter Agent")
        # Recipient of the direct ``Greeting`` sent from ``on_ask``.
        self._receive_agent_id = receive_agent_id

    @message_handler
    async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
        """Send a ``Greeting`` to the receiver, then publish ``Feedback`` built from the reply."""
        greeting = Greeting(f"Hello, {message.content}!")
        response = await self.send_message(greeting, recipient=self._receive_agent_id)
        # NOTE(review): this assumes the runtime hands back a deserialized message
        # exposing ``.content``; confirm against the runtime's response handling.
        feedback = Feedback(f"Feedback: {response.content}")
        await self.publish_message(feedback)
async def main() -> None:
    """Connect a worker runtime to the host, register both agents and publish the first message.

    The greeter is constructed with the receiver's agent id so it can message it
    directly. The coroutine then idles; the runtime is stopped when the coroutine
    is cancelled or interrupted.
    """
    runtime = WorkerAgentRuntime()
    # Every message type that crosses the wire must be known to the registry.
    for message_type in (Greeting, AskToGreet, Feedback):
        MESSAGE_TYPE_REGISTRY.add_type(message_type)
    await runtime.start(host_connection_string="localhost:50051")
    try:
        await runtime.register("reciever", lambda: ReceiveAgent())
        reciever = await runtime.get("reciever")
        await runtime.register("greeter", lambda: GreeterAgent(reciever))
        await runtime.publish_message(AskToGreet("Hello World!"), namespace="default")

        # Just to keep the runtime running.
        await asyncio.sleep(1000000)
    finally:
        # Under asyncio.run() a Ctrl+C normally reaches this coroutine as
        # CancelledError rather than KeyboardInterrupt, so cleanup has to live in
        # a finally block to run at all.
        await runtime.stop()
if __name__ == "__main__":
    # Emit DEBUG output from both the root logger and the "agnext" logger.
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger("agnext")
    logger.setLevel(logging.DEBUG)
    asyncio.run(main())

View File

@ -2,6 +2,8 @@
The :mod:`agnext.application` module provides implementations of core components that are used to compose an application
"""
from ._host_runtime_servicer import HostRuntimeServicer
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
from ._worker_runtime import WorkerAgentRuntime
__all__ = ["SingleThreadedAgentRuntime"]
__all__ = ["SingleThreadedAgentRuntime", "WorkerAgentRuntime", "HostRuntimeServicer"]

View File

@ -0,0 +1,150 @@
import asyncio
import logging
from _collections_abc import AsyncIterator, Iterator
from asyncio import Future, Task
from typing import Any, Dict, Set
import grpc
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
logger = logging.getLogger("agnext")
event_logger = logging.getLogger("agnext.events")
class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
"""A gRPC servicer that hosts message delivery service for agents."""
def __init__(self) -> None:
self._client_id = 0
self._client_id_lock = asyncio.Lock()
self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {}
self._agent_type_to_client_id: Dict[str, int] = {}
self._pending_requests: Dict[int, Dict[str, Future[Any]]] = {}
self._background_tasks: Set[Task[Any]] = set()
async def OpenChannel( # type: ignore
self,
request_iterator: AsyncIterator[agent_worker_pb2.Message],
context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message],
) -> Iterator[agent_worker_pb2.Message] | AsyncIterator[agent_worker_pb2.Message]: # type: ignore
# Aquire the lock to get a new client id.
async with self._client_id_lock:
self._client_id += 1
client_id = self._client_id
# Register the client with the server and create a send queue for the client.
send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue()
self._send_queues[client_id] = send_queue
logger.info(f"Client {client_id} connected.")
try:
# Concurrently handle receiving messages from the client and sending messages to the client.
# This task will receive messages from the client.
receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator))
# Return an async generator that will yield messages from the send queue to the client.
while True:
message = await send_queue.get()
# Yield the message to the client.
try:
yield message
except Exception as e:
logger.error(f"Failed to send message to client {client_id}: {e}", exc_info=True)
break
# Wait for the receiving task to finish.
await receiving_task
finally:
# Clean up the client connection.
del self._send_queues[client_id]
# Cancel pending requests sent to this client.
for future in self._pending_requests.pop(client_id, {}).values():
future.cancel()
logger.info(f"Client {client_id} disconnected.")
async def _receive_messages(
self, client_id: int, request_iterator: AsyncIterator[agent_worker_pb2.Message]
) -> None:
# Receive messages from the client and process them.
async for message in request_iterator:
oneofcase = message.WhichOneof("message")
match oneofcase:
case "request":
request: agent_worker_pb2.RpcRequest = message.request
logger.info(f"Received request message: {request}")
task = asyncio.create_task(self._process_request(request, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case "response":
response: agent_worker_pb2.RpcResponse = message.response
logger.info(f"Received response message: {response}")
task = asyncio.create_task(self._process_response(response, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case "event":
event: agent_worker_pb2.Event = message.event
logger.info(f"Received event message: {event}")
task = asyncio.create_task(self._process_event(event))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case "registerAgentType":
register_agent_type: agent_worker_pb2.RegisterAgentType = message.registerAgentType
logger.info(f"Received register agent type message: {register_agent_type}")
task = asyncio.create_task(self._process_register_agent_type(register_agent_type, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warning("Received empty message")
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
# Deliver the message to a client given the target agent type.
target_client_id = self._agent_type_to_client_id.get(request.target.name)
if target_client_id is None:
logger.error(f"Agent {request.target.name} not found, failed to deliver message.")
return
target_send_queue = self._send_queues.get(target_client_id)
if target_send_queue is None:
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
return
await target_send_queue.put(agent_worker_pb2.Message(request=request))
# Create a future to wait for the response.
future = asyncio.get_event_loop().create_future()
self._pending_requests.setdefault(client_id, {})[request.request_id] = future
# Create a task to wait for the response and send it back to the client.
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id))
self._background_tasks.add(send_response_task)
send_response_task.add_done_callback(self._background_tasks.discard)
async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None:
response = await future
message = agent_worker_pb2.Message(response=response)
send_queue = self._send_queues.get(client_id)
if send_queue is None:
logger.error(f"Client {client_id} not found, failed to send response message.")
return
await send_queue.put(message)
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None:
# Setting the result of the future will send the response back to the original sender.
future = self._pending_requests[client_id].pop(response.request_id)
future.set_result(response)
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
# Deliver the event to all the clients.
# TODO: deliver based on subscriptions.
for send_queue in self._send_queues.values():
await send_queue.put(agent_worker_pb2.Message(event=event))
async def _process_register_agent_type(
self, register_agent_type: agent_worker_pb2.RegisterAgentType, client_id: int
) -> None:
# Register the agent type with the host runtime.
if register_agent_type.type in self._agent_type_to_client_id:
existing_client_id = self._agent_type_to_client_id[register_agent_type.type]
logger.warning(
f"Agent type {register_agent_type.type} already registered with client {existing_client_id}, overwriting the client mapping to client {client_id}."
)
self._agent_type_to_client_id[register_agent_type.type] = client_id

View File

@ -12,8 +12,6 @@ from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from agnext.core import MessageContext
from ..core import (
MESSAGE_TYPE_REGISTRY,
Agent,
@ -23,6 +21,7 @@ from ..core import (
AgentProxy,
AgentRuntime,
CancellationToken,
MessageContext,
)
from ..core.exceptions import MessageDroppedException
from ..core.intervention import DropMessage, InterventionHandler

View File

@ -2,12 +2,9 @@ import asyncio
import inspect
import json
import logging
import threading
import warnings
from asyncio import Future, Task
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
@ -50,40 +47,6 @@ if TYPE_CHECKING:
logger = logging.getLogger("agnext")
event_logger = logging.getLogger("agnext.events")
@dataclass(kw_only=True)
class PublishMessageEnvelope:
"""A message envelope for publishing messages to all agents that can handle
the message of the type T."""
message: Any
cancellation_token: CancellationToken
sender: AgentId | None
namespace: str
@dataclass(kw_only=True)
class SendMessageEnvelope:
"""A message envelope for sending a message to a specific agent that can handle
the message of the type T."""
message: Any
sender: AgentId | None
recipient: AgentId
future: Future[Any]
cancellation_token: CancellationToken
@dataclass(kw_only=True)
class ResponseMessageEnvelope:
"""A message envelope for sending a response to a message."""
message: Any
future: Future[Any]
sender: AgentId
recipient: AgentId | None
P = ParamSpec("P")
T = TypeVar("T", bound=Agent)
@ -99,7 +62,7 @@ class QueueAsyncIterable(AsyncIterator[Any], AsyncIterable[Any]):
return self
class RuntimeConnection:
class HostConnection:
DEFAULT_GRPC_CONFIG: ClassVar[Mapping[str, Any]] = {
"methodConfig": [
{
@ -129,9 +92,6 @@ class RuntimeConnection:
channel = grpc.aio.insecure_channel(
connection_string, options=[("grpc.service_config", json.dumps(grpc_config))]
)
# logger.info("awaiting channel_ready")
# await channel.channel_ready()
# logger.info("channel_ready")
instance = cls(channel)
instance._connection_task = asyncio.create_task(
instance._connect(channel, instance._send_queue, instance._recv_queue)
@ -176,102 +136,69 @@ class RuntimeConnection:
async def recv(self) -> Message:
logger.info("Getting message from queue")
return await self._recv_queue.get()
logger.info("Got message from queue")
class WorkerAgentRuntime(AgentRuntime):
def __init__(self) -> None:
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
# (namespace, type) -> List[AgentId]
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
self._agent_factories: Dict[
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
] = {}
# If empty, then all namespaces are valid for that agent type
self._valid_namespaces: Dict[str, Sequence[str]] = {}
self._instantiated_agents: Dict[AgentId, Agent] = {}
self._known_namespaces: set[str] = set()
self._read_task: None | Task[None] = None
self._running = False
self._pending_requests: Dict[str, Future[Any]] = {}
self._pending_requests_lock = threading.Lock()
self._pending_requests_lock = asyncio.Lock()
self._next_request_id = 0
self._runtime_connection: RuntimeConnection | None = None
self._host_connection: HostConnection | None = None
self._background_tasks: Set[Task[Any]] = set()
async def setup_channel(self, connection_string: str) -> None:
logger.info(f"connecting to: {connection_string}")
self._runtime_connection = await RuntimeConnection.from_connection_string(connection_string)
async def start(self, host_connection_string: str) -> None:
if self._running:
raise ValueError("Runtime is already running.")
logger.info(f"Connecting to host: {host_connection_string}")
self._host_connection = await HostConnection.from_connection_string(host_connection_string)
logger.info("connection")
if self._read_task is None:
self._read_task = asyncio.create_task(self.run_read_loop())
self._read_task = asyncio.create_task(self._run_read_loop())
self._running = True
async def send_register_agent_type(self, agent_type: str) -> None:
assert self._runtime_connection is not None
message = Message(registerAgentType=RegisterAgentType(type=agent_type))
await self._runtime_connection.send(message)
logger.info("Sent registerAgentType message for %s", agent_type)
async def run_read_loop(self) -> None:
async def _run_read_loop(self) -> None:
logger.info("Starting read loop")
# TODO: catch exceptions and reconnect
while self._running:
try:
message = await self._runtime_connection.recv() # type: ignore
message = await self._host_connection.recv() # type: ignore
logger.info("Got message: %s", message)
oneofcase = Message.WhichOneof(message, "message")
match oneofcase:
case "registerAgentType":
logger.warn("Cant handle registerAgentType")
logger.warn("Cant handle registerAgentType, skipping.")
case "request":
# request: RpcRequest = message.request
# source = AgentId(request.source.name, request.source.namespace)
# target = AgentId(request.target.name, request.target.namespace)
raise NotImplementedError("Sending messages is not yet implemented.")
request: RpcRequest = message.request
task = asyncio.create_task(self._process_request(request))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case "response":
response: RpcResponse = message.response
future = self._pending_requests.pop(response.request_id)
if len(response.error) > 0:
future.set_exception(Exception(response.error))
break
future.set_result(response.result)
task = asyncio.create_task(self._process_response(response))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case "event":
event: Event = message.event
message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.type)
# namespace = event.namespace
namespace = "default"
logger.info("Got event: %s", message)
for agent_id in self._per_type_subscribers[
(namespace, MESSAGE_TYPE_REGISTRY.type_name(message))
]:
logger.info("Sending message to %s", agent_id)
agent = await self._get_agent(agent_id)
message_context = MessageContext(
# TODO: should sender be in the proto even for published events?
sender=None,
# TODO: topic_id
topic_id=None,
is_rpc=False,
cancellation_token=CancellationToken(),
)
try:
await agent.on_message(message, ctx=message_context)
logger.info("%s handled event %s", agent_id, message)
except Exception as e:
event_logger.error("Error handling message", exc_info=e)
logger.warn("Cant handle event")
task = asyncio.create_task(self._process_event(event))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warn("No message")
except Exception as e:
logger.error("Error in read loop", exc_info=e)
async def close_channel(self) -> None:
async def stop(self) -> None:
self._running = False
if self._runtime_connection is not None:
await self._runtime_connection.close()
if self._host_connection is not None:
await self._host_connection.close()
if self._read_task is not None:
await self._read_task
@ -279,7 +206,6 @@ class WorkerAgentRuntime(AgentRuntime):
def _known_agent_names(self) -> Set[str]:
return set(self._agent_factories.keys())
# Returns the response of the message
async def send_message(
self,
message: Any,
@ -288,25 +214,32 @@ class WorkerAgentRuntime(AgentRuntime):
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Any:
assert self._runtime_connection is not None
if not self._running:
raise ValueError("Runtime must be running when sending message.")
assert self._host_connection is not None
# create a new future for the result
future = asyncio.get_event_loop().create_future()
with self._pending_requests_lock:
async with self._pending_requests_lock:
self._next_request_id += 1
request_id = self._next_request_id
request_id_str = str(request_id)
self._pending_requests[request_id_str] = future
sender = cast(AgentId, sender)
runtime_message = Message(
request=RpcRequest(
request_id=request_id_str,
target=AgentIdProto(name=recipient.type, namespace=recipient.key),
source=AgentIdProto(name=sender.type, namespace=sender.key),
data=message,
)
request_id_str = str(request_id)
self._pending_requests[request_id_str] = future
sender = cast(AgentId, sender)
method = MESSAGE_TYPE_REGISTRY.type_name(message)
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=method)
runtime_message = Message(
request=RpcRequest(
request_id=request_id_str,
target=AgentIdProto(name=recipient.type, namespace=recipient.key),
source=AgentIdProto(name=sender.type, namespace=sender.key),
method=method,
data=serialized_message,
)
# TODO: Find a way to handle timeouts/errors
asyncio.create_task(self._runtime_connection.send(runtime_message))
)
# TODO: Find a way to handle timeouts/errors
task = asyncio.create_task(self._host_connection.send(runtime_message))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return await future
async def publish_message(
@ -317,26 +250,24 @@ class WorkerAgentRuntime(AgentRuntime):
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> None:
assert self._runtime_connection is not None
if not self._running:
raise ValueError("Runtime must be running when publishing message.")
assert self._host_connection is not None
sender_namespace = sender.key if sender is not None else None
explicit_namespace = namespace
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
raise ValueError(
f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}"
)
assert explicit_namespace is not None or sender_namespace is not None
actual_namespace = cast(str, explicit_namespace or sender_namespace)
await self._process_seen_namespace(actual_namespace)
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type)
message = Message(event=Event(namespace=actual_namespace, type=message_type, data=serialized_message))
async def write_message() -> None:
assert self._runtime_connection is not None
await self._runtime_connection.send(message)
await asyncio.create_task(write_message())
task = asyncio.create_task(self._host_connection.send(message))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
async def save_state(self) -> Mapping[str, Any]:
raise NotImplementedError("Saving state is not yet implemented.")
@ -358,6 +289,8 @@ class WorkerAgentRuntime(AgentRuntime):
name: str,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
) -> None:
if not self._running:
raise ValueError("Runtime must be running when registering agent.")
if name in self._agent_factories:
raise ValueError(f"Agent with name {name} already exists.")
self._agent_factories[name] = agent_factory
@ -366,7 +299,75 @@ class WorkerAgentRuntime(AgentRuntime):
for namespace in self._known_namespaces:
await self._get_agent(AgentId(type=name, key=namespace))
await self.send_register_agent_type(name)
assert self._host_connection is not None
message = Message(registerAgentType=RegisterAgentType(type=name))
await self._host_connection.send(message)
logger.info("Sent registerAgentType message for %s", name)
async def _process_request(self, request: RpcRequest) -> None:
    """Handle an RPC request from the host and send the agent's reply back.

    The payload is deserialized using ``request.method`` as the type name and
    passed to the target agent's ``on_message``. The serialized result, or an
    error description if anything raised, is sent to the host as an
    ``RpcResponse`` carrying the same ``request_id``.
    """
    assert self._host_connection is not None
    target = AgentId(request.target.name, request.target.namespace)
    source = AgentId(request.source.name, request.source.namespace)
    try:
        logger.info(f"Processing request from {source} to {target}")
        target_agent = await self._get_agent(target)
        message_context = MessageContext(
            sender=source,
            topic_id=None,
            is_rpc=True,
            cancellation_token=CancellationToken(),
        )
        message = MESSAGE_TYPE_REGISTRY.deserialize(request.data, type_name=request.method)
        response = await target_agent.on_message(message, ctx=message_context)
        # NOTE(review): the response is serialized under the *request's* type name;
        # confirm this is intended when a handler returns a different type.
        serialized_response = MESSAGE_TYPE_REGISTRY.serialize(response, type_name=request.method)
        response_message = Message(
            response=RpcResponse(
                request_id=request.request_id,
                result=serialized_response,
            )
        )
    except Exception as e:
        # Only ordinary errors become error responses; cancellation and interpreter
        # exit (BaseException) must propagate so the task can actually stop.
        logger.error(f"Error processing request from {source} to {target}", exc_info=True)
        response_message = Message(
            response=RpcResponse(
                request_id=request.request_id,
                # The receiver treats an empty error string as success, so never send one.
                error=str(e) or repr(e),
            )
        )

    # Send the response.
    await self._host_connection.send(response_message)
async def _process_response(self, response: RpcResponse) -> None:
# TODO: deserialize the response and set the future result
future = self._pending_requests.pop(response.request_id)
if len(response.error) > 0:
future.set_exception(Exception(response.error))
else:
future.set_result(response.result)
async def _process_event(self, event: Event) -> None:
    """Deliver a published event to every local agent subscribed to its type in its namespace.

    The event payload is deserialized once. Each subscriber's ``on_message`` is
    invoked with a non-RPC context and the resulting awaitables are awaited
    together. Cancellation ends the method quietly; any other failure is logged
    to the event logger.
    """
    payload = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.type)
    subscription_key = (event.namespace, MESSAGE_TYPE_REGISTRY.type_name(payload))
    pending: List[Awaitable[Any]] = []
    for subscriber_id in self._per_type_subscribers[subscription_key]:
        # TODO: skip the sender?
        context = MessageContext(
            sender=None,
            topic_id=None,
            is_rpc=False,
            cancellation_token=CancellationToken(),
        )
        subscriber = await self._get_agent(subscriber_id)
        pending.append(subscriber.on_message(payload, ctx=context))
    try:
        await asyncio.gather(*pending)
    except asyncio.CancelledError:
        return
    except BaseException as e:
        event_logger.error("Error handling event message", exc_info=e)
async def _invoke_agent_factory(
self,

View File

@ -9,11 +9,21 @@ sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
from typing import TYPE_CHECKING
from .agent_worker_pb2 import Event, Message, RegisterAgentType, RpcRequest, RpcResponse, AgentId
from .agent_worker_pb2_grpc import AgentRpcStub
from .agent_worker_pb2 import AgentId, Event, Message, RegisterAgentType, RpcRequest, RpcResponse
from .agent_worker_pb2_grpc import AgentRpcServicer, AgentRpcStub, add_AgentRpcServicer_to_server
if TYPE_CHECKING:
from .agent_worker_pb2_grpc import AgentRpcAsyncStub
__all__ = ["RpcRequest", "RpcResponse", "Event", "RegisterAgentType", "AgentRpcAsyncStub", "AgentRpcStub", "Message", "AgentId"]
__all__ = [
"RpcRequest",
"RpcResponse",
"Event",
"RegisterAgentType",
"AgentRpcAsyncStub",
"AgentRpcStub",
"Message",
"AgentId",
]
else:
__all__ = ["RpcRequest", "RpcResponse", "Event", "RegisterAgentType", "AgentRpcStub", "Message", "AgentId"]

View File

@ -1,7 +0,0 @@
"""
The :mod:`agnext.worker` module provides a set of classes for creating distributed agents
"""
from .worker_runtime import WorkerAgentRuntime
__all__ = ["WorkerAgentRuntime"]

View File

@ -1,42 +0,0 @@
from agnext.worker.worker_runtime import WorkerAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import CancellationToken, AgentId
import logging
import asyncio
import os
from dataclasses import dataclass
@dataclass
class ExampleMessagePayload:
content: str
class ExampleAgent(TypeRoutedAgent):
def __init__(self) -> None:
super().__init__("Example Agent")
@message_handler
async def on_example_payload(self, message: ExampleMessagePayload, cancellation_token: CancellationToken) -> None:
upper_case = message.content.upper()
await self.publish_message(ExampleMessagePayload(content=upper_case))
async def main() -> None:
logger = logging.getLogger("main")
runtime = WorkerAgentRuntime()
await runtime.setup_channel(os.environ["AGENT_HOST"])
runtime.register("ExampleAgent", lambda: ExampleAgent())
while True:
try:
res = await runtime.send_message("testing!", recipient=AgentId(name="greeter", namespace="testing"), sender=AgentId(name="ExampleAgent", namespace="testing"))
logger.info("Response: %s", res)
except Exception as e:
logger.warning("Error: %s", e)
await asyncio.sleep(5)
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
asyncio.run(main())