mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-20 14:42:33 +00:00
Python host and worker runtime for distributed agents. (#369)
* Python host runtime impl * update * ignore proto generated files * move worker runtime to application * Move example to samples * Fix import * fix * update * server client * better shutdown * fix doc conf * add type
This commit is contained in:
parent
eb4a5b7df5
commit
5eca0dba4a
1
python/.gitignore
vendored
1
python/.gitignore
vendored
@ -166,6 +166,7 @@ cython_debug/
|
|||||||
|
|
||||||
# Generated proto files
|
# Generated proto files
|
||||||
src/agnext/worker/protos/agent*
|
src/agnext/worker/protos/agent*
|
||||||
|
src/agnext/application/protos/agent*
|
||||||
|
|
||||||
# Generated log files
|
# Generated log files
|
||||||
log.jsonl
|
log.jsonl
|
||||||
|
@ -28,7 +28,7 @@ apidoc_template_dir = "_apidoc_templates"
|
|||||||
apidoc_separate_modules = True
|
apidoc_separate_modules = True
|
||||||
apidoc_extra_args = ["--no-toc"]
|
apidoc_extra_args = ["--no-toc"]
|
||||||
napoleon_custom_sections = [("Returns", "params_style")]
|
napoleon_custom_sections = [("Returns", "params_style")]
|
||||||
apidoc_excluded_paths = ["./worker/protos/"]
|
apidoc_excluded_paths = ["./application/protos/"]
|
||||||
|
|
||||||
templates_path = []
|
templates_path = []
|
||||||
exclude_patterns = ["reference/agnext.rst"]
|
exclude_patterns = ["reference/agnext.rst"]
|
||||||
|
@ -59,7 +59,6 @@ To learn about the core concepts of AGNext, read the `overview <core-concepts/ov
|
|||||||
reference/agnext.components
|
reference/agnext.components
|
||||||
reference/agnext.application
|
reference/agnext.application
|
||||||
reference/agnext.core
|
reference/agnext.core
|
||||||
reference/agnext.worker
|
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:caption: Other
|
:caption: Other
|
||||||
|
@ -128,7 +128,7 @@ dependencies = [
|
|||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 120
|
line-length = 120
|
||||||
fix = true
|
fix = true
|
||||||
exclude = ["build", "dist", "src/agnext/worker/protos"]
|
exclude = ["build", "dist", "src/agnext/application/protos"]
|
||||||
target-version = "py310"
|
target-version = "py310"
|
||||||
include = ["src/**", "samples/*.py", "docs/**/*.ipynb"]
|
include = ["src/**", "samples/*.py", "docs/**/*.ipynb"]
|
||||||
|
|
||||||
@ -145,7 +145,7 @@ ignore = ["F401", "E501"]
|
|||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
files = ["src", "samples", "tests"]
|
files = ["src", "samples", "tests"]
|
||||||
exclude = ["src/agnext/worker/protos"]
|
exclude = ["src/agnext/application/protos"]
|
||||||
|
|
||||||
strict = true
|
strict = true
|
||||||
python_version = "3.10"
|
python_version = "3.10"
|
||||||
@ -168,7 +168,7 @@ include = ["src", "tests", "samples"]
|
|||||||
typeCheckingMode = "strict"
|
typeCheckingMode = "strict"
|
||||||
reportUnnecessaryIsInstance = false
|
reportUnnecessaryIsInstance = false
|
||||||
reportMissingTypeStubs = false
|
reportMissingTypeStubs = false
|
||||||
exclude = ["src/agnext/worker/protos"]
|
exclude = ["src/agnext/application/protos"]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
minversion = "6.0"
|
minversion = "6.0"
|
||||||
@ -178,7 +178,7 @@ testpaths = ["tests"]
|
|||||||
dependencies = ["hatch-protobuf", "mypy-protobuf~=3.0"]
|
dependencies = ["hatch-protobuf", "mypy-protobuf~=3.0"]
|
||||||
generate_pyi = false
|
generate_pyi = false
|
||||||
proto_paths = ["../protos"]
|
proto_paths = ["../protos"]
|
||||||
output_path = "src/agnext/worker/protos"
|
output_path = "src/agnext/application/protos"
|
||||||
|
|
||||||
|
|
||||||
[[tool.hatch.build.hooks.protobuf.generators]]
|
[[tool.hatch.build.hooks.protobuf.generators]]
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from agnext.application import WorkerAgentRuntime
|
||||||
from agnext.core._serialization import MESSAGE_TYPE_REGISTRY
|
from agnext.core._serialization import MESSAGE_TYPE_REGISTRY
|
||||||
from agnext.worker.worker_runtime import WorkerAgentRuntime
|
|
||||||
from app import build_app
|
from app import build_app
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated
|
from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated
|
||||||
@ -18,7 +18,7 @@ async def main() -> None:
|
|||||||
MESSAGE_TYPE_REGISTRY.add_type(AuditText)
|
MESSAGE_TYPE_REGISTRY.add_type(AuditText)
|
||||||
MESSAGE_TYPE_REGISTRY.add_type(AuditorAlert)
|
MESSAGE_TYPE_REGISTRY.add_type(AuditorAlert)
|
||||||
agnext_logger.info("1")
|
agnext_logger.info("1")
|
||||||
await runtime.setup_channel("localhost:5145")
|
await runtime.start("localhost:5145")
|
||||||
|
|
||||||
agnext_logger.info("2")
|
agnext_logger.info("2")
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ async def main() -> None:
|
|||||||
await asyncio.sleep(1000000)
|
await asyncio.sleep(1000000)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
await runtime.close_channel()
|
await runtime.stop()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
53
python/samples/worker/run_host.py
Normal file
53
python/samples/worker/run_host.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import asyncio
|
||||||
|
import signal
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
from agnext.application import HostRuntimeServicer
|
||||||
|
from agnext.application.protos import agent_worker_pb2_grpc
|
||||||
|
|
||||||
|
|
||||||
|
async def serve(server: grpc.aio.Server) -> None: # type: ignore
|
||||||
|
await server.start()
|
||||||
|
print("Server started")
|
||||||
|
await server.wait_for_termination()
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> None:
|
||||||
|
server = grpc.aio.server()
|
||||||
|
agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(HostRuntimeServicer(), server)
|
||||||
|
server.add_insecure_port("[::]:50051")
|
||||||
|
|
||||||
|
# Set up signal handling for graceful shutdown
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
|
def signal_handler() -> None:
|
||||||
|
print("Received exit signal, shutting down gracefully...")
|
||||||
|
shutdown_event.set()
|
||||||
|
|
||||||
|
loop.add_signal_handler(signal.SIGINT, signal_handler)
|
||||||
|
loop.add_signal_handler(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
|
# Start server in background task
|
||||||
|
serve_task = asyncio.create_task(serve(server))
|
||||||
|
|
||||||
|
# Wait for the signal to trigger the shutdown event
|
||||||
|
await shutdown_event.wait()
|
||||||
|
|
||||||
|
# Graceful shutdown
|
||||||
|
await server.stop(5) # 5 second grace period
|
||||||
|
await serve_task
|
||||||
|
print("Server stopped")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("Server shutdown interrupted.")
|
87
python/samples/worker/run_worker_pub_sub.py
Normal file
87
python/samples/worker/run_worker_pub_sub.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from agnext.application import WorkerAgentRuntime
|
||||||
|
from agnext.components import TypeRoutedAgent, message_handler
|
||||||
|
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AskToGreet:
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Greeting:
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReturnedGreeting:
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Feedback:
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReturnedFeedback:
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiveAgent(TypeRoutedAgent):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__("Receive Agent")
|
||||||
|
|
||||||
|
@message_handler
|
||||||
|
async def on_greet(self, message: Greeting, ctx: MessageContext) -> None:
|
||||||
|
await self.publish_message(ReturnedGreeting(f"Returned greeting: {message.content}"))
|
||||||
|
|
||||||
|
@message_handler
|
||||||
|
async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
|
||||||
|
await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}"))
|
||||||
|
|
||||||
|
|
||||||
|
class GreeterAgent(TypeRoutedAgent):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__("Greeter Agent")
|
||||||
|
|
||||||
|
@message_handler
|
||||||
|
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
|
||||||
|
await self.publish_message(Greeting(f"Hello, {message.content}!"))
|
||||||
|
|
||||||
|
@message_handler
|
||||||
|
async def on_returned_greet(self, message: ReturnedGreeting, ctx: MessageContext) -> None:
|
||||||
|
await self.publish_message(Feedback(f"Feedback: {message.content}"))
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> None:
|
||||||
|
runtime = WorkerAgentRuntime()
|
||||||
|
MESSAGE_TYPE_REGISTRY.add_type(Greeting)
|
||||||
|
MESSAGE_TYPE_REGISTRY.add_type(AskToGreet)
|
||||||
|
MESSAGE_TYPE_REGISTRY.add_type(Feedback)
|
||||||
|
MESSAGE_TYPE_REGISTRY.add_type(ReturnedGreeting)
|
||||||
|
MESSAGE_TYPE_REGISTRY.add_type(ReturnedFeedback)
|
||||||
|
await runtime.start(host_connection_string="localhost:50051")
|
||||||
|
|
||||||
|
await runtime.register("reciever", lambda: ReceiveAgent())
|
||||||
|
await runtime.register("greeter", lambda: GreeterAgent())
|
||||||
|
|
||||||
|
await runtime.publish_message(AskToGreet("Hello World!"), namespace="default")
|
||||||
|
|
||||||
|
# Just to keep the runtime running
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(1000000)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
await runtime.stop()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
logger = logging.getLogger("agnext")
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
asyncio.run(main())
|
74
python/samples/worker/run_worker_rpc.py
Normal file
74
python/samples/worker/run_worker_rpc.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from agnext.application import WorkerAgentRuntime
|
||||||
|
from agnext.components import TypeRoutedAgent, message_handler
|
||||||
|
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, MessageContext
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AskToGreet:
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Greeting:
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Feedback:
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiveAgent(TypeRoutedAgent):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__("Receive Agent")
|
||||||
|
|
||||||
|
@message_handler
|
||||||
|
async def on_greet(self, message: Greeting, ctx: MessageContext) -> Greeting:
|
||||||
|
return Greeting(content=f"Received: {message.content}")
|
||||||
|
|
||||||
|
@message_handler
|
||||||
|
async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
|
||||||
|
print(f"Feedback received: {message.content}")
|
||||||
|
|
||||||
|
|
||||||
|
class GreeterAgent(TypeRoutedAgent):
|
||||||
|
def __init__(self, receive_agent_id: AgentId) -> None:
|
||||||
|
super().__init__("Greeter Agent")
|
||||||
|
self._receive_agent_id = receive_agent_id
|
||||||
|
|
||||||
|
@message_handler
|
||||||
|
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
|
||||||
|
response = await self.send_message(Greeting(f"Hello, {message.content}!"), recipient=self._receive_agent_id)
|
||||||
|
await self.publish_message(Feedback(f"Feedback: {response.content}"))
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> None:
|
||||||
|
runtime = WorkerAgentRuntime()
|
||||||
|
MESSAGE_TYPE_REGISTRY.add_type(Greeting)
|
||||||
|
MESSAGE_TYPE_REGISTRY.add_type(AskToGreet)
|
||||||
|
MESSAGE_TYPE_REGISTRY.add_type(Feedback)
|
||||||
|
await runtime.start(host_connection_string="localhost:50051")
|
||||||
|
|
||||||
|
await runtime.register("reciever", lambda: ReceiveAgent())
|
||||||
|
reciever = await runtime.get("reciever")
|
||||||
|
await runtime.register("greeter", lambda: GreeterAgent(reciever))
|
||||||
|
|
||||||
|
await runtime.publish_message(AskToGreet("Hello World!"), namespace="default")
|
||||||
|
|
||||||
|
# Just to keep the runtime running
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(1000000)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
await runtime.stop()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
logger = logging.getLogger("agnext")
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
asyncio.run(main())
|
@ -2,6 +2,8 @@
|
|||||||
The :mod:`agnext.application` module provides implementations of core components that are used to compose an application
|
The :mod:`agnext.application` module provides implementations of core components that are used to compose an application
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from ._host_runtime_servicer import HostRuntimeServicer
|
||||||
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
||||||
|
from ._worker_runtime import WorkerAgentRuntime
|
||||||
|
|
||||||
__all__ = ["SingleThreadedAgentRuntime"]
|
__all__ = ["SingleThreadedAgentRuntime", "WorkerAgentRuntime", "HostRuntimeServicer"]
|
||||||
|
150
python/src/agnext/application/_host_runtime_servicer.py
Normal file
150
python/src/agnext/application/_host_runtime_servicer.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from _collections_abc import AsyncIterator, Iterator
|
||||||
|
from asyncio import Future, Task
|
||||||
|
from typing import Any, Dict, Set
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
|
||||||
|
|
||||||
|
logger = logging.getLogger("agnext")
|
||||||
|
event_logger = logging.getLogger("agnext.events")
|
||||||
|
|
||||||
|
|
||||||
|
class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||||
|
"""A gRPC servicer that hosts message delivery service for agents."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._client_id = 0
|
||||||
|
self._client_id_lock = asyncio.Lock()
|
||||||
|
self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {}
|
||||||
|
self._agent_type_to_client_id: Dict[str, int] = {}
|
||||||
|
self._pending_requests: Dict[int, Dict[str, Future[Any]]] = {}
|
||||||
|
self._background_tasks: Set[Task[Any]] = set()
|
||||||
|
|
||||||
|
async def OpenChannel( # type: ignore
|
||||||
|
self,
|
||||||
|
request_iterator: AsyncIterator[agent_worker_pb2.Message],
|
||||||
|
context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message],
|
||||||
|
) -> Iterator[agent_worker_pb2.Message] | AsyncIterator[agent_worker_pb2.Message]: # type: ignore
|
||||||
|
# Aquire the lock to get a new client id.
|
||||||
|
async with self._client_id_lock:
|
||||||
|
self._client_id += 1
|
||||||
|
client_id = self._client_id
|
||||||
|
|
||||||
|
# Register the client with the server and create a send queue for the client.
|
||||||
|
send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue()
|
||||||
|
self._send_queues[client_id] = send_queue
|
||||||
|
logger.info(f"Client {client_id} connected.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Concurrently handle receiving messages from the client and sending messages to the client.
|
||||||
|
# This task will receive messages from the client.
|
||||||
|
receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator))
|
||||||
|
|
||||||
|
# Return an async generator that will yield messages from the send queue to the client.
|
||||||
|
while True:
|
||||||
|
message = await send_queue.get()
|
||||||
|
# Yield the message to the client.
|
||||||
|
try:
|
||||||
|
yield message
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send message to client {client_id}: {e}", exc_info=True)
|
||||||
|
break
|
||||||
|
# Wait for the receiving task to finish.
|
||||||
|
await receiving_task
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up the client connection.
|
||||||
|
del self._send_queues[client_id]
|
||||||
|
# Cancel pending requests sent to this client.
|
||||||
|
for future in self._pending_requests.pop(client_id, {}).values():
|
||||||
|
future.cancel()
|
||||||
|
logger.info(f"Client {client_id} disconnected.")
|
||||||
|
|
||||||
|
async def _receive_messages(
|
||||||
|
self, client_id: int, request_iterator: AsyncIterator[agent_worker_pb2.Message]
|
||||||
|
) -> None:
|
||||||
|
# Receive messages from the client and process them.
|
||||||
|
async for message in request_iterator:
|
||||||
|
oneofcase = message.WhichOneof("message")
|
||||||
|
match oneofcase:
|
||||||
|
case "request":
|
||||||
|
request: agent_worker_pb2.RpcRequest = message.request
|
||||||
|
logger.info(f"Received request message: {request}")
|
||||||
|
task = asyncio.create_task(self._process_request(request, client_id))
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
case "response":
|
||||||
|
response: agent_worker_pb2.RpcResponse = message.response
|
||||||
|
logger.info(f"Received response message: {response}")
|
||||||
|
task = asyncio.create_task(self._process_response(response, client_id))
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
case "event":
|
||||||
|
event: agent_worker_pb2.Event = message.event
|
||||||
|
logger.info(f"Received event message: {event}")
|
||||||
|
task = asyncio.create_task(self._process_event(event))
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
case "registerAgentType":
|
||||||
|
register_agent_type: agent_worker_pb2.RegisterAgentType = message.registerAgentType
|
||||||
|
logger.info(f"Received register agent type message: {register_agent_type}")
|
||||||
|
task = asyncio.create_task(self._process_register_agent_type(register_agent_type, client_id))
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
case None:
|
||||||
|
logger.warning("Received empty message")
|
||||||
|
|
||||||
|
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
|
||||||
|
# Deliver the message to a client given the target agent type.
|
||||||
|
target_client_id = self._agent_type_to_client_id.get(request.target.name)
|
||||||
|
if target_client_id is None:
|
||||||
|
logger.error(f"Agent {request.target.name} not found, failed to deliver message.")
|
||||||
|
return
|
||||||
|
target_send_queue = self._send_queues.get(target_client_id)
|
||||||
|
if target_send_queue is None:
|
||||||
|
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
|
||||||
|
return
|
||||||
|
await target_send_queue.put(agent_worker_pb2.Message(request=request))
|
||||||
|
|
||||||
|
# Create a future to wait for the response.
|
||||||
|
future = asyncio.get_event_loop().create_future()
|
||||||
|
self._pending_requests.setdefault(client_id, {})[request.request_id] = future
|
||||||
|
|
||||||
|
# Create a task to wait for the response and send it back to the client.
|
||||||
|
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id))
|
||||||
|
self._background_tasks.add(send_response_task)
|
||||||
|
send_response_task.add_done_callback(self._background_tasks.discard)
|
||||||
|
|
||||||
|
async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None:
|
||||||
|
response = await future
|
||||||
|
message = agent_worker_pb2.Message(response=response)
|
||||||
|
send_queue = self._send_queues.get(client_id)
|
||||||
|
if send_queue is None:
|
||||||
|
logger.error(f"Client {client_id} not found, failed to send response message.")
|
||||||
|
return
|
||||||
|
await send_queue.put(message)
|
||||||
|
|
||||||
|
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None:
|
||||||
|
# Setting the result of the future will send the response back to the original sender.
|
||||||
|
future = self._pending_requests[client_id].pop(response.request_id)
|
||||||
|
future.set_result(response)
|
||||||
|
|
||||||
|
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
|
||||||
|
# Deliver the event to all the clients.
|
||||||
|
# TODO: deliver based on subscriptions.
|
||||||
|
for send_queue in self._send_queues.values():
|
||||||
|
await send_queue.put(agent_worker_pb2.Message(event=event))
|
||||||
|
|
||||||
|
async def _process_register_agent_type(
|
||||||
|
self, register_agent_type: agent_worker_pb2.RegisterAgentType, client_id: int
|
||||||
|
) -> None:
|
||||||
|
# Register the agent type with the host runtime.
|
||||||
|
if register_agent_type.type in self._agent_type_to_client_id:
|
||||||
|
existing_client_id = self._agent_type_to_client_id[register_agent_type.type]
|
||||||
|
logger.warning(
|
||||||
|
f"Agent type {register_agent_type.type} already registered with client {existing_client_id}, overwriting the client mapping to client {client_id}."
|
||||||
|
)
|
||||||
|
self._agent_type_to_client_id[register_agent_type.type] = client_id
|
@ -12,8 +12,6 @@ from dataclasses import dataclass
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||||
|
|
||||||
from agnext.core import MessageContext
|
|
||||||
|
|
||||||
from ..core import (
|
from ..core import (
|
||||||
MESSAGE_TYPE_REGISTRY,
|
MESSAGE_TYPE_REGISTRY,
|
||||||
Agent,
|
Agent,
|
||||||
@ -23,6 +21,7 @@ from ..core import (
|
|||||||
AgentProxy,
|
AgentProxy,
|
||||||
AgentRuntime,
|
AgentRuntime,
|
||||||
CancellationToken,
|
CancellationToken,
|
||||||
|
MessageContext,
|
||||||
)
|
)
|
||||||
from ..core.exceptions import MessageDroppedException
|
from ..core.exceptions import MessageDroppedException
|
||||||
from ..core.intervention import DropMessage, InterventionHandler
|
from ..core.intervention import DropMessage, InterventionHandler
|
||||||
|
@ -2,12 +2,9 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
|
||||||
import warnings
|
import warnings
|
||||||
from asyncio import Future, Task
|
from asyncio import Future, Task
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Sequence
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -50,40 +47,6 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger("agnext")
|
logger = logging.getLogger("agnext")
|
||||||
event_logger = logging.getLogger("agnext.events")
|
event_logger = logging.getLogger("agnext.events")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
|
||||||
class PublishMessageEnvelope:
|
|
||||||
"""A message envelope for publishing messages to all agents that can handle
|
|
||||||
the message of the type T."""
|
|
||||||
|
|
||||||
message: Any
|
|
||||||
cancellation_token: CancellationToken
|
|
||||||
sender: AgentId | None
|
|
||||||
namespace: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
|
||||||
class SendMessageEnvelope:
|
|
||||||
"""A message envelope for sending a message to a specific agent that can handle
|
|
||||||
the message of the type T."""
|
|
||||||
|
|
||||||
message: Any
|
|
||||||
sender: AgentId | None
|
|
||||||
recipient: AgentId
|
|
||||||
future: Future[Any]
|
|
||||||
cancellation_token: CancellationToken
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
|
||||||
class ResponseMessageEnvelope:
|
|
||||||
"""A message envelope for sending a response to a message."""
|
|
||||||
|
|
||||||
message: Any
|
|
||||||
future: Future[Any]
|
|
||||||
sender: AgentId
|
|
||||||
recipient: AgentId | None
|
|
||||||
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
T = TypeVar("T", bound=Agent)
|
T = TypeVar("T", bound=Agent)
|
||||||
|
|
||||||
@ -99,7 +62,7 @@ class QueueAsyncIterable(AsyncIterator[Any], AsyncIterable[Any]):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class RuntimeConnection:
|
class HostConnection:
|
||||||
DEFAULT_GRPC_CONFIG: ClassVar[Mapping[str, Any]] = {
|
DEFAULT_GRPC_CONFIG: ClassVar[Mapping[str, Any]] = {
|
||||||
"methodConfig": [
|
"methodConfig": [
|
||||||
{
|
{
|
||||||
@ -129,9 +92,6 @@ class RuntimeConnection:
|
|||||||
channel = grpc.aio.insecure_channel(
|
channel = grpc.aio.insecure_channel(
|
||||||
connection_string, options=[("grpc.service_config", json.dumps(grpc_config))]
|
connection_string, options=[("grpc.service_config", json.dumps(grpc_config))]
|
||||||
)
|
)
|
||||||
# logger.info("awaiting channel_ready")
|
|
||||||
# await channel.channel_ready()
|
|
||||||
# logger.info("channel_ready")
|
|
||||||
instance = cls(channel)
|
instance = cls(channel)
|
||||||
instance._connection_task = asyncio.create_task(
|
instance._connection_task = asyncio.create_task(
|
||||||
instance._connect(channel, instance._send_queue, instance._recv_queue)
|
instance._connect(channel, instance._send_queue, instance._recv_queue)
|
||||||
@ -176,102 +136,69 @@ class RuntimeConnection:
|
|||||||
async def recv(self) -> Message:
|
async def recv(self) -> Message:
|
||||||
logger.info("Getting message from queue")
|
logger.info("Getting message from queue")
|
||||||
return await self._recv_queue.get()
|
return await self._recv_queue.get()
|
||||||
logger.info("Got message from queue")
|
|
||||||
|
|
||||||
|
|
||||||
class WorkerAgentRuntime(AgentRuntime):
|
class WorkerAgentRuntime(AgentRuntime):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
|
||||||
# (namespace, type) -> List[AgentId]
|
|
||||||
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
|
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
|
||||||
self._agent_factories: Dict[
|
self._agent_factories: Dict[
|
||||||
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
|
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
|
||||||
] = {}
|
] = {}
|
||||||
# If empty, then all namespaces are valid for that agent type
|
|
||||||
self._valid_namespaces: Dict[str, Sequence[str]] = {}
|
|
||||||
self._instantiated_agents: Dict[AgentId, Agent] = {}
|
self._instantiated_agents: Dict[AgentId, Agent] = {}
|
||||||
self._known_namespaces: set[str] = set()
|
self._known_namespaces: set[str] = set()
|
||||||
self._read_task: None | Task[None] = None
|
self._read_task: None | Task[None] = None
|
||||||
self._running = False
|
self._running = False
|
||||||
self._pending_requests: Dict[str, Future[Any]] = {}
|
self._pending_requests: Dict[str, Future[Any]] = {}
|
||||||
self._pending_requests_lock = threading.Lock()
|
self._pending_requests_lock = asyncio.Lock()
|
||||||
self._next_request_id = 0
|
self._next_request_id = 0
|
||||||
self._runtime_connection: RuntimeConnection | None = None
|
self._host_connection: HostConnection | None = None
|
||||||
|
self._background_tasks: Set[Task[Any]] = set()
|
||||||
|
|
||||||
async def setup_channel(self, connection_string: str) -> None:
|
async def start(self, host_connection_string: str) -> None:
|
||||||
logger.info(f"connecting to: {connection_string}")
|
if self._running:
|
||||||
self._runtime_connection = await RuntimeConnection.from_connection_string(connection_string)
|
raise ValueError("Runtime is already running.")
|
||||||
|
logger.info(f"Connecting to host: {host_connection_string}")
|
||||||
|
self._host_connection = await HostConnection.from_connection_string(host_connection_string)
|
||||||
logger.info("connection")
|
logger.info("connection")
|
||||||
if self._read_task is None:
|
if self._read_task is None:
|
||||||
self._read_task = asyncio.create_task(self.run_read_loop())
|
self._read_task = asyncio.create_task(self._run_read_loop())
|
||||||
self._running = True
|
self._running = True
|
||||||
|
|
||||||
async def send_register_agent_type(self, agent_type: str) -> None:
|
async def _run_read_loop(self) -> None:
|
||||||
assert self._runtime_connection is not None
|
|
||||||
message = Message(registerAgentType=RegisterAgentType(type=agent_type))
|
|
||||||
await self._runtime_connection.send(message)
|
|
||||||
logger.info("Sent registerAgentType message for %s", agent_type)
|
|
||||||
|
|
||||||
async def run_read_loop(self) -> None:
|
|
||||||
logger.info("Starting read loop")
|
logger.info("Starting read loop")
|
||||||
# TODO: catch exceptions and reconnect
|
# TODO: catch exceptions and reconnect
|
||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
message = await self._runtime_connection.recv() # type: ignore
|
message = await self._host_connection.recv() # type: ignore
|
||||||
logger.info("Got message: %s", message)
|
logger.info("Got message: %s", message)
|
||||||
oneofcase = Message.WhichOneof(message, "message")
|
oneofcase = Message.WhichOneof(message, "message")
|
||||||
match oneofcase:
|
match oneofcase:
|
||||||
case "registerAgentType":
|
case "registerAgentType":
|
||||||
logger.warn("Cant handle registerAgentType")
|
logger.warn("Cant handle registerAgentType, skipping.")
|
||||||
case "request":
|
case "request":
|
||||||
# request: RpcRequest = message.request
|
request: RpcRequest = message.request
|
||||||
# source = AgentId(request.source.name, request.source.namespace)
|
task = asyncio.create_task(self._process_request(request))
|
||||||
# target = AgentId(request.target.name, request.target.namespace)
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
raise NotImplementedError("Sending messages is not yet implemented.")
|
|
||||||
case "response":
|
case "response":
|
||||||
response: RpcResponse = message.response
|
response: RpcResponse = message.response
|
||||||
future = self._pending_requests.pop(response.request_id)
|
task = asyncio.create_task(self._process_response(response))
|
||||||
if len(response.error) > 0:
|
self._background_tasks.add(task)
|
||||||
future.set_exception(Exception(response.error))
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
break
|
|
||||||
future.set_result(response.result)
|
|
||||||
case "event":
|
case "event":
|
||||||
event: Event = message.event
|
event: Event = message.event
|
||||||
message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.type)
|
task = asyncio.create_task(self._process_event(event))
|
||||||
# namespace = event.namespace
|
self._background_tasks.add(task)
|
||||||
namespace = "default"
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
|
||||||
logger.info("Got event: %s", message)
|
|
||||||
for agent_id in self._per_type_subscribers[
|
|
||||||
(namespace, MESSAGE_TYPE_REGISTRY.type_name(message))
|
|
||||||
]:
|
|
||||||
logger.info("Sending message to %s", agent_id)
|
|
||||||
agent = await self._get_agent(agent_id)
|
|
||||||
message_context = MessageContext(
|
|
||||||
# TODO: should sender be in the proto even for published events?
|
|
||||||
sender=None,
|
|
||||||
# TODO: topic_id
|
|
||||||
topic_id=None,
|
|
||||||
is_rpc=False,
|
|
||||||
cancellation_token=CancellationToken(),
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
await agent.on_message(message, ctx=message_context)
|
|
||||||
logger.info("%s handled event %s", agent_id, message)
|
|
||||||
except Exception as e:
|
|
||||||
event_logger.error("Error handling message", exc_info=e)
|
|
||||||
|
|
||||||
logger.warn("Cant handle event")
|
|
||||||
case None:
|
case None:
|
||||||
logger.warn("No message")
|
logger.warn("No message")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error in read loop", exc_info=e)
|
logger.error("Error in read loop", exc_info=e)
|
||||||
|
|
||||||
async def close_channel(self) -> None:
|
async def stop(self) -> None:
|
||||||
self._running = False
|
self._running = False
|
||||||
if self._runtime_connection is not None:
|
if self._host_connection is not None:
|
||||||
await self._runtime_connection.close()
|
await self._host_connection.close()
|
||||||
if self._read_task is not None:
|
if self._read_task is not None:
|
||||||
await self._read_task
|
await self._read_task
|
||||||
|
|
||||||
@ -279,7 +206,6 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||||||
def _known_agent_names(self) -> Set[str]:
|
def _known_agent_names(self) -> Set[str]:
|
||||||
return set(self._agent_factories.keys())
|
return set(self._agent_factories.keys())
|
||||||
|
|
||||||
# Returns the response of the message
|
|
||||||
async def send_message(
|
async def send_message(
|
||||||
self,
|
self,
|
||||||
message: Any,
|
message: Any,
|
||||||
@ -288,25 +214,32 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||||||
sender: AgentId | None = None,
|
sender: AgentId | None = None,
|
||||||
cancellation_token: CancellationToken | None = None,
|
cancellation_token: CancellationToken | None = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
assert self._runtime_connection is not None
|
if not self._running:
|
||||||
|
raise ValueError("Runtime must be running when sending message.")
|
||||||
|
assert self._host_connection is not None
|
||||||
# create a new future for the result
|
# create a new future for the result
|
||||||
future = asyncio.get_event_loop().create_future()
|
future = asyncio.get_event_loop().create_future()
|
||||||
with self._pending_requests_lock:
|
async with self._pending_requests_lock:
|
||||||
self._next_request_id += 1
|
self._next_request_id += 1
|
||||||
request_id = self._next_request_id
|
request_id = self._next_request_id
|
||||||
request_id_str = str(request_id)
|
request_id_str = str(request_id)
|
||||||
self._pending_requests[request_id_str] = future
|
self._pending_requests[request_id_str] = future
|
||||||
sender = cast(AgentId, sender)
|
sender = cast(AgentId, sender)
|
||||||
|
method = MESSAGE_TYPE_REGISTRY.type_name(message)
|
||||||
|
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=method)
|
||||||
runtime_message = Message(
|
runtime_message = Message(
|
||||||
request=RpcRequest(
|
request=RpcRequest(
|
||||||
request_id=request_id_str,
|
request_id=request_id_str,
|
||||||
target=AgentIdProto(name=recipient.type, namespace=recipient.key),
|
target=AgentIdProto(name=recipient.type, namespace=recipient.key),
|
||||||
source=AgentIdProto(name=sender.type, namespace=sender.key),
|
source=AgentIdProto(name=sender.type, namespace=sender.key),
|
||||||
data=message,
|
method=method,
|
||||||
|
data=serialized_message,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# TODO: Find a way to handle timeouts/errors
|
# TODO: Find a way to handle timeouts/errors
|
||||||
asyncio.create_task(self._runtime_connection.send(runtime_message))
|
task = asyncio.create_task(self._host_connection.send(runtime_message))
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
return await future
|
return await future
|
||||||
|
|
||||||
async def publish_message(
|
async def publish_message(
|
||||||
@ -317,26 +250,24 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||||||
sender: AgentId | None = None,
|
sender: AgentId | None = None,
|
||||||
cancellation_token: CancellationToken | None = None,
|
cancellation_token: CancellationToken | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert self._runtime_connection is not None
|
if not self._running:
|
||||||
|
raise ValueError("Runtime must be running when publishing message.")
|
||||||
|
assert self._host_connection is not None
|
||||||
sender_namespace = sender.key if sender is not None else None
|
sender_namespace = sender.key if sender is not None else None
|
||||||
explicit_namespace = namespace
|
explicit_namespace = namespace
|
||||||
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
|
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}"
|
f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert explicit_namespace is not None or sender_namespace is not None
|
assert explicit_namespace is not None or sender_namespace is not None
|
||||||
actual_namespace = cast(str, explicit_namespace or sender_namespace)
|
actual_namespace = cast(str, explicit_namespace or sender_namespace)
|
||||||
await self._process_seen_namespace(actual_namespace)
|
await self._process_seen_namespace(actual_namespace)
|
||||||
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
|
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
|
||||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type)
|
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type)
|
||||||
message = Message(event=Event(namespace=actual_namespace, type=message_type, data=serialized_message))
|
message = Message(event=Event(namespace=actual_namespace, type=message_type, data=serialized_message))
|
||||||
|
task = asyncio.create_task(self._host_connection.send(message))
|
||||||
async def write_message() -> None:
|
self._background_tasks.add(task)
|
||||||
assert self._runtime_connection is not None
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
await self._runtime_connection.send(message)
|
|
||||||
|
|
||||||
await asyncio.create_task(write_message())
|
|
||||||
|
|
||||||
async def save_state(self) -> Mapping[str, Any]:
|
async def save_state(self) -> Mapping[str, Any]:
|
||||||
raise NotImplementedError("Saving state is not yet implemented.")
|
raise NotImplementedError("Saving state is not yet implemented.")
|
||||||
@ -358,6 +289,8 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||||||
name: str,
|
name: str,
|
||||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if not self._running:
|
||||||
|
raise ValueError("Runtime must be running when registering agent.")
|
||||||
if name in self._agent_factories:
|
if name in self._agent_factories:
|
||||||
raise ValueError(f"Agent with name {name} already exists.")
|
raise ValueError(f"Agent with name {name} already exists.")
|
||||||
self._agent_factories[name] = agent_factory
|
self._agent_factories[name] = agent_factory
|
||||||
@ -366,7 +299,75 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||||||
for namespace in self._known_namespaces:
|
for namespace in self._known_namespaces:
|
||||||
await self._get_agent(AgentId(type=name, key=namespace))
|
await self._get_agent(AgentId(type=name, key=namespace))
|
||||||
|
|
||||||
await self.send_register_agent_type(name)
|
assert self._host_connection is not None
|
||||||
|
message = Message(registerAgentType=RegisterAgentType(type=name))
|
||||||
|
await self._host_connection.send(message)
|
||||||
|
logger.info("Sent registerAgentType message for %s", name)
|
||||||
|
|
||||||
|
async def _process_request(self, request: RpcRequest) -> None:
|
||||||
|
assert self._host_connection is not None
|
||||||
|
target = AgentId(request.target.name, request.target.namespace)
|
||||||
|
source = AgentId(request.source.name, request.source.namespace)
|
||||||
|
|
||||||
|
try:
|
||||||
|
logging.info(f"Processing request from {source} to {target}")
|
||||||
|
target_agent = await self._get_agent(target)
|
||||||
|
message_context = MessageContext(
|
||||||
|
sender=source,
|
||||||
|
topic_id=None,
|
||||||
|
is_rpc=True,
|
||||||
|
cancellation_token=CancellationToken(),
|
||||||
|
)
|
||||||
|
message = MESSAGE_TYPE_REGISTRY.deserialize(request.data, type_name=request.method)
|
||||||
|
response = await target_agent.on_message(message, ctx=message_context)
|
||||||
|
serialized_response = MESSAGE_TYPE_REGISTRY.serialize(response, type_name=request.method)
|
||||||
|
response_message = Message(
|
||||||
|
response=RpcResponse(
|
||||||
|
request_id=request.request_id,
|
||||||
|
result=serialized_response,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except BaseException as e:
|
||||||
|
response_message = Message(
|
||||||
|
response=RpcResponse(
|
||||||
|
request_id=request.request_id,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send the response.
|
||||||
|
await self._host_connection.send(response_message)
|
||||||
|
|
||||||
|
async def _process_response(self, response: RpcResponse) -> None:
|
||||||
|
# TODO: deserialize the response and set the future result
|
||||||
|
future = self._pending_requests.pop(response.request_id)
|
||||||
|
if len(response.error) > 0:
|
||||||
|
future.set_exception(Exception(response.error))
|
||||||
|
else:
|
||||||
|
future.set_result(response.result)
|
||||||
|
|
||||||
|
async def _process_event(self, event: Event) -> None:
|
||||||
|
message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.type)
|
||||||
|
namespace = event.namespace
|
||||||
|
responses: List[Awaitable[Any]] = []
|
||||||
|
for agent_id in self._per_type_subscribers[(namespace, MESSAGE_TYPE_REGISTRY.type_name(message))]:
|
||||||
|
# TODO: skip the sender?
|
||||||
|
message_context = MessageContext(
|
||||||
|
sender=None,
|
||||||
|
topic_id=None,
|
||||||
|
is_rpc=False,
|
||||||
|
cancellation_token=CancellationToken(),
|
||||||
|
)
|
||||||
|
agent = await self._get_agent(agent_id)
|
||||||
|
future = agent.on_message(message, ctx=message_context)
|
||||||
|
responses.append(future)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ = await asyncio.gather(*responses)
|
||||||
|
except BaseException as e:
|
||||||
|
if isinstance(e, asyncio.CancelledError):
|
||||||
|
return
|
||||||
|
event_logger.error("Error handling event message", exc_info=e)
|
||||||
|
|
||||||
async def _invoke_agent_factory(
|
async def _invoke_agent_factory(
|
||||||
self,
|
self,
|
@ -9,11 +9,21 @@ sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from .agent_worker_pb2 import Event, Message, RegisterAgentType, RpcRequest, RpcResponse, AgentId
|
from .agent_worker_pb2 import AgentId, Event, Message, RegisterAgentType, RpcRequest, RpcResponse
|
||||||
from .agent_worker_pb2_grpc import AgentRpcStub
|
from .agent_worker_pb2_grpc import AgentRpcServicer, AgentRpcStub, add_AgentRpcServicer_to_server
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .agent_worker_pb2_grpc import AgentRpcAsyncStub
|
from .agent_worker_pb2_grpc import AgentRpcAsyncStub
|
||||||
__all__ = ["RpcRequest", "RpcResponse", "Event", "RegisterAgentType", "AgentRpcAsyncStub", "AgentRpcStub", "Message", "AgentId"]
|
|
||||||
|
__all__ = [
|
||||||
|
"RpcRequest",
|
||||||
|
"RpcResponse",
|
||||||
|
"Event",
|
||||||
|
"RegisterAgentType",
|
||||||
|
"AgentRpcAsyncStub",
|
||||||
|
"AgentRpcStub",
|
||||||
|
"Message",
|
||||||
|
"AgentId",
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
__all__ = ["RpcRequest", "RpcResponse", "Event", "RegisterAgentType", "AgentRpcStub", "Message", "AgentId"]
|
__all__ = ["RpcRequest", "RpcResponse", "Event", "RegisterAgentType", "AgentRpcStub", "Message", "AgentId"]
|
@ -1,7 +0,0 @@
|
|||||||
"""
|
|
||||||
The :mod:`agnext.worker` module provides a set of classes for creating distributed agents
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .worker_runtime import WorkerAgentRuntime
|
|
||||||
|
|
||||||
__all__ = ["WorkerAgentRuntime"]
|
|
@ -1,42 +0,0 @@
|
|||||||
from agnext.worker.worker_runtime import WorkerAgentRuntime
|
|
||||||
from agnext.components import TypeRoutedAgent, message_handler
|
|
||||||
from agnext.core import CancellationToken, AgentId
|
|
||||||
import logging
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ExampleMessagePayload:
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
class ExampleAgent(TypeRoutedAgent):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__("Example Agent")
|
|
||||||
|
|
||||||
@message_handler
|
|
||||||
async def on_example_payload(self, message: ExampleMessagePayload, cancellation_token: CancellationToken) -> None:
|
|
||||||
upper_case = message.content.upper()
|
|
||||||
await self.publish_message(ExampleMessagePayload(content=upper_case))
|
|
||||||
|
|
||||||
|
|
||||||
async def main() -> None:
|
|
||||||
logger = logging.getLogger("main")
|
|
||||||
runtime = WorkerAgentRuntime()
|
|
||||||
await runtime.setup_channel(os.environ["AGENT_HOST"])
|
|
||||||
|
|
||||||
runtime.register("ExampleAgent", lambda: ExampleAgent())
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
res = await runtime.send_message("testing!", recipient=AgentId(name="greeter", namespace="testing"), sender=AgentId(name="ExampleAgent", namespace="testing"))
|
|
||||||
logger.info("Response: %s", res)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Error: %s", e)
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user