mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-12 23:41:28 +00:00
Communicate client id via metadata in grpc runtime (#5185)
Communicate client id via metadata
This commit is contained in:
parent
89631966cb
commit
b375d4b18c
@ -117,6 +117,7 @@ class HostConnection:
|
||||
self._send_queue = asyncio.Queue[agent_worker_pb2.Message]()
|
||||
self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]()
|
||||
self._connection_task: Task[None] | None = None
|
||||
self._client_id = str(uuid.uuid4())
|
||||
|
||||
@classmethod
|
||||
def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG) -> Self:
|
||||
@ -132,7 +133,7 @@ class HostConnection:
|
||||
)
|
||||
instance = cls(channel)
|
||||
instance._connection_task = asyncio.create_task(
|
||||
instance._connect(channel, instance._send_queue, instance._recv_queue)
|
||||
instance._connect(channel, instance._send_queue, instance._recv_queue, instance._client_id)
|
||||
)
|
||||
return instance
|
||||
|
||||
@ -147,6 +148,7 @@ class HostConnection:
|
||||
channel: grpc.aio.Channel,
|
||||
send_queue: asyncio.Queue[agent_worker_pb2.Message],
|
||||
receive_queue: asyncio.Queue[agent_worker_pb2.Message],
|
||||
client_id: str,
|
||||
) -> None:
|
||||
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore
|
||||
|
||||
@ -154,7 +156,7 @@ class HostConnection:
|
||||
|
||||
# TODO: where do exceptions from reading the iterable go? How do we recover from those?
|
||||
recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
|
||||
QueueAsyncIterable(send_queue)
|
||||
QueueAsyncIterable(send_queue), metadata=[("client-id", client_id)]
|
||||
) # type: ignore
|
||||
|
||||
while True:
|
||||
|
||||
@ -2,7 +2,7 @@ import asyncio
|
||||
import logging
|
||||
from _collections_abc import AsyncIterator, Iterator
|
||||
from asyncio import Future, Task
|
||||
from typing import Any, Dict, Set, cast
|
||||
from typing import Any, Dict, Sequence, Set, Tuple, cast
|
||||
|
||||
from autogen_core import Subscription, TopicId, TypePrefixSubscription, TypeSubscription
|
||||
from autogen_core._runtime_impl_helpers import SubscriptionManager
|
||||
@ -19,30 +19,36 @@ from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
|
||||
logger = logging.getLogger("autogen_core")
|
||||
event_logger = logging.getLogger("autogen_core.events")
|
||||
|
||||
ClientConnectionId = str
|
||||
|
||||
|
||||
def metadata_to_dict(metadata: Sequence[Tuple[str, str]] | None) -> Dict[str, str]:
|
||||
if metadata is None:
|
||||
return {}
|
||||
return {key: value for key, value in metadata}
|
||||
|
||||
|
||||
class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
"""A gRPC servicer that hosts message delivery service for agents."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._client_id = 0
|
||||
self._client_id_lock = asyncio.Lock()
|
||||
self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {}
|
||||
self._send_queues: Dict[ClientConnectionId, asyncio.Queue[agent_worker_pb2.Message]] = {}
|
||||
self._agent_type_to_client_id_lock = asyncio.Lock()
|
||||
self._agent_type_to_client_id: Dict[str, int] = {}
|
||||
self._pending_responses: Dict[int, Dict[str, Future[Any]]] = {}
|
||||
self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {}
|
||||
self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {}
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
self._client_id_to_subscription_id_mapping: Dict[int, set[str]] = {}
|
||||
self._client_id_to_subscription_id_mapping: Dict[ClientConnectionId, set[str]] = {}
|
||||
|
||||
async def OpenChannel( # type: ignore
|
||||
self,
|
||||
request_iterator: AsyncIterator[agent_worker_pb2.Message],
|
||||
context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message],
|
||||
) -> Iterator[agent_worker_pb2.Message] | AsyncIterator[agent_worker_pb2.Message]: # type: ignore
|
||||
# Aquire the lock to get a new client id.
|
||||
async with self._client_id_lock:
|
||||
self._client_id += 1
|
||||
client_id = self._client_id
|
||||
metadata = metadata_to_dict(context.invocation_metadata()) # type: ignore
|
||||
if (client_id := cast(ClientConnectionId | None, metadata.get("client-id"))) is None: # type: ignore
|
||||
logger.error("Client id not found in metadata. Refusing connection.")
|
||||
return
|
||||
|
||||
# Register the client with the server and create a send queue for the client.
|
||||
send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue()
|
||||
@ -76,7 +82,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
# Remove the client id from the agent type to client id mapping.
|
||||
await self._on_client_disconnect(client_id)
|
||||
|
||||
async def _on_client_disconnect(self, client_id: int) -> None:
|
||||
async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None:
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id]
|
||||
for agent_type in agent_types:
|
||||
@ -93,7 +99,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
raise exception
|
||||
|
||||
async def _receive_messages(
|
||||
self, client_id: int, request_iterator: AsyncIterator[agent_worker_pb2.Message]
|
||||
self, client_id: ClientConnectionId, request_iterator: AsyncIterator[agent_worker_pb2.Message]
|
||||
) -> None:
|
||||
# Receive messages from the client and process them.
|
||||
async for message in request_iterator:
|
||||
@ -138,7 +144,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
case None:
|
||||
logger.warning("Received empty message")
|
||||
|
||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
|
||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None:
|
||||
# Deliver the message to a client given the target agent type.
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
target_client_id = self._agent_type_to_client_id.get(request.target.type)
|
||||
@ -161,7 +167,9 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
send_response_task.add_done_callback(self._raise_on_exception)
|
||||
send_response_task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None:
|
||||
async def _wait_and_send_response(
|
||||
self, future: Future[agent_worker_pb2.RpcResponse], client_id: ClientConnectionId
|
||||
) -> None:
|
||||
response = await future
|
||||
message = agent_worker_pb2.Message(response=response)
|
||||
send_queue = self._send_queues.get(client_id)
|
||||
@ -170,7 +178,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
return
|
||||
await send_queue.put(message)
|
||||
|
||||
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None:
|
||||
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: ClientConnectionId) -> None:
|
||||
# Setting the result of the future will send the response back to the original sender.
|
||||
future = self._pending_responses[client_id].pop(response.request_id)
|
||||
future.set_result(response)
|
||||
@ -180,7 +188,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
|
||||
# Get the client ids of the recipients.
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
client_ids: Set[int] = set()
|
||||
client_ids: Set[ClientConnectionId] = set()
|
||||
for recipient in recipients:
|
||||
client_id = self._agent_type_to_client_id.get(recipient.type)
|
||||
if client_id is not None:
|
||||
@ -192,7 +200,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event))
|
||||
|
||||
async def _process_register_agent_type_request(
|
||||
self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: int
|
||||
self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: ClientConnectionId
|
||||
) -> None:
|
||||
# Register the agent type with the host runtime.
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
@ -217,7 +225,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
)
|
||||
|
||||
async def _process_add_subscription_request(
|
||||
self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: int
|
||||
self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: ClientConnectionId
|
||||
) -> None:
|
||||
oneofcase = add_subscription_req.subscription.WhichOneof("subscription")
|
||||
subscription: Subscription | None = None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user