diff --git a/docs/src/conf.py b/docs/src/conf.py index de2225338..0560ff78f 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -25,6 +25,7 @@ apidoc_output_dir = 'reference' apidoc_template_dir = '_apidoc_templates' apidoc_separate_modules = True apidoc_extra_args = ["--no-toc"] +napoleon_custom_sections = [('Returns', 'params_style')] templates_path = [] exclude_patterns = ["reference/agnext.rst"] diff --git a/src/agnext/agent_components/type_routed_agent.py b/src/agnext/agent_components/type_routed_agent.py index d6fa10309..3cf4df40d 100644 --- a/src/agnext/agent_components/type_routed_agent.py +++ b/src/agnext/agent_components/type_routed_agent.py @@ -31,13 +31,9 @@ def message_handler( class TypeRoutedAgent(BaseAgent): def __init__(self, name: str, router: AgentRuntime) -> None: - super().__init__(name, router) - # Self is already bound to the handlers self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {} - router.add_agent(self) - for attr in dir(self): if callable(getattr(self, attr, None)): handler = getattr(self, attr) @@ -45,6 +41,8 @@ class TypeRoutedAgent(BaseAgent): for target_type in handler._target_types: self._handlers[target_type] = handler + super().__init__(name, router) + @property def subscriptions(self) -> Sequence[Type[Any]]: return list(self._handlers.keys()) diff --git a/src/agnext/application_components/single_threaded_agent_runtime.py b/src/agnext/application_components/single_threaded_agent_runtime.py index 9d9320c1b..906b1dd03 100644 --- a/src/agnext/application_components/single_threaded_agent_runtime.py +++ b/src/agnext/application_components/single_threaded_agent_runtime.py @@ -51,6 +51,10 @@ class SingleThreadedAgentRuntime(AgentRuntime): self._before_send = before_send def add_agent(self, agent: Agent) -> None: + agent_names = {agent.name for agent in self._agents} + if agent.name in agent_names: + raise ValueError(f"Agent with name {agent.name} already exists. Agent names must be unique.") + for message_type in agent.subscriptions: if message_type not in self._per_type_subscribers: self._per_type_subscribers[message_type] = [] diff --git a/src/agnext/core/agent.py b/src/agnext/core/agent.py index 2fd60ec63..01d8c599f 100644 --- a/src/agnext/core/agent.py +++ b/src/agnext/core/agent.py @@ -6,9 +6,30 @@ from agnext.core.cancellation_token import CancellationToken @runtime_checkable class Agent(Protocol): @property - def name(self) -> str: ... + def name(self) -> str: + """Name of the agent. + + Note: + This name should be unique within the runtime. + """ + ... @property - def subscriptions(self) -> Sequence[type]: ... + def subscriptions(self) -> Sequence[type]: + """Types of messages that this agent can receive.""" + ... - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ... + async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: + """Message handler for the agent. This should only be called by the runtime, not by other agents. + + Args: + message (Any): Received message. Type is one of the types in `subscriptions`. + cancellation_token (CancellationToken): Cancellation token for the message. + + Returns: + Any: Response to the message. Can be None. + + Notes: + If there was a cancellation, this function should raise a `CancelledError`. + """ + ... diff --git a/src/agnext/core/agent_runtime.py b/src/agnext/core/agent_runtime.py index 7bd4875fe..e886a11a3 100644 --- a/src/agnext/core/agent_runtime.py +++ b/src/agnext/core/agent_runtime.py @@ -8,7 +8,16 @@ from agnext.core.cancellation_token import CancellationToken class AgentRuntime(Protocol): - def add_agent(self, agent: Agent) -> None: ... + def add_agent(self, agent: Agent) -> None: + """Add an agent to the runtime. + + Args: + agent (Agent): Agent to add to the runtime. + + Note: + The name of the agent should be unique within the runtime. + """ + ... # Returns the response of the message def send_message( diff --git a/src/agnext/core/base_agent.py b/src/agnext/core/base_agent.py index 38c13ab87..51ef4f536 100644 --- a/src/agnext/core/base_agent.py +++ b/src/agnext/core/base_agent.py @@ -18,6 +18,7 @@ class BaseAgent(ABC, Agent): def __init__(self, name: str, router: AgentRuntime) -> None: self._name = name self._router = router + router.add_agent(self) @property def name(self) -> str: @@ -29,7 +30,7 @@ class BaseAgent(ABC, Agent): return [] @abstractmethod - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ... + async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: ... # Returns the response of the message def _send_message( diff --git a/tests/test_runtime.py b/tests/test_runtime.py new file mode 100644 index 000000000..1935f8ab2 --- /dev/null +++ b/tests/test_runtime.py @@ -0,0 +1,32 @@ +from typing import Any, Sequence +import pytest + +from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime +from agnext.core.agent_runtime import AgentRuntime +from agnext.core.base_agent import BaseAgent +from agnext.core.cancellation_token import CancellationToken + +class NoopAgent(BaseAgent): + def __init__(self, name: str, router: AgentRuntime) -> None: + super().__init__(name, router) + + @property + def subscriptions(self) -> Sequence[type]: + return [] + + async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: + raise NotImplementedError + + +@pytest.mark.asyncio +async def test_agent_names_must_be_unique() -> None: + router = SingleThreadedAgentRuntime() + + _agent1 = NoopAgent("name1", router) + + with pytest.raises(ValueError): + _agent1_again = NoopAgent("name1", router) + + _agent3 = NoopAgent("name3", router) + +