# Source: autogen/python/tests/test_runtime.py

import asyncio
import pytest
# (viewer blame timestamp: 2024-06-22 14:50:32 -04:00)
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeSubscription, DefaultTopicId, DefaultSubscription
from agnext.base import AgentId, AgentInstantiationContext
from agnext.base import TopicId
from agnext.base import Subscription
from agnext.base import SubscriptionInstantiationContext
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
@pytest.mark.asyncio
async def test_agent_names_must_be_unique() -> None:
    """Registering an already-used agent type name must raise ``ValueError``."""
    runtime = SingleThreadedAgentRuntime()

    def agent_factory() -> NoopAgent:
        # Whenever the runtime calls this factory, the instantiation context
        # is expected to report the id of the agent under construction, and
        # the freshly built agent is expected to carry that same id.
        expected_id = AgentInstantiationContext.current_agent_id()
        assert expected_id == AgentId("name1", "default")
        instance = NoopAgent()
        assert instance.id == expected_id
        return instance

    await runtime.register("name1", agent_factory)

    # Re-using "name1" must be rejected.
    with pytest.raises(ValueError):
        await runtime.register("name1", NoopAgent)

    # A different name must still register without error afterwards.
    await runtime.register("name3", NoopAgent)
@pytest.mark.asyncio
async def test_register_receives_publish() -> None:
    """Publish one message after adding a ``TypeSubscription`` by hand.

    Expects the ``"default"``-keyed agent to be called exactly once and the
    ``"other"``-keyed agent not to be called at all.
    """
    runtime = SingleThreadedAgentRuntime()
    await runtime.register("name", LoopbackAgent)
    runtime.start()
    await runtime.add_subscription(TypeSubscription("default", "name"))

    await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default"))
    await runtime.stop_when_idle()

    # The agent keyed "default" should have handled the single message.
    default_agent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="default"), type=LoopbackAgent
    )
    assert default_agent.num_calls == 1

    # The agent keyed "other" should have seen nothing.
    other_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert other_agent.num_calls == 0
@pytest.mark.asyncio
async def test_register_receives_publish_cascade() -> None:
    """Check per-agent call counts when several ``CascadingAgent``s share a topic."""
    runtime = SingleThreadedAgentRuntime()
    num_agents = 5
    num_initial_messages = 5
    max_rounds = 5

    # Expected calls per agent: a geometric sum over the rounds with ratio
    # (num_agents - 1). NOTE(review): presumably each re-publish reaches the
    # other agents only; confirm against CascadingAgent in test_utils.
    expected_calls = sum(
        num_initial_messages * ((num_agents - 1) ** round_index) for round_index in range(max_rounds)
    )

    # Register every agent type and subscribe it to the "default" topic type.
    for index in range(num_agents):
        agent_type = f"name{index}"
        await runtime.register(agent_type, lambda: CascadingAgent(max_rounds))
        await runtime.add_subscription(TypeSubscription("default", agent_type))

    runtime.start()

    # Publish the initial round-1 messages.
    cascade_topic = TopicId("default", "default")
    for _ in range(num_initial_messages):
        await runtime.publish_message(CascadingMessageType(round=1), cascade_topic)

    # Let the runtime drain before inspecting the agents.
    await runtime.stop_when_idle()

    # Every agent must have been called the same, expected number of times.
    for index in range(num_agents):
        cascading_agent = await runtime.try_get_underlying_agent_instance(
            AgentId(f"name{index}", "default"), CascadingAgent
        )
        assert cascading_agent.num_calls == expected_calls
@pytest.mark.asyncio
async def test_register_factory_explicit_name() -> None:
    """Register with a subscription factory that hard-codes the agent type name.

    Expects the ``"default"``-keyed agent to be called exactly once and the
    ``"other"``-keyed agent not to be called at all.
    """
    runtime = SingleThreadedAgentRuntime()
    await runtime.register("name", LoopbackAgent, lambda: [TypeSubscription("default", "name")])
    runtime.start()

    await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default"))
    await runtime.stop_when_idle()

    # The agent keyed "default" should have handled the single message.
    default_agent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="default"), type=LoopbackAgent
    )
    assert default_agent.num_calls == 1

    # The agent keyed "other" should have seen nothing.
    other_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert other_agent.num_calls == 0
@pytest.mark.asyncio
async def test_register_factory_context_var_name() -> None:
    """Register with a subscription factory that reads the agent type from context.

    The factory takes the type name from
    ``SubscriptionInstantiationContext.agent_type()`` instead of spelling it
    out. Expects the ``"default"``-keyed agent to be called exactly once and
    the ``"other"``-keyed agent not to be called at all.
    """
    runtime = SingleThreadedAgentRuntime()
    await runtime.register(
        "name",
        LoopbackAgent,
        lambda: [TypeSubscription("default", SubscriptionInstantiationContext.agent_type().type)],
    )
    runtime.start()

    await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default"))
    await runtime.stop_when_idle()

    # The agent keyed "default" should have handled the single message.
    default_agent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="default"), type=LoopbackAgent
    )
    assert default_agent.num_calls == 1

    # The agent keyed "other" should have seen nothing.
    other_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert other_agent.num_calls == 0
@pytest.mark.asyncio
async def test_register_factory_async() -> None:
    """Register with a coroutine subscription factory.

    Expects the ``"default"``-keyed agent to be called exactly once and the
    ``"other"``-keyed agent not to be called at all.
    """
    runtime = SingleThreadedAgentRuntime()

    async def sub_factory() -> list[Subscription]:
        # Suspend once before reading the context, so the agent type is
        # looked up only after the coroutine has yielded control.
        await asyncio.sleep(0.1)
        agent_type_name = SubscriptionInstantiationContext.agent_type().type
        return [TypeSubscription("default", agent_type_name)]

    await runtime.register("name", LoopbackAgent, sub_factory)
    runtime.start()

    await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default"))
    await runtime.stop_when_idle()

    # The agent keyed "default" should have handled the single message.
    default_agent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="default"), type=LoopbackAgent
    )
    assert default_agent.num_calls == 1

    # The agent keyed "other" should have seen nothing.
    other_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert other_agent.num_calls == 0
@pytest.mark.asyncio
async def test_register_factory_direct_list() -> None:
    """Register with a plain list of subscriptions rather than a factory.

    Expects the ``"default"``-keyed agent to be called exactly once and the
    ``"other"``-keyed agent not to be called at all.
    """
    runtime = SingleThreadedAgentRuntime()
    subscriptions = [TypeSubscription("default", "name")]
    await runtime.register("name", LoopbackAgent, subscriptions)
    runtime.start()

    await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default"))
    await runtime.stop_when_idle()

    # The agent keyed "default" should have handled the single message.
    default_agent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="default"), type=LoopbackAgent
    )
    assert default_agent.num_calls == 1

    # The agent keyed "other" should have seen nothing.
    other_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert other_agent.num_calls == 0
@pytest.mark.asyncio
async def test_default_subscription() -> None:
    """Pair ``DefaultSubscription()`` with a message published to ``DefaultTopicId()``.

    Expects the ``"default"``-keyed agent to be called exactly once and the
    ``"other"``-keyed agent not to be called at all.
    """
    runtime = SingleThreadedAgentRuntime()
    await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
    runtime.start()

    await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
    await runtime.stop_when_idle()

    # The agent keyed "default" should have handled the single message.
    default_agent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="default"), type=LoopbackAgent
    )
    assert default_agent.num_calls == 1

    # The agent keyed "other" should have seen nothing.
    other_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert other_agent.num_calls == 0
@pytest.mark.asyncio
async def test_non_default_default_subscription() -> None:
    """Use ``DefaultSubscription`` and ``DefaultTopicId`` with topic type ``"Other"``.

    Expects the ``"default"``-keyed agent to be called exactly once and the
    ``"other"``-keyed agent not to be called at all.
    """
    runtime = SingleThreadedAgentRuntime()
    await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription(topic_type="Other")])
    runtime.start()

    await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(type="Other"))
    await runtime.stop_when_idle()

    # The agent keyed "default" should have handled the single message.
    default_agent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="default"), type=LoopbackAgent
    )
    assert default_agent.num_calls == 1

    # The agent keyed "other" should have seen nothing.
    other_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert other_agent.num_calls == 0
@pytest.mark.asyncio
async def test_non_publish_to_other_source() -> None:
    """Publish to ``DefaultTopicId(source="other")`` with a ``DefaultSubscription()``.

    Expects the ``"other"``-keyed agent to be called exactly once and the
    ``"default"``-keyed agent not to be called at all.
    """
    runtime = SingleThreadedAgentRuntime()
    await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
    runtime.start()

    await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))
    await runtime.stop_when_idle()

    # The agent keyed "default" should NOT have received the message.
    default_agent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="default"), type=LoopbackAgent
    )
    assert default_agent.num_calls == 0

    # The agent keyed "other" should have received it exactly once.
    other_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert other_agent.num_calls == 1