mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-18 05:31:25 +00:00
Python host and worker runtime for distributed agents. (#369)
* Python host runtime impl * update * ignore proto generated files * move worker runtime to application * Move example to samples * Fix import * fix * update * server client * better shutdown * fix doc conf * add type
This commit is contained in:
parent
eb4a5b7df5
commit
5eca0dba4a
3
python/.gitignore
vendored
3
python/.gitignore
vendored
@ -166,6 +166,7 @@ cython_debug/
|
||||
|
||||
# Generated proto files
|
||||
src/agnext/worker/protos/agent*
|
||||
src/agnext/application/protos/agent*
|
||||
|
||||
# Generated log files
|
||||
log.jsonl
|
||||
@ -174,4 +175,4 @@ log.jsonl
|
||||
docs/**/jupyter_execute
|
||||
|
||||
# Temporary files
|
||||
tmp_code_*.py
|
||||
tmp_code_*.py
|
||||
|
@ -28,7 +28,7 @@ apidoc_template_dir = "_apidoc_templates"
|
||||
apidoc_separate_modules = True
|
||||
apidoc_extra_args = ["--no-toc"]
|
||||
napoleon_custom_sections = [("Returns", "params_style")]
|
||||
apidoc_excluded_paths = ["./worker/protos/"]
|
||||
apidoc_excluded_paths = ["./application/protos/"]
|
||||
|
||||
templates_path = []
|
||||
exclude_patterns = ["reference/agnext.rst"]
|
||||
|
@ -59,7 +59,6 @@ To learn about the core concepts of AGNext, read the `overview <core-concepts/ov
|
||||
reference/agnext.components
|
||||
reference/agnext.application
|
||||
reference/agnext.core
|
||||
reference/agnext.worker
|
||||
|
||||
.. toctree::
|
||||
:caption: Other
|
||||
|
@ -128,7 +128,7 @@ dependencies = [
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
fix = true
|
||||
exclude = ["build", "dist", "src/agnext/worker/protos"]
|
||||
exclude = ["build", "dist", "src/agnext/application/protos"]
|
||||
target-version = "py310"
|
||||
include = ["src/**", "samples/*.py", "docs/**/*.ipynb"]
|
||||
|
||||
@ -145,7 +145,7 @@ ignore = ["F401", "E501"]
|
||||
|
||||
[tool.mypy]
|
||||
files = ["src", "samples", "tests"]
|
||||
exclude = ["src/agnext/worker/protos"]
|
||||
exclude = ["src/agnext/application/protos"]
|
||||
|
||||
strict = true
|
||||
python_version = "3.10"
|
||||
@ -168,7 +168,7 @@ include = ["src", "tests", "samples"]
|
||||
typeCheckingMode = "strict"
|
||||
reportUnnecessaryIsInstance = false
|
||||
reportMissingTypeStubs = false
|
||||
exclude = ["src/agnext/worker/protos"]
|
||||
exclude = ["src/agnext/application/protos"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
minversion = "6.0"
|
||||
@ -178,7 +178,7 @@ testpaths = ["tests"]
|
||||
dependencies = ["hatch-protobuf", "mypy-protobuf~=3.0"]
|
||||
generate_pyi = false
|
||||
proto_paths = ["../protos"]
|
||||
output_path = "src/agnext/worker/protos"
|
||||
output_path = "src/agnext/application/protos"
|
||||
|
||||
|
||||
[[tool.hatch.build.hooks.protobuf.generators]]
|
||||
|
@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from agnext.application import WorkerAgentRuntime
|
||||
from agnext.core._serialization import MESSAGE_TYPE_REGISTRY
|
||||
from agnext.worker.worker_runtime import WorkerAgentRuntime
|
||||
from app import build_app
|
||||
from dotenv import load_dotenv
|
||||
from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated
|
||||
@ -18,7 +18,7 @@ async def main() -> None:
|
||||
MESSAGE_TYPE_REGISTRY.add_type(AuditText)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(AuditorAlert)
|
||||
agnext_logger.info("1")
|
||||
await runtime.setup_channel("localhost:5145")
|
||||
await runtime.start("localhost:5145")
|
||||
|
||||
agnext_logger.info("2")
|
||||
|
||||
@ -30,7 +30,7 @@ async def main() -> None:
|
||||
await asyncio.sleep(1000000)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
await runtime.close_channel()
|
||||
await runtime.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
53
python/samples/worker/run_host.py
Normal file
53
python/samples/worker/run_host.py
Normal file
@ -0,0 +1,53 @@
|
||||
import asyncio
|
||||
import signal
|
||||
|
||||
import grpc
|
||||
from agnext.application import HostRuntimeServicer
|
||||
from agnext.application.protos import agent_worker_pb2_grpc
|
||||
|
||||
|
||||
async def serve(server: grpc.aio.Server) -> None: # type: ignore
|
||||
await server.start()
|
||||
print("Server started")
|
||||
await server.wait_for_termination()
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
server = grpc.aio.server()
|
||||
agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(HostRuntimeServicer(), server)
|
||||
server.add_insecure_port("[::]:50051")
|
||||
|
||||
# Set up signal handling for graceful shutdown
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
shutdown_event = asyncio.Event()
|
||||
|
||||
def signal_handler() -> None:
|
||||
print("Received exit signal, shutting down gracefully...")
|
||||
shutdown_event.set()
|
||||
|
||||
loop.add_signal_handler(signal.SIGINT, signal_handler)
|
||||
loop.add_signal_handler(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Start server in background task
|
||||
serve_task = asyncio.create_task(serve(server))
|
||||
|
||||
# Wait for the signal to trigger the shutdown event
|
||||
await shutdown_event.wait()
|
||||
|
||||
# Graceful shutdown
|
||||
await server.stop(5) # 5 second grace period
|
||||
await serve_task
|
||||
print("Server stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("Server shutdown interrupted.")
|
87
python/samples/worker/run_worker_pub_sub.py
Normal file
87
python/samples/worker/run_worker_pub_sub.py
Normal file
@ -0,0 +1,87 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agnext.application import WorkerAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class AskToGreet:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Greeting:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReturnedGreeting:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Feedback:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReturnedFeedback:
|
||||
content: str
|
||||
|
||||
|
||||
class ReceiveAgent(TypeRoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Receive Agent")
|
||||
|
||||
@message_handler
|
||||
async def on_greet(self, message: Greeting, ctx: MessageContext) -> None:
|
||||
await self.publish_message(ReturnedGreeting(f"Returned greeting: {message.content}"))
|
||||
|
||||
@message_handler
|
||||
async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
|
||||
await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}"))
|
||||
|
||||
|
||||
class GreeterAgent(TypeRoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Greeter Agent")
|
||||
|
||||
@message_handler
|
||||
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
|
||||
await self.publish_message(Greeting(f"Hello, {message.content}!"))
|
||||
|
||||
@message_handler
|
||||
async def on_returned_greet(self, message: ReturnedGreeting, ctx: MessageContext) -> None:
|
||||
await self.publish_message(Feedback(f"Feedback: {message.content}"))
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime()
|
||||
MESSAGE_TYPE_REGISTRY.add_type(Greeting)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(AskToGreet)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(Feedback)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(ReturnedGreeting)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(ReturnedFeedback)
|
||||
await runtime.start(host_connection_string="localhost:50051")
|
||||
|
||||
await runtime.register("reciever", lambda: ReceiveAgent())
|
||||
await runtime.register("greeter", lambda: GreeterAgent())
|
||||
|
||||
await runtime.publish_message(AskToGreet("Hello World!"), namespace="default")
|
||||
|
||||
# Just to keep the runtime running
|
||||
try:
|
||||
await asyncio.sleep(1000000)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
await runtime.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger("agnext")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
asyncio.run(main())
|
74
python/samples/worker/run_worker_rpc.py
Normal file
74
python/samples/worker/run_worker_rpc.py
Normal file
@ -0,0 +1,74 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agnext.application import WorkerAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, MessageContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class AskToGreet:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Greeting:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Feedback:
|
||||
content: str
|
||||
|
||||
|
||||
class ReceiveAgent(TypeRoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Receive Agent")
|
||||
|
||||
@message_handler
|
||||
async def on_greet(self, message: Greeting, ctx: MessageContext) -> Greeting:
|
||||
return Greeting(content=f"Received: {message.content}")
|
||||
|
||||
@message_handler
|
||||
async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
|
||||
print(f"Feedback received: {message.content}")
|
||||
|
||||
|
||||
class GreeterAgent(TypeRoutedAgent):
|
||||
def __init__(self, receive_agent_id: AgentId) -> None:
|
||||
super().__init__("Greeter Agent")
|
||||
self._receive_agent_id = receive_agent_id
|
||||
|
||||
@message_handler
|
||||
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
|
||||
response = await self.send_message(Greeting(f"Hello, {message.content}!"), recipient=self._receive_agent_id)
|
||||
await self.publish_message(Feedback(f"Feedback: {response.content}"))
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime()
|
||||
MESSAGE_TYPE_REGISTRY.add_type(Greeting)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(AskToGreet)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(Feedback)
|
||||
await runtime.start(host_connection_string="localhost:50051")
|
||||
|
||||
await runtime.register("reciever", lambda: ReceiveAgent())
|
||||
reciever = await runtime.get("reciever")
|
||||
await runtime.register("greeter", lambda: GreeterAgent(reciever))
|
||||
|
||||
await runtime.publish_message(AskToGreet("Hello World!"), namespace="default")
|
||||
|
||||
# Just to keep the runtime running
|
||||
try:
|
||||
await asyncio.sleep(1000000)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
await runtime.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger("agnext")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
asyncio.run(main())
|
@ -2,6 +2,8 @@
|
||||
The :mod:`agnext.application` module provides implementations of core components that are used to compose an application
|
||||
"""
|
||||
|
||||
from ._host_runtime_servicer import HostRuntimeServicer
|
||||
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
||||
from ._worker_runtime import WorkerAgentRuntime
|
||||
|
||||
__all__ = ["SingleThreadedAgentRuntime"]
|
||||
__all__ = ["SingleThreadedAgentRuntime", "WorkerAgentRuntime", "HostRuntimeServicer"]
|
||||
|
150
python/src/agnext/application/_host_runtime_servicer.py
Normal file
150
python/src/agnext/application/_host_runtime_servicer.py
Normal file
@ -0,0 +1,150 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from _collections_abc import AsyncIterator, Iterator
|
||||
from asyncio import Future, Task
|
||||
from typing import Any, Dict, Set
|
||||
|
||||
import grpc
|
||||
|
||||
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
|
||||
|
||||
logger = logging.getLogger("agnext")
|
||||
event_logger = logging.getLogger("agnext.events")
|
||||
|
||||
|
||||
class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
"""A gRPC servicer that hosts message delivery service for agents."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._client_id = 0
|
||||
self._client_id_lock = asyncio.Lock()
|
||||
self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {}
|
||||
self._agent_type_to_client_id: Dict[str, int] = {}
|
||||
self._pending_requests: Dict[int, Dict[str, Future[Any]]] = {}
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
|
||||
async def OpenChannel( # type: ignore
|
||||
self,
|
||||
request_iterator: AsyncIterator[agent_worker_pb2.Message],
|
||||
context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message],
|
||||
) -> Iterator[agent_worker_pb2.Message] | AsyncIterator[agent_worker_pb2.Message]: # type: ignore
|
||||
# Aquire the lock to get a new client id.
|
||||
async with self._client_id_lock:
|
||||
self._client_id += 1
|
||||
client_id = self._client_id
|
||||
|
||||
# Register the client with the server and create a send queue for the client.
|
||||
send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue()
|
||||
self._send_queues[client_id] = send_queue
|
||||
logger.info(f"Client {client_id} connected.")
|
||||
|
||||
try:
|
||||
# Concurrently handle receiving messages from the client and sending messages to the client.
|
||||
# This task will receive messages from the client.
|
||||
receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator))
|
||||
|
||||
# Return an async generator that will yield messages from the send queue to the client.
|
||||
while True:
|
||||
message = await send_queue.get()
|
||||
# Yield the message to the client.
|
||||
try:
|
||||
yield message
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send message to client {client_id}: {e}", exc_info=True)
|
||||
break
|
||||
# Wait for the receiving task to finish.
|
||||
await receiving_task
|
||||
|
||||
finally:
|
||||
# Clean up the client connection.
|
||||
del self._send_queues[client_id]
|
||||
# Cancel pending requests sent to this client.
|
||||
for future in self._pending_requests.pop(client_id, {}).values():
|
||||
future.cancel()
|
||||
logger.info(f"Client {client_id} disconnected.")
|
||||
|
||||
async def _receive_messages(
|
||||
self, client_id: int, request_iterator: AsyncIterator[agent_worker_pb2.Message]
|
||||
) -> None:
|
||||
# Receive messages from the client and process them.
|
||||
async for message in request_iterator:
|
||||
oneofcase = message.WhichOneof("message")
|
||||
match oneofcase:
|
||||
case "request":
|
||||
request: agent_worker_pb2.RpcRequest = message.request
|
||||
logger.info(f"Received request message: {request}")
|
||||
task = asyncio.create_task(self._process_request(request, client_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "response":
|
||||
response: agent_worker_pb2.RpcResponse = message.response
|
||||
logger.info(f"Received response message: {response}")
|
||||
task = asyncio.create_task(self._process_response(response, client_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "event":
|
||||
event: agent_worker_pb2.Event = message.event
|
||||
logger.info(f"Received event message: {event}")
|
||||
task = asyncio.create_task(self._process_event(event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "registerAgentType":
|
||||
register_agent_type: agent_worker_pb2.RegisterAgentType = message.registerAgentType
|
||||
logger.info(f"Received register agent type message: {register_agent_type}")
|
||||
task = asyncio.create_task(self._process_register_agent_type(register_agent_type, client_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case None:
|
||||
logger.warning("Received empty message")
|
||||
|
||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
|
||||
# Deliver the message to a client given the target agent type.
|
||||
target_client_id = self._agent_type_to_client_id.get(request.target.name)
|
||||
if target_client_id is None:
|
||||
logger.error(f"Agent {request.target.name} not found, failed to deliver message.")
|
||||
return
|
||||
target_send_queue = self._send_queues.get(target_client_id)
|
||||
if target_send_queue is None:
|
||||
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
|
||||
return
|
||||
await target_send_queue.put(agent_worker_pb2.Message(request=request))
|
||||
|
||||
# Create a future to wait for the response.
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._pending_requests.setdefault(client_id, {})[request.request_id] = future
|
||||
|
||||
# Create a task to wait for the response and send it back to the client.
|
||||
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id))
|
||||
self._background_tasks.add(send_response_task)
|
||||
send_response_task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None:
|
||||
response = await future
|
||||
message = agent_worker_pb2.Message(response=response)
|
||||
send_queue = self._send_queues.get(client_id)
|
||||
if send_queue is None:
|
||||
logger.error(f"Client {client_id} not found, failed to send response message.")
|
||||
return
|
||||
await send_queue.put(message)
|
||||
|
||||
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None:
|
||||
# Setting the result of the future will send the response back to the original sender.
|
||||
future = self._pending_requests[client_id].pop(response.request_id)
|
||||
future.set_result(response)
|
||||
|
||||
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
|
||||
# Deliver the event to all the clients.
|
||||
# TODO: deliver based on subscriptions.
|
||||
for send_queue in self._send_queues.values():
|
||||
await send_queue.put(agent_worker_pb2.Message(event=event))
|
||||
|
||||
async def _process_register_agent_type(
|
||||
self, register_agent_type: agent_worker_pb2.RegisterAgentType, client_id: int
|
||||
) -> None:
|
||||
# Register the agent type with the host runtime.
|
||||
if register_agent_type.type in self._agent_type_to_client_id:
|
||||
existing_client_id = self._agent_type_to_client_id[register_agent_type.type]
|
||||
logger.warning(
|
||||
f"Agent type {register_agent_type.type} already registered with client {existing_client_id}, overwriting the client mapping to client {client_id}."
|
||||
)
|
||||
self._agent_type_to_client_id[register_agent_type.type] = client_id
|
@ -12,8 +12,6 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||
|
||||
from agnext.core import MessageContext
|
||||
|
||||
from ..core import (
|
||||
MESSAGE_TYPE_REGISTRY,
|
||||
Agent,
|
||||
@ -23,6 +21,7 @@ from ..core import (
|
||||
AgentProxy,
|
||||
AgentRuntime,
|
||||
CancellationToken,
|
||||
MessageContext,
|
||||
)
|
||||
from ..core.exceptions import MessageDroppedException
|
||||
from ..core.intervention import DropMessage, InterventionHandler
|
||||
|
@ -2,12 +2,9 @@ import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import warnings
|
||||
from asyncio import Future, Task
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -50,40 +47,6 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger("agnext")
|
||||
event_logger = logging.getLogger("agnext.events")
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class PublishMessageEnvelope:
|
||||
"""A message envelope for publishing messages to all agents that can handle
|
||||
the message of the type T."""
|
||||
|
||||
message: Any
|
||||
cancellation_token: CancellationToken
|
||||
sender: AgentId | None
|
||||
namespace: str
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class SendMessageEnvelope:
|
||||
"""A message envelope for sending a message to a specific agent that can handle
|
||||
the message of the type T."""
|
||||
|
||||
message: Any
|
||||
sender: AgentId | None
|
||||
recipient: AgentId
|
||||
future: Future[Any]
|
||||
cancellation_token: CancellationToken
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ResponseMessageEnvelope:
|
||||
"""A message envelope for sending a response to a message."""
|
||||
|
||||
message: Any
|
||||
future: Future[Any]
|
||||
sender: AgentId
|
||||
recipient: AgentId | None
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
@ -99,7 +62,7 @@ class QueueAsyncIterable(AsyncIterator[Any], AsyncIterable[Any]):
|
||||
return self
|
||||
|
||||
|
||||
class RuntimeConnection:
|
||||
class HostConnection:
|
||||
DEFAULT_GRPC_CONFIG: ClassVar[Mapping[str, Any]] = {
|
||||
"methodConfig": [
|
||||
{
|
||||
@ -129,9 +92,6 @@ class RuntimeConnection:
|
||||
channel = grpc.aio.insecure_channel(
|
||||
connection_string, options=[("grpc.service_config", json.dumps(grpc_config))]
|
||||
)
|
||||
# logger.info("awaiting channel_ready")
|
||||
# await channel.channel_ready()
|
||||
# logger.info("channel_ready")
|
||||
instance = cls(channel)
|
||||
instance._connection_task = asyncio.create_task(
|
||||
instance._connect(channel, instance._send_queue, instance._recv_queue)
|
||||
@ -176,102 +136,69 @@ class RuntimeConnection:
|
||||
async def recv(self) -> Message:
|
||||
logger.info("Getting message from queue")
|
||||
return await self._recv_queue.get()
|
||||
logger.info("Got message from queue")
|
||||
|
||||
|
||||
class WorkerAgentRuntime(AgentRuntime):
|
||||
def __init__(self) -> None:
|
||||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||
# (namespace, type) -> List[AgentId]
|
||||
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
|
||||
self._agent_factories: Dict[
|
||||
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
|
||||
] = {}
|
||||
# If empty, then all namespaces are valid for that agent type
|
||||
self._valid_namespaces: Dict[str, Sequence[str]] = {}
|
||||
self._instantiated_agents: Dict[AgentId, Agent] = {}
|
||||
self._known_namespaces: set[str] = set()
|
||||
self._read_task: None | Task[None] = None
|
||||
self._running = False
|
||||
self._pending_requests: Dict[str, Future[Any]] = {}
|
||||
self._pending_requests_lock = threading.Lock()
|
||||
self._pending_requests_lock = asyncio.Lock()
|
||||
self._next_request_id = 0
|
||||
self._runtime_connection: RuntimeConnection | None = None
|
||||
self._host_connection: HostConnection | None = None
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
|
||||
async def setup_channel(self, connection_string: str) -> None:
|
||||
logger.info(f"connecting to: {connection_string}")
|
||||
self._runtime_connection = await RuntimeConnection.from_connection_string(connection_string)
|
||||
async def start(self, host_connection_string: str) -> None:
|
||||
if self._running:
|
||||
raise ValueError("Runtime is already running.")
|
||||
logger.info(f"Connecting to host: {host_connection_string}")
|
||||
self._host_connection = await HostConnection.from_connection_string(host_connection_string)
|
||||
logger.info("connection")
|
||||
if self._read_task is None:
|
||||
self._read_task = asyncio.create_task(self.run_read_loop())
|
||||
self._read_task = asyncio.create_task(self._run_read_loop())
|
||||
self._running = True
|
||||
|
||||
async def send_register_agent_type(self, agent_type: str) -> None:
|
||||
assert self._runtime_connection is not None
|
||||
message = Message(registerAgentType=RegisterAgentType(type=agent_type))
|
||||
await self._runtime_connection.send(message)
|
||||
logger.info("Sent registerAgentType message for %s", agent_type)
|
||||
|
||||
async def run_read_loop(self) -> None:
|
||||
async def _run_read_loop(self) -> None:
|
||||
logger.info("Starting read loop")
|
||||
# TODO: catch exceptions and reconnect
|
||||
while self._running:
|
||||
try:
|
||||
message = await self._runtime_connection.recv() # type: ignore
|
||||
message = await self._host_connection.recv() # type: ignore
|
||||
logger.info("Got message: %s", message)
|
||||
oneofcase = Message.WhichOneof(message, "message")
|
||||
match oneofcase:
|
||||
case "registerAgentType":
|
||||
logger.warn("Cant handle registerAgentType")
|
||||
logger.warn("Cant handle registerAgentType, skipping.")
|
||||
case "request":
|
||||
# request: RpcRequest = message.request
|
||||
# source = AgentId(request.source.name, request.source.namespace)
|
||||
# target = AgentId(request.target.name, request.target.namespace)
|
||||
|
||||
raise NotImplementedError("Sending messages is not yet implemented.")
|
||||
request: RpcRequest = message.request
|
||||
task = asyncio.create_task(self._process_request(request))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "response":
|
||||
response: RpcResponse = message.response
|
||||
future = self._pending_requests.pop(response.request_id)
|
||||
if len(response.error) > 0:
|
||||
future.set_exception(Exception(response.error))
|
||||
break
|
||||
future.set_result(response.result)
|
||||
task = asyncio.create_task(self._process_response(response))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "event":
|
||||
event: Event = message.event
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.type)
|
||||
# namespace = event.namespace
|
||||
namespace = "default"
|
||||
|
||||
logger.info("Got event: %s", message)
|
||||
for agent_id in self._per_type_subscribers[
|
||||
(namespace, MESSAGE_TYPE_REGISTRY.type_name(message))
|
||||
]:
|
||||
logger.info("Sending message to %s", agent_id)
|
||||
agent = await self._get_agent(agent_id)
|
||||
message_context = MessageContext(
|
||||
# TODO: should sender be in the proto even for published events?
|
||||
sender=None,
|
||||
# TODO: topic_id
|
||||
topic_id=None,
|
||||
is_rpc=False,
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
try:
|
||||
await agent.on_message(message, ctx=message_context)
|
||||
logger.info("%s handled event %s", agent_id, message)
|
||||
except Exception as e:
|
||||
event_logger.error("Error handling message", exc_info=e)
|
||||
|
||||
logger.warn("Cant handle event")
|
||||
task = asyncio.create_task(self._process_event(event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case None:
|
||||
logger.warn("No message")
|
||||
except Exception as e:
|
||||
logger.error("Error in read loop", exc_info=e)
|
||||
|
||||
async def close_channel(self) -> None:
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
if self._runtime_connection is not None:
|
||||
await self._runtime_connection.close()
|
||||
if self._host_connection is not None:
|
||||
await self._host_connection.close()
|
||||
if self._read_task is not None:
|
||||
await self._read_task
|
||||
|
||||
@ -279,7 +206,6 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
def _known_agent_names(self) -> Set[str]:
|
||||
return set(self._agent_factories.keys())
|
||||
|
||||
# Returns the response of the message
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
@ -288,25 +214,32 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Any:
|
||||
assert self._runtime_connection is not None
|
||||
if not self._running:
|
||||
raise ValueError("Runtime must be running when sending message.")
|
||||
assert self._host_connection is not None
|
||||
# create a new future for the result
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
with self._pending_requests_lock:
|
||||
async with self._pending_requests_lock:
|
||||
self._next_request_id += 1
|
||||
request_id = self._next_request_id
|
||||
request_id_str = str(request_id)
|
||||
self._pending_requests[request_id_str] = future
|
||||
sender = cast(AgentId, sender)
|
||||
runtime_message = Message(
|
||||
request=RpcRequest(
|
||||
request_id=request_id_str,
|
||||
target=AgentIdProto(name=recipient.type, namespace=recipient.key),
|
||||
source=AgentIdProto(name=sender.type, namespace=sender.key),
|
||||
data=message,
|
||||
)
|
||||
request_id_str = str(request_id)
|
||||
self._pending_requests[request_id_str] = future
|
||||
sender = cast(AgentId, sender)
|
||||
method = MESSAGE_TYPE_REGISTRY.type_name(message)
|
||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=method)
|
||||
runtime_message = Message(
|
||||
request=RpcRequest(
|
||||
request_id=request_id_str,
|
||||
target=AgentIdProto(name=recipient.type, namespace=recipient.key),
|
||||
source=AgentIdProto(name=sender.type, namespace=sender.key),
|
||||
method=method,
|
||||
data=serialized_message,
|
||||
)
|
||||
# TODO: Find a way to handle timeouts/errors
|
||||
asyncio.create_task(self._runtime_connection.send(runtime_message))
|
||||
)
|
||||
# TODO: Find a way to handle timeouts/errors
|
||||
task = asyncio.create_task(self._host_connection.send(runtime_message))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
return await future
|
||||
|
||||
async def publish_message(
|
||||
@ -317,26 +250,24 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> None:
|
||||
assert self._runtime_connection is not None
|
||||
if not self._running:
|
||||
raise ValueError("Runtime must be running when publishing message.")
|
||||
assert self._host_connection is not None
|
||||
sender_namespace = sender.key if sender is not None else None
|
||||
explicit_namespace = namespace
|
||||
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
|
||||
raise ValueError(
|
||||
f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}"
|
||||
)
|
||||
|
||||
assert explicit_namespace is not None or sender_namespace is not None
|
||||
actual_namespace = cast(str, explicit_namespace or sender_namespace)
|
||||
await self._process_seen_namespace(actual_namespace)
|
||||
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
|
||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type)
|
||||
message = Message(event=Event(namespace=actual_namespace, type=message_type, data=serialized_message))
|
||||
|
||||
async def write_message() -> None:
|
||||
assert self._runtime_connection is not None
|
||||
await self._runtime_connection.send(message)
|
||||
|
||||
await asyncio.create_task(write_message())
|
||||
task = asyncio.create_task(self._host_connection.send(message))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
raise NotImplementedError("Saving state is not yet implemented.")
|
||||
@ -358,6 +289,8 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
name: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
) -> None:
|
||||
if not self._running:
|
||||
raise ValueError("Runtime must be running when registering agent.")
|
||||
if name in self._agent_factories:
|
||||
raise ValueError(f"Agent with name {name} already exists.")
|
||||
self._agent_factories[name] = agent_factory
|
||||
@ -366,7 +299,75 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
for namespace in self._known_namespaces:
|
||||
await self._get_agent(AgentId(type=name, key=namespace))
|
||||
|
||||
await self.send_register_agent_type(name)
|
||||
assert self._host_connection is not None
|
||||
message = Message(registerAgentType=RegisterAgentType(type=name))
|
||||
await self._host_connection.send(message)
|
||||
logger.info("Sent registerAgentType message for %s", name)
|
||||
|
||||
async def _process_request(self, request: RpcRequest) -> None:
|
||||
assert self._host_connection is not None
|
||||
target = AgentId(request.target.name, request.target.namespace)
|
||||
source = AgentId(request.source.name, request.source.namespace)
|
||||
|
||||
try:
|
||||
logging.info(f"Processing request from {source} to {target}")
|
||||
target_agent = await self._get_agent(target)
|
||||
message_context = MessageContext(
|
||||
sender=source,
|
||||
topic_id=None,
|
||||
is_rpc=True,
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(request.data, type_name=request.method)
|
||||
response = await target_agent.on_message(message, ctx=message_context)
|
||||
serialized_response = MESSAGE_TYPE_REGISTRY.serialize(response, type_name=request.method)
|
||||
response_message = Message(
|
||||
response=RpcResponse(
|
||||
request_id=request.request_id,
|
||||
result=serialized_response,
|
||||
)
|
||||
)
|
||||
except BaseException as e:
|
||||
response_message = Message(
|
||||
response=RpcResponse(
|
||||
request_id=request.request_id,
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
|
||||
# Send the response.
|
||||
await self._host_connection.send(response_message)
|
||||
|
||||
async def _process_response(self, response: RpcResponse) -> None:
|
||||
# TODO: deserialize the response and set the future result
|
||||
future = self._pending_requests.pop(response.request_id)
|
||||
if len(response.error) > 0:
|
||||
future.set_exception(Exception(response.error))
|
||||
else:
|
||||
future.set_result(response.result)
|
||||
|
||||
async def _process_event(self, event: Event) -> None:
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.type)
|
||||
namespace = event.namespace
|
||||
responses: List[Awaitable[Any]] = []
|
||||
for agent_id in self._per_type_subscribers[(namespace, MESSAGE_TYPE_REGISTRY.type_name(message))]:
|
||||
# TODO: skip the sender?
|
||||
message_context = MessageContext(
|
||||
sender=None,
|
||||
topic_id=None,
|
||||
is_rpc=False,
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
agent = await self._get_agent(agent_id)
|
||||
future = agent.on_message(message, ctx=message_context)
|
||||
responses.append(future)
|
||||
|
||||
try:
|
||||
_ = await asyncio.gather(*responses)
|
||||
except BaseException as e:
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
return
|
||||
event_logger.error("Error handling event message", exc_info=e)
|
||||
|
||||
async def _invoke_agent_factory(
|
||||
self,
|
@ -9,11 +9,21 @@ sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .agent_worker_pb2 import Event, Message, RegisterAgentType, RpcRequest, RpcResponse, AgentId
|
||||
from .agent_worker_pb2_grpc import AgentRpcStub
|
||||
from .agent_worker_pb2 import AgentId, Event, Message, RegisterAgentType, RpcRequest, RpcResponse
|
||||
from .agent_worker_pb2_grpc import AgentRpcServicer, AgentRpcStub, add_AgentRpcServicer_to_server
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agent_worker_pb2_grpc import AgentRpcAsyncStub
|
||||
__all__ = ["RpcRequest", "RpcResponse", "Event", "RegisterAgentType", "AgentRpcAsyncStub", "AgentRpcStub", "Message", "AgentId"]
|
||||
|
||||
__all__ = [
|
||||
"RpcRequest",
|
||||
"RpcResponse",
|
||||
"Event",
|
||||
"RegisterAgentType",
|
||||
"AgentRpcAsyncStub",
|
||||
"AgentRpcStub",
|
||||
"Message",
|
||||
"AgentId",
|
||||
]
|
||||
else:
|
||||
__all__ = ["RpcRequest", "RpcResponse", "Event", "RegisterAgentType", "AgentRpcStub", "Message", "AgentId"]
|
@ -1,7 +0,0 @@
|
||||
"""
|
||||
The :mod:`agnext.worker` module provides a set of classes for creating distributed agents
|
||||
"""
|
||||
|
||||
from .worker_runtime import WorkerAgentRuntime
|
||||
|
||||
__all__ = ["WorkerAgentRuntime"]
|
@ -1,42 +0,0 @@
|
||||
from agnext.worker.worker_runtime import WorkerAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import CancellationToken, AgentId
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class ExampleMessagePayload:
|
||||
content: str
|
||||
|
||||
|
||||
class ExampleAgent(TypeRoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Example Agent")
|
||||
|
||||
@message_handler
|
||||
async def on_example_payload(self, message: ExampleMessagePayload, cancellation_token: CancellationToken) -> None:
|
||||
upper_case = message.content.upper()
|
||||
await self.publish_message(ExampleMessagePayload(content=upper_case))
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
logger = logging.getLogger("main")
|
||||
runtime = WorkerAgentRuntime()
|
||||
await runtime.setup_channel(os.environ["AGENT_HOST"])
|
||||
|
||||
runtime.register("ExampleAgent", lambda: ExampleAgent())
|
||||
while True:
|
||||
try:
|
||||
res = await runtime.send_message("testing!", recipient=AgentId(name="greeter", namespace="testing"), sender=AgentId(name="ExampleAgent", namespace="testing"))
|
||||
logger.info("Response: %s", res)
|
||||
except Exception as e:
|
||||
logger.warning("Error: %s", e)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
asyncio.run(main())
|
||||
|
Loading…
x
Reference in New Issue
Block a user