Instantiation context refactor (#293)

* WIP refactor instantiation context

* finish up changes

* Update python/src/agnext/core/_agent_runtime.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* Update python/src/agnext/core/_agent_runtime.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* add warning

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Jack Gerrits 2024-08-02 11:02:45 -04:00 committed by GitHub
parent d8bf7ee8a8
commit 1f9d5177d3
10 changed files with 90 additions and 104 deletions

View File

@ -9,7 +9,7 @@ from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.memory import ChatMemory
from agnext.components.models import ChatCompletionClient, SystemMessage
from agnext.core import AgentRuntime, CancellationToken
from agnext.core import AgentInstantiationContext, AgentRuntime, CancellationToken
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -97,8 +97,8 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
)
alice = await runtime.register_and_get_proxy(
"Alice",
lambda rt, id: ChatRoomAgent(
name=id.name,
lambda: ChatRoomAgent(
name=AgentInstantiationContext.current_agent_id().name,
description="Alice in the chat room.",
background_story="Alice is a software engineer who loves to code.",
memory=BufferedChatMemory(buffer_size=10),
@ -107,8 +107,8 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
)
bob = await runtime.register_and_get_proxy(
"Bob",
lambda rt, id: ChatRoomAgent(
name=id.name,
lambda: ChatRoomAgent(
name=AgentInstantiationContext.current_agent_id().name,
description="Bob in the chat room.",
background_story="Bob is a data scientist who loves to analyze data.",
memory=BufferedChatMemory(buffer_size=10),
@ -117,8 +117,8 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
)
charlie = await runtime.register_and_get_proxy(
"Charlie",
lambda rt, id: ChatRoomAgent(
name=id.name,
lambda: ChatRoomAgent(
name=AgentInstantiationContext.current_agent_id().name,
description="Charlie in the chat room.",
background_story="Charlie is a designer who loves to create art.",
memory=BufferedChatMemory(buffer_size=10),

View File

@ -4,6 +4,7 @@ import asyncio
import inspect
import logging
import threading
import warnings
from asyncio import CancelledError, Future, Task
from collections import defaultdict
from collections.abc import Sequence
@ -15,11 +16,11 @@ from ..core import (
MESSAGE_TYPE_REGISTRY,
Agent,
AgentId,
AgentInstantiationContext,
AgentMetadata,
AgentProxy,
AgentRuntime,
CancellationToken,
agent_instantiation_context,
)
from ..core.exceptions import MessageDroppedException
from ..core.intervention import DropMessage, InterventionHandler
@ -461,11 +462,15 @@ class SingleThreadedAgentRuntime(AgentRuntime):
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
agent_id: AgentId,
) -> T:
with agent_instantiation_context((self, agent_id)):
with AgentInstantiationContext.populate_context((self, agent_id)):
if len(inspect.signature(agent_factory).parameters) == 0:
factory_one = cast(Callable[[], T], agent_factory)
agent = factory_one()
elif len(inspect.signature(agent_factory).parameters) == 2:
warnings.warn(
"Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
stacklevel=2,
)
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
agent = factory_two(self, agent_id)
else:

View File

@ -3,8 +3,9 @@ from typing import Any, Awaitable, Callable, Mapping, Sequence, TypeVar, get_typ
from ..core._agent import Agent
from ..core._agent_id import AgentId
from ..core._agent_instantiation import AgentInstantiationContext
from ..core._agent_metadata import AgentMetadata
from ..core._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime
from ..core._agent_runtime import AgentRuntime
from ..core._cancellation_token import CancellationToken
from ..core._serialization import MESSAGE_TYPE_REGISTRY
from ..core.exceptions import CantHandleException
@ -46,8 +47,9 @@ class ClosureAgent(Agent):
self, description: str, closure: Callable[[AgentRuntime, AgentId, T, CancellationToken], Awaitable[Any]]
) -> None:
try:
runtime, id = AGENT_INSTANTIATION_CONTEXT_VAR.get()
except LookupError as e:
runtime = AgentInstantiationContext.current_runtime()
id = AgentInstantiationContext.current_agent_id()
except Exception as e:
raise RuntimeError(
"ClosureAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
) from e

View File

@ -4,10 +4,11 @@ The :mod:`agnext.core` module provides the foundational generic interfaces upon
from ._agent import Agent
from ._agent_id import AgentId
from ._agent_instantiation import AgentInstantiationContext
from ._agent_metadata import AgentMetadata
from ._agent_props import AgentChildren
from ._agent_proxy import AgentProxy
from ._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime, agent_instantiation_context
from ._agent_runtime import AgentRuntime
from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken
from ._serialization import MESSAGE_TYPE_REGISTRY, TypeDeserializer, TypeSerializer
@ -23,8 +24,7 @@ __all__ = [
"BaseAgent",
"CancellationToken",
"AgentChildren",
"agent_instantiation_context",
"AGENT_INSTANTIATION_CONTEXT_VAR",
"AgentInstantiationContext",
"MESSAGE_TYPE_REGISTRY",
"TypeSerializer",
"TypeDeserializer",

View File

@ -0,0 +1,44 @@
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, ClassVar, Generator
from ._agent_id import AgentId
from ._agent_runtime import AgentRuntime
class AgentInstantiationContext:
def __init__(self) -> None:
raise RuntimeError(
"AgentInstantiationContext cannot be instantiated. It is a static class that provides context management for agent instantiation."
)
AGENT_INSTANTIATION_CONTEXT_VAR: ClassVar[ContextVar[tuple[AgentRuntime, AgentId]]] = ContextVar(
"AGENT_INSTANTIATION_CONTEXT_VAR"
)
@classmethod
@contextmanager
def populate_context(cls, ctx: tuple[AgentRuntime, AgentId]) -> Generator[None, Any, None]:
token = AgentInstantiationContext.AGENT_INSTANTIATION_CONTEXT_VAR.set(ctx)
try:
yield
finally:
AgentInstantiationContext.AGENT_INSTANTIATION_CONTEXT_VAR.reset(token)
@classmethod
def current_runtime(cls) -> AgentRuntime:
try:
return cls.AGENT_INSTANTIATION_CONTEXT_VAR.get()[0]
except LookupError as e:
raise RuntimeError(
"AgentInstantiationContext.runtime() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
) from e
@classmethod
def current_agent_id(cls) -> AgentId:
try:
return cls.AGENT_INSTANTIATION_CONTEXT_VAR.get()[1]
except LookupError as e:
raise RuntimeError(
"AgentInstantiationContext.agent_id() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
) from e

View File

@ -1,8 +1,6 @@
from __future__ import annotations
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Awaitable, Callable, Generator, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, runtime_checkable
from ._agent import Agent
from ._agent_id import AgentId
@ -14,19 +12,6 @@ from ._cancellation_token import CancellationToken
T = TypeVar("T", bound=Agent)
AGENT_INSTANTIATION_CONTEXT_VAR: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar(
"AGENT_INSTANTIATION_CONTEXT_VAR"
)
@contextmanager
def agent_instantiation_context(ctx: tuple[AgentRuntime, AgentId]) -> Generator[None, Any, None]:
token = AGENT_INSTANTIATION_CONTEXT_VAR.set(ctx)
try:
yield
finally:
AGENT_INSTANTIATION_CONTEXT_VAR.reset(token)
@runtime_checkable
class AgentRuntime(Protocol):
@ -79,30 +64,16 @@ class AgentRuntime(Protocol):
UndeliverableException: If the message cannot be delivered.
"""
@overload
async def register(
self,
name: str,
agent_factory: Callable[[], T | Awaitable[T]],
) -> None: ...
@overload
async def register(
self,
name: str,
agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
) -> None: ...
async def register(
self,
name: str,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
) -> None:
"""Register an agent factory with the runtime associated with a specific name. The name must be unique.
Args:
name (str): The name of the type agent this factory creates.
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent, where T is a concrete Agent type.
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agnext.core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
Example:
@ -166,36 +137,18 @@ class AgentRuntime(Protocol):
"""
...
@overload
async def register_and_get(
self,
name: str,
agent_factory: Callable[[], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentId: ...
@overload
async def register_and_get(
self,
name: str,
agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentId: ...
async def register_and_get(
self,
name: str,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentId:
"""Register an agent factory with the runtime associated with a specific name and get the agent id. The name must be unique.
Args:
name (str): The name of the type agent this factory creates.
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent, where T is a concrete Agent type.
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agnext.core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
namespace (str, optional): The namespace of the agent. Defaults to "default".
Returns:
@ -204,36 +157,18 @@ class AgentRuntime(Protocol):
await self.register(name, agent_factory)
return await self.get(name, namespace=namespace)
@overload
async def register_and_get_proxy(
self,
name: str,
agent_factory: Callable[[], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentProxy: ...
@overload
async def register_and_get_proxy(
self,
name: str,
agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentProxy: ...
async def register_and_get_proxy(
self,
name: str,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentProxy:
"""Register an agent factory with the runtime associated with a specific name and get the agent proxy. The name must be unique.
Args:
name (str): The name of the type agent this factory creates.
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent, where T is a concrete Agent type.
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type.
namespace (str, optional): The namespace of the agent. Defaults to "default".
Returns:

View File

@ -4,8 +4,9 @@ from typing import Any, Mapping, Sequence
from ._agent import Agent
from ._agent_id import AgentId
from ._agent_instantiation import AgentInstantiationContext
from ._agent_metadata import AgentMetadata
from ._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime
from ._agent_runtime import AgentRuntime
from ._cancellation_token import CancellationToken
@ -22,7 +23,8 @@ class BaseAgent(ABC, Agent):
def __init__(self, description: str, subscriptions: Sequence[str]) -> None:
try:
runtime, id = AGENT_INSTANTIATION_CONTEXT_VAR.get()
runtime = AgentInstantiationContext.current_runtime()
id = AgentInstantiationContext.current_agent_id()
except LookupError as e:
raise RuntimeError(
"BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."

View File

@ -3,6 +3,7 @@ import inspect
import json
import logging
import threading
import warnings
from asyncio import Future, Task
from collections import defaultdict
from collections.abc import Sequence
@ -30,16 +31,9 @@ import grpc
from grpc.aio import StreamStreamCall
from typing_extensions import Self
from agnext.core import MESSAGE_TYPE_REGISTRY, agent_instantiation_context
from agnext.core import MESSAGE_TYPE_REGISTRY
from ..core import (
Agent,
AgentId,
AgentMetadata,
AgentProxy,
AgentRuntime,
CancellationToken,
)
from ..core import Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentProxy, AgentRuntime, CancellationToken
from .protos import AgentId as AgentIdProto
from .protos import (
AgentRpcStub,
@ -371,11 +365,15 @@ class WorkerAgentRuntime(AgentRuntime):
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
agent_id: AgentId,
) -> T:
with agent_instantiation_context((self, agent_id)):
with AgentInstantiationContext.populate_context((self, agent_id)):
if len(inspect.signature(agent_factory).parameters) == 0:
factory_one = cast(Callable[[], T], agent_factory)
agent = factory_one()
elif len(inspect.signature(agent_factory).parameters) == 2:
warnings.warn(
"Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
stacklevel=2,
)
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
agent = factory_two(self, agent_id)
else:

View File

@ -1,6 +1,6 @@
import pytest
from pytest_mock import MockerFixture
from agnext.core import AgentRuntime, AGENT_INSTANTIATION_CONTEXT_VAR, AgentId
from agnext.core import AgentRuntime, AgentInstantiationContext, AgentId
from test_utils import NoopAgent
@ -11,9 +11,8 @@ async def test_base_agent_create(mocker: MockerFixture) -> None:
runtime = mocker.Mock(spec=AgentRuntime)
# Shows how to set the context for the agent instantiation in a test context
AGENT_INSTANTIATION_CONTEXT_VAR.set((runtime, AgentId("name", "namespace")))
agent = NoopAgent()
assert agent.runtime == runtime
assert agent.id == AgentId("name", "namespace")
with AgentInstantiationContext.populate_context((runtime, AgentId("name", "namespace"))):
agent = NoopAgent()
assert agent.runtime == runtime
assert agent.id == AgentId("name", "namespace")

View File

@ -1,6 +1,6 @@
import pytest
from agnext.application import SingleThreadedAgentRuntime
from agnext.core import AgentId, AgentRuntime
from agnext.core import AgentId, AgentInstantiationContext
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
@ -8,7 +8,8 @@ from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, Mess
async def test_agent_names_must_be_unique() -> None:
runtime = SingleThreadedAgentRuntime()
def agent_factory(runtime: AgentRuntime, id: AgentId) -> NoopAgent:
def agent_factory() -> NoopAgent:
id = AgentInstantiationContext.current_agent_id()
assert id == AgentId("name1", "default")
agent = NoopAgent()
assert agent.id == id