Refactor subscription in single threaded agent runtime (#388)

This commit is contained in:
Eric Zhu 2024-08-21 20:22:10 -07:00 committed by GitHub
parent ed0890525d
commit 494b805080
2 changed files with 60 additions and 46 deletions

View File

@ -1,8 +1,11 @@
from typing import Awaitable, Callable
from collections import defaultdict
from typing import Awaitable, Callable, DefaultDict, List, Set
from ..core._agent import Agent
from ..core._agent_id import AgentId
from ..core._agent_type import AgentType
from ..core._subscription import Subscription
from ..core._topic import TopicId
async def get_impl(
@ -24,3 +27,53 @@ async def get_impl(
await instance_getter(id)
return id
class SubscriptionManager:
def __init__(self) -> None:
self._subscriptions: List[Subscription] = []
self._seen_topics: Set[TopicId] = set()
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
async def add_subscription(self, subscription: Subscription) -> None:
# Check if the subscription already exists
if any(sub.id == subscription.id for sub in self._subscriptions):
raise ValueError("Subscription already exists")
if len(self._seen_topics) > 0:
raise NotImplementedError("Cannot add subscription after topics have been seen yet")
self._subscriptions.append(subscription)
async def remove_subscription(self, id: str) -> None:
# Check if the subscription exists
if not any(sub.id == id for sub in self._subscriptions):
raise ValueError("Subscription does not exist")
def is_not_sub(x: Subscription) -> bool:
return x.id != id
self._subscriptions = list(filter(is_not_sub, self._subscriptions))
# Rebuild the subscriptions
self._rebuild_subscriptions(self._seen_topics)
async def get_subscribed_recipients(self, topic: TopicId) -> List[AgentId]:
if topic not in self._seen_topics:
self._build_for_new_topic(topic)
return self._subscribed_recipients[topic]
# TODO: optimize this...
def _rebuild_subscriptions(self, topics: Set[TopicId]) -> None:
self._subscribed_recipients.clear()
for topic in topics:
self._build_for_new_topic(topic)
def _build_for_new_topic(self, topic: TopicId) -> None:
if topic in self._seen_topics:
return
self._seen_topics.add(topic)
for subscription in self._subscriptions:
if subscription.is_match(topic):
self._subscribed_recipients[topic].append(subscription.map_to_agent(topic))

View File

@ -6,11 +6,10 @@ import logging
import threading
import warnings
from asyncio import CancelledError, Future, Task
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from agnext.core import AgentType, Subscription, TopicId
@ -25,7 +24,7 @@ from ..core import (
)
from ..core.exceptions import MessageDroppedException
from ..core.intervention import DropMessage, InterventionHandler
from ._helpers import get_impl
from ._helpers import SubscriptionManager, get_impl
logger = logging.getLogger("agnext")
event_logger = logging.getLogger("agnext.events")
@ -132,11 +131,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._intervention_handler = intervention_handler
self._outstanding_tasks = Counter()
self._background_tasks: Set[Task[Any]] = set()
self._subscriptions: List[Subscription] = []
self._seen_topics: Set[TopicId] = set()
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
self._subscription_manager = SubscriptionManager()
self._run_context: RunContext | None = None
@property
@ -286,10 +281,8 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._outstanding_tasks.decrement()
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
self._build_for_new_topic(message_envelope.topic_id)
responses: List[Awaitable[Any]] = []
recipients = self._subscribed_recipients[message_envelope.topic_id]
recipients = await self._subscription_manager.get_subscribed_recipients(message_envelope.topic_id)
for agent_id in recipients:
# Avoid sending the message back to the sender
if message_envelope.sender is not None and agent_id == message_envelope.sender:
@ -522,42 +515,10 @@ class SingleThreadedAgentRuntime(AgentRuntime):
return agent_instance
async def add_subscription(self, subscription: Subscription) -> None:
# Check if the subscription already exists
if any(sub.id == subscription.id for sub in self._subscriptions):
raise ValueError("Subscription already exists")
if len(self._seen_topics) > 0:
raise NotImplementedError("Cannot add subscription after topics have been seen yet")
self._subscriptions.append(subscription)
await self._subscription_manager.add_subscription(subscription)
async def remove_subscription(self, id: str) -> None:
# Check if the subscription exists
if not any(sub.id == id for sub in self._subscriptions):
raise ValueError("Subscription does not exist")
def is_not_sub(x: Subscription) -> bool:
return x.id != id
self._subscriptions = list(filter(is_not_sub, self._subscriptions))
# Rebuild the subscriptions
self._rebuild_subscriptions(self._seen_topics)
# TODO: optimize this...
def _rebuild_subscriptions(self, topics: Set[TopicId]) -> None:
self._subscribed_recipients.clear()
for topic in topics:
self._build_for_new_topic(topic)
def _build_for_new_topic(self, topic: TopicId) -> None:
if topic in self._seen_topics:
return
self._seen_topics.add(topic)
for subscription in self._subscriptions:
if subscription.is_match(topic):
self._subscribed_recipients[topic].append(subscription.map_to_agent(topic))
await self._subscription_manager.remove_subscription(id)
async def get(
self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True