Communicate client id via metadata in grpc runtime (#5185)

Communicate client id via metadata
This commit is contained in:
Jack Gerrits 2025-01-24 13:41:31 -05:00 committed by GitHub
parent 89631966cb
commit b375d4b18c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 21 deletions

View File

@ -117,6 +117,7 @@ class HostConnection:
self._send_queue = asyncio.Queue[agent_worker_pb2.Message]()
self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]()
self._connection_task: Task[None] | None = None
self._client_id = str(uuid.uuid4())
@classmethod
def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG) -> Self:
@ -132,7 +133,7 @@ class HostConnection:
)
instance = cls(channel)
instance._connection_task = asyncio.create_task(
instance._connect(channel, instance._send_queue, instance._recv_queue)
instance._connect(channel, instance._send_queue, instance._recv_queue, instance._client_id)
)
return instance
@ -147,6 +148,7 @@ class HostConnection:
channel: grpc.aio.Channel,
send_queue: asyncio.Queue[agent_worker_pb2.Message],
receive_queue: asyncio.Queue[agent_worker_pb2.Message],
client_id: str,
) -> None:
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore
@ -154,7 +156,7 @@ class HostConnection:
# TODO: where do exceptions from reading the iterable go? How do we recover from those?
recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
QueueAsyncIterable(send_queue)
QueueAsyncIterable(send_queue), metadata=[("client-id", client_id)]
) # type: ignore
while True:

View File

@ -2,7 +2,7 @@ import asyncio
import logging
from _collections_abc import AsyncIterator, Iterator
from asyncio import Future, Task
from typing import Any, Dict, Set, cast
from typing import Any, Dict, Sequence, Set, Tuple, cast
from autogen_core import Subscription, TopicId, TypePrefixSubscription, TypeSubscription
from autogen_core._runtime_impl_helpers import SubscriptionManager
@ -19,30 +19,36 @@ from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
logger = logging.getLogger("autogen_core")
event_logger = logging.getLogger("autogen_core.events")
ClientConnectionId = str
def metadata_to_dict(metadata: Sequence[Tuple[str, str]] | None) -> Dict[str, str]:
if metadata is None:
return {}
return {key: value for key, value in metadata}
class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
"""A gRPC servicer that hosts message delivery service for agents."""
def __init__(self) -> None:
self._client_id = 0
self._client_id_lock = asyncio.Lock()
self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {}
self._send_queues: Dict[ClientConnectionId, asyncio.Queue[agent_worker_pb2.Message]] = {}
self._agent_type_to_client_id_lock = asyncio.Lock()
self._agent_type_to_client_id: Dict[str, int] = {}
self._pending_responses: Dict[int, Dict[str, Future[Any]]] = {}
self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {}
self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {}
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._client_id_to_subscription_id_mapping: Dict[int, set[str]] = {}
self._client_id_to_subscription_id_mapping: Dict[ClientConnectionId, set[str]] = {}
async def OpenChannel( # type: ignore
self,
request_iterator: AsyncIterator[agent_worker_pb2.Message],
context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message],
) -> Iterator[agent_worker_pb2.Message] | AsyncIterator[agent_worker_pb2.Message]: # type: ignore
# Acquire the lock to get a new client id.
async with self._client_id_lock:
self._client_id += 1
client_id = self._client_id
metadata = metadata_to_dict(context.invocation_metadata()) # type: ignore
if (client_id := cast(ClientConnectionId | None, metadata.get("client-id"))) is None: # type: ignore
logger.error("Client id not found in metadata. Refusing connection.")
return
# Register the client with the server and create a send queue for the client.
send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue()
@ -76,7 +82,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
# Remove the client id from the agent type to client id mapping.
await self._on_client_disconnect(client_id)
async def _on_client_disconnect(self, client_id: int) -> None:
async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None:
async with self._agent_type_to_client_id_lock:
agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id]
for agent_type in agent_types:
@ -93,7 +99,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
raise exception
async def _receive_messages(
self, client_id: int, request_iterator: AsyncIterator[agent_worker_pb2.Message]
self, client_id: ClientConnectionId, request_iterator: AsyncIterator[agent_worker_pb2.Message]
) -> None:
# Receive messages from the client and process them.
async for message in request_iterator:
@ -138,7 +144,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
case None:
logger.warning("Received empty message")
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None:
# Deliver the message to a client given the target agent type.
async with self._agent_type_to_client_id_lock:
target_client_id = self._agent_type_to_client_id.get(request.target.type)
@ -161,7 +167,9 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
send_response_task.add_done_callback(self._raise_on_exception)
send_response_task.add_done_callback(self._background_tasks.discard)
async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None:
async def _wait_and_send_response(
self, future: Future[agent_worker_pb2.RpcResponse], client_id: ClientConnectionId
) -> None:
response = await future
message = agent_worker_pb2.Message(response=response)
send_queue = self._send_queues.get(client_id)
@ -170,7 +178,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
return
await send_queue.put(message)
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None:
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: ClientConnectionId) -> None:
# Setting the result of the future will send the response back to the original sender.
future = self._pending_responses[client_id].pop(response.request_id)
future.set_result(response)
@ -180,7 +188,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
# Get the client ids of the recipients.
async with self._agent_type_to_client_id_lock:
client_ids: Set[int] = set()
client_ids: Set[ClientConnectionId] = set()
for recipient in recipients:
client_id = self._agent_type_to_client_id.get(recipient.type)
if client_id is not None:
@ -192,7 +200,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event))
async def _process_register_agent_type_request(
self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: int
self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: ClientConnectionId
) -> None:
# Register the agent type with the host runtime.
async with self._agent_type_to_client_id_lock:
@ -217,7 +225,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
)
async def _process_add_subscription_request(
self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: int
self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: ClientConnectionId
) -> None:
oneofcase = add_subscription_req.subscription.WhichOneof("subscription")
subscription: Subscription | None = None