mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-25 22:18:53 +00:00
format (#593)
This commit is contained in:
parent
207330577f
commit
46ca778423
@ -564,6 +564,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
expected_class: type[T],
|
||||
) -> AgentType:
|
||||
if type.type in self._agent_factories:
|
||||
raise ValueError(f"Agent with type {type} already exists.")
|
||||
|
||||
async def factory_wrapper() -> T:
|
||||
maybe_agent_instance = agent_factory()
|
||||
if inspect.isawaitable(maybe_agent_instance):
|
||||
|
||||
@ -5,11 +5,19 @@ from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core.base import (
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
AgentType,
|
||||
Subscription,
|
||||
SubscriptionInstantiationContext,
|
||||
TopicId,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from autogen_core.components import (
|
||||
DefaultSubscription,
|
||||
DefaultTopicId,
|
||||
TypeSubscription,
|
||||
default_subscription,
|
||||
type_subscription,
|
||||
)
|
||||
from autogen_core.components import DefaultSubscription, DefaultTopicId, TypeSubscription
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
|
||||
from test_utils.telemetry_test_utils import TestExporter, get_test_tracer_provider
|
||||
@ -24,7 +32,7 @@ def tracer_provider() -> TracerProvider:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_names_must_be_unique() -> None:
|
||||
async def test_agent_type_must_be_unique() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
def agent_factory() -> NoopAgent:
|
||||
@ -34,29 +42,30 @@ async def test_agent_names_must_be_unique() -> None:
|
||||
assert agent.id == id
|
||||
return agent
|
||||
|
||||
await runtime.register("name1", agent_factory)
|
||||
await runtime.register_factory(type=AgentType("name1"), agent_factory=agent_factory, expected_class=NoopAgent)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await runtime.register("name1", NoopAgent)
|
||||
await runtime.register_factory(type=AgentType("name1"), agent_factory=agent_factory, expected_class=NoopAgent)
|
||||
|
||||
await runtime.register("name3", NoopAgent)
|
||||
await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory, expected_class=NoopAgent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish(tracer_provider: TracerProvider) -> None:
|
||||
runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)
|
||||
|
||||
await runtime.register("name", LoopbackAgent)
|
||||
runtime.start()
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
await runtime.register_factory(
|
||||
type=AgentType("name"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "name"))
|
||||
agent_id = AgentId("name", key="default")
|
||||
topic_id = TopicId("default", "default")
|
||||
await runtime.publish_message(MessageType(), topic_id=topic_id)
|
||||
|
||||
runtime.start()
|
||||
await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default"))
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(AgentId("name", "default"), type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
@ -77,7 +86,6 @@ async def test_register_receives_publish(tracer_provider: TracerProvider) -> Non
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish_cascade() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
num_agents = 5
|
||||
num_initial_messages = 5
|
||||
max_rounds = 5
|
||||
@ -85,16 +93,17 @@ async def test_register_receives_publish_cascade() -> None:
|
||||
for i in range(0, max_rounds):
|
||||
total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i)
|
||||
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
# Register agents
|
||||
for i in range(num_agents):
|
||||
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()])
|
||||
await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds))
|
||||
|
||||
runtime.start()
|
||||
|
||||
# Publish messages
|
||||
topic_id = TopicId("default", "default")
|
||||
for _ in range(num_initial_messages):
|
||||
await runtime.publish_message(CascadingMessageType(round=1), topic_id)
|
||||
await runtime.publish_message(CascadingMessageType(round=1), DefaultTopicId())
|
||||
|
||||
# Process until idle.
|
||||
await runtime.stop_when_idle()
|
||||
@ -206,64 +215,72 @@ async def test_register_factory_direct_list() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_subscription() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
|
||||
runtime.start()
|
||||
|
||||
@default_subscription
|
||||
class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ...
|
||||
|
||||
await LoopbackAgentWithDefaultSubscription.register(runtime, "name", LoopbackAgentWithDefaultSubscription)
|
||||
|
||||
agent_id = AgentId("name", key="default")
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(
|
||||
agent_id, type=LoopbackAgentWithDefaultSubscription
|
||||
)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
other_long_running_agent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_default_default_subscription() -> None:
|
||||
async def test_type_subscription() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription(topic_type="Other")])
|
||||
runtime.start()
|
||||
agent_id = AgentId("name", key="default")
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(type="Other"))
|
||||
|
||||
@type_subscription(topic_type="Other")
|
||||
class LoopbackAgentWithSubscription(LoopbackAgent): ...
|
||||
|
||||
await LoopbackAgentWithSubscription.register(runtime, "name", LoopbackAgentWithSubscription)
|
||||
|
||||
agent_id = AgentId("name", key="default")
|
||||
await runtime.publish_message(MessageType(), topic_id=TopicId("Other", "default"))
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgentWithSubscription)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
other_long_running_agent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgentWithSubscription
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_publish_to_other_source() -> None:
|
||||
async def test_default_subscription_publish_to_other_source() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
|
||||
runtime.start()
|
||||
|
||||
@default_subscription
|
||||
class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ...
|
||||
|
||||
await LoopbackAgentWithDefaultSubscription.register(runtime, "name", LoopbackAgentWithDefaultSubscription)
|
||||
|
||||
agent_id = AgentId("name", key="default")
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(
|
||||
agent_id, type=LoopbackAgentWithDefaultSubscription
|
||||
)
|
||||
assert long_running_agent.num_calls == 0
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
other_long_running_agent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 1
|
||||
|
||||
@ -9,7 +9,13 @@ from autogen_core.base import (
|
||||
TopicId,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from autogen_core.components import DefaultSubscription, DefaultTopicId, TypeSubscription
|
||||
from autogen_core.components import (
|
||||
DefaultSubscription,
|
||||
DefaultTopicId,
|
||||
TypeSubscription,
|
||||
default_subscription,
|
||||
type_subscription,
|
||||
)
|
||||
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
|
||||
|
||||
|
||||
@ -190,83 +196,107 @@ async def test_default_subscription() -> None:
|
||||
host_address = "localhost:50054"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime.start()
|
||||
worker = WorkerAgentRuntime(host_address=host_address)
|
||||
worker.start()
|
||||
publisher = WorkerAgentRuntime(host_address=host_address)
|
||||
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
publisher.start()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
|
||||
agent_id = AgentId("name", key="default")
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
|
||||
@default_subscription
|
||||
class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ...
|
||||
|
||||
await LoopbackAgentWithDefaultSubscription.register(worker, "name", lambda: LoopbackAgentWithDefaultSubscription())
|
||||
|
||||
await publisher.publish_message(MessageType(), topic_id=DefaultTopicId())
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
# Agent in default topic source should have received the message.
|
||||
long_running_agent = await worker.try_get_underlying_agent_instance(
|
||||
AgentId("name", "default"), type=LoopbackAgentWithDefaultSubscription
|
||||
)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
# Agent in other topic source should not have received the message.
|
||||
other_long_running_agent = await worker.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
await runtime.stop()
|
||||
await worker.stop()
|
||||
await publisher.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_default_default_subscription() -> None:
|
||||
host_address = "localhost:50055"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime.start()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription(topic_type="Other")])
|
||||
agent_id = AgentId("name", key="default")
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(type="Other"))
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
await runtime.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_publish_to_other_source() -> None:
|
||||
async def test_default_subscription_other_source() -> None:
|
||||
host_address = "localhost:50056"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime.start()
|
||||
publisher = WorkerAgentRuntime(host_address=host_address)
|
||||
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
publisher.start()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
|
||||
agent_id = AgentId("name", key="default")
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))
|
||||
@default_subscription
|
||||
class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ...
|
||||
|
||||
await LoopbackAgentWithDefaultSubscription.register(runtime, "name", lambda: LoopbackAgentWithDefaultSubscription())
|
||||
|
||||
await publisher.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", "default"), type=LoopbackAgentWithDefaultSubscription
|
||||
)
|
||||
assert long_running_agent.num_calls == 0
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
other_long_running_agent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 1
|
||||
|
||||
await runtime.stop()
|
||||
await publisher.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_type_subscription() -> None:
|
||||
host_address = "localhost:50055"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
worker = WorkerAgentRuntime(host_address=host_address)
|
||||
worker.start()
|
||||
publisher = WorkerAgentRuntime(host_address=host_address)
|
||||
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
publisher.start()
|
||||
|
||||
@type_subscription("Other")
|
||||
class LoopbackAgentWithSubscription(LoopbackAgent): ...
|
||||
|
||||
await LoopbackAgentWithSubscription.register(worker, "name", lambda: LoopbackAgentWithSubscription())
|
||||
|
||||
await publisher.publish_message(MessageType(), topic_id=TopicId(type="Other", source="default"))
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Agent in default topic source should have received the message.
|
||||
long_running_agent = await worker.try_get_underlying_agent_instance(
|
||||
AgentId("name", "default"), type=LoopbackAgentWithSubscription
|
||||
)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other topic source should not have received the message.
|
||||
other_long_running_agent = await worker.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgentWithSubscription
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
await worker.stop()
|
||||
await publisher.stop()
|
||||
await host.stop()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user