mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-18 02:28:54 +00:00
Communicate client id via metadata in grpc runtime (#5185)
Communicate client id via metadata
This commit is contained in:
parent
89631966cb
commit
b375d4b18c
@ -117,6 +117,7 @@ class HostConnection:
|
|||||||
self._send_queue = asyncio.Queue[agent_worker_pb2.Message]()
|
self._send_queue = asyncio.Queue[agent_worker_pb2.Message]()
|
||||||
self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]()
|
self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]()
|
||||||
self._connection_task: Task[None] | None = None
|
self._connection_task: Task[None] | None = None
|
||||||
|
self._client_id = str(uuid.uuid4())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG) -> Self:
|
def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG) -> Self:
|
||||||
@ -132,7 +133,7 @@ class HostConnection:
|
|||||||
)
|
)
|
||||||
instance = cls(channel)
|
instance = cls(channel)
|
||||||
instance._connection_task = asyncio.create_task(
|
instance._connection_task = asyncio.create_task(
|
||||||
instance._connect(channel, instance._send_queue, instance._recv_queue)
|
instance._connect(channel, instance._send_queue, instance._recv_queue, instance._client_id)
|
||||||
)
|
)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
@ -147,6 +148,7 @@ class HostConnection:
|
|||||||
channel: grpc.aio.Channel,
|
channel: grpc.aio.Channel,
|
||||||
send_queue: asyncio.Queue[agent_worker_pb2.Message],
|
send_queue: asyncio.Queue[agent_worker_pb2.Message],
|
||||||
receive_queue: asyncio.Queue[agent_worker_pb2.Message],
|
receive_queue: asyncio.Queue[agent_worker_pb2.Message],
|
||||||
|
client_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore
|
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore
|
||||||
|
|
||||||
@ -154,7 +156,7 @@ class HostConnection:
|
|||||||
|
|
||||||
# TODO: where do exceptions from reading the iterable go? How do we recover from those?
|
# TODO: where do exceptions from reading the iterable go? How do we recover from those?
|
||||||
recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
|
recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
|
||||||
QueueAsyncIterable(send_queue)
|
QueueAsyncIterable(send_queue), metadata=[("client-id", client_id)]
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
from _collections_abc import AsyncIterator, Iterator
|
from _collections_abc import AsyncIterator, Iterator
|
||||||
from asyncio import Future, Task
|
from asyncio import Future, Task
|
||||||
from typing import Any, Dict, Set, cast
|
from typing import Any, Dict, Sequence, Set, Tuple, cast
|
||||||
|
|
||||||
from autogen_core import Subscription, TopicId, TypePrefixSubscription, TypeSubscription
|
from autogen_core import Subscription, TopicId, TypePrefixSubscription, TypeSubscription
|
||||||
from autogen_core._runtime_impl_helpers import SubscriptionManager
|
from autogen_core._runtime_impl_helpers import SubscriptionManager
|
||||||
@ -19,30 +19,36 @@ from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
|
|||||||
logger = logging.getLogger("autogen_core")
|
logger = logging.getLogger("autogen_core")
|
||||||
event_logger = logging.getLogger("autogen_core.events")
|
event_logger = logging.getLogger("autogen_core.events")
|
||||||
|
|
||||||
|
ClientConnectionId = str
|
||||||
|
|
||||||
|
|
||||||
|
def metadata_to_dict(metadata: Sequence[Tuple[str, str]] | None) -> Dict[str, str]:
|
||||||
|
if metadata is None:
|
||||||
|
return {}
|
||||||
|
return {key: value for key, value in metadata}
|
||||||
|
|
||||||
|
|
||||||
class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||||
"""A gRPC servicer that hosts message delivery service for agents."""
|
"""A gRPC servicer that hosts message delivery service for agents."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._client_id = 0
|
self._send_queues: Dict[ClientConnectionId, asyncio.Queue[agent_worker_pb2.Message]] = {}
|
||||||
self._client_id_lock = asyncio.Lock()
|
|
||||||
self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {}
|
|
||||||
self._agent_type_to_client_id_lock = asyncio.Lock()
|
self._agent_type_to_client_id_lock = asyncio.Lock()
|
||||||
self._agent_type_to_client_id: Dict[str, int] = {}
|
self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {}
|
||||||
self._pending_responses: Dict[int, Dict[str, Future[Any]]] = {}
|
self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {}
|
||||||
self._background_tasks: Set[Task[Any]] = set()
|
self._background_tasks: Set[Task[Any]] = set()
|
||||||
self._subscription_manager = SubscriptionManager()
|
self._subscription_manager = SubscriptionManager()
|
||||||
self._client_id_to_subscription_id_mapping: Dict[int, set[str]] = {}
|
self._client_id_to_subscription_id_mapping: Dict[ClientConnectionId, set[str]] = {}
|
||||||
|
|
||||||
async def OpenChannel( # type: ignore
|
async def OpenChannel( # type: ignore
|
||||||
self,
|
self,
|
||||||
request_iterator: AsyncIterator[agent_worker_pb2.Message],
|
request_iterator: AsyncIterator[agent_worker_pb2.Message],
|
||||||
context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message],
|
context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message],
|
||||||
) -> Iterator[agent_worker_pb2.Message] | AsyncIterator[agent_worker_pb2.Message]: # type: ignore
|
) -> Iterator[agent_worker_pb2.Message] | AsyncIterator[agent_worker_pb2.Message]: # type: ignore
|
||||||
# Aquire the lock to get a new client id.
|
metadata = metadata_to_dict(context.invocation_metadata()) # type: ignore
|
||||||
async with self._client_id_lock:
|
if (client_id := cast(ClientConnectionId | None, metadata.get("client-id"))) is None: # type: ignore
|
||||||
self._client_id += 1
|
logger.error("Client id not found in metadata. Refusing connection.")
|
||||||
client_id = self._client_id
|
return
|
||||||
|
|
||||||
# Register the client with the server and create a send queue for the client.
|
# Register the client with the server and create a send queue for the client.
|
||||||
send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue()
|
send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue()
|
||||||
@ -76,7 +82,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
|||||||
# Remove the client id from the agent type to client id mapping.
|
# Remove the client id from the agent type to client id mapping.
|
||||||
await self._on_client_disconnect(client_id)
|
await self._on_client_disconnect(client_id)
|
||||||
|
|
||||||
async def _on_client_disconnect(self, client_id: int) -> None:
|
async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None:
|
||||||
async with self._agent_type_to_client_id_lock:
|
async with self._agent_type_to_client_id_lock:
|
||||||
agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id]
|
agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id]
|
||||||
for agent_type in agent_types:
|
for agent_type in agent_types:
|
||||||
@ -93,7 +99,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
|||||||
raise exception
|
raise exception
|
||||||
|
|
||||||
async def _receive_messages(
|
async def _receive_messages(
|
||||||
self, client_id: int, request_iterator: AsyncIterator[agent_worker_pb2.Message]
|
self, client_id: ClientConnectionId, request_iterator: AsyncIterator[agent_worker_pb2.Message]
|
||||||
) -> None:
|
) -> None:
|
||||||
# Receive messages from the client and process them.
|
# Receive messages from the client and process them.
|
||||||
async for message in request_iterator:
|
async for message in request_iterator:
|
||||||
@ -138,7 +144,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
|||||||
case None:
|
case None:
|
||||||
logger.warning("Received empty message")
|
logger.warning("Received empty message")
|
||||||
|
|
||||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
|
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None:
|
||||||
# Deliver the message to a client given the target agent type.
|
# Deliver the message to a client given the target agent type.
|
||||||
async with self._agent_type_to_client_id_lock:
|
async with self._agent_type_to_client_id_lock:
|
||||||
target_client_id = self._agent_type_to_client_id.get(request.target.type)
|
target_client_id = self._agent_type_to_client_id.get(request.target.type)
|
||||||
@ -161,7 +167,9 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
|||||||
send_response_task.add_done_callback(self._raise_on_exception)
|
send_response_task.add_done_callback(self._raise_on_exception)
|
||||||
send_response_task.add_done_callback(self._background_tasks.discard)
|
send_response_task.add_done_callback(self._background_tasks.discard)
|
||||||
|
|
||||||
async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None:
|
async def _wait_and_send_response(
|
||||||
|
self, future: Future[agent_worker_pb2.RpcResponse], client_id: ClientConnectionId
|
||||||
|
) -> None:
|
||||||
response = await future
|
response = await future
|
||||||
message = agent_worker_pb2.Message(response=response)
|
message = agent_worker_pb2.Message(response=response)
|
||||||
send_queue = self._send_queues.get(client_id)
|
send_queue = self._send_queues.get(client_id)
|
||||||
@ -170,7 +178,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
|||||||
return
|
return
|
||||||
await send_queue.put(message)
|
await send_queue.put(message)
|
||||||
|
|
||||||
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None:
|
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: ClientConnectionId) -> None:
|
||||||
# Setting the result of the future will send the response back to the original sender.
|
# Setting the result of the future will send the response back to the original sender.
|
||||||
future = self._pending_responses[client_id].pop(response.request_id)
|
future = self._pending_responses[client_id].pop(response.request_id)
|
||||||
future.set_result(response)
|
future.set_result(response)
|
||||||
@ -180,7 +188,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
|||||||
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
|
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
|
||||||
# Get the client ids of the recipients.
|
# Get the client ids of the recipients.
|
||||||
async with self._agent_type_to_client_id_lock:
|
async with self._agent_type_to_client_id_lock:
|
||||||
client_ids: Set[int] = set()
|
client_ids: Set[ClientConnectionId] = set()
|
||||||
for recipient in recipients:
|
for recipient in recipients:
|
||||||
client_id = self._agent_type_to_client_id.get(recipient.type)
|
client_id = self._agent_type_to_client_id.get(recipient.type)
|
||||||
if client_id is not None:
|
if client_id is not None:
|
||||||
@ -192,7 +200,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
|||||||
await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event))
|
await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event))
|
||||||
|
|
||||||
async def _process_register_agent_type_request(
|
async def _process_register_agent_type_request(
|
||||||
self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: int
|
self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: ClientConnectionId
|
||||||
) -> None:
|
) -> None:
|
||||||
# Register the agent type with the host runtime.
|
# Register the agent type with the host runtime.
|
||||||
async with self._agent_type_to_client_id_lock:
|
async with self._agent_type_to_client_id_lock:
|
||||||
@ -217,7 +225,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _process_add_subscription_request(
|
async def _process_add_subscription_request(
|
||||||
self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: int
|
self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: ClientConnectionId
|
||||||
) -> None:
|
) -> None:
|
||||||
oneofcase = add_subscription_req.subscription.WhichOneof("subscription")
|
oneofcase = add_subscription_req.subscription.WhichOneof("subscription")
|
||||||
subscription: Subscription | None = None
|
subscription: Subscription | None = None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user