| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
import asyncio
import logging
import os
from contextlib import AsyncExitStack
from typing import Any, List

import pytest
from autogen_core import (
    PROTOBUF_DATA_CONTENT_TYPE,
    AgentId,
    AgentType,
    DefaultSubscription,
    DefaultTopicId,
    MessageContext,
    RoutedAgent,
    Subscription,
    TopicId,
    TypeSubscription,
    default_subscription,
    event,
    try_get_known_serializers_for_type,
    type_subscription,
)
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime, GrpcWorkerAgentRuntimeHost
from autogen_test_utils import (
    CascadingAgent,
    CascadingMessageType,
    ContentMessage,
    LoopbackAgent,
    LoopbackAgentWithDefaultSubscription,
    MessageType,
    NoopAgent,
)

from .protos.serialization_test_pb2 import ProtoMessage
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-07 11:57:30 -05:00
										 |  |  | @pytest.mark.grpc | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  | async def test_agent_types_must_be_unique_single_worker() -> None: | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     host_address = "localhost:50051" | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     host = GrpcWorkerAgentRuntimeHost(address=host_address) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     host.start() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     worker = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await worker.start() | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  |     await worker.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     with pytest.raises(ValueError): | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  |         await worker.register_factory( | 
					
						
							|  |  |  |             type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  |     await worker.register_factory(type=AgentType("name4"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent) | 
					
						
							| 
									
										
										
										
											2025-02-14 18:19:32 -08:00
										 |  |  |     await worker.register_factory(type=AgentType("name5"), agent_factory=lambda: NoopAgent()) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     await worker.stop() | 
					
						
							|  |  |  |     await host.stop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-07 11:57:30 -05:00
										 |  |  | @pytest.mark.grpc | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_agent_types_must_be_unique_multiple_workers() -> None: | 
					
						
							| 
									
										
										
										
											2024-10-08 15:01:13 -07:00
										 |  |  |     host_address = "localhost:50052" | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     host = GrpcWorkerAgentRuntimeHost(address=host_address) | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  |     host.start() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     worker1 = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await worker1.start() | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     worker2 = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await worker2.start() | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     await worker1.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-28 11:15:57 -05:00
										 |  |  |     with pytest.raises(Exception, match="Agent type name1 already registered"): | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  |         await worker2.register_factory( | 
					
						
							|  |  |  |             type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await worker2.register_factory(type=AgentType("name4"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await worker1.stop() | 
					
						
							|  |  |  |     await worker2.stop() | 
					
						
							|  |  |  |     await host.stop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-07 11:57:30 -05:00
										 |  |  | @pytest.mark.grpc | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_register_receives_publish() -> None: | 
					
						
							| 
									
										
										
										
											2024-10-08 15:01:13 -07:00
										 |  |  |     host_address = "localhost:50053" | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     host = GrpcWorkerAgentRuntimeHost(address=host_address) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     host.start() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     worker1 = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await worker1.start() | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  |     worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType)) | 
					
						
							|  |  |  |     await worker1.register_factory( | 
					
						
							|  |  |  |         type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     await worker1.add_subscription(TypeSubscription("default", "name1")) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     worker2 = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await worker2.start() | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  |     worker2.add_message_serializer(try_get_known_serializers_for_type(MessageType)) | 
					
						
							|  |  |  |     await worker2.register_factory( | 
					
						
							|  |  |  |         type=AgentType("name2"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     await worker2.add_subscription(TypeSubscription("default", "name2")) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  |     # Publish message from worker1 | 
					
						
							|  |  |  |     await worker1.publish_message(MessageType(), topic_id=TopicId("default", "default")) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Let the agent run for a bit. | 
					
						
							|  |  |  |     await asyncio.sleep(2) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  |     # Agents in default topic source should have received the message. | 
					
						
							|  |  |  |     worker1_agent = await worker1.try_get_underlying_agent_instance(AgentId("name1", "default"), LoopbackAgent) | 
					
						
							|  |  |  |     assert worker1_agent.num_calls == 1 | 
					
						
							|  |  |  |     worker2_agent = await worker2.try_get_underlying_agent_instance(AgentId("name2", "default"), LoopbackAgent) | 
					
						
							|  |  |  |     assert worker2_agent.num_calls == 1 | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  |     # Agents in other topic source should not have received the message. | 
					
						
							|  |  |  |     worker1_agent = await worker1.try_get_underlying_agent_instance(AgentId("name1", "other"), LoopbackAgent) | 
					
						
							|  |  |  |     assert worker1_agent.num_calls == 0 | 
					
						
							|  |  |  |     worker2_agent = await worker2.try_get_underlying_agent_instance(AgentId("name2", "other"), LoopbackAgent) | 
					
						
							|  |  |  |     assert worker2_agent.num_calls == 0 | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  |     await worker1.stop() | 
					
						
							|  |  |  |     await worker2.stop() | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     await host.stop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-07 11:57:30 -05:00
										 |  |  | @pytest.mark.grpc | 
					
						
							| 
									
										
										
										
											2025-02-11 17:42:09 -05:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_register_doesnt_receive_after_removing_subscription() -> None: | 
					
						
							|  |  |  |     host_address = "localhost:50053" | 
					
						
							|  |  |  |     host = GrpcWorkerAgentRuntimeHost(address=host_address) | 
					
						
							|  |  |  |     host.start() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     worker1 = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await worker1.start() | 
					
						
							| 
									
										
										
										
											2025-02-11 17:42:09 -05:00
										 |  |  |     worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType)) | 
					
						
							|  |  |  |     await worker1.register_factory( | 
					
						
							|  |  |  |         type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     sub = DefaultSubscription(agent_type="name1") | 
					
						
							|  |  |  |     await worker1.add_subscription(sub) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     agent_1_instance = await worker1.try_get_underlying_agent_instance(AgentId("name1", "default"), LoopbackAgent) | 
					
						
							|  |  |  |     # Publish message from worker1 | 
					
						
							|  |  |  |     await worker1.publish_message(MessageType(), topic_id=DefaultTopicId()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Let the agent run for a bit. | 
					
						
							|  |  |  |     await agent_1_instance.event.wait() | 
					
						
							|  |  |  |     agent_1_instance.event.clear() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Agents in default topic source should have received the message. | 
					
						
							|  |  |  |     assert agent_1_instance.num_calls == 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await worker1.remove_subscription(sub.id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Publish message from worker1 | 
					
						
							|  |  |  |     await worker1.publish_message(MessageType(), topic_id=DefaultTopicId()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Let the agent run for a bit. | 
					
						
							|  |  |  |     await asyncio.sleep(2) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Agent should not have received the message. | 
					
						
							|  |  |  |     assert agent_1_instance.num_calls == 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await worker1.stop() | 
					
						
							|  |  |  |     await host.stop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  | async def test_register_receives_publish_cascade_single_worker() -> None: | 
					
						
							| 
									
										
										
										
											2024-10-08 15:01:13 -07:00
										 |  |  |     host_address = "localhost:50054" | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     host = GrpcWorkerAgentRuntimeHost(address=host_address) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     host.start() | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     runtime = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await runtime.start() | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     num_agents = 5 | 
					
						
							|  |  |  |     num_initial_messages = 5 | 
					
						
							|  |  |  |     max_rounds = 5 | 
					
						
							|  |  |  |     total_num_calls_expected = 0 | 
					
						
							|  |  |  |     for i in range(0, max_rounds): | 
					
						
							|  |  |  |         total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Register agents | 
					
						
							|  |  |  |     for i in range(num_agents): | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  |         await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds)) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Publish messages | 
					
						
							|  |  |  |     for _ in range(num_initial_messages): | 
					
						
							|  |  |  |         await runtime.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId()) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-18 19:08:35 -07:00
										 |  |  |     # Wait for all agents to finish. | 
					
						
							|  |  |  |     await asyncio.sleep(10) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Check that each agent received the correct number of messages. | 
					
						
							|  |  |  |     for i in range(num_agents): | 
					
						
							|  |  |  |         agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent) | 
					
						
							|  |  |  |         assert agent.num_calls == total_num_calls_expected | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await runtime.stop() | 
					
						
							|  |  |  |     await host.stop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-07 11:57:30 -05:00
										 |  |  | @pytest.mark.grpc | 
					
						
							| 
									
										
										
										
											2024-09-18 19:08:35 -07:00
										 |  |  | @pytest.mark.skip(reason="Fix flakiness") | 
					
						
							|  |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_register_receives_publish_cascade_multiple_workers() -> None: | 
					
						
							|  |  |  |     logging.basicConfig(level=logging.DEBUG) | 
					
						
							| 
									
										
										
										
											2024-10-08 15:01:13 -07:00
										 |  |  |     host_address = "localhost:50055" | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     host = GrpcWorkerAgentRuntimeHost(address=host_address) | 
					
						
							| 
									
										
										
										
											2024-09-18 19:08:35 -07:00
										 |  |  |     host.start() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # TODO: Increasing num_initial_messages or max_round to 2 causes the test to fail. | 
					
						
							|  |  |  |     num_agents = 2 | 
					
						
							|  |  |  |     num_initial_messages = 1 | 
					
						
							|  |  |  |     max_rounds = 1 | 
					
						
							|  |  |  |     total_num_calls_expected = 0 | 
					
						
							|  |  |  |     for i in range(0, max_rounds): | 
					
						
							|  |  |  |         total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Run multiple workers one for each agent. | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     workers: List[GrpcWorkerAgentRuntime] = [] | 
					
						
							| 
									
										
										
										
											2024-09-18 19:08:35 -07:00
										 |  |  |     # Register agents | 
					
						
							|  |  |  |     for i in range(num_agents): | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |         runtime = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |         await runtime.start() | 
					
						
							| 
									
										
										
										
											2024-09-19 10:50:17 -07:00
										 |  |  |         await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds)) | 
					
						
							| 
									
										
										
										
											2024-09-18 19:08:35 -07:00
										 |  |  |         workers.append(runtime) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Publish messages | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     publisher = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2024-09-18 19:08:35 -07:00
										 |  |  |     publisher.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType)) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await publisher.start() | 
					
						
							| 
									
										
										
										
											2024-09-18 19:08:35 -07:00
										 |  |  |     for _ in range(num_initial_messages): | 
					
						
							|  |  |  |         await publisher.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await asyncio.sleep(20) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Check that each agent received the correct number of messages. | 
					
						
							|  |  |  |     for i in range(num_agents): | 
					
						
							|  |  |  |         agent = await workers[i].try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent) | 
					
						
							|  |  |  |         assert agent.num_calls == total_num_calls_expected | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for worker in workers: | 
					
						
							|  |  |  |         await worker.stop() | 
					
						
							|  |  |  |     await publisher.stop() | 
					
						
							|  |  |  |     await host.stop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-07 11:57:30 -05:00
										 |  |  | @pytest.mark.grpc | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_default_subscription() -> None: | 
					
						
							| 
									
										
										
										
											2024-10-08 15:01:13 -07:00
										 |  |  |     host_address = "localhost:50056" | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     host = GrpcWorkerAgentRuntimeHost(address=host_address) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     host.start() | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     worker = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await worker.start() | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     publisher = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType)) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await publisher.start() | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     await LoopbackAgentWithDefaultSubscription.register(worker, "name", lambda: LoopbackAgentWithDefaultSubscription()) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     await publisher.publish_message(MessageType(), topic_id=DefaultTopicId()) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     await asyncio.sleep(2) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     # Agent in default topic source should have received the message. | 
					
						
							|  |  |  |     long_running_agent = await worker.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         AgentId("name", "default"), type=LoopbackAgentWithDefaultSubscription | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     assert long_running_agent.num_calls == 1 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     # Agent in other topic source should not have received the message. | 
					
						
							|  |  |  |     other_long_running_agent = await worker.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     ) | 
					
						
							|  |  |  |     assert other_long_running_agent.num_calls == 0 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     await worker.stop() | 
					
						
							|  |  |  |     await publisher.stop() | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     await host.stop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-07 11:57:30 -05:00
										 |  |  | @pytest.mark.grpc | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  | async def test_default_subscription_other_source() -> None: | 
					
						
							| 
									
										
										
										
											2024-10-08 15:01:13 -07:00
										 |  |  |     host_address = "localhost:50057" | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     host = GrpcWorkerAgentRuntimeHost(address=host_address) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     host.start() | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     runtime = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await runtime.start() | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     publisher = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType)) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await publisher.start() | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     await LoopbackAgentWithDefaultSubscription.register(runtime, "name", lambda: LoopbackAgentWithDefaultSubscription()) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     await publisher.publish_message(MessageType(), topic_id=DefaultTopicId(source="other")) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     await asyncio.sleep(2) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Agent in default namespace should have received the message | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     long_running_agent = await runtime.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         AgentId("name", "default"), type=LoopbackAgentWithDefaultSubscription | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert long_running_agent.num_calls == 0 | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Agent in other namespace should not have received the message | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     other_long_running_agent = await runtime.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     assert other_long_running_agent.num_calls == 1 | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     await runtime.stop() | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     await publisher.stop() | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     await host.stop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-07 11:57:30 -05:00
										 |  |  | @pytest.mark.grpc | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  | async def test_type_subscription() -> None: | 
					
						
							| 
									
										
										
										
											2024-10-08 15:01:13 -07:00
										 |  |  |     host_address = "localhost:50058" | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     host = GrpcWorkerAgentRuntimeHost(address=host_address) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     host.start() | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     worker = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await worker.start() | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     publisher = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType)) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await publisher.start() | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @type_subscription("Other") | 
					
						
							|  |  |  |     class LoopbackAgentWithSubscription(LoopbackAgent): ... | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     await LoopbackAgentWithSubscription.register(worker, "name", lambda: LoopbackAgentWithSubscription()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await publisher.publish_message(MessageType(), topic_id=TopicId(type="Other", source="default")) | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     await asyncio.sleep(2) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     # Agent in default topic source should have received the message. | 
					
						
							|  |  |  |     long_running_agent = await worker.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         AgentId("name", "default"), type=LoopbackAgentWithSubscription | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert long_running_agent.num_calls == 1 | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     # Agent in other topic source should not have received the message. | 
					
						
							|  |  |  |     other_long_running_agent = await worker.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         AgentId("name", key="other"), type=LoopbackAgentWithSubscription | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     assert other_long_running_agent.num_calls == 0 | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     await worker.stop() | 
					
						
							|  |  |  |     await publisher.stop() | 
					
						
							| 
									
										
										
										
											2024-09-13 08:17:53 -07:00
										 |  |  |     await host.stop() | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-07 11:57:30 -05:00
										 |  |  | @pytest.mark.grpc | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_duplicate_subscription() -> None: | 
					
						
							|  |  |  |     host_address = "localhost:50059" | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     host = GrpcWorkerAgentRuntimeHost(address=host_address) | 
					
						
							|  |  |  |     worker1 = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							|  |  |  |     worker1_2 = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  |     host.start() | 
					
						
							|  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |         await worker1.start() | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  |         await NoopAgent.register(worker1, "worker1", lambda: NoopAgent()) | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |         await worker1_2.start() | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Note: This passes because worker1 is still running | 
					
						
							| 
									
										
										
										
											2025-01-28 11:15:57 -05:00
										 |  |  |         with pytest.raises(Exception, match="Agent type worker1 already registered"): | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  |             await NoopAgent.register(worker1_2, "worker1", lambda: NoopAgent()) | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # This is somehow covered in test_disconnected_agent as well as a stop will also disconnect the agent. | 
					
						
							|  |  |  |         #  Will keep them both for now as we might replace the way we simulate a disconnect | 
					
						
							|  |  |  |         await worker1.stop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         with pytest.raises(ValueError): | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  |             await NoopAgent.register(worker1_2, "worker1", lambda: NoopAgent()) | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     except Exception as ex: | 
					
						
							|  |  |  |         raise ex | 
					
						
							|  |  |  |     finally: | 
					
						
							|  |  |  |         await worker1_2.stop() | 
					
						
							|  |  |  |         await host.stop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-07 11:57:30 -05:00
										 |  |  | @pytest.mark.grpc | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_disconnected_agent() -> None: | 
					
						
							| 
									
										
										
										
											2024-10-08 15:01:13 -07:00
										 |  |  |     host_address = "localhost:50060" | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     host = GrpcWorkerAgentRuntimeHost(address=host_address) | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  |     host.start() | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     worker1 = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							|  |  |  |     worker1_2 = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # TODO: Implementing `get_current_subscriptions` and `get_subscribed_recipients` requires access | 
					
						
							|  |  |  |     # to some private properties. This needs to be updated once they are available publicly | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_current_subscriptions() -> List[Subscription]: | 
					
						
							|  |  |  |         return host._servicer._subscription_manager._subscriptions  # type: ignore[reportPrivateUsage] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def get_subscribed_recipients() -> List[AgentId]: | 
					
						
							|  |  |  |         return await host._servicer._subscription_manager.get_subscribed_recipients(DefaultTopicId())  # type: ignore[reportPrivateUsage] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |         await worker1.start() | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  |         await LoopbackAgentWithDefaultSubscription.register( | 
					
						
							|  |  |  |             worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         subscriptions1 = get_current_subscriptions() | 
					
						
							| 
									
										
										
										
											2024-11-26 17:01:25 -05:00
										 |  |  |         assert len(subscriptions1) == 2 | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  |         recipients1 = await get_subscribed_recipients() | 
					
						
							|  |  |  |         assert AgentId(type="worker1", key="default") in recipients1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         first_subscription_id = subscriptions1[0].id | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  |         await worker1.publish_message(ContentMessage(content="Hello!"), DefaultTopicId()) | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  |         # This is a simple simulation of worker disconnct | 
					
						
							|  |  |  |         if worker1._host_connection is not None:  # type: ignore[reportPrivateUsage] | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 await worker1._host_connection.close()  # type: ignore[reportPrivateUsage] | 
					
						
							|  |  |  |             except asyncio.CancelledError: | 
					
						
							|  |  |  |                 pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         await asyncio.sleep(1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         subscriptions2 = get_current_subscriptions() | 
					
						
							|  |  |  |         assert len(subscriptions2) == 0 | 
					
						
							|  |  |  |         recipients2 = await get_subscribed_recipients() | 
					
						
							|  |  |  |         assert len(recipients2) == 0 | 
					
						
							|  |  |  |         await asyncio.sleep(1) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |         await worker1_2.start() | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  |         await LoopbackAgentWithDefaultSubscription.register( | 
					
						
							|  |  |  |             worker1_2, "worker1", lambda: LoopbackAgentWithDefaultSubscription() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         subscriptions3 = get_current_subscriptions() | 
					
						
							| 
									
										
										
										
											2024-11-26 17:01:25 -05:00
										 |  |  |         assert len(subscriptions3) == 2 | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  |         assert first_subscription_id not in [x.id for x in subscriptions3] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         recipients3 = await get_subscribed_recipients() | 
					
						
							|  |  |  |         assert len(set(recipients2)) == len(recipients2)  # Make sure there are no duplicates | 
					
						
							|  |  |  |         assert AgentId(type="worker1", key="default") in recipients3 | 
					
						
							|  |  |  |     except Exception as ex: | 
					
						
							|  |  |  |         raise ex | 
					
						
							|  |  |  |     finally: | 
					
						
							|  |  |  |         await worker1.stop() | 
					
						
							|  |  |  |         await worker1_2.stop() | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-27 11:32:03 -05:00
										 |  |  | @default_subscription | 
					
						
							|  |  |  | class ProtoReceivingAgent(RoutedAgent): | 
					
						
							|  |  |  |     def __init__(self) -> None: | 
					
						
							|  |  |  |         super().__init__("A loop back agent.") | 
					
						
							|  |  |  |         self.num_calls = 0 | 
					
						
							|  |  |  |         self.received_messages: list[Any] = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @event | 
					
						
							| 
									
										
										
										
											2025-02-10 15:27:27 -05:00
										 |  |  |     async def on_new_message(self, message: ProtoMessage, ctx: MessageContext) -> None:  # type: ignore | 
					
						
							| 
									
										
										
										
											2024-11-27 11:32:03 -05:00
										 |  |  |         self.num_calls += 1 | 
					
						
							|  |  |  |         self.received_messages.append(message) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-07 11:57:30 -05:00
										 |  |  | @pytest.mark.grpc | 
					
						
							| 
									
										
										
										
											2024-11-27 11:32:03 -05:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_proto_payloads() -> None: | 
					
						
							|  |  |  |     host_address = "localhost:50057" | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     host = GrpcWorkerAgentRuntimeHost(address=host_address) | 
					
						
							| 
									
										
										
										
											2024-11-27 11:32:03 -05:00
										 |  |  |     host.start() | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     receiver_runtime = GrpcWorkerAgentRuntime( | 
					
						
							| 
									
										
										
										
											2024-11-27 11:32:03 -05:00
										 |  |  |         host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await receiver_runtime.start() | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     publisher_runtime = GrpcWorkerAgentRuntime( | 
					
						
							| 
									
										
										
										
											2024-11-27 11:32:03 -05:00
										 |  |  |         host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     publisher_runtime.add_message_serializer(try_get_known_serializers_for_type(ProtoMessage)) | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |     await publisher_runtime.start() | 
					
						
							| 
									
										
										
										
											2024-11-27 11:32:03 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |     await ProtoReceivingAgent.register(receiver_runtime, "name", ProtoReceivingAgent) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await publisher_runtime.publish_message(ProtoMessage(message="Hello!"), topic_id=DefaultTopicId()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await asyncio.sleep(2) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Agent in default namespace should have received the message | 
					
						
							|  |  |  |     long_running_agent = await receiver_runtime.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         AgentId("name", "default"), type=ProtoReceivingAgent | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert long_running_agent.num_calls == 1 | 
					
						
							|  |  |  |     assert long_running_agent.received_messages[0].message == "Hello!" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Agent in other namespace should not have received the message | 
					
						
							|  |  |  |     other_long_running_agent = await receiver_runtime.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         AgentId("name", key="other"), type=ProtoReceivingAgent | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert other_long_running_agent.num_calls == 0 | 
					
						
							|  |  |  |     assert len(other_long_running_agent.received_messages) == 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await receiver_runtime.stop() | 
					
						
							|  |  |  |     await publisher_runtime.stop() | 
					
						
							|  |  |  |     await host.stop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # TODO add tests for failure to deserialize | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-07 11:57:30 -05:00
										 |  |  | @pytest.mark.grpc | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  | @pytest.mark.skip(reason="Fix flakiness") | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  | async def test_grpc_max_message_size() -> None: | 
					
						
							|  |  |  |     default_max_size = 2**22 | 
					
						
							|  |  |  |     new_max_size = default_max_size * 2 | 
					
						
							|  |  |  |     small_message = ContentMessage(content="small message") | 
					
						
							|  |  |  |     big_message = ContentMessage(content="." * (default_max_size + 1)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     extra_grpc_config = [ | 
					
						
							|  |  |  |         ("grpc.max_send_message_length", new_max_size), | 
					
						
							|  |  |  |         ("grpc.max_receive_message_length", new_max_size), | 
					
						
							|  |  |  |     ] | 
					
						
							| 
									
										
										
										
											2024-10-08 15:01:13 -07:00
										 |  |  |     host_address = "localhost:50061" | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     host = GrpcWorkerAgentRuntimeHost(address=host_address, extra_grpc_config=extra_grpc_config) | 
					
						
							|  |  |  |     worker1 = GrpcWorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config) | 
					
						
							|  |  |  |     worker2 = GrpcWorkerAgentRuntime(host_address=host_address) | 
					
						
							|  |  |  |     worker3 = GrpcWorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config) | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         host.start() | 
					
						
							| 
									
										
										
										
											2025-02-12 16:40:52 -05:00
										 |  |  |         await worker1.start() | 
					
						
							|  |  |  |         await worker2.start() | 
					
						
							|  |  |  |         await worker3.start() | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  |         await LoopbackAgentWithDefaultSubscription.register( | 
					
						
							|  |  |  |             worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription() | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         await LoopbackAgentWithDefaultSubscription.register( | 
					
						
							|  |  |  |             worker2, "worker2", lambda: LoopbackAgentWithDefaultSubscription() | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         await LoopbackAgentWithDefaultSubscription.register( | 
					
						
							|  |  |  |             worker3, "worker3", lambda: LoopbackAgentWithDefaultSubscription() | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # with pytest.raises(Exception): | 
					
						
							|  |  |  |         await worker1.publish_message(small_message, DefaultTopicId()) | 
					
						
							|  |  |  |         # This is a simple simulation of worker disconnct | 
					
						
							|  |  |  |         await asyncio.sleep(1) | 
					
						
							|  |  |  |         agent_instance_2 = await worker2.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |             AgentId("worker2", key="default"), type=LoopbackAgent | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         agent_instance_3 = await worker3.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |             AgentId("worker3", key="default"), type=LoopbackAgent | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         assert agent_instance_2.num_calls == 1 | 
					
						
							|  |  |  |         assert agent_instance_3.num_calls == 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         await worker1.publish_message(big_message, DefaultTopicId()) | 
					
						
							|  |  |  |         await asyncio.sleep(2) | 
					
						
							|  |  |  |         assert agent_instance_2.num_calls == 1  # Worker 2 won't receive the big message | 
					
						
							|  |  |  |         assert agent_instance_3.num_calls == 2  # Worker 3 will receive the big message as has increased message length | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         raise e | 
					
						
							|  |  |  |     finally: | 
					
						
							|  |  |  |         await worker1.stop() | 
					
						
							|  |  |  |         # await worker2.stop() # Worker 2 somehow breaks can can not be stopped. | 
					
						
							|  |  |  |         await worker3.stop() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  |         await host.stop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     os.environ["GRPC_VERBOSITY"] = "DEBUG" | 
					
						
							|  |  |  |     os.environ["GRPC_TRACE"] = "all" | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-05 15:15:01 +00:00
										 |  |  |     asyncio.run(test_disconnected_agent()) | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  |     asyncio.run(test_grpc_max_message_size()) |