mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-24 21:49:42 +00:00
Instantiation context refactor (#293)
* WIP refactor instantiation context * finish up changes * Update python/src/agnext/core/_agent_runtime.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * Update python/src/agnext/core/_agent_runtime.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * add warning --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
d8bf7ee8a8
commit
1f9d5177d3
@ -9,7 +9,7 @@ from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components.memory import ChatMemory
|
||||
from agnext.components.models import ChatCompletionClient, SystemMessage
|
||||
from agnext.core import AgentRuntime, CancellationToken
|
||||
from agnext.core import AgentInstantiationContext, AgentRuntime, CancellationToken
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
@ -97,8 +97,8 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||
)
|
||||
alice = await runtime.register_and_get_proxy(
|
||||
"Alice",
|
||||
lambda rt, id: ChatRoomAgent(
|
||||
name=id.name,
|
||||
lambda: ChatRoomAgent(
|
||||
name=AgentInstantiationContext.current_agent_id().name,
|
||||
description="Alice in the chat room.",
|
||||
background_story="Alice is a software engineer who loves to code.",
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
@ -107,8 +107,8 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||
)
|
||||
bob = await runtime.register_and_get_proxy(
|
||||
"Bob",
|
||||
lambda rt, id: ChatRoomAgent(
|
||||
name=id.name,
|
||||
lambda: ChatRoomAgent(
|
||||
name=AgentInstantiationContext.current_agent_id().name,
|
||||
description="Bob in the chat room.",
|
||||
background_story="Bob is a data scientist who loves to analyze data.",
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
@ -117,8 +117,8 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||
)
|
||||
charlie = await runtime.register_and_get_proxy(
|
||||
"Charlie",
|
||||
lambda rt, id: ChatRoomAgent(
|
||||
name=id.name,
|
||||
lambda: ChatRoomAgent(
|
||||
name=AgentInstantiationContext.current_agent_id().name,
|
||||
description="Charlie in the chat room.",
|
||||
background_story="Charlie is a designer who loves to create art.",
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
|
||||
@ -4,6 +4,7 @@ import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import warnings
|
||||
from asyncio import CancelledError, Future, Task
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
@ -15,11 +16,11 @@ from ..core import (
|
||||
MESSAGE_TYPE_REGISTRY,
|
||||
Agent,
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
AgentMetadata,
|
||||
AgentProxy,
|
||||
AgentRuntime,
|
||||
CancellationToken,
|
||||
agent_instantiation_context,
|
||||
)
|
||||
from ..core.exceptions import MessageDroppedException
|
||||
from ..core.intervention import DropMessage, InterventionHandler
|
||||
@ -461,11 +462,15 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
agent_id: AgentId,
|
||||
) -> T:
|
||||
with agent_instantiation_context((self, agent_id)):
|
||||
with AgentInstantiationContext.populate_context((self, agent_id)):
|
||||
if len(inspect.signature(agent_factory).parameters) == 0:
|
||||
factory_one = cast(Callable[[], T], agent_factory)
|
||||
agent = factory_one()
|
||||
elif len(inspect.signature(agent_factory).parameters) == 2:
|
||||
warnings.warn(
|
||||
"Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
|
||||
stacklevel=2,
|
||||
)
|
||||
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
|
||||
agent = factory_two(self, agent_id)
|
||||
else:
|
||||
|
||||
@ -3,8 +3,9 @@ from typing import Any, Awaitable, Callable, Mapping, Sequence, TypeVar, get_typ
|
||||
|
||||
from ..core._agent import Agent
|
||||
from ..core._agent_id import AgentId
|
||||
from ..core._agent_instantiation import AgentInstantiationContext
|
||||
from ..core._agent_metadata import AgentMetadata
|
||||
from ..core._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime
|
||||
from ..core._agent_runtime import AgentRuntime
|
||||
from ..core._cancellation_token import CancellationToken
|
||||
from ..core._serialization import MESSAGE_TYPE_REGISTRY
|
||||
from ..core.exceptions import CantHandleException
|
||||
@ -46,8 +47,9 @@ class ClosureAgent(Agent):
|
||||
self, description: str, closure: Callable[[AgentRuntime, AgentId, T, CancellationToken], Awaitable[Any]]
|
||||
) -> None:
|
||||
try:
|
||||
runtime, id = AGENT_INSTANTIATION_CONTEXT_VAR.get()
|
||||
except LookupError as e:
|
||||
runtime = AgentInstantiationContext.current_runtime()
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"ClosureAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
|
||||
) from e
|
||||
|
||||
@ -4,10 +4,11 @@ The :mod:`agnext.core` module provides the foundational generic interfaces upon
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_instantiation import AgentInstantiationContext
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_props import AgentChildren
|
||||
from ._agent_proxy import AgentProxy
|
||||
from ._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime, agent_instantiation_context
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._base_agent import BaseAgent
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._serialization import MESSAGE_TYPE_REGISTRY, TypeDeserializer, TypeSerializer
|
||||
@ -23,8 +24,7 @@ __all__ = [
|
||||
"BaseAgent",
|
||||
"CancellationToken",
|
||||
"AgentChildren",
|
||||
"agent_instantiation_context",
|
||||
"AGENT_INSTANTIATION_CONTEXT_VAR",
|
||||
"AgentInstantiationContext",
|
||||
"MESSAGE_TYPE_REGISTRY",
|
||||
"TypeSerializer",
|
||||
"TypeDeserializer",
|
||||
|
||||
44
python/src/agnext/core/_agent_instantiation.py
Normal file
44
python/src/agnext/core/_agent_instantiation.py
Normal file
@ -0,0 +1,44 @@
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, ClassVar, Generator
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_runtime import AgentRuntime
|
||||
|
||||
|
||||
class AgentInstantiationContext:
|
||||
def __init__(self) -> None:
|
||||
raise RuntimeError(
|
||||
"AgentInstantiationContext cannot be instantiated. It is a static class that provides context management for agent instantiation."
|
||||
)
|
||||
|
||||
AGENT_INSTANTIATION_CONTEXT_VAR: ClassVar[ContextVar[tuple[AgentRuntime, AgentId]]] = ContextVar(
|
||||
"AGENT_INSTANTIATION_CONTEXT_VAR"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def populate_context(cls, ctx: tuple[AgentRuntime, AgentId]) -> Generator[None, Any, None]:
|
||||
token = AgentInstantiationContext.AGENT_INSTANTIATION_CONTEXT_VAR.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
AgentInstantiationContext.AGENT_INSTANTIATION_CONTEXT_VAR.reset(token)
|
||||
|
||||
@classmethod
|
||||
def current_runtime(cls) -> AgentRuntime:
|
||||
try:
|
||||
return cls.AGENT_INSTANTIATION_CONTEXT_VAR.get()[0]
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"AgentInstantiationContext.runtime() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def current_agent_id(cls) -> AgentId:
|
||||
try:
|
||||
return cls.AGENT_INSTANTIATION_CONTEXT_VAR.get()[1]
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"AgentInstantiationContext.agent_id() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
|
||||
) from e
|
||||
@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Awaitable, Callable, Generator, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
|
||||
from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, runtime_checkable
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
@ -14,19 +12,6 @@ from ._cancellation_token import CancellationToken
|
||||
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
AGENT_INSTANTIATION_CONTEXT_VAR: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar(
|
||||
"AGENT_INSTANTIATION_CONTEXT_VAR"
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def agent_instantiation_context(ctx: tuple[AgentRuntime, AgentId]) -> Generator[None, Any, None]:
|
||||
token = AGENT_INSTANTIATION_CONTEXT_VAR.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
AGENT_INSTANTIATION_CONTEXT_VAR.reset(token)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentRuntime(Protocol):
|
||||
@ -79,30 +64,16 @@ class AgentRuntime(Protocol):
|
||||
UndeliverableException: If the message cannot be delivered.
|
||||
"""
|
||||
|
||||
@overload
|
||||
async def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
async def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
) -> None: ...
|
||||
|
||||
async def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
) -> None:
|
||||
"""Register an agent factory with the runtime associated with a specific name. The name must be unique.
|
||||
|
||||
Args:
|
||||
name (str): The name of the type agent this factory creates.
|
||||
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent, where T is a concrete Agent type.
|
||||
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agnext.core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
|
||||
|
||||
|
||||
Example:
|
||||
@ -166,36 +137,18 @@ class AgentRuntime(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
@overload
|
||||
async def register_and_get(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentId: ...
|
||||
|
||||
@overload
|
||||
async def register_and_get(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentId: ...
|
||||
|
||||
async def register_and_get(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentId:
|
||||
"""Register an agent factory with the runtime associated with a specific name and get the agent id. The name must be unique.
|
||||
|
||||
Args:
|
||||
name (str): The name of the type agent this factory creates.
|
||||
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent, where T is a concrete Agent type.
|
||||
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agnext.core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
|
||||
namespace (str, optional): The namespace of the agent. Defaults to "default".
|
||||
|
||||
Returns:
|
||||
@ -204,36 +157,18 @@ class AgentRuntime(Protocol):
|
||||
await self.register(name, agent_factory)
|
||||
return await self.get(name, namespace=namespace)
|
||||
|
||||
@overload
|
||||
async def register_and_get_proxy(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentProxy: ...
|
||||
|
||||
@overload
|
||||
async def register_and_get_proxy(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentProxy: ...
|
||||
|
||||
async def register_and_get_proxy(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentProxy:
|
||||
"""Register an agent factory with the runtime associated with a specific name and get the agent proxy. The name must be unique.
|
||||
|
||||
Args:
|
||||
name (str): The name of the type agent this factory creates.
|
||||
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent, where T is a concrete Agent type.
|
||||
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type.
|
||||
namespace (str, optional): The namespace of the agent. Defaults to "default".
|
||||
|
||||
Returns:
|
||||
|
||||
@ -4,8 +4,9 @@ from typing import Any, Mapping, Sequence
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_instantiation import AgentInstantiationContext
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
|
||||
@ -22,7 +23,8 @@ class BaseAgent(ABC, Agent):
|
||||
|
||||
def __init__(self, description: str, subscriptions: Sequence[str]) -> None:
|
||||
try:
|
||||
runtime, id = AGENT_INSTANTIATION_CONTEXT_VAR.get()
|
||||
runtime = AgentInstantiationContext.current_runtime()
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
|
||||
|
||||
@ -3,6 +3,7 @@ import inspect
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import warnings
|
||||
from asyncio import Future, Task
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
@ -30,16 +31,9 @@ import grpc
|
||||
from grpc.aio import StreamStreamCall
|
||||
from typing_extensions import Self
|
||||
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY, agent_instantiation_context
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY
|
||||
|
||||
from ..core import (
|
||||
Agent,
|
||||
AgentId,
|
||||
AgentMetadata,
|
||||
AgentProxy,
|
||||
AgentRuntime,
|
||||
CancellationToken,
|
||||
)
|
||||
from ..core import Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentProxy, AgentRuntime, CancellationToken
|
||||
from .protos import AgentId as AgentIdProto
|
||||
from .protos import (
|
||||
AgentRpcStub,
|
||||
@ -371,11 +365,15 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
agent_id: AgentId,
|
||||
) -> T:
|
||||
with agent_instantiation_context((self, agent_id)):
|
||||
with AgentInstantiationContext.populate_context((self, agent_id)):
|
||||
if len(inspect.signature(agent_factory).parameters) == 0:
|
||||
factory_one = cast(Callable[[], T], agent_factory)
|
||||
agent = factory_one()
|
||||
elif len(inspect.signature(agent_factory).parameters) == 2:
|
||||
warnings.warn(
|
||||
"Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
|
||||
stacklevel=2,
|
||||
)
|
||||
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
|
||||
agent = factory_two(self, agent_id)
|
||||
else:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from agnext.core import AgentRuntime, AGENT_INSTANTIATION_CONTEXT_VAR, AgentId
|
||||
from agnext.core import AgentRuntime, AgentInstantiationContext, AgentId
|
||||
|
||||
from test_utils import NoopAgent
|
||||
|
||||
@ -11,9 +11,8 @@ async def test_base_agent_create(mocker: MockerFixture) -> None:
|
||||
runtime = mocker.Mock(spec=AgentRuntime)
|
||||
|
||||
# Shows how to set the context for the agent instantiation in a test context
|
||||
AGENT_INSTANTIATION_CONTEXT_VAR.set((runtime, AgentId("name", "namespace")))
|
||||
|
||||
agent = NoopAgent()
|
||||
assert agent.runtime == runtime
|
||||
assert agent.id == AgentId("name", "namespace")
|
||||
with AgentInstantiationContext.populate_context((runtime, AgentId("name", "namespace"))):
|
||||
agent = NoopAgent()
|
||||
assert agent.runtime == runtime
|
||||
assert agent.id == AgentId("name", "namespace")
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.core import AgentId, AgentRuntime
|
||||
from agnext.core import AgentId, AgentInstantiationContext
|
||||
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
|
||||
|
||||
|
||||
@ -8,7 +8,8 @@ from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, Mess
|
||||
async def test_agent_names_must_be_unique() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
def agent_factory(runtime: AgentRuntime, id: AgentId) -> NoopAgent:
|
||||
def agent_factory() -> NoopAgent:
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
assert id == AgentId("name1", "default")
|
||||
agent = NoopAgent()
|
||||
assert agent.id == id
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user