Add special topic for agent direct messaging (#4385)

* Add special topic for agent direct messaging

* move to base

* update sub counts

* Fix tests
This commit is contained in:
Jack Gerrits 2024-11-26 17:01:25 -05:00 committed by GitHub
parent cf80b1bc14
commit df183be35a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 20 additions and 6 deletions

View File

@ -4,7 +4,7 @@ from _collections_abc import AsyncIterator, Iterator
from asyncio import Future, Task
from typing import Any, Dict, Set
from autogen_core.components._type_prefix_subscription import TypePrefixSubscription
from autogen_core.base._type_prefix_subscription import TypePrefixSubscription
from ..base import Subscription, TopicId
from ..components import TypeSubscription

View File

@ -20,6 +20,7 @@ from ._serialization import MessageSerializer, try_get_known_serializers_for_typ
from ._subscription import Subscription, UnboundSubscription
from ._subscription_context import SubscriptionInstantiationContext
from ._topic import TopicId
from ._type_prefix_subscription import TypePrefixSubscription
T = TypeVar("T", bound=Agent)
@ -149,6 +150,7 @@ class BaseAgent(ABC, Agent):
factory: Callable[[], Self | Awaitable[Self]],
*,
skip_class_subscriptions: bool = False,
skip_direct_message_subscription: bool = False,
) -> AgentType:
agent_type = AgentType(type)
agent_type = await runtime.register_factory(type=agent_type, agent_factory=factory, expected_class=cls)
@ -166,6 +168,16 @@ class BaseAgent(ABC, Agent):
for subscription in subscriptions:
await runtime.add_subscription(subscription)
if not skip_direct_message_subscription:
# Additionally adds a special prefix subscription for this agent to receive direct messages
await runtime.add_subscription(
TypePrefixSubscription(
# The prefix MUST include ":" to avoid collisions with other agents
topic_type_prefix=agent_type.type + ":",
agent_type=agent_type.type,
)
)
# TODO: deduplication
for _message_type, serializer in cls._handles_types():
runtime.add_message_serializer(serializer)

View File

@ -1,7 +1,9 @@
import uuid
from ..base import AgentId, Subscription, TopicId
from ..base.exceptions import CantHandleException
from ._agent_id import AgentId
from ._subscription import Subscription
from ._topic import TopicId
from .exceptions import CantHandleException
class TypePrefixSubscription(Subscription):

View File

@ -2,12 +2,12 @@
The :mod:`autogen_core.components` module provides building blocks for creating single agents
"""
from ..base._type_prefix_subscription import TypePrefixSubscription
from ._closure_agent import ClosureAgent
from ._default_subscription import DefaultSubscription, default_subscription, type_subscription
from ._default_topic import DefaultTopicId
from ._image import Image
from ._routed_agent import RoutedAgent, TypeRoutedAgent, event, message_handler, rpc
from ._type_prefix_subscription import TypePrefixSubscription
from ._type_subscription import TypeSubscription
from ._types import FunctionCall

View File

@ -360,7 +360,7 @@ async def test_disconnected_agent() -> None:
)
subscriptions1 = get_current_subscriptions()
assert len(subscriptions1) == 1
assert len(subscriptions1) == 2
recipients1 = await get_subscribed_recipients()
assert AgentId(type="worker1", key="default") in recipients1
@ -388,7 +388,7 @@ async def test_disconnected_agent() -> None:
)
subscriptions3 = get_current_subscriptions()
assert len(subscriptions3) == 1
assert len(subscriptions3) == 2
assert first_subscription_id not in [x.id for x in subscriptions3]
recipients3 = await get_subscribed_recipients()