mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-27 15:09:41 +00:00
simplify namespace usage (#116)
* simplify namespace usage * format * pyright
This commit is contained in:
parent
606e43b325
commit
6189fdb05c
@ -1,7 +1,6 @@
|
||||
# Namespace
|
||||
|
||||
A namespace is a logical boundary between agents. By default, agents in one
|
||||
namespace cannot communicate with agents in another namespace.
|
||||
Namespace allow for defining logical boundaries between agents.
|
||||
|
||||
Namespaces are strings, and the default is `default`.
|
||||
|
||||
@ -15,4 +14,6 @@ Two possible use cases of agents are:
|
||||
|
||||
The {py:class}`agnext.core.AgentId` is used to address an agent, it is the combination of the agent's namespace and its name.
|
||||
|
||||
When getting an agent reference ({py:meth}`agnext.core.AgentRuntime.get`) or proxy ({py:meth}`agnext.core.AgentRuntime.get_proxy`) from the runtime the namespace can be specified. Agents have an ID property ({py:attr}`agnext.core.Agent.id`) that returns the agent's id. Additionally, the register method takes a factory that can optionally accept the ID as an argument ({py:meth}`agnext.core.AgentRuntime.register`).
|
||||
When getting an agent reference ({py:meth}`agnext.core.AgentRuntime.get`) or proxy ({py:meth}`agnext.core.AgentRuntime.get_proxy`) from the runtime the namespace can be specified. Agents have an ID property ({py:attr}`agnext.core.Agent.id`) that returns the agent's id. Additionally, the register method takes a factory that can optionally accept the ID as an argument ({py:meth}`agnext.core.AgentRuntime.register`).
|
||||
|
||||
By default, there are no restrictions and are left to the application to enforce. The runtime will however automatically create agents in a namespace if it does not exist.
|
||||
|
||||
@ -68,7 +68,7 @@ Finally, we add this handler to the runtime and use it to detect termination and
|
||||
async def main() -> None:
|
||||
termination_handler = TerminationHandler()
|
||||
runtime = SingleThreadedAgentRuntime(
|
||||
before_send=termination_handler
|
||||
intervention_handler=termination_handler
|
||||
)
|
||||
|
||||
# Add Agents and kick off task
|
||||
|
||||
@ -267,7 +267,7 @@ class DisplayAgent(TypeRoutedAgent):
|
||||
|
||||
async def main() -> None:
|
||||
termination_handler = TerminationHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=termination_handler)
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=termination_handler)
|
||||
runtime.register(
|
||||
"ReviewerAgent",
|
||||
lambda: ReviewerAgent(
|
||||
|
||||
@ -133,7 +133,7 @@ class DisplayAgent(TypeRoutedAgent):
|
||||
|
||||
async def main() -> None:
|
||||
termination_handler = TerminationHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=termination_handler)
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=termination_handler)
|
||||
# TODO: use different models for each agent.
|
||||
runtime.register(
|
||||
"ReferenceAgent1",
|
||||
|
||||
@ -6,7 +6,7 @@ from asyncio import Future
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, TypeVar, cast
|
||||
|
||||
from ..core import (
|
||||
Agent,
|
||||
@ -14,7 +14,6 @@ from ..core import (
|
||||
AgentMetadata,
|
||||
AgentProxy,
|
||||
AgentRuntime,
|
||||
AllNamespaces,
|
||||
CancellationToken,
|
||||
agent_instantiation_context,
|
||||
)
|
||||
@ -82,15 +81,13 @@ class Counter:
|
||||
|
||||
|
||||
class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
|
||||
def __init__(self, *, intervention_handler: InterventionHandler | None = None) -> None:
|
||||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||
# (namespace, type) -> List[AgentId]
|
||||
self._per_type_subscribers: DefaultDict[tuple[str, type], Set[AgentId]] = defaultdict(set)
|
||||
self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {}
|
||||
# If empty, then all namespaces are valid for that agent type
|
||||
self._valid_namespaces: Dict[str, Sequence[str]] = {}
|
||||
self._instantiated_agents: Dict[AgentId, Agent] = {}
|
||||
self._before_send = before_send
|
||||
self._intervention_handler = intervention_handler
|
||||
self._known_namespaces: set[str] = set()
|
||||
self._outstanding_tasks = Counter()
|
||||
|
||||
@ -322,9 +319,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
match message_envelope:
|
||||
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._before_send is not None:
|
||||
if self._intervention_handler is not None:
|
||||
try:
|
||||
temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient)
|
||||
temp_message = await self._intervention_handler.on_send(
|
||||
message, sender=sender, recipient=recipient
|
||||
)
|
||||
except BaseException as e:
|
||||
future.set_exception(e)
|
||||
return
|
||||
@ -339,9 +338,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
message=message,
|
||||
sender=sender,
|
||||
):
|
||||
if self._before_send is not None:
|
||||
if self._intervention_handler is not None:
|
||||
try:
|
||||
temp_message = await self._before_send.on_publish(message, sender=sender)
|
||||
temp_message = await self._intervention_handler.on_publish(message, sender=sender)
|
||||
except BaseException as e:
|
||||
# TODO: we should raise the intervention exception to the publisher.
|
||||
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
|
||||
@ -354,9 +353,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self._outstanding_tasks.increment()
|
||||
asyncio.create_task(self._process_publish(message_envelope))
|
||||
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._before_send is not None:
|
||||
if self._intervention_handler is not None:
|
||||
try:
|
||||
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
|
||||
temp_message = await self._intervention_handler.on_response(
|
||||
message, sender=sender, recipient=recipient
|
||||
)
|
||||
except BaseException as e:
|
||||
# TODO: should we raise the exception to sender of the response instead?
|
||||
future.set_exception(e)
|
||||
@ -385,21 +386,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
|
||||
) -> None:
|
||||
if name in self._agent_factories:
|
||||
raise ValueError(f"Agent with name {name} already exists.")
|
||||
self._agent_factories[name] = agent_factory
|
||||
if valid_namespaces is not AllNamespaces:
|
||||
self._valid_namespaces[name] = cast(Sequence[str], valid_namespaces)
|
||||
else:
|
||||
self._valid_namespaces[name] = []
|
||||
|
||||
# For all already prepared namespaces we need to prepare this agent
|
||||
for namespace in self._known_namespaces:
|
||||
if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)):
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
|
||||
def _invoke_agent_factory(
|
||||
self, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], agent_id: AgentId
|
||||
@ -419,24 +413,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
return agent
|
||||
|
||||
def _type_valid_for_namespace(self, agent_id: AgentId) -> bool:
|
||||
if agent_id.name not in self._agent_factories:
|
||||
raise KeyError(f"Agent with name {agent_id.name} not found.")
|
||||
|
||||
valid_namespaces = self._valid_namespaces[agent_id.name]
|
||||
if len(valid_namespaces) == 0:
|
||||
return True
|
||||
|
||||
return agent_id.namespace in valid_namespaces
|
||||
|
||||
def _get_agent(self, agent_id: AgentId) -> Agent:
|
||||
self._process_seen_namespace(agent_id.namespace)
|
||||
if agent_id in self._instantiated_agents:
|
||||
return self._instantiated_agents[agent_id]
|
||||
|
||||
if not self._type_valid_for_namespace(agent_id):
|
||||
raise ValueError(f"Agent with name {agent_id.name} not valid for namespace {agent_id.namespace}.")
|
||||
|
||||
if agent_id.name not in self._agent_factories:
|
||||
raise ValueError(f"Agent with name {agent_id.name} not found.")
|
||||
|
||||
@ -463,5 +444,4 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
self._known_namespaces.add(namespace)
|
||||
for name in self._known_agent_names:
|
||||
if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)):
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
|
||||
@ -7,7 +7,7 @@ from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_props import AgentChildren
|
||||
from ._agent_proxy import AgentProxy
|
||||
from ._agent_runtime import AgentRuntime, AllNamespaces, agent_instantiation_context
|
||||
from ._agent_runtime import AgentRuntime, agent_instantiation_context
|
||||
from ._base_agent import BaseAgent
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
@ -17,7 +17,6 @@ __all__ = [
|
||||
"AgentProxy",
|
||||
"AgentMetadata",
|
||||
"AgentRuntime",
|
||||
"AllNamespaces",
|
||||
"BaseAgent",
|
||||
"CancellationToken",
|
||||
"AgentChildren",
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from asyncio import Future
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Mapping, Protocol, Sequence, Type, TypeVar, overload, runtime_checkable
|
||||
from typing import Any, Callable, Mapping, Protocol, TypeVar, overload, runtime_checkable
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
@ -17,10 +17,6 @@ T = TypeVar("T", bound=Agent)
|
||||
agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar("agent_instantiation_context")
|
||||
|
||||
|
||||
class AllNamespaces:
|
||||
pass
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentRuntime(Protocol):
|
||||
# Returns the response of the message
|
||||
@ -45,7 +41,9 @@ class AgentRuntime(Protocol):
|
||||
|
||||
@overload
|
||||
def register(
|
||||
self, name: str, agent_factory: Callable[[], T], *, valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T],
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
@ -53,23 +51,18 @@ class AgentRuntime(Protocol):
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> None: ...
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
|
||||
) -> None:
|
||||
"""Register an agent factory with the runtime associated with a specific name. The name must be unique.
|
||||
|
||||
Args:
|
||||
name (str): The name of the type agent this factory creates.
|
||||
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent.
|
||||
valid_namespaces (Sequence[str] | Type[AllNamespaces], optional): Valid namespaces for this type. Defaults to AllNamespaces.
|
||||
|
||||
|
||||
Example:
|
||||
@ -99,7 +92,6 @@ class AgentRuntime(Protocol):
|
||||
agent_factory: Callable[[], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> AgentId: ...
|
||||
|
||||
@overload
|
||||
@ -109,7 +101,6 @@ class AgentRuntime(Protocol):
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> AgentId: ...
|
||||
|
||||
def register_and_get(
|
||||
@ -118,7 +109,6 @@ class AgentRuntime(Protocol):
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
|
||||
) -> AgentId:
|
||||
self.register(name, agent_factory)
|
||||
return self.get(name, namespace=namespace)
|
||||
@ -130,7 +120,6 @@ class AgentRuntime(Protocol):
|
||||
agent_factory: Callable[[], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> AgentProxy: ...
|
||||
|
||||
@overload
|
||||
@ -140,7 +129,6 @@ class AgentRuntime(Protocol):
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> AgentProxy: ...
|
||||
|
||||
def register_and_get_proxy(
|
||||
@ -149,7 +137,6 @@ class AgentRuntime(Protocol):
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
|
||||
) -> AgentProxy:
|
||||
self.register(name, agent_factory)
|
||||
return self.get_proxy(name, namespace=namespace)
|
||||
|
||||
@ -18,7 +18,7 @@ async def test_intervention_count_messages() -> None:
|
||||
return message
|
||||
|
||||
handler = DebugInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
|
||||
response = runtime.send_message(MessageType(), recipient=loopback)
|
||||
@ -38,7 +38,7 @@ async def test_intervention_drop_send() -> None:
|
||||
return DropMessage
|
||||
|
||||
handler = DropSendInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = runtime.send_message(MessageType(), recipient=loopback)
|
||||
@ -61,7 +61,7 @@ async def test_intervention_drop_response() -> None:
|
||||
return DropMessage
|
||||
|
||||
handler = DropResponseInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = runtime.send_message(MessageType(), recipient=loopback)
|
||||
@ -84,7 +84,7 @@ async def test_intervention_raise_exception_on_send() -> None:
|
||||
raise InterventionException
|
||||
|
||||
handler = ExceptionInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
long_running = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = runtime.send_message(MessageType(), recipient=long_running)
|
||||
@ -109,7 +109,7 @@ async def test_intervention_raise_exception_on_respond() -> None:
|
||||
raise InterventionException
|
||||
|
||||
handler = ExceptionInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
|
||||
|
||||
long_running = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = runtime.send_message(MessageType(), recipient=long_running)
|
||||
|
||||
@ -41,34 +41,3 @@ async def test_register_receives_publish() -> None:
|
||||
other_long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name", namespace="other")) # type: ignore
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_try_instantiate_agent_invalid_namespace() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
runtime.register("name", LoopbackAgent, valid_namespaces=["default"])
|
||||
await runtime.publish_message(MessageType(), namespace="non_default")
|
||||
|
||||
while len(runtime.unprocessed_messages) > 0 or runtime.outstanding_tasks > 0:
|
||||
await runtime.process_next()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name")) # type: ignore
|
||||
assert long_running_agent.num_calls == 0
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_agent = runtime.get("name", namespace="non_default")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_crosses_namepace() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
runtime.register("name", LoopbackAgent)
|
||||
|
||||
default_ns_agent = runtime.get("name")
|
||||
non_default_ns_agent = runtime.get("name", namespace="non_default")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await runtime.send_message(MessageType(), default_ns_agent, sender=non_default_ns_agent)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user