from typing import Any

import pytest
from agnext.application import SingleThreadedAgentRuntime
from agnext.core import BaseAgent, CancellationToken
from test_utils import LoopbackAgent, MessageType


class NoopAgent(BaseAgent):  # type: ignore
    """Minimal agent used only to exercise registration; it never handles messages."""

    def __init__(self) -> None:  # type: ignore
        super().__init__("A no op agent", [])

    async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:  # type: ignore
        # The tests that use this agent never deliver a message to it.
        raise NotImplementedError


async def _process_until_idle(runtime: SingleThreadedAgentRuntime) -> None:
    """Step ``runtime`` until it has no unprocessed messages and no outstanding tasks.

    Leading underscore so pytest does not collect this helper as a test.
    """
    while len(runtime.unprocessed_messages) > 0 or runtime.outstanding_tasks > 0:
        await runtime.process_next()


@pytest.mark.asyncio
async def test_agent_names_must_be_unique() -> None:
    """Registering a second agent under an already-used name raises ValueError."""
    runtime = SingleThreadedAgentRuntime()

    _agent1 = runtime.register_and_get("name1", NoopAgent)

    with pytest.raises(ValueError):
        _agent1 = runtime.register_and_get("name1", NoopAgent)

    # A distinct name is still accepted after the rejected duplicate.
    _agent3 = runtime.register_and_get("name3", NoopAgent)


@pytest.mark.asyncio
async def test_register_receives_publish() -> None:
    """A message published to "default" reaches only the default-namespace instance."""
    runtime = SingleThreadedAgentRuntime()
    runtime.register("name", LoopbackAgent)

    await runtime.publish_message(MessageType(), namespace="default")
    await _process_until_idle(runtime)

    # Agent in default namespace should have received the message.
    default_agent: LoopbackAgent = runtime._get_agent(runtime.get("name"))  # type: ignore
    assert default_agent.num_calls == 1

    # Agent in other namespace should not have received the message.
    other_agent: LoopbackAgent = runtime._get_agent(runtime.get("name", namespace="other"))  # type: ignore
    assert other_agent.num_calls == 0


@pytest.mark.asyncio
async def test_try_instantiate_agent_invalid_namespace() -> None:
    """An agent restricted to "default" is not reachable from another namespace."""
    runtime = SingleThreadedAgentRuntime()
    runtime.register("name", LoopbackAgent, valid_namespaces=["default"])

    await runtime.publish_message(MessageType(), namespace="non_default")
    await _process_until_idle(runtime)

    # The publish targeted "non_default", where this agent is not valid, so the
    # default-namespace instance must NOT have received the message.
    default_agent: LoopbackAgent = runtime._get_agent(runtime.get("name"))  # type: ignore
    assert default_agent.num_calls == 0

    # Looking the agent up in a namespace outside valid_namespaces is rejected.
    with pytest.raises(ValueError):
        _agent = runtime.get("name", namespace="non_default")


@pytest.mark.asyncio
async def test_send_crosses_namepace() -> None:
    """Sending from a "non_default" agent to a "default" agent raises ValueError."""
    runtime = SingleThreadedAgentRuntime()
    runtime.register("name", LoopbackAgent)

    default_ns_agent = runtime.get("name")
    non_default_ns_agent = runtime.get("name", namespace="non_default")

    with pytest.raises(ValueError):
        await runtime.send_message(MessageType(), default_ns_agent, sender=non_default_ns_agent)