mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-20 06:31:54 +00:00

* Python host runtime impl * update * ignore proto generated files * move worker runtime to application * Move example to samples * Fix import * fix * update * server client * better shutdown * fix doc conf * add type
151 lines
7.5 KiB
Python
151 lines
7.5 KiB
Python
import asyncio
|
|
import logging
|
|
from _collections_abc import AsyncIterator, Iterator
|
|
from asyncio import Future, Task
|
|
from typing import Any, Dict, Set
|
|
|
|
import grpc
|
|
|
|
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
|
|
|
|
# Module-level loggers: ``logger`` for general diagnostics of the host
# runtime, ``event_logger`` for the "agnext.events" channel.
logger, event_logger = (logging.getLogger(name) for name in ("agnext", "agnext.events"))
|
|
|
|
|
|
class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
|
"""A gRPC servicer that hosts message delivery service for agents."""
|
|
|
|
def __init__(self) -> None:
|
|
self._client_id = 0
|
|
self._client_id_lock = asyncio.Lock()
|
|
self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {}
|
|
self._agent_type_to_client_id: Dict[str, int] = {}
|
|
self._pending_requests: Dict[int, Dict[str, Future[Any]]] = {}
|
|
self._background_tasks: Set[Task[Any]] = set()
|
|
|
|
async def OpenChannel( # type: ignore
|
|
self,
|
|
request_iterator: AsyncIterator[agent_worker_pb2.Message],
|
|
context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message],
|
|
) -> Iterator[agent_worker_pb2.Message] | AsyncIterator[agent_worker_pb2.Message]: # type: ignore
|
|
# Aquire the lock to get a new client id.
|
|
async with self._client_id_lock:
|
|
self._client_id += 1
|
|
client_id = self._client_id
|
|
|
|
# Register the client with the server and create a send queue for the client.
|
|
send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue()
|
|
self._send_queues[client_id] = send_queue
|
|
logger.info(f"Client {client_id} connected.")
|
|
|
|
try:
|
|
# Concurrently handle receiving messages from the client and sending messages to the client.
|
|
# This task will receive messages from the client.
|
|
receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator))
|
|
|
|
# Return an async generator that will yield messages from the send queue to the client.
|
|
while True:
|
|
message = await send_queue.get()
|
|
# Yield the message to the client.
|
|
try:
|
|
yield message
|
|
except Exception as e:
|
|
logger.error(f"Failed to send message to client {client_id}: {e}", exc_info=True)
|
|
break
|
|
# Wait for the receiving task to finish.
|
|
await receiving_task
|
|
|
|
finally:
|
|
# Clean up the client connection.
|
|
del self._send_queues[client_id]
|
|
# Cancel pending requests sent to this client.
|
|
for future in self._pending_requests.pop(client_id, {}).values():
|
|
future.cancel()
|
|
logger.info(f"Client {client_id} disconnected.")
|
|
|
|
async def _receive_messages(
|
|
self, client_id: int, request_iterator: AsyncIterator[agent_worker_pb2.Message]
|
|
) -> None:
|
|
# Receive messages from the client and process them.
|
|
async for message in request_iterator:
|
|
oneofcase = message.WhichOneof("message")
|
|
match oneofcase:
|
|
case "request":
|
|
request: agent_worker_pb2.RpcRequest = message.request
|
|
logger.info(f"Received request message: {request}")
|
|
task = asyncio.create_task(self._process_request(request, client_id))
|
|
self._background_tasks.add(task)
|
|
task.add_done_callback(self._background_tasks.discard)
|
|
case "response":
|
|
response: agent_worker_pb2.RpcResponse = message.response
|
|
logger.info(f"Received response message: {response}")
|
|
task = asyncio.create_task(self._process_response(response, client_id))
|
|
self._background_tasks.add(task)
|
|
task.add_done_callback(self._background_tasks.discard)
|
|
case "event":
|
|
event: agent_worker_pb2.Event = message.event
|
|
logger.info(f"Received event message: {event}")
|
|
task = asyncio.create_task(self._process_event(event))
|
|
self._background_tasks.add(task)
|
|
task.add_done_callback(self._background_tasks.discard)
|
|
case "registerAgentType":
|
|
register_agent_type: agent_worker_pb2.RegisterAgentType = message.registerAgentType
|
|
logger.info(f"Received register agent type message: {register_agent_type}")
|
|
task = asyncio.create_task(self._process_register_agent_type(register_agent_type, client_id))
|
|
self._background_tasks.add(task)
|
|
task.add_done_callback(self._background_tasks.discard)
|
|
case None:
|
|
logger.warning("Received empty message")
|
|
|
|
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
|
|
# Deliver the message to a client given the target agent type.
|
|
target_client_id = self._agent_type_to_client_id.get(request.target.name)
|
|
if target_client_id is None:
|
|
logger.error(f"Agent {request.target.name} not found, failed to deliver message.")
|
|
return
|
|
target_send_queue = self._send_queues.get(target_client_id)
|
|
if target_send_queue is None:
|
|
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
|
|
return
|
|
await target_send_queue.put(agent_worker_pb2.Message(request=request))
|
|
|
|
# Create a future to wait for the response.
|
|
future = asyncio.get_event_loop().create_future()
|
|
self._pending_requests.setdefault(client_id, {})[request.request_id] = future
|
|
|
|
# Create a task to wait for the response and send it back to the client.
|
|
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id))
|
|
self._background_tasks.add(send_response_task)
|
|
send_response_task.add_done_callback(self._background_tasks.discard)
|
|
|
|
async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None:
|
|
response = await future
|
|
message = agent_worker_pb2.Message(response=response)
|
|
send_queue = self._send_queues.get(client_id)
|
|
if send_queue is None:
|
|
logger.error(f"Client {client_id} not found, failed to send response message.")
|
|
return
|
|
await send_queue.put(message)
|
|
|
|
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None:
|
|
# Setting the result of the future will send the response back to the original sender.
|
|
future = self._pending_requests[client_id].pop(response.request_id)
|
|
future.set_result(response)
|
|
|
|
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
|
|
# Deliver the event to all the clients.
|
|
# TODO: deliver based on subscriptions.
|
|
for send_queue in self._send_queues.values():
|
|
await send_queue.put(agent_worker_pb2.Message(event=event))
|
|
|
|
async def _process_register_agent_type(
|
|
self, register_agent_type: agent_worker_pb2.RegisterAgentType, client_id: int
|
|
) -> None:
|
|
# Register the agent type with the host runtime.
|
|
if register_agent_type.type in self._agent_type_to_client_id:
|
|
existing_client_id = self._agent_type_to_client_id[register_agent_type.type]
|
|
logger.warning(
|
|
f"Agent type {register_agent_type.type} already registered with client {existing_client_id}, overwriting the client mapping to client {client_id}."
|
|
)
|
|
self._agent_type_to_client_id[register_agent_type.type] = client_id
|