mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-12 07:21:18 +00:00
allow class associated subscriptions to be skipped on register (#587)
* allow class associated subscriptions to be skipped on register * format
This commit is contained in:
parent
7f25d28aac
commit
d6dce9ebb1
@ -144,23 +144,28 @@ class BaseAgent(ABC, Agent):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def register(
|
async def register(
|
||||||
cls, runtime: AgentRuntime, type: str, factory: Callable[[], Self | Awaitable[Self]]
|
cls,
|
||||||
|
runtime: AgentRuntime,
|
||||||
|
type: str,
|
||||||
|
factory: Callable[[], Self | Awaitable[Self]],
|
||||||
|
*,
|
||||||
|
skip_class_subscriptions: bool = False,
|
||||||
) -> AgentType:
|
) -> AgentType:
|
||||||
agent_type = AgentType(type)
|
agent_type = AgentType(type)
|
||||||
with SubscriptionInstantiationContext.populate_context(agent_type):
|
|
||||||
subscriptions = []
|
|
||||||
for unbound_subscription in cls._unbound_subscriptions():
|
|
||||||
subscriptions_list_result = unbound_subscription()
|
|
||||||
if inspect.isawaitable(subscriptions_list_result):
|
|
||||||
subscriptions_list = await subscriptions_list_result
|
|
||||||
else:
|
|
||||||
subscriptions_list = subscriptions_list_result
|
|
||||||
|
|
||||||
subscriptions.extend(subscriptions_list)
|
|
||||||
|
|
||||||
agent_type = await runtime.register_factory(type=agent_type, agent_factory=factory, expected_class=cls)
|
agent_type = await runtime.register_factory(type=agent_type, agent_factory=factory, expected_class=cls)
|
||||||
for subscription in subscriptions:
|
if not skip_class_subscriptions:
|
||||||
await runtime.add_subscription(subscription)
|
with SubscriptionInstantiationContext.populate_context(agent_type):
|
||||||
|
subscriptions = []
|
||||||
|
for unbound_subscription in cls._unbound_subscriptions():
|
||||||
|
subscriptions_list_result = unbound_subscription()
|
||||||
|
if inspect.isawaitable(subscriptions_list_result):
|
||||||
|
subscriptions_list = await subscriptions_list_result
|
||||||
|
else:
|
||||||
|
subscriptions_list = subscriptions_list_result
|
||||||
|
|
||||||
|
subscriptions.extend(subscriptions_list)
|
||||||
|
for subscription in subscriptions:
|
||||||
|
await runtime.add_subscription(subscription)
|
||||||
|
|
||||||
# TODO: deduplication
|
# TODO: deduplication
|
||||||
for _message_type, serializer in cls._handles_types():
|
for _message_type, serializer in cls._handles_types():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user