mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-25 22:18:53 +00:00
Add subscription factory to AgentRuntime.register (#393)
* Add subscriptions to factory * tests and bug fix
This commit is contained in:
parent
dc847d3985
commit
30d1b50c0d
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from agnext.application import WorkerAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler, TypeSubscription
|
||||
from agnext.components import TypeRoutedAgent, TypeSubscription, message_handler
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, AgentInstantiationContext, MessageContext, TopicId
|
||||
|
||||
|
||||
|
||||
@ -11,16 +11,18 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||
|
||||
from agnext.core import AgentType, Subscription, TopicId
|
||||
|
||||
from ..core import (
|
||||
Agent,
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
AgentMetadata,
|
||||
AgentRuntime,
|
||||
AgentType,
|
||||
CancellationToken,
|
||||
MessageContext,
|
||||
Subscription,
|
||||
SubscriptionInstantiationContext,
|
||||
TopicId,
|
||||
)
|
||||
from ..core.exceptions import MessageDroppedException
|
||||
from ..core.intervention import DropMessage, InterventionHandler
|
||||
@ -459,9 +461,27 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self,
|
||||
type: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]]
|
||||
| list[Subscription]
|
||||
| None = None,
|
||||
) -> AgentType:
|
||||
if type in self._agent_factories:
|
||||
raise ValueError(f"Agent with type {type} already exists.")
|
||||
|
||||
if subscriptions is not None:
|
||||
if callable(subscriptions):
|
||||
with SubscriptionInstantiationContext.populate_context(AgentType(type)):
|
||||
subscriptions_list_result = subscriptions()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list = await subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions
|
||||
|
||||
for subscription in subscriptions_list:
|
||||
await self.add_subscription(subscription)
|
||||
|
||||
self._agent_factories[type] = agent_factory
|
||||
return AgentType(type)
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@ from ..core import (
|
||||
CancellationToken,
|
||||
MessageContext,
|
||||
Subscription,
|
||||
SubscriptionInstantiationContext,
|
||||
TopicId,
|
||||
)
|
||||
from ._helpers import SubscriptionManager, get_impl
|
||||
@ -388,6 +389,9 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
self,
|
||||
type: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]]
|
||||
| list[Subscription]
|
||||
| None = None,
|
||||
) -> AgentType:
|
||||
if type in self._agent_factories:
|
||||
raise ValueError(f"Agent with type {type} already exists.")
|
||||
@ -397,6 +401,21 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
message = agent_worker_pb2.Message(registerAgentType=agent_worker_pb2.RegisterAgentType(type=type))
|
||||
await self._host_connection.send(message)
|
||||
|
||||
if subscriptions is not None:
|
||||
if callable(subscriptions):
|
||||
with SubscriptionInstantiationContext.populate_context(AgentType(type)):
|
||||
subscriptions_list_result = subscriptions()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list = await subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions
|
||||
|
||||
for subscription in subscriptions_list:
|
||||
await self.add_subscription(subscription)
|
||||
|
||||
return AgentType(type)
|
||||
|
||||
async def _invoke_agent_factory(
|
||||
|
||||
@ -15,6 +15,7 @@ from ._cancellation_token import CancellationToken
|
||||
from ._message_context import MessageContext
|
||||
from ._serialization import MESSAGE_TYPE_REGISTRY, Serialization, TypeDeserializer, TypeSerializer
|
||||
from ._subscription import Subscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._topic import TopicId
|
||||
|
||||
__all__ = [
|
||||
@ -35,4 +36,5 @@ __all__ = [
|
||||
"MessageContext",
|
||||
"Serialization",
|
||||
"AgentType",
|
||||
"SubscriptionInstantiationContext",
|
||||
]
|
||||
|
||||
@ -71,12 +71,16 @@ class AgentRuntime(Protocol):
|
||||
self,
|
||||
type: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]]
|
||||
| list[Subscription]
|
||||
| None = None,
|
||||
) -> AgentType:
|
||||
"""Register an agent factory with the runtime associated with a specific type. The type must be unique.
|
||||
|
||||
Args:
|
||||
type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes.
|
||||
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agnext.core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
|
||||
subscriptions (Callable[[], list[Subscription]] | list[Subscription] | None, optional): The subscriptions that the agent should be subscribed to. Defaults to None.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
32
python/src/agnext/core/_subscription_context.py
Normal file
32
python/src/agnext/core/_subscription_context.py
Normal file
@ -0,0 +1,32 @@
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, ClassVar, Generator
|
||||
|
||||
from agnext.core._agent_type import AgentType
|
||||
|
||||
|
||||
class SubscriptionInstantiationContext:
|
||||
def __init__(self) -> None:
|
||||
raise RuntimeError(
|
||||
"SubscriptionInstantiationContext cannot be instantiated. It is a static class that provides context management for subscription instantiation."
|
||||
)
|
||||
|
||||
SUBSCRIPTION_CONTEXT_VAR: ClassVar[ContextVar[AgentType]] = ContextVar("SUBSCRIPTION_CONTEXT_VAR")
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def populate_context(cls, ctx: AgentType) -> Generator[None, Any, None]:
|
||||
token = SubscriptionInstantiationContext.SUBSCRIPTION_CONTEXT_VAR.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
SubscriptionInstantiationContext.SUBSCRIPTION_CONTEXT_VAR.reset(token)
|
||||
|
||||
@classmethod
|
||||
def agent_type(cls) -> AgentType:
|
||||
try:
|
||||
return cls.SUBSCRIPTION_CONTEXT_VAR.get()
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"SubscriptionInstantiationContext.runtime() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
|
||||
) from e
|
||||
@ -1,8 +1,11 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.core import AgentId, AgentInstantiationContext
|
||||
from agnext.core import TopicId
|
||||
from agnext.core._subscription import Subscription
|
||||
from agnext.core._subscription_context import SubscriptionInstantiationContext
|
||||
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
|
||||
|
||||
|
||||
@ -76,3 +79,87 @@ async def test_register_receives_publish_cascade() -> None:
|
||||
for i in range(num_agents):
|
||||
agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent)
|
||||
assert agent.num_calls == total_num_calls_expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_factory_explicit_name() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [TypeSubscription("default", "name")])
|
||||
runtime.start()
|
||||
agent_id = AgentId("name", key="default")
|
||||
topic_id = TopicId("default", "default")
|
||||
await runtime.publish_message(MessageType(), topic_id=topic_id)
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_factory_context_var_name() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [TypeSubscription("default", SubscriptionInstantiationContext.agent_type().type)])
|
||||
runtime.start()
|
||||
agent_id = AgentId("name", key="default")
|
||||
topic_id = TopicId("default", "default")
|
||||
await runtime.publish_message(MessageType(), topic_id=topic_id)
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_factory_async() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
async def sub_factory() -> list[Subscription]:
|
||||
await asyncio.sleep(0.1)
|
||||
return [TypeSubscription("default", SubscriptionInstantiationContext.agent_type().type)]
|
||||
|
||||
await runtime.register("name", LoopbackAgent, sub_factory)
|
||||
runtime.start()
|
||||
agent_id = AgentId("name", key="default")
|
||||
topic_id = TopicId("default", "default")
|
||||
await runtime.publish_message(MessageType(), topic_id=topic_id)
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_factory_direct_list() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, [TypeSubscription("default", "name")])
|
||||
runtime.start()
|
||||
agent_id = AgentId("name", key="default")
|
||||
topic_id = TopicId("default", "default")
|
||||
await runtime.publish_message(MessageType(), topic_id=topic_id)
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user