From d6dce9ebb1ca27dc33f1fb3eae1e7611d6e665dc Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Thu, 19 Sep 2024 15:50:59 -0400 Subject: [PATCH] allow class associated subscriptions to be skipped on register (#587) * allow class associated subscriptions to be skipped on register * format --- .../src/autogen_core/base/_base_agent.py | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py index cb0da711c..69f3d1c4c 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py @@ -144,23 +144,28 @@ class BaseAgent(ABC, Agent): @classmethod async def register( - cls, runtime: AgentRuntime, type: str, factory: Callable[[], Self | Awaitable[Self]] + cls, + runtime: AgentRuntime, + type: str, + factory: Callable[[], Self | Awaitable[Self]], + *, + skip_class_subscriptions: bool = False, ) -> AgentType: agent_type = AgentType(type) - with SubscriptionInstantiationContext.populate_context(agent_type): - subscriptions = [] - for unbound_subscription in cls._unbound_subscriptions(): - subscriptions_list_result = unbound_subscription() - if inspect.isawaitable(subscriptions_list_result): - subscriptions_list = await subscriptions_list_result - else: - subscriptions_list = subscriptions_list_result - - subscriptions.extend(subscriptions_list) - agent_type = await runtime.register_factory(type=agent_type, agent_factory=factory, expected_class=cls) - for subscription in subscriptions: - await runtime.add_subscription(subscription) + if not skip_class_subscriptions: + with SubscriptionInstantiationContext.populate_context(agent_type): + subscriptions = [] + for unbound_subscription in cls._unbound_subscriptions(): + subscriptions_list_result = unbound_subscription() + if inspect.isawaitable(subscriptions_list_result): + subscriptions_list = await subscriptions_list_result + else: + subscriptions_list = subscriptions_list_result + + subscriptions.extend(subscriptions_list) + for subscription in subscriptions: + await runtime.add_subscription(subscription) # TODO: deduplication for _message_type, serializer in cls._handles_types():