autogen/python/packages/autogen-ext/tests/test_worker_runtime.py
Jacob Alber 676b611064
fix: Make race condition between channel open and RPC less likely to occur (#5514)
Right now we rely on opening the channel to associate a ClientId with an
entry on the gateway side. This causes a race when the channel is being
opened in the background while an RPC (e.g. MyAgent.register()) is
invoked.

If the RPC is processed first, the gateway rejects it due to "invalid"
clientId.

This fix makes this condition less likely to trigger, but there is still
a piece of the puzzle that needs to be solved on the Gateway side.
2025-02-12 16:40:52 -05:00

585 lines
21 KiB
Python

import asyncio
import logging
import os
from typing import Any, List
import pytest
from autogen_core import (
PROTOBUF_DATA_CONTENT_TYPE,
AgentId,
AgentType,
DefaultSubscription,
DefaultTopicId,
MessageContext,
RoutedAgent,
Subscription,
TopicId,
TypeSubscription,
default_subscription,
event,
try_get_known_serializers_for_type,
type_subscription,
)
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime, GrpcWorkerAgentRuntimeHost
from autogen_test_utils import (
CascadingAgent,
CascadingMessageType,
ContentMessage,
LoopbackAgent,
LoopbackAgentWithDefaultSubscription,
MessageType,
NoopAgent,
)
from .protos.serialization_test_pb2 import ProtoMessage
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_agent_types_must_be_unique_single_worker() -> None:
    """Registering the same agent type twice on one worker raises ``ValueError``."""
    address = "localhost:50051"
    runtime_host = GrpcWorkerAgentRuntimeHost(address=address)
    runtime_host.start()
    runtime = GrpcWorkerAgentRuntime(host_address=address)
    await runtime.start()

    # The first registration of "name1" is expected to go through.
    await runtime.register_factory(
        type=AgentType("name1"),
        agent_factory=lambda: NoopAgent(),
        expected_class=NoopAgent,
    )

    # Registering the very same type name again must be rejected.
    with pytest.raises(ValueError):
        await runtime.register_factory(
            type=AgentType("name1"),
            agent_factory=lambda: NoopAgent(),
            expected_class=NoopAgent,
        )

    # A different type name can still be registered afterwards.
    await runtime.register_factory(
        type=AgentType("name4"),
        agent_factory=lambda: NoopAgent(),
        expected_class=NoopAgent,
    )

    await runtime.stop()
    await runtime_host.stop()
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_agent_types_must_be_unique_multiple_workers() -> None:
    """An agent type registered by one worker cannot be registered by another worker."""
    address = "localhost:50052"
    runtime_host = GrpcWorkerAgentRuntimeHost(address=address)
    runtime_host.start()

    first_worker = GrpcWorkerAgentRuntime(host_address=address)
    await first_worker.start()
    second_worker = GrpcWorkerAgentRuntime(host_address=address)
    await second_worker.start()

    # The first worker claims the "name1" type.
    await first_worker.register_factory(
        type=AgentType("name1"),
        agent_factory=lambda: NoopAgent(),
        expected_class=NoopAgent,
    )

    # The second worker must be refused when it asks for the same type name.
    with pytest.raises(Exception, match="Agent type name1 already registered"):
        await second_worker.register_factory(
            type=AgentType("name1"),
            agent_factory=lambda: NoopAgent(),
            expected_class=NoopAgent,
        )

    # A type name nobody has claimed is still accepted from the second worker.
    await second_worker.register_factory(
        type=AgentType("name4"),
        agent_factory=lambda: NoopAgent(),
        expected_class=NoopAgent,
    )

    await first_worker.stop()
    await second_worker.stop()
    await runtime_host.stop()
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_register_receives_publish() -> None:
    """A message published on the default topic reaches subscribed agents on both workers."""
    address = "localhost:50053"
    runtime_host = GrpcWorkerAgentRuntimeHost(address=address)
    runtime_host.start()

    # Bring up one worker per agent type, each subscribed to the "default" topic type.
    workers: dict[str, GrpcWorkerAgentRuntime] = {}
    for agent_type in ("name1", "name2"):
        runtime = GrpcWorkerAgentRuntime(host_address=address)
        await runtime.start()
        runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
        await runtime.register_factory(
            type=AgentType(agent_type),
            agent_factory=lambda: LoopbackAgent(),
            expected_class=LoopbackAgent,
        )
        await runtime.add_subscription(TypeSubscription("default", agent_type))
        workers[agent_type] = runtime

    # Publish a single message from the first worker.
    await workers["name1"].publish_message(MessageType(), topic_id=TopicId("default", "default"))

    # Give the agents a moment to process the message.
    await asyncio.sleep(2)

    # Agents keyed by the "default" source should have seen exactly one message;
    # agents keyed by any other source should have seen none.
    for agent_key, expected_calls in (("default", 1), ("other", 0)):
        for agent_type, runtime in workers.items():
            instance = await runtime.try_get_underlying_agent_instance(
                AgentId(agent_type, agent_key), LoopbackAgent
            )
            assert instance.num_calls == expected_calls

    for runtime in workers.values():
        await runtime.stop()
    await runtime_host.stop()
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_register_doesnt_receive_after_removing_subscription() -> None:
    """An agent stops receiving published messages once its subscription is removed."""
    # Use a port that no other test in this module binds: this test previously
    # shared 50053 with test_register_receives_publish, so a host left running
    # by a failure in one test could break the other.
    host_address = "localhost:50062"
    host = GrpcWorkerAgentRuntimeHost(address=host_address)
    host.start()

    worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
    await worker1.start()
    worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
    await worker1.register_factory(
        type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
    )
    sub = DefaultSubscription(agent_type="name1")
    await worker1.add_subscription(sub)

    agent_1_instance = await worker1.try_get_underlying_agent_instance(AgentId("name1", "default"), LoopbackAgent)

    # Publish message from worker1
    await worker1.publish_message(MessageType(), topic_id=DefaultTopicId())

    # Wait until the agent signals that it handled the message, then reset the
    # signal.
    await agent_1_instance.event.wait()
    agent_1_instance.event.clear()

    # Agents in default topic source should have received the message.
    assert agent_1_instance.num_calls == 1

    await worker1.remove_subscription(sub.id)

    # Publish message from worker1
    await worker1.publish_message(MessageType(), topic_id=DefaultTopicId())

    # No event is expected this time, so wait a fixed amount of time instead.
    await asyncio.sleep(2)

    # Agent should not have received the message.
    assert agent_1_instance.num_calls == 1

    await worker1.stop()
    await host.stop()
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_register_receives_publish_cascade_single_worker() -> None:
    """Cascading agents on a single worker each receive the expected number of messages.

    Carries the ``grpc`` marker like every other test in this module, so that
    marker-based selection (``-m grpc`` / ``-m "not grpc"``) treats it the same.
    """
    host_address = "localhost:50054"
    host = GrpcWorkerAgentRuntimeHost(address=host_address)
    host.start()
    runtime = GrpcWorkerAgentRuntime(host_address=host_address)
    await runtime.start()

    num_agents = 5
    num_initial_messages = 5
    max_rounds = 5
    # Expected calls per agent: for each round i, the initial message count
    # multiplied by (num_agents - 1) ** i.
    total_num_calls_expected = sum(
        num_initial_messages * ((num_agents - 1) ** i) for i in range(max_rounds)
    )

    # Register agents
    for i in range(num_agents):
        await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds))

    # Publish messages
    for _ in range(num_initial_messages):
        await runtime.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId())

    # Wait for all agents to finish.
    await asyncio.sleep(10)

    # Check that each agent received the correct number of messages.
    for i in range(num_agents):
        agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent)
        assert agent.num_calls == total_num_calls_expected

    await runtime.stop()
    await host.stop()
@pytest.mark.grpc
@pytest.mark.skip(reason="Fix flakiness")
@pytest.mark.asyncio
async def test_register_receives_publish_cascade_multiple_workers() -> None:
    """Cascading agents spread over one worker each receive the expected number of messages."""
    logging.basicConfig(level=logging.DEBUG)
    address = "localhost:50055"
    runtime_host = GrpcWorkerAgentRuntimeHost(address=address)
    runtime_host.start()

    # TODO: Increasing num_initial_messages or max_round to 2 causes the test to fail.
    num_agents = 2
    num_initial_messages = 1
    max_rounds = 1
    expected_calls = sum(
        num_initial_messages * ((num_agents - 1) ** round_index) for round_index in range(max_rounds)
    )

    # One dedicated worker runtime per agent type.
    workers: List[GrpcWorkerAgentRuntime] = []
    for index in range(num_agents):
        agent_runtime = GrpcWorkerAgentRuntime(host_address=address)
        await agent_runtime.start()
        await CascadingAgent.register(agent_runtime, f"name{index}", lambda: CascadingAgent(max_rounds))
        workers.append(agent_runtime)

    # A separate runtime is used only for publishing the initial messages.
    publisher = GrpcWorkerAgentRuntime(host_address=address)
    publisher.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
    await publisher.start()
    for _ in range(num_initial_messages):
        await publisher.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId())

    # Fixed wait for the cascade to play out.
    await asyncio.sleep(20)

    # Every agent must have been called the expected number of times.
    for index, agent_runtime in enumerate(workers):
        instance = await agent_runtime.try_get_underlying_agent_instance(
            AgentId(f"name{index}", "default"), CascadingAgent
        )
        assert instance.num_calls == expected_calls

    for agent_runtime in workers:
        await agent_runtime.stop()
    await publisher.stop()
    await runtime_host.stop()
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_default_subscription() -> None:
    """A default-topic message reaches only the agent keyed by the default source."""
    address = "localhost:50056"
    runtime_host = GrpcWorkerAgentRuntimeHost(address=address)
    runtime_host.start()

    subscriber_runtime = GrpcWorkerAgentRuntime(host_address=address)
    await subscriber_runtime.start()

    publishing_runtime = GrpcWorkerAgentRuntime(host_address=address)
    publishing_runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
    await publishing_runtime.start()

    await LoopbackAgentWithDefaultSubscription.register(
        subscriber_runtime, "name", lambda: LoopbackAgentWithDefaultSubscription()
    )

    await publishing_runtime.publish_message(MessageType(), topic_id=DefaultTopicId())

    # Fixed wait for the message to be delivered.
    await asyncio.sleep(2)

    # The agent keyed "default" is expected to have handled the message once.
    default_agent = await subscriber_runtime.try_get_underlying_agent_instance(
        AgentId("name", "default"), type=LoopbackAgentWithDefaultSubscription
    )
    assert default_agent.num_calls == 1

    # The agent keyed "other" is expected not to have been called.
    other_agent = await subscriber_runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
    )
    assert other_agent.num_calls == 0

    await subscriber_runtime.stop()
    await publishing_runtime.stop()
    await runtime_host.stop()
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_default_subscription_other_source() -> None:
    """A default-topic message with source "other" reaches only the agent keyed "other"."""
    address = "localhost:50057"
    runtime_host = GrpcWorkerAgentRuntimeHost(address=address)
    runtime_host.start()

    subscriber_runtime = GrpcWorkerAgentRuntime(host_address=address)
    await subscriber_runtime.start()

    publishing_runtime = GrpcWorkerAgentRuntime(host_address=address)
    publishing_runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
    await publishing_runtime.start()

    await LoopbackAgentWithDefaultSubscription.register(
        subscriber_runtime, "name", lambda: LoopbackAgentWithDefaultSubscription()
    )

    await publishing_runtime.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))

    # Fixed wait for the message to be delivered.
    await asyncio.sleep(2)

    # The agent keyed "default" should NOT have received the message.
    default_agent = await subscriber_runtime.try_get_underlying_agent_instance(
        AgentId("name", "default"), type=LoopbackAgentWithDefaultSubscription
    )
    assert default_agent.num_calls == 0

    # The agent keyed "other" should have received the message exactly once.
    other_agent = await subscriber_runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
    )
    assert other_agent.num_calls == 1

    await subscriber_runtime.stop()
    await publishing_runtime.stop()
    await runtime_host.stop()
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_type_subscription() -> None:
    """An agent class decorated with ``type_subscription("Other")`` receives "Other" topic messages."""
    address = "localhost:50058"
    runtime_host = GrpcWorkerAgentRuntimeHost(address=address)
    runtime_host.start()

    subscriber_runtime = GrpcWorkerAgentRuntime(host_address=address)
    await subscriber_runtime.start()

    publishing_runtime = GrpcWorkerAgentRuntime(host_address=address)
    publishing_runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
    await publishing_runtime.start()

    @type_subscription("Other")
    class LoopbackAgentWithSubscription(LoopbackAgent): ...

    await LoopbackAgentWithSubscription.register(
        subscriber_runtime, "name", lambda: LoopbackAgentWithSubscription()
    )

    await publishing_runtime.publish_message(
        MessageType(), topic_id=TopicId(type="Other", source="default")
    )

    # Fixed wait for the message to be delivered.
    await asyncio.sleep(2)

    # The agent keyed "default" matches the topic source and should have one call.
    default_agent = await subscriber_runtime.try_get_underlying_agent_instance(
        AgentId("name", "default"), type=LoopbackAgentWithSubscription
    )
    assert default_agent.num_calls == 1

    # The agent keyed "other" does not match the topic source and should have none.
    other_agent = await subscriber_runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgentWithSubscription
    )
    assert other_agent.num_calls == 0

    await subscriber_runtime.stop()
    await publishing_runtime.stop()
    await runtime_host.stop()
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_duplicate_subscription() -> None:
    """Registering an agent type twice is rejected, both across workers and on the same worker."""
    address = "localhost:50059"
    runtime_host = GrpcWorkerAgentRuntimeHost(address=address)
    original_worker = GrpcWorkerAgentRuntime(host_address=address)
    replacement_worker = GrpcWorkerAgentRuntime(host_address=address)
    runtime_host.start()
    try:
        await original_worker.start()
        await NoopAgent.register(original_worker, "worker1", lambda: NoopAgent())

        await replacement_worker.start()

        # Note: This passes because the original worker is still running
        with pytest.raises(Exception, match="Agent type worker1 already registered"):
            await NoopAgent.register(replacement_worker, "worker1", lambda: NoopAgent())

        # This is somehow covered in test_disconnected_agent as well, since a stop also
        # disconnects the agent. Both are kept for now as the way a disconnect is
        # simulated might be replaced.
        await original_worker.stop()

        with pytest.raises(ValueError):
            await NoopAgent.register(replacement_worker, "worker1", lambda: NoopAgent())
    finally:
        await replacement_worker.stop()
        await runtime_host.stop()
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_disconnected_agent() -> None:
    """Host drops a disconnected worker's subscriptions and accepts a re-registration.

    The test registers an agent on ``worker1``, closes that worker's host
    connection to simulate a disconnect, checks that the host no longer holds
    any subscriptions or recipients, and then registers the same agent type from
    a second worker, checking that fresh (non-duplicated) subscriptions exist.
    """
    host_address = "localhost:50060"
    host = GrpcWorkerAgentRuntimeHost(address=host_address)
    host.start()
    worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
    worker1_2 = GrpcWorkerAgentRuntime(host_address=host_address)

    # TODO: Implementing `get_current_subscriptions` and `get_subscribed_recipients` requires access
    # to some private properties. This needs to be updated once they are available publicly

    def get_current_subscriptions() -> List[Subscription]:
        """Return the subscriptions currently held by the host's subscription manager."""
        return host._servicer._subscription_manager._subscriptions  # type: ignore[reportPrivateUsage]

    async def get_subscribed_recipients() -> List[AgentId]:
        """Return the agents the host would deliver a default-topic message to."""
        return await host._servicer._subscription_manager.get_subscribed_recipients(DefaultTopicId())  # type: ignore[reportPrivateUsage]

    try:
        await worker1.start()
        await LoopbackAgentWithDefaultSubscription.register(
            worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
        )

        subscriptions1 = get_current_subscriptions()
        assert len(subscriptions1) == 2
        recipients1 = await get_subscribed_recipients()
        assert AgentId(type="worker1", key="default") in recipients1

        # Remember one of the original subscription ids so we can later verify
        # that the re-registration produced new subscriptions.
        first_subscription_id = subscriptions1[0].id

        await worker1.publish_message(ContentMessage(content="Hello!"), DefaultTopicId())

        # This is a simple simulation of worker disconnect
        if worker1._host_connection is not None:  # type: ignore[reportPrivateUsage]
            try:
                await worker1._host_connection.close()  # type: ignore[reportPrivateUsage]
            except asyncio.CancelledError:
                # Closing the connection may surface a cancellation; that is
                # expected here and deliberately ignored.
                pass

        await asyncio.sleep(1)

        # After the disconnect the host should hold nothing for this worker.
        subscriptions2 = get_current_subscriptions()
        assert len(subscriptions2) == 0
        recipients2 = await get_subscribed_recipients()
        assert len(recipients2) == 0

        await asyncio.sleep(1)

        await worker1_2.start()
        await LoopbackAgentWithDefaultSubscription.register(
            worker1_2, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
        )

        subscriptions3 = get_current_subscriptions()
        assert len(subscriptions3) == 2
        assert first_subscription_id not in [x.id for x in subscriptions3]

        recipients3 = await get_subscribed_recipients()
        # Make sure there are no duplicates. This must inspect recipients3: the
        # previous check used recipients2, which is empty at this point and
        # therefore could never fail.
        assert len(set(recipients3)) == len(recipients3)
        assert AgentId(type="worker1", key="default") in recipients3
    finally:
        try:
            await worker1.stop()
            await worker1_2.stop()
        finally:
            # Always shut the host down so its port is released even when a
            # worker fails to stop.
            await host.stop()
@default_subscription
class ProtoReceivingAgent(RoutedAgent):
    """Routed agent that records every ``ProtoMessage`` it is handed."""

    def __init__(self) -> None:
        super().__init__("A loop back agent.")
        # Every message seen so far, in arrival order.
        self.received_messages: list[Any] = []
        # Number of times the handler has been invoked.
        self.num_calls = 0

    @event
    async def on_new_message(self, message: ProtoMessage, ctx: MessageContext) -> None:  # type: ignore
        """Store the incoming message and bump the call counter."""
        self.received_messages.append(message)
        self.num_calls = self.num_calls + 1
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_proto_payloads() -> None:
    """A protobuf message published with protobuf payload serialization reaches the subscribed agent."""
    # Use a port that no other test in this module binds: this test previously
    # shared 50057 with test_default_subscription_other_source, so a host left
    # running by a failure in one test could break the other.
    host_address = "localhost:50063"
    host = GrpcWorkerAgentRuntimeHost(address=host_address)
    host.start()

    receiver_runtime = GrpcWorkerAgentRuntime(
        host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
    )
    await receiver_runtime.start()

    publisher_runtime = GrpcWorkerAgentRuntime(
        host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
    )
    publisher_runtime.add_message_serializer(try_get_known_serializers_for_type(ProtoMessage))
    await publisher_runtime.start()

    await ProtoReceivingAgent.register(receiver_runtime, "name", ProtoReceivingAgent)

    await publisher_runtime.publish_message(ProtoMessage(message="Hello!"), topic_id=DefaultTopicId())

    # Fixed wait for the message to be delivered.
    await asyncio.sleep(2)

    # Agent keyed "default" should have received the message
    long_running_agent = await receiver_runtime.try_get_underlying_agent_instance(
        AgentId("name", "default"), type=ProtoReceivingAgent
    )
    assert long_running_agent.num_calls == 1
    assert long_running_agent.received_messages[0].message == "Hello!"

    # Agent keyed "other" should not have received the message
    other_long_running_agent = await receiver_runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=ProtoReceivingAgent
    )
    assert other_long_running_agent.num_calls == 0
    assert len(other_long_running_agent.received_messages) == 0

    await receiver_runtime.stop()
    await publisher_runtime.stop()
    await host.stop()
# TODO add tests for failure to deserialize
@pytest.mark.grpc
@pytest.mark.asyncio
@pytest.mark.skip(reason="Fix flakiness")
async def test_grpc_max_message_size() -> None:
    """Only workers configured with a larger gRPC message limit receive an oversized message."""
    default_max_size = 2**22
    new_max_size = default_max_size * 2
    small_message = ContentMessage(content="small message")
    # One character more than the default limit used above.
    big_message = ContentMessage(content="." * (default_max_size + 1))

    extra_grpc_config = [
        ("grpc.max_send_message_length", new_max_size),
        ("grpc.max_receive_message_length", new_max_size),
    ]
    address = "localhost:50061"
    runtime_host = GrpcWorkerAgentRuntimeHost(address=address, extra_grpc_config=extra_grpc_config)
    # The first and third workers get the raised limits; the second keeps the defaults.
    raised_sender = GrpcWorkerAgentRuntime(host_address=address, extra_grpc_config=extra_grpc_config)
    default_receiver = GrpcWorkerAgentRuntime(host_address=address)
    raised_receiver = GrpcWorkerAgentRuntime(host_address=address, extra_grpc_config=extra_grpc_config)
    named_workers = (
        ("worker1", raised_sender),
        ("worker2", default_receiver),
        ("worker3", raised_receiver),
    )

    try:
        runtime_host.start()
        for _, runtime in named_workers:
            await runtime.start()
        for agent_type, runtime in named_workers:
            await LoopbackAgentWithDefaultSubscription.register(
                runtime, agent_type, lambda: LoopbackAgentWithDefaultSubscription()
            )

        await raised_sender.publish_message(small_message, DefaultTopicId())
        # Fixed wait for the small message to be delivered.
        await asyncio.sleep(1)

        agent_instance_2 = await default_receiver.try_get_underlying_agent_instance(
            AgentId("worker2", key="default"), type=LoopbackAgent
        )
        agent_instance_3 = await raised_receiver.try_get_underlying_agent_instance(
            AgentId("worker3", key="default"), type=LoopbackAgent
        )
        # The small message is expected on both receivers.
        assert agent_instance_2.num_calls == 1
        assert agent_instance_3.num_calls == 1

        await raised_sender.publish_message(big_message, DefaultTopicId())
        await asyncio.sleep(2)

        # Worker 2 won't receive the big message
        assert agent_instance_2.num_calls == 1
        # Worker 3 will receive the big message as it has an increased message length
        assert agent_instance_3.num_calls == 2
    finally:
        await raised_sender.stop()
        # Worker 2 is deliberately not stopped here: it somehow breaks and cannot be stopped.
        await raised_receiver.stop()
        await runtime_host.stop()
if __name__ == "__main__":
    # Set the gRPC verbosity/trace environment variables, then run the two
    # scenarios directly, each in its own event loop.
    os.environ.update({"GRPC_VERBOSITY": "DEBUG", "GRPC_TRACE": "all"})
    for scenario in (test_disconnected_agent, test_grpc_max_message_size):
        asyncio.run(scenario())