mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-28 07:29:54 +00:00
Refactor subscription in single threaded agent runtime (#388)
This commit is contained in:
parent
ed0890525d
commit
494b805080
@ -1,8 +1,11 @@
|
||||
from typing import Awaitable, Callable
|
||||
from collections import defaultdict
|
||||
from typing import Awaitable, Callable, DefaultDict, List, Set
|
||||
|
||||
from ..core._agent import Agent
|
||||
from ..core._agent_id import AgentId
|
||||
from ..core._agent_type import AgentType
|
||||
from ..core._subscription import Subscription
|
||||
from ..core._topic import TopicId
|
||||
|
||||
|
||||
async def get_impl(
|
||||
@ -24,3 +27,53 @@ async def get_impl(
|
||||
await instance_getter(id)
|
||||
|
||||
return id
|
||||
|
||||
|
||||
class SubscriptionManager:
|
||||
def __init__(self) -> None:
|
||||
self._subscriptions: List[Subscription] = []
|
||||
self._seen_topics: Set[TopicId] = set()
|
||||
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
|
||||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
# Check if the subscription already exists
|
||||
if any(sub.id == subscription.id for sub in self._subscriptions):
|
||||
raise ValueError("Subscription already exists")
|
||||
|
||||
if len(self._seen_topics) > 0:
|
||||
raise NotImplementedError("Cannot add subscription after topics have been seen yet")
|
||||
|
||||
self._subscriptions.append(subscription)
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
# Check if the subscription exists
|
||||
if not any(sub.id == id for sub in self._subscriptions):
|
||||
raise ValueError("Subscription does not exist")
|
||||
|
||||
def is_not_sub(x: Subscription) -> bool:
|
||||
return x.id != id
|
||||
|
||||
self._subscriptions = list(filter(is_not_sub, self._subscriptions))
|
||||
|
||||
# Rebuild the subscriptions
|
||||
self._rebuild_subscriptions(self._seen_topics)
|
||||
|
||||
async def get_subscribed_recipients(self, topic: TopicId) -> List[AgentId]:
|
||||
if topic not in self._seen_topics:
|
||||
self._build_for_new_topic(topic)
|
||||
return self._subscribed_recipients[topic]
|
||||
|
||||
# TODO: optimize this...
|
||||
def _rebuild_subscriptions(self, topics: Set[TopicId]) -> None:
|
||||
self._subscribed_recipients.clear()
|
||||
for topic in topics:
|
||||
self._build_for_new_topic(topic)
|
||||
|
||||
def _build_for_new_topic(self, topic: TopicId) -> None:
|
||||
if topic in self._seen_topics:
|
||||
return
|
||||
|
||||
self._seen_topics.add(topic)
|
||||
for subscription in self._subscriptions:
|
||||
if subscription.is_match(topic):
|
||||
self._subscribed_recipients[topic].append(subscription.map_to_agent(topic))
|
||||
|
||||
@ -6,11 +6,10 @@ import logging
|
||||
import threading
|
||||
import warnings
|
||||
from asyncio import CancelledError, Future, Task
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||
|
||||
from agnext.core import AgentType, Subscription, TopicId
|
||||
|
||||
@ -25,7 +24,7 @@ from ..core import (
|
||||
)
|
||||
from ..core.exceptions import MessageDroppedException
|
||||
from ..core.intervention import DropMessage, InterventionHandler
|
||||
from ._helpers import get_impl
|
||||
from ._helpers import SubscriptionManager, get_impl
|
||||
|
||||
logger = logging.getLogger("agnext")
|
||||
event_logger = logging.getLogger("agnext.events")
|
||||
@ -132,11 +131,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self._intervention_handler = intervention_handler
|
||||
self._outstanding_tasks = Counter()
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
|
||||
self._subscriptions: List[Subscription] = []
|
||||
self._seen_topics: Set[TopicId] = set()
|
||||
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
|
||||
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
self._run_context: RunContext | None = None
|
||||
|
||||
@property
|
||||
@ -286,10 +281,8 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self._outstanding_tasks.decrement()
|
||||
|
||||
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
|
||||
self._build_for_new_topic(message_envelope.topic_id)
|
||||
responses: List[Awaitable[Any]] = []
|
||||
|
||||
recipients = self._subscribed_recipients[message_envelope.topic_id]
|
||||
recipients = await self._subscription_manager.get_subscribed_recipients(message_envelope.topic_id)
|
||||
for agent_id in recipients:
|
||||
# Avoid sending the message back to the sender
|
||||
if message_envelope.sender is not None and agent_id == message_envelope.sender:
|
||||
@ -522,42 +515,10 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
return agent_instance
|
||||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
# Check if the subscription already exists
|
||||
if any(sub.id == subscription.id for sub in self._subscriptions):
|
||||
raise ValueError("Subscription already exists")
|
||||
|
||||
if len(self._seen_topics) > 0:
|
||||
raise NotImplementedError("Cannot add subscription after topics have been seen yet")
|
||||
|
||||
self._subscriptions.append(subscription)
|
||||
await self._subscription_manager.add_subscription(subscription)
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
# Check if the subscription exists
|
||||
if not any(sub.id == id for sub in self._subscriptions):
|
||||
raise ValueError("Subscription does not exist")
|
||||
|
||||
def is_not_sub(x: Subscription) -> bool:
|
||||
return x.id != id
|
||||
|
||||
self._subscriptions = list(filter(is_not_sub, self._subscriptions))
|
||||
|
||||
# Rebuild the subscriptions
|
||||
self._rebuild_subscriptions(self._seen_topics)
|
||||
|
||||
# TODO: optimize this...
|
||||
def _rebuild_subscriptions(self, topics: Set[TopicId]) -> None:
|
||||
self._subscribed_recipients.clear()
|
||||
for topic in topics:
|
||||
self._build_for_new_topic(topic)
|
||||
|
||||
def _build_for_new_topic(self, topic: TopicId) -> None:
|
||||
if topic in self._seen_topics:
|
||||
return
|
||||
|
||||
self._seen_topics.add(topic)
|
||||
for subscription in self._subscriptions:
|
||||
if subscription.is_match(topic):
|
||||
self._subscribed_recipients[topic].append(subscription.map_to_agent(topic))
|
||||
await self._subscription_manager.remove_subscription(id)
|
||||
|
||||
async def get(
|
||||
self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user