mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-26 06:28:50 +00:00
Impl remove and get subscription APIs for python xlang (#5365)
Closes #5297 --------- Co-authored-by: Ryan Sweet <rysweet@microsoft.com> Co-authored-by: Jacob Alber <jaalber@microsoft.com> Co-authored-by: Jacob Alber <jacob.alber@microsoft.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
392aa14491
commit
dc877d5737
@ -1,5 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import Awaitable, Callable, DefaultDict, List, Set
|
||||
from typing import Awaitable, Callable, DefaultDict, List, Set, Sequence
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
@ -35,6 +35,10 @@ class SubscriptionManager:
|
||||
self._seen_topics: Set[TopicId] = set()
|
||||
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[Subscription]:
|
||||
return self._subscriptions
|
||||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
# Check if the subscription already exists
|
||||
if any(sub == subscription for sub in self._subscriptions):
|
||||
|
||||
@ -790,7 +790,15 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
|
||||
await self._subscription_manager.add_subscription(subscription)
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
raise NotImplementedError("Subscriptions cannot be removed while using distributed runtime currently.")
|
||||
if self._host_connection is None:
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
|
||||
message = agent_worker_pb2.RemoveSubscriptionRequest(id=id)
|
||||
_response: agent_worker_pb2.RemoveSubscriptionResponse = await self._host_connection.stub.RemoveSubscription(
|
||||
message, metadata=self._host_connection.metadata
|
||||
)
|
||||
|
||||
await self._subscription_manager.remove_subscription(id)
|
||||
|
||||
async def get(
|
||||
self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True
|
||||
|
||||
@ -11,7 +11,7 @@ from autogen_core._agent_id import AgentId
|
||||
from autogen_core._runtime_impl_helpers import SubscriptionManager
|
||||
|
||||
from ._constants import GRPC_IMPORT_ERROR_STR
|
||||
from ._utils import subscription_from_proto
|
||||
from ._utils import subscription_from_proto, subscription_to_proto
|
||||
|
||||
try:
|
||||
import grpc
|
||||
@ -170,7 +170,11 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
del self._agent_type_to_client_id[agent_type]
|
||||
for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, set()):
|
||||
logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}")
|
||||
await self._subscription_manager.remove_subscription(sub_id)
|
||||
try:
|
||||
await self._subscription_manager.remove_subscription(sub_id)
|
||||
# Catch and ignore if the subscription does not exist.
|
||||
except ValueError:
|
||||
continue
|
||||
logger.info(f"Client {client_id} disconnected successfully")
|
||||
|
||||
def _raise_on_exception(self, task: Task[Any]) -> None:
|
||||
@ -327,7 +331,8 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
],
|
||||
) -> agent_worker_pb2.RemoveSubscriptionResponse:
|
||||
_client_id = await get_client_id_or_abort(context)
|
||||
raise NotImplementedError("Method not implemented.")
|
||||
await self._subscription_manager.remove_subscription(request.id)
|
||||
return agent_worker_pb2.RemoveSubscriptionResponse()
|
||||
|
||||
async def GetSubscriptions( # type: ignore
|
||||
self,
|
||||
@ -337,4 +342,23 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
],
|
||||
) -> agent_worker_pb2.GetSubscriptionsResponse:
|
||||
_client_id = await get_client_id_or_abort(context)
|
||||
raise NotImplementedError("Method not implemented.")
|
||||
subscriptions = self._subscription_manager.subscriptions
|
||||
return agent_worker_pb2.GetSubscriptionsResponse(
|
||||
subscriptions=[subscription_to_proto(sub) for sub in subscriptions]
|
||||
)
|
||||
|
||||
# async def GetState( # type: ignore
|
||||
# self,
|
||||
# request: agent_worker_pb2.AgentId,
|
||||
# context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.GetStateResponse],
|
||||
# ) -> agent_worker_pb2.GetStateResponse:
|
||||
# _client_id = await get_client_id_or_abort(context)
|
||||
# raise NotImplementedError("Method not implemented!")
|
||||
|
||||
# async def SaveState( # type: ignore
|
||||
# self,
|
||||
# request: agent_worker_pb2.AgentState,
|
||||
# context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.SaveStateResponse],
|
||||
# ) -> agent_worker_pb2.SaveStateResponse:
|
||||
# _client_id = await get_client_id_or_abort(context)
|
||||
# raise NotImplementedError("Method not implemented!")
|
||||
|
||||
@ -8,6 +8,7 @@ from autogen_core import (
|
||||
PROTOBUF_DATA_CONTENT_TYPE,
|
||||
AgentId,
|
||||
AgentType,
|
||||
DefaultSubscription,
|
||||
DefaultTopicId,
|
||||
MessageContext,
|
||||
RoutedAgent,
|
||||
@ -129,6 +130,47 @@ async def test_register_receives_publish() -> None:
|
||||
|
||||
|
||||
@pytest.mark.grpc
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_doesnt_receive_after_removing_subscription() -> None:
|
||||
host_address = "localhost:50053"
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker1.start()
|
||||
worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
await worker1.register_factory(
|
||||
type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
|
||||
)
|
||||
sub = DefaultSubscription(agent_type="name1")
|
||||
await worker1.add_subscription(sub)
|
||||
|
||||
agent_1_instance = await worker1.try_get_underlying_agent_instance(AgentId("name1", "default"), LoopbackAgent)
|
||||
# Publish message from worker1
|
||||
await worker1.publish_message(MessageType(), topic_id=DefaultTopicId())
|
||||
|
||||
# Let the agent run for a bit.
|
||||
await agent_1_instance.event.wait()
|
||||
agent_1_instance.event.clear()
|
||||
|
||||
# Agents in default topic source should have received the message.
|
||||
assert agent_1_instance.num_calls == 1
|
||||
|
||||
await worker1.remove_subscription(sub.id)
|
||||
|
||||
# Publish message from worker1
|
||||
await worker1.publish_message(MessageType(), topic_id=DefaultTopicId())
|
||||
|
||||
# Let the agent run for a bit.
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Agent should not have received the message.
|
||||
assert agent_1_instance.num_calls == 1
|
||||
|
||||
await worker1.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish_cascade_single_worker() -> None:
|
||||
host_address = "localhost:50054"
|
||||
|
||||
@ -16,6 +16,8 @@ from autogen_core import (
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from asyncio import Event
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageType: ...
|
||||
@ -36,6 +38,7 @@ class LoopbackAgent(RoutedAgent):
|
||||
super().__init__("A loop back agent.")
|
||||
self.num_calls = 0
|
||||
self.received_messages: list[Any] = []
|
||||
self.event = Event()
|
||||
|
||||
@message_handler
|
||||
async def on_new_message(
|
||||
@ -43,6 +46,7 @@ class LoopbackAgent(RoutedAgent):
|
||||
) -> MessageType | ContentMessage:
|
||||
self.num_calls += 1
|
||||
self.received_messages.append(message)
|
||||
self.event.set()
|
||||
return message
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user