Add subscription factory to AgentRuntime.register (#393)

* Add subscriptions to factory

* tests and bug fix
This commit is contained in:
Jack Gerrits 2024-08-22 16:53:35 -04:00 committed by GitHub
parent dc847d3985
commit 30d1b50c0d
7 changed files with 167 additions and 3 deletions

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Any, NoReturn
from agnext.application import WorkerAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler, TypeSubscription
from agnext.components import TypeRoutedAgent, TypeSubscription, message_handler
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, AgentInstantiationContext, MessageContext, TopicId

View File

@ -11,16 +11,18 @@ from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from agnext.core import AgentType, Subscription, TopicId
from ..core import (
Agent,
AgentId,
AgentInstantiationContext,
AgentMetadata,
AgentRuntime,
AgentType,
CancellationToken,
MessageContext,
Subscription,
SubscriptionInstantiationContext,
TopicId,
)
from ..core.exceptions import MessageDroppedException
from ..core.intervention import DropMessage, InterventionHandler
@ -459,9 +461,27 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self,
type: str,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]]
| list[Subscription]
| None = None,
) -> AgentType:
if type in self._agent_factories:
raise ValueError(f"Agent with type {type} already exists.")
if subscriptions is not None:
if callable(subscriptions):
with SubscriptionInstantiationContext.populate_context(AgentType(type)):
subscriptions_list_result = subscriptions()
if inspect.isawaitable(subscriptions_list_result):
subscriptions_list = await subscriptions_list_result
else:
subscriptions_list = subscriptions_list_result
else:
subscriptions_list = subscriptions
for subscription in subscriptions_list:
await self.add_subscription(subscription)
self._agent_factories[type] = agent_factory
return AgentType(type)

View File

@ -40,6 +40,7 @@ from ..core import (
CancellationToken,
MessageContext,
Subscription,
SubscriptionInstantiationContext,
TopicId,
)
from ._helpers import SubscriptionManager, get_impl
@ -388,6 +389,9 @@ class WorkerAgentRuntime(AgentRuntime):
self,
type: str,
agent_factory: Callable[[], T | Awaitable[T]],
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]]
| list[Subscription]
| None = None,
) -> AgentType:
if type in self._agent_factories:
raise ValueError(f"Agent with type {type} already exists.")
@ -397,6 +401,21 @@ class WorkerAgentRuntime(AgentRuntime):
raise RuntimeError("Host connection is not set.")
message = agent_worker_pb2.Message(registerAgentType=agent_worker_pb2.RegisterAgentType(type=type))
await self._host_connection.send(message)
if subscriptions is not None:
if callable(subscriptions):
with SubscriptionInstantiationContext.populate_context(AgentType(type)):
subscriptions_list_result = subscriptions()
if inspect.isawaitable(subscriptions_list_result):
subscriptions_list = await subscriptions_list_result
else:
subscriptions_list = subscriptions_list_result
else:
subscriptions_list = subscriptions
for subscription in subscriptions_list:
await self.add_subscription(subscription)
return AgentType(type)
async def _invoke_agent_factory(

View File

@ -15,6 +15,7 @@ from ._cancellation_token import CancellationToken
from ._message_context import MessageContext
from ._serialization import MESSAGE_TYPE_REGISTRY, Serialization, TypeDeserializer, TypeSerializer
from ._subscription import Subscription
from ._subscription_context import SubscriptionInstantiationContext
from ._topic import TopicId
__all__ = [
@ -35,4 +36,5 @@ __all__ = [
"MessageContext",
"Serialization",
"AgentType",
"SubscriptionInstantiationContext",
]

View File

@ -71,12 +71,16 @@ class AgentRuntime(Protocol):
self,
type: str,
agent_factory: Callable[[], T | Awaitable[T]],
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]]
| list[Subscription]
| None = None,
) -> AgentType:
"""Register an agent factory with the runtime associated with a specific type. The type must be unique.
Args:
type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes.
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agnext.core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
subscriptions (Callable[[], list[Subscription]] | list[Subscription] | None, optional): The subscriptions that the agent should be subscribed to. Defaults to None.
Example:

View File

@ -0,0 +1,32 @@
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, ClassVar, Generator
from agnext.core._agent_type import AgentType
class SubscriptionInstantiationContext:
def __init__(self) -> None:
raise RuntimeError(
"SubscriptionInstantiationContext cannot be instantiated. It is a static class that provides context management for subscription instantiation."
)
SUBSCRIPTION_CONTEXT_VAR: ClassVar[ContextVar[AgentType]] = ContextVar("SUBSCRIPTION_CONTEXT_VAR")
@classmethod
@contextmanager
def populate_context(cls, ctx: AgentType) -> Generator[None, Any, None]:
token = SubscriptionInstantiationContext.SUBSCRIPTION_CONTEXT_VAR.set(ctx)
try:
yield
finally:
SubscriptionInstantiationContext.SUBSCRIPTION_CONTEXT_VAR.reset(token)
@classmethod
def agent_type(cls) -> AgentType:
try:
return cls.SUBSCRIPTION_CONTEXT_VAR.get()
except LookupError as e:
raise RuntimeError(
"SubscriptionInstantiationContext.runtime() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
) from e

View File

@ -1,8 +1,11 @@
import asyncio
import pytest
from agnext.application import SingleThreadedAgentRuntime
from agnext.components._type_subscription import TypeSubscription
from agnext.core import AgentId, AgentInstantiationContext
from agnext.core import TopicId
from agnext.core._subscription import Subscription
from agnext.core._subscription_context import SubscriptionInstantiationContext
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
@ -76,3 +79,87 @@ async def test_register_receives_publish_cascade() -> None:
for i in range(num_agents):
agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent)
assert agent.num_calls == total_num_calls_expected
@pytest.mark.asyncio
async def test_register_factory_explicit_name() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("name", LoopbackAgent, lambda: [TypeSubscription("default", "name")])
runtime.start()
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
await runtime.publish_message(MessageType(), topic_id=topic_id)
await runtime.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 0
@pytest.mark.asyncio
async def test_register_factory_context_var_name() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("name", LoopbackAgent, lambda: [TypeSubscription("default", SubscriptionInstantiationContext.agent_type().type)])
runtime.start()
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
await runtime.publish_message(MessageType(), topic_id=topic_id)
await runtime.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 0
@pytest.mark.asyncio
async def test_register_factory_async() -> None:
runtime = SingleThreadedAgentRuntime()
async def sub_factory() -> list[Subscription]:
await asyncio.sleep(0.1)
return [TypeSubscription("default", SubscriptionInstantiationContext.agent_type().type)]
await runtime.register("name", LoopbackAgent, sub_factory)
runtime.start()
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
await runtime.publish_message(MessageType(), topic_id=topic_id)
await runtime.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 0
@pytest.mark.asyncio
async def test_register_factory_direct_list() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("name", LoopbackAgent, [TypeSubscription("default", "name")])
runtime.start()
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
await runtime.publish_message(MessageType(), topic_id=topic_id)
await runtime.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 0