Impl remove and get subscription APIs for python xlang (#5365)

Closes #5297

---------

Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
Co-authored-by: Jacob Alber <jaalber@microsoft.com>
Co-authored-by: Jacob Alber <jacob.alber@microsoft.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Jack Gerrits 2025-02-11 17:42:09 -05:00 committed by GitHub
parent 392aa14491
commit dc877d5737
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 88 additions and 6 deletions

View File

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Awaitable, Callable, DefaultDict, List, Set
from typing import Awaitable, Callable, DefaultDict, List, Set, Sequence
from ._agent import Agent
from ._agent_id import AgentId
@ -35,6 +35,10 @@ class SubscriptionManager:
self._seen_topics: Set[TopicId] = set()
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
@property
def subscriptions(self) -> Sequence[Subscription]:
return self._subscriptions
async def add_subscription(self, subscription: Subscription) -> None:
# Check if the subscription already exists
if any(sub == subscription for sub in self._subscriptions):

View File

@ -790,7 +790,15 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
await self._subscription_manager.add_subscription(subscription)
async def remove_subscription(self, id: str) -> None:
raise NotImplementedError("Subscriptions cannot be removed while using distributed runtime currently.")
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
message = agent_worker_pb2.RemoveSubscriptionRequest(id=id)
_response: agent_worker_pb2.RemoveSubscriptionResponse = await self._host_connection.stub.RemoveSubscription(
message, metadata=self._host_connection.metadata
)
await self._subscription_manager.remove_subscription(id)
async def get(
self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True

View File

@ -11,7 +11,7 @@ from autogen_core._agent_id import AgentId
from autogen_core._runtime_impl_helpers import SubscriptionManager
from ._constants import GRPC_IMPORT_ERROR_STR
from ._utils import subscription_from_proto
from ._utils import subscription_from_proto, subscription_to_proto
try:
import grpc
@ -170,7 +170,11 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
del self._agent_type_to_client_id[agent_type]
for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, set()):
logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}")
await self._subscription_manager.remove_subscription(sub_id)
try:
await self._subscription_manager.remove_subscription(sub_id)
# Catch and ignore if the subscription does not exist.
except ValueError:
continue
logger.info(f"Client {client_id} disconnected successfully")
def _raise_on_exception(self, task: Task[Any]) -> None:
@ -327,7 +331,8 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
],
) -> agent_worker_pb2.RemoveSubscriptionResponse:
_client_id = await get_client_id_or_abort(context)
raise NotImplementedError("Method not implemented.")
await self._subscription_manager.remove_subscription(request.id)
return agent_worker_pb2.RemoveSubscriptionResponse()
async def GetSubscriptions( # type: ignore
self,
@ -337,4 +342,23 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
],
) -> agent_worker_pb2.GetSubscriptionsResponse:
_client_id = await get_client_id_or_abort(context)
raise NotImplementedError("Method not implemented.")
subscriptions = self._subscription_manager.subscriptions
return agent_worker_pb2.GetSubscriptionsResponse(
subscriptions=[subscription_to_proto(sub) for sub in subscriptions]
)
# async def GetState( # type: ignore
# self,
# request: agent_worker_pb2.AgentId,
# context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.GetStateResponse],
# ) -> agent_worker_pb2.GetStateResponse:
# _client_id = await get_client_id_or_abort(context)
# raise NotImplementedError("Method not implemented!")
# async def SaveState( # type: ignore
# self,
# request: agent_worker_pb2.AgentState,
# context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.SaveStateResponse],
# ) -> agent_worker_pb2.SaveStateResponse:
# _client_id = await get_client_id_or_abort(context)
# raise NotImplementedError("Method not implemented!")

View File

@ -8,6 +8,7 @@ from autogen_core import (
PROTOBUF_DATA_CONTENT_TYPE,
AgentId,
AgentType,
DefaultSubscription,
DefaultTopicId,
MessageContext,
RoutedAgent,
@ -129,6 +130,47 @@ async def test_register_receives_publish() -> None:
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_register_doesnt_receive_after_removing_subscription() -> None:
host_address = "localhost:50053"
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
worker1.start()
worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
await worker1.register_factory(
type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
)
sub = DefaultSubscription(agent_type="name1")
await worker1.add_subscription(sub)
agent_1_instance = await worker1.try_get_underlying_agent_instance(AgentId("name1", "default"), LoopbackAgent)
# Publish message from worker1
await worker1.publish_message(MessageType(), topic_id=DefaultTopicId())
# Let the agent run for a bit.
await agent_1_instance.event.wait()
agent_1_instance.event.clear()
# Agents in default topic source should have received the message.
assert agent_1_instance.num_calls == 1
await worker1.remove_subscription(sub.id)
# Publish message from worker1
await worker1.publish_message(MessageType(), topic_id=DefaultTopicId())
# Let the agent run for a bit.
await asyncio.sleep(2)
# Agent should not have received the message.
assert agent_1_instance.num_calls == 1
await worker1.stop()
await host.stop()
@pytest.mark.asyncio
async def test_register_receives_publish_cascade_single_worker() -> None:
host_address = "localhost:50054"

View File

@ -16,6 +16,8 @@ from autogen_core import (
)
from pydantic import BaseModel
from asyncio import Event
@dataclass
class MessageType: ...
@ -36,6 +38,7 @@ class LoopbackAgent(RoutedAgent):
super().__init__("A loop back agent.")
self.num_calls = 0
self.received_messages: list[Any] = []
self.event = Event()
@message_handler
async def on_new_message(
@ -43,6 +46,7 @@ class LoopbackAgent(RoutedAgent):
) -> MessageType | ContentMessage:
self.num_calls += 1
self.received_messages.append(message)
self.event.set()
return message