Communicate client id via metadata in grpc runtime (#5185)

Communicate client id via metadata
This commit is contained in:
Jack Gerrits 2025-01-24 13:41:31 -05:00 committed by GitHub
parent 89631966cb
commit b375d4b18c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 21 deletions

View File

@ -117,6 +117,7 @@ class HostConnection:
self._send_queue = asyncio.Queue[agent_worker_pb2.Message]() self._send_queue = asyncio.Queue[agent_worker_pb2.Message]()
self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]() self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]()
self._connection_task: Task[None] | None = None self._connection_task: Task[None] | None = None
self._client_id = str(uuid.uuid4())
@classmethod @classmethod
def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG) -> Self: def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG) -> Self:
@ -132,7 +133,7 @@ class HostConnection:
) )
instance = cls(channel) instance = cls(channel)
instance._connection_task = asyncio.create_task( instance._connection_task = asyncio.create_task(
instance._connect(channel, instance._send_queue, instance._recv_queue) instance._connect(channel, instance._send_queue, instance._recv_queue, instance._client_id)
) )
return instance return instance
@ -147,6 +148,7 @@ class HostConnection:
channel: grpc.aio.Channel, channel: grpc.aio.Channel,
send_queue: asyncio.Queue[agent_worker_pb2.Message], send_queue: asyncio.Queue[agent_worker_pb2.Message],
receive_queue: asyncio.Queue[agent_worker_pb2.Message], receive_queue: asyncio.Queue[agent_worker_pb2.Message],
client_id: str,
) -> None: ) -> None:
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore
@ -154,7 +156,7 @@ class HostConnection:
# TODO: where do exceptions from reading the iterable go? How do we recover from those? # TODO: where do exceptions from reading the iterable go? How do we recover from those?
recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
QueueAsyncIterable(send_queue) QueueAsyncIterable(send_queue), metadata=[("client-id", client_id)]
) # type: ignore ) # type: ignore
while True: while True:

View File

@ -2,7 +2,7 @@ import asyncio
import logging import logging
from _collections_abc import AsyncIterator, Iterator from _collections_abc import AsyncIterator, Iterator
from asyncio import Future, Task from asyncio import Future, Task
from typing import Any, Dict, Set, cast from typing import Any, Dict, Sequence, Set, Tuple, cast
from autogen_core import Subscription, TopicId, TypePrefixSubscription, TypeSubscription from autogen_core import Subscription, TopicId, TypePrefixSubscription, TypeSubscription
from autogen_core._runtime_impl_helpers import SubscriptionManager from autogen_core._runtime_impl_helpers import SubscriptionManager
@ -19,30 +19,36 @@ from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
logger = logging.getLogger("autogen_core") logger = logging.getLogger("autogen_core")
event_logger = logging.getLogger("autogen_core.events") event_logger = logging.getLogger("autogen_core.events")
ClientConnectionId = str
def metadata_to_dict(metadata: Sequence[Tuple[str, str]] | None) -> Dict[str, str]:
if metadata is None:
return {}
return {key: value for key, value in metadata}
class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer): class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
"""A gRPC servicer that hosts message delivery service for agents.""" """A gRPC servicer that hosts message delivery service for agents."""
def __init__(self) -> None: def __init__(self) -> None:
self._client_id = 0 self._send_queues: Dict[ClientConnectionId, asyncio.Queue[agent_worker_pb2.Message]] = {}
self._client_id_lock = asyncio.Lock()
self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {}
self._agent_type_to_client_id_lock = asyncio.Lock() self._agent_type_to_client_id_lock = asyncio.Lock()
self._agent_type_to_client_id: Dict[str, int] = {} self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {}
self._pending_responses: Dict[int, Dict[str, Future[Any]]] = {} self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {}
self._background_tasks: Set[Task[Any]] = set() self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager() self._subscription_manager = SubscriptionManager()
self._client_id_to_subscription_id_mapping: Dict[int, set[str]] = {} self._client_id_to_subscription_id_mapping: Dict[ClientConnectionId, set[str]] = {}
async def OpenChannel( # type: ignore async def OpenChannel( # type: ignore
self, self,
request_iterator: AsyncIterator[agent_worker_pb2.Message], request_iterator: AsyncIterator[agent_worker_pb2.Message],
context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message], context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message],
) -> Iterator[agent_worker_pb2.Message] | AsyncIterator[agent_worker_pb2.Message]: # type: ignore ) -> Iterator[agent_worker_pb2.Message] | AsyncIterator[agent_worker_pb2.Message]: # type: ignore
# Acquire the lock to get a new client id. metadata = metadata_to_dict(context.invocation_metadata()) # type: ignore
async with self._client_id_lock: if (client_id := cast(ClientConnectionId | None, metadata.get("client-id"))) is None: # type: ignore
self._client_id += 1 logger.error("Client id not found in metadata. Refusing connection.")
client_id = self._client_id return
# Register the client with the server and create a send queue for the client. # Register the client with the server and create a send queue for the client.
send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue() send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue()
@ -76,7 +82,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
# Remove the client id from the agent type to client id mapping. # Remove the client id from the agent type to client id mapping.
await self._on_client_disconnect(client_id) await self._on_client_disconnect(client_id)
async def _on_client_disconnect(self, client_id: int) -> None: async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None:
async with self._agent_type_to_client_id_lock: async with self._agent_type_to_client_id_lock:
agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id] agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id]
for agent_type in agent_types: for agent_type in agent_types:
@ -93,7 +99,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
raise exception raise exception
async def _receive_messages( async def _receive_messages(
self, client_id: int, request_iterator: AsyncIterator[agent_worker_pb2.Message] self, client_id: ClientConnectionId, request_iterator: AsyncIterator[agent_worker_pb2.Message]
) -> None: ) -> None:
# Receive messages from the client and process them. # Receive messages from the client and process them.
async for message in request_iterator: async for message in request_iterator:
@ -138,7 +144,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
case None: case None:
logger.warning("Received empty message") logger.warning("Received empty message")
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None: async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None:
# Deliver the message to a client given the target agent type. # Deliver the message to a client given the target agent type.
async with self._agent_type_to_client_id_lock: async with self._agent_type_to_client_id_lock:
target_client_id = self._agent_type_to_client_id.get(request.target.type) target_client_id = self._agent_type_to_client_id.get(request.target.type)
@ -161,7 +167,9 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
send_response_task.add_done_callback(self._raise_on_exception) send_response_task.add_done_callback(self._raise_on_exception)
send_response_task.add_done_callback(self._background_tasks.discard) send_response_task.add_done_callback(self._background_tasks.discard)
async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None: async def _wait_and_send_response(
self, future: Future[agent_worker_pb2.RpcResponse], client_id: ClientConnectionId
) -> None:
response = await future response = await future
message = agent_worker_pb2.Message(response=response) message = agent_worker_pb2.Message(response=response)
send_queue = self._send_queues.get(client_id) send_queue = self._send_queues.get(client_id)
@ -170,7 +178,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
return return
await send_queue.put(message) await send_queue.put(message)
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None: async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: ClientConnectionId) -> None:
# Setting the result of the future will send the response back to the original sender. # Setting the result of the future will send the response back to the original sender.
future = self._pending_responses[client_id].pop(response.request_id) future = self._pending_responses[client_id].pop(response.request_id)
future.set_result(response) future.set_result(response)
@ -180,7 +188,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id) recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
# Get the client ids of the recipients. # Get the client ids of the recipients.
async with self._agent_type_to_client_id_lock: async with self._agent_type_to_client_id_lock:
client_ids: Set[int] = set() client_ids: Set[ClientConnectionId] = set()
for recipient in recipients: for recipient in recipients:
client_id = self._agent_type_to_client_id.get(recipient.type) client_id = self._agent_type_to_client_id.get(recipient.type)
if client_id is not None: if client_id is not None:
@ -192,7 +200,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event)) await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event))
async def _process_register_agent_type_request( async def _process_register_agent_type_request(
self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: int self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: ClientConnectionId
) -> None: ) -> None:
# Register the agent type with the host runtime. # Register the agent type with the host runtime.
async with self._agent_type_to_client_id_lock: async with self._agent_type_to_client_id_lock:
@ -217,7 +225,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
) )
async def _process_add_subscription_request( async def _process_add_subscription_request(
self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: int self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: ClientConnectionId
) -> None: ) -> None:
oneofcase = add_subscription_req.subscription.WhichOneof("subscription") oneofcase = add_subscription_req.subscription.WhichOneof("subscription")
subscription: Subscription | None = None subscription: Subscription | None = None