# autogen/python/tests/test_runtime.py
# (file-view metadata: 76 lines, 2.7 KiB, Python)

from typing import Any
import pytest
# 2024-06-04 10:00:05 -04:00
from agnext.application import SingleThreadedAgentRuntime
from agnext.core import BaseAgent, CancellationToken
from test_utils import LoopbackAgent, MessageType
class NoopAgent(BaseAgent):  # type: ignore
    """Placeholder agent used to exercise registration/naming in the runtime.

    It is constructed with a fixed description and an empty list as the second
    ``BaseAgent.__init__`` argument. Handling any message raises
    ``NotImplementedError``.
    """

    def __init__(self) -> None:  # type: ignore
        description = "A no op agent"
        super().__init__(description, [])

    async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:  # type: ignore
        # Getting here means a test routed a message to an agent that is
        # only meant to be registered, never messaged.
        raise NotImplementedError
@pytest.mark.asyncio
async def test_agent_names_must_be_unique() -> None:
    """Registering a second agent under an already-used name raises ValueError."""
    runtime = SingleThreadedAgentRuntime()

    _first = runtime.register_and_get("name1", NoopAgent)

    # Re-using "name1" must be rejected...
    with pytest.raises(ValueError):
        _duplicate = runtime.register_and_get("name1", NoopAgent)

    # ...while a previously unused name is still accepted afterwards.
    _other = runtime.register_and_get("name3", NoopAgent)
@pytest.mark.asyncio
async def test_register_receives_publish() -> None:
    """A message published to "default" reaches only the agent in that namespace."""
    runtime = SingleThreadedAgentRuntime()
    runtime.register("name", LoopbackAgent)

    await runtime.publish_message(MessageType(), namespace="default")

    # Step the runtime until there are no unprocessed messages and no
    # outstanding tasks left.
    while True:
        queued = len(runtime.unprocessed_messages)
        if queued == 0 and runtime.outstanding_tasks <= 0:
            break
        await runtime.process_next()

    # The agent in the default namespace saw the published message exactly once.
    default_agent: LoopbackAgent = runtime._get_agent(runtime.get("name"))  # type: ignore
    assert default_agent.num_calls == 1

    # The same agent type looked up in a different namespace saw nothing.
    other_handle = runtime.get("name", namespace="other")
    other_agent: LoopbackAgent = runtime._get_agent(other_handle)  # type: ignore
    assert other_agent.num_calls == 0
@pytest.mark.asyncio
async def test_try_instantiate_agent_invalid_namespace() -> None:
    """An agent registered with ``valid_namespaces=["default"]`` is not reachable elsewhere.

    Publishing to "non_default" must not reach the default-namespace agent, and
    asking the runtime for the agent in "non_default" must raise ``ValueError``.
    """
    runtime = SingleThreadedAgentRuntime()
    runtime.register("name", LoopbackAgent, valid_namespaces=["default"])
    await runtime.publish_message(MessageType(), namespace="non_default")
    # Step the runtime until there are no unprocessed messages and no
    # outstanding tasks left.
    while len(runtime.unprocessed_messages) > 0 or runtime.outstanding_tasks > 0:
        await runtime.process_next()
    # The message was published to "non_default", so the agent in the default
    # namespace must NOT have received it.
    long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name"))  # type: ignore
    assert long_running_agent.num_calls == 0
    # "non_default" is not in valid_namespaces for this agent, so the lookup fails.
    with pytest.raises(ValueError):
        _agent = runtime.get("name", namespace="non_default")
@pytest.mark.asyncio
async def test_send_crosses_namepace() -> None:
    """Sending from an agent in one namespace to an agent in another raises ValueError."""
    # NOTE(review): "namepace" in the test name is a typo for "namespace"; the
    # name is left unchanged so existing test IDs / -k selections keep working.
    runtime = SingleThreadedAgentRuntime()
    runtime.register("name", LoopbackAgent)

    recipient = runtime.get("name")
    sender = runtime.get("name", namespace="non_default")

    with pytest.raises(ValueError):
        await runtime.send_message(MessageType(), recipient, sender=sender)