simplify namespace usage (#116)

* simplify namespace usage

* format

* pyright
This commit is contained in:
Jack Gerrits 2024-06-24 16:52:09 -04:00 committed by GitHub
parent 606e43b325
commit 6189fdb05c
9 changed files with 32 additions and 96 deletions

View File

@ -1,7 +1,6 @@
# Namespace
A namespace is a logical boundary between agents. By default, agents in one
namespace cannot communicate with agents in another namespace.
Namespace allow for defining logical boundaries between agents.
Namespaces are strings, and the default is `default`.
@ -15,4 +14,6 @@ Two possible use cases of agents are:
The {py:class}`agnext.core.AgentId` is used to address an agent, it is the combination of the agent's namespace and its name.
When getting an agent reference ({py:meth}`agnext.core.AgentRuntime.get`) or proxy ({py:meth}`agnext.core.AgentRuntime.get_proxy`) from the runtime the namespace can be specified. Agents have an ID property ({py:attr}`agnext.core.Agent.id`) that returns the agent's id. Additionally, the register method takes a factory that can optionally accept the ID as an argument ({py:meth}`agnext.core.AgentRuntime.register`).
When getting an agent reference ({py:meth}`agnext.core.AgentRuntime.get`) or proxy ({py:meth}`agnext.core.AgentRuntime.get_proxy`) from the runtime the namespace can be specified. Agents have an ID property ({py:attr}`agnext.core.Agent.id`) that returns the agent's id. Additionally, the register method takes a factory that can optionally accept the ID as an argument ({py:meth}`agnext.core.AgentRuntime.register`).
By default, there are no restrictions and are left to the application to enforce. The runtime will however automatically create agents in a namespace if it does not exist.

View File

@ -68,7 +68,7 @@ Finally, we add this handler to the runtime and use it to detect termination and
async def main() -> None:
termination_handler = TerminationHandler()
runtime = SingleThreadedAgentRuntime(
before_send=termination_handler
intervention_handler=termination_handler
)
# Add Agents and kick off task

View File

@ -267,7 +267,7 @@ class DisplayAgent(TypeRoutedAgent):
async def main() -> None:
termination_handler = TerminationHandler()
runtime = SingleThreadedAgentRuntime(before_send=termination_handler)
runtime = SingleThreadedAgentRuntime(intervention_handler=termination_handler)
runtime.register(
"ReviewerAgent",
lambda: ReviewerAgent(

View File

@ -133,7 +133,7 @@ class DisplayAgent(TypeRoutedAgent):
async def main() -> None:
termination_handler = TerminationHandler()
runtime = SingleThreadedAgentRuntime(before_send=termination_handler)
runtime = SingleThreadedAgentRuntime(intervention_handler=termination_handler)
# TODO: use different models for each agent.
runtime.register(
"ReferenceAgent1",

View File

@ -6,7 +6,7 @@ from asyncio import Future
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, TypeVar, cast
from ..core import (
Agent,
@ -14,7 +14,6 @@ from ..core import (
AgentMetadata,
AgentProxy,
AgentRuntime,
AllNamespaces,
CancellationToken,
agent_instantiation_context,
)
@ -82,15 +81,13 @@ class Counter:
class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
def __init__(self, *, intervention_handler: InterventionHandler | None = None) -> None:
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
# (namespace, type) -> List[AgentId]
self._per_type_subscribers: DefaultDict[tuple[str, type], Set[AgentId]] = defaultdict(set)
self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {}
# If empty, then all namespaces are valid for that agent type
self._valid_namespaces: Dict[str, Sequence[str]] = {}
self._instantiated_agents: Dict[AgentId, Agent] = {}
self._before_send = before_send
self._intervention_handler = intervention_handler
self._known_namespaces: set[str] = set()
self._outstanding_tasks = Counter()
@ -322,9 +319,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
match message_envelope:
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
if self._intervention_handler is not None:
try:
temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient)
temp_message = await self._intervention_handler.on_send(
message, sender=sender, recipient=recipient
)
except BaseException as e:
future.set_exception(e)
return
@ -339,9 +338,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
message=message,
sender=sender,
):
if self._before_send is not None:
if self._intervention_handler is not None:
try:
temp_message = await self._before_send.on_publish(message, sender=sender)
temp_message = await self._intervention_handler.on_publish(message, sender=sender)
except BaseException as e:
# TODO: we should raise the intervention exception to the publisher.
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
@ -354,9 +353,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._outstanding_tasks.increment()
asyncio.create_task(self._process_publish(message_envelope))
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
if self._intervention_handler is not None:
try:
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
temp_message = await self._intervention_handler.on_response(
message, sender=sender, recipient=recipient
)
except BaseException as e:
# TODO: should we raise the exception to sender of the response instead?
future.set_exception(e)
@ -385,21 +386,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self,
name: str,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
*,
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
) -> None:
if name in self._agent_factories:
raise ValueError(f"Agent with name {name} already exists.")
self._agent_factories[name] = agent_factory
if valid_namespaces is not AllNamespaces:
self._valid_namespaces[name] = cast(Sequence[str], valid_namespaces)
else:
self._valid_namespaces[name] = []
# For all already prepared namespaces we need to prepare this agent
for namespace in self._known_namespaces:
if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)):
self._get_agent(AgentId(name=name, namespace=namespace))
self._get_agent(AgentId(name=name, namespace=namespace))
def _invoke_agent_factory(
self, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], agent_id: AgentId
@ -419,24 +413,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
return agent
def _type_valid_for_namespace(self, agent_id: AgentId) -> bool:
if agent_id.name not in self._agent_factories:
raise KeyError(f"Agent with name {agent_id.name} not found.")
valid_namespaces = self._valid_namespaces[agent_id.name]
if len(valid_namespaces) == 0:
return True
return agent_id.namespace in valid_namespaces
def _get_agent(self, agent_id: AgentId) -> Agent:
self._process_seen_namespace(agent_id.namespace)
if agent_id in self._instantiated_agents:
return self._instantiated_agents[agent_id]
if not self._type_valid_for_namespace(agent_id):
raise ValueError(f"Agent with name {agent_id.name} not valid for namespace {agent_id.namespace}.")
if agent_id.name not in self._agent_factories:
raise ValueError(f"Agent with name {agent_id.name} not found.")
@ -463,5 +444,4 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._known_namespaces.add(namespace)
for name in self._known_agent_names:
if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)):
self._get_agent(AgentId(name=name, namespace=namespace))
self._get_agent(AgentId(name=name, namespace=namespace))

View File

@ -7,7 +7,7 @@ from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._agent_props import AgentChildren
from ._agent_proxy import AgentProxy
from ._agent_runtime import AgentRuntime, AllNamespaces, agent_instantiation_context
from ._agent_runtime import AgentRuntime, agent_instantiation_context
from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken
@ -17,7 +17,6 @@ __all__ = [
"AgentProxy",
"AgentMetadata",
"AgentRuntime",
"AllNamespaces",
"BaseAgent",
"CancellationToken",
"AgentChildren",

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from asyncio import Future
from contextvars import ContextVar
from typing import Any, Callable, Mapping, Protocol, Sequence, Type, TypeVar, overload, runtime_checkable
from typing import Any, Callable, Mapping, Protocol, TypeVar, overload, runtime_checkable
from ._agent import Agent
from ._agent_id import AgentId
@ -17,10 +17,6 @@ T = TypeVar("T", bound=Agent)
agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar("agent_instantiation_context")
class AllNamespaces:
pass
@runtime_checkable
class AgentRuntime(Protocol):
# Returns the response of the message
@ -45,7 +41,9 @@ class AgentRuntime(Protocol):
@overload
def register(
self, name: str, agent_factory: Callable[[], T], *, valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...
self,
name: str,
agent_factory: Callable[[], T],
) -> None: ...
@overload
@ -53,23 +51,18 @@ class AgentRuntime(Protocol):
self,
name: str,
agent_factory: Callable[[AgentRuntime, AgentId], T],
*,
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> None: ...
def register(
self,
name: str,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
*,
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
) -> None:
"""Register an agent factory with the runtime associated with a specific name. The name must be unique.
Args:
name (str): The name of the type agent this factory creates.
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent.
valid_namespaces (Sequence[str] | Type[AllNamespaces], optional): Valid namespaces for this type. Defaults to AllNamespaces.
Example:
@ -99,7 +92,6 @@ class AgentRuntime(Protocol):
agent_factory: Callable[[], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> AgentId: ...
@overload
@ -109,7 +101,6 @@ class AgentRuntime(Protocol):
agent_factory: Callable[[AgentRuntime, AgentId], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> AgentId: ...
def register_and_get(
@ -118,7 +109,6 @@ class AgentRuntime(Protocol):
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
) -> AgentId:
self.register(name, agent_factory)
return self.get(name, namespace=namespace)
@ -130,7 +120,6 @@ class AgentRuntime(Protocol):
agent_factory: Callable[[], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> AgentProxy: ...
@overload
@ -140,7 +129,6 @@ class AgentRuntime(Protocol):
agent_factory: Callable[[AgentRuntime, AgentId], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> AgentProxy: ...
def register_and_get_proxy(
@ -149,7 +137,6 @@ class AgentRuntime(Protocol):
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
) -> AgentProxy:
self.register(name, agent_factory)
return self.get_proxy(name, namespace=namespace)

View File

@ -18,7 +18,7 @@ async def test_intervention_count_messages() -> None:
return message
handler = DebugInterventionHandler()
runtime = SingleThreadedAgentRuntime(before_send=handler)
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = runtime.register_and_get("name", LoopbackAgent)
response = runtime.send_message(MessageType(), recipient=loopback)
@ -38,7 +38,7 @@ async def test_intervention_drop_send() -> None:
return DropMessage
handler = DropSendInterventionHandler()
runtime = SingleThreadedAgentRuntime(before_send=handler)
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = runtime.register_and_get("name", LoopbackAgent)
response = runtime.send_message(MessageType(), recipient=loopback)
@ -61,7 +61,7 @@ async def test_intervention_drop_response() -> None:
return DropMessage
handler = DropResponseInterventionHandler()
runtime = SingleThreadedAgentRuntime(before_send=handler)
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = runtime.register_and_get("name", LoopbackAgent)
response = runtime.send_message(MessageType(), recipient=loopback)
@ -84,7 +84,7 @@ async def test_intervention_raise_exception_on_send() -> None:
raise InterventionException
handler = ExceptionInterventionHandler()
runtime = SingleThreadedAgentRuntime(before_send=handler)
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
long_running = runtime.register_and_get("name", LoopbackAgent)
response = runtime.send_message(MessageType(), recipient=long_running)
@ -109,7 +109,7 @@ async def test_intervention_raise_exception_on_respond() -> None:
raise InterventionException
handler = ExceptionInterventionHandler()
runtime = SingleThreadedAgentRuntime(before_send=handler)
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
long_running = runtime.register_and_get("name", LoopbackAgent)
response = runtime.send_message(MessageType(), recipient=long_running)

View File

@ -41,34 +41,3 @@ async def test_register_receives_publish() -> None:
other_long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name", namespace="other")) # type: ignore
assert other_long_running_agent.num_calls == 0
@pytest.mark.asyncio
async def test_try_instantiate_agent_invalid_namespace() -> None:
runtime = SingleThreadedAgentRuntime()
runtime.register("name", LoopbackAgent, valid_namespaces=["default"])
await runtime.publish_message(MessageType(), namespace="non_default")
while len(runtime.unprocessed_messages) > 0 or runtime.outstanding_tasks > 0:
await runtime.process_next()
# Agent in default namespace should have received the message
long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name")) # type: ignore
assert long_running_agent.num_calls == 0
with pytest.raises(ValueError):
_agent = runtime.get("name", namespace="non_default")
@pytest.mark.asyncio
async def test_send_crosses_namepace() -> None:
runtime = SingleThreadedAgentRuntime()
runtime.register("name", LoopbackAgent)
default_ns_agent = runtime.get("name")
non_default_ns_agent = runtime.get("name", namespace="non_default")
with pytest.raises(ValueError):
await runtime.send_message(MessageType(), default_ns_agent, sender=non_default_ns_agent)