| 
									
										
										
										
											2024-10-13 12:12:24 -07:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-09-11 16:47:55 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 15:48:14 -04:00
										 |  |  | import pytest | 
					
						
							| 
									
										
										
										
											2024-12-03 17:00:44 -08:00
										 |  |  | from autogen_core import ( | 
					
						
							| 
									
										
										
										
											2024-09-11 16:47:55 -07:00
										 |  |  |     AgentId, | 
					
						
							|  |  |  |     AgentInstantiationContext, | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     AgentType, | 
					
						
							| 
									
										
										
										
											2024-12-03 17:00:44 -08:00
										 |  |  |     DefaultTopicId, | 
					
						
							| 
									
										
										
										
											2025-02-26 13:34:53 -05:00
										 |  |  |     MessageContext, | 
					
						
							|  |  |  |     RoutedAgent, | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  |     SingleThreadedAgentRuntime, | 
					
						
							| 
									
										
										
										
											2024-09-11 16:47:55 -07:00
										 |  |  |     TopicId, | 
					
						
							| 
									
										
										
										
											2024-12-03 17:00:44 -08:00
										 |  |  |     TypeSubscription, | 
					
						
							| 
									
										
										
										
											2025-02-26 13:34:53 -05:00
										 |  |  |     event, | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     try_get_known_serializers_for_type, | 
					
						
							| 
									
										
										
										
											2024-12-03 17:00:44 -08:00
										 |  |  |     type_subscription, | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  | ) | 
					
						
							| 
									
										
										
										
											2025-02-26 13:34:53 -05:00
										 |  |  | from autogen_core._default_subscription import default_subscription | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  | from autogen_test_utils import ( | 
					
						
							| 
									
										
										
										
											2024-10-08 18:46:12 +00:00
										 |  |  |     CascadingAgent, | 
					
						
							|  |  |  |     CascadingMessageType, | 
					
						
							|  |  |  |     LoopbackAgent, | 
					
						
							|  |  |  |     LoopbackAgentWithDefaultSubscription, | 
					
						
							|  |  |  |     MessageType, | 
					
						
							|  |  |  |     NoopAgent, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-12-31 15:11:48 -05:00
										 |  |  | from autogen_test_utils.telemetry_test_utils import MyTestExporter, get_test_tracer_provider | 
					
						
							| 
									
										
										
										
											2024-12-04 16:23:20 -08:00
										 |  |  | from opentelemetry.sdk.trace import TracerProvider | 
					
						
							| 
									
										
										
										
											2024-09-11 16:47:55 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-31 15:11:48 -05:00
										 |  |  | test_exporter = MyTestExporter() | 
					
						
							| 
									
										
										
										
											2024-09-11 16:47:55 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-13 10:41:15 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-11 16:47:55 -07:00
										 |  |  | @pytest.fixture | 
					
						
							|  |  |  | def tracer_provider() -> TracerProvider: | 
					
						
							|  |  |  |     test_exporter.clear() | 
					
						
							|  |  |  |     return get_test_tracer_provider(test_exporter) | 
					
						
							| 
									
										
										
										
											2024-05-27 16:33:28 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-14 18:19:32 -08:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_agent_type_register_factory() -> None: | 
					
						
							|  |  |  |     runtime = SingleThreadedAgentRuntime() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def agent_factory() -> NoopAgent: | 
					
						
							|  |  |  |         id = AgentInstantiationContext.current_agent_id() | 
					
						
							|  |  |  |         assert id == AgentId("name1", "default") | 
					
						
							|  |  |  |         agent = NoopAgent() | 
					
						
							|  |  |  |         assert agent.id == id | 
					
						
							|  |  |  |         return agent | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await runtime.register_factory(type=AgentType("name1"), agent_factory=agent_factory, expected_class=NoopAgent) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with pytest.raises(ValueError): | 
					
						
							|  |  |  |         # This should fail because the expected class does not match the actual class. | 
					
						
							|  |  |  |         await runtime.register_factory( | 
					
						
							|  |  |  |             type=AgentType("name1"), | 
					
						
							|  |  |  |             agent_factory=agent_factory,  # type: ignore | 
					
						
							|  |  |  |             expected_class=CascadingAgent, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Without expected_class, no error. | 
					
						
							|  |  |  |     await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-27 16:33:28 -04:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  | async def test_agent_type_must_be_unique() -> None: | 
					
						
							| 
									
										
										
										
											2024-06-17 10:44:46 -04:00
										 |  |  |     runtime = SingleThreadedAgentRuntime() | 
					
						
							| 
									
										
										
										
											2024-05-27 16:33:28 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-02 11:02:45 -04:00
										 |  |  |     def agent_factory() -> NoopAgent: | 
					
						
							|  |  |  |         id = AgentInstantiationContext.current_agent_id() | 
					
						
							| 
									
										
										
										
											2024-06-22 14:50:32 -04:00
										 |  |  |         assert id == AgentId("name1", "default") | 
					
						
							|  |  |  |         agent = NoopAgent() | 
					
						
							|  |  |  |         assert agent.id == id | 
					
						
							|  |  |  |         return agent | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-25 16:15:17 -07:00
										 |  |  |     await NoopAgent.register(runtime, "name1", agent_factory) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # await runtime.register_factory(type=AgentType("name1"), agent_factory=agent_factory, expected_class=NoopAgent) | 
					
						
							| 
									
										
										
										
											2024-05-27 16:33:28 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     with pytest.raises(ValueError): | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |         await runtime.register_factory(type=AgentType("name1"), agent_factory=agent_factory, expected_class=NoopAgent) | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory, expected_class=NoopAgent) | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-27 11:46:06 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							| 
									
										
										
										
											2024-09-11 16:47:55 -07:00
										 |  |  | async def test_register_receives_publish(tracer_provider: TracerProvider) -> None: | 
					
						
							|  |  |  |     runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider) | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType)) | 
					
						
							|  |  |  |     await runtime.register_factory( | 
					
						
							|  |  |  |         type=AgentType("name"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-08-20 14:41:24 -04:00
										 |  |  |     await runtime.add_subscription(TypeSubscription("default", "name")) | 
					
						
							| 
									
										
										
										
											2024-05-27 16:33:28 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     runtime.start() | 
					
						
							|  |  |  |     await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default")) | 
					
						
							| 
									
										
										
										
											2024-08-21 13:59:59 -07:00
										 |  |  |     await runtime.stop_when_idle() | 
					
						
							| 
									
										
										
										
											2024-05-27 16:33:28 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     # Agent in default namespace should have received the message | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     long_running_agent = await runtime.try_get_underlying_agent_instance(AgentId("name", "default"), type=LoopbackAgent) | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     assert long_running_agent.num_calls == 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Agent in other namespace should not have received the message | 
					
						
							| 
									
										
										
										
											2024-09-13 10:41:15 -04:00
										 |  |  |     other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         AgentId("name", key="other"), type=LoopbackAgent | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     assert other_long_running_agent.num_calls == 0 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-11 16:47:55 -07:00
										 |  |  |     exported_spans = test_exporter.get_exported_spans() | 
					
						
							|  |  |  |     assert len(exported_spans) == 3 | 
					
						
							|  |  |  |     span_names = [span.name for span in exported_spans] | 
					
						
							| 
									
										
										
										
											2024-09-13 10:41:15 -04:00
										 |  |  |     assert span_names == [ | 
					
						
							|  |  |  |         "autogen create default.(default)-T", | 
					
						
							|  |  |  |         "autogen process name.(default)-A", | 
					
						
							|  |  |  |         "autogen publish default.(default)-T", | 
					
						
							|  |  |  |     ] | 
					
						
							| 
									
										
										
										
											2024-09-11 16:47:55 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-07 16:37:02 -05:00
										 |  |  |     await runtime.close() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-27 11:46:06 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-13 12:12:24 -07:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							| 
									
										
										
										
											2024-12-31 15:11:48 -05:00
										 |  |  | async def test_register_receives_publish_with_construction(caplog: pytest.LogCaptureFixture) -> None: | 
					
						
							| 
									
										
										
										
											2024-10-13 12:12:24 -07:00
										 |  |  |     runtime = SingleThreadedAgentRuntime() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def agent_factory() -> LoopbackAgent: | 
					
						
							|  |  |  |         raise ValueError("test") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await runtime.register_factory(type=AgentType("name"), agent_factory=agent_factory, expected_class=LoopbackAgent) | 
					
						
							|  |  |  |     await runtime.add_subscription(TypeSubscription("default", "name")) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with caplog.at_level(logging.ERROR): | 
					
						
							|  |  |  |         runtime.start() | 
					
						
							|  |  |  |         await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default")) | 
					
						
							|  |  |  |         await runtime.stop_when_idle() | 
					
						
							| 
									
										
										
										
											2024-12-31 15:11:48 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Check if logger has the exception. | 
					
						
							|  |  |  |     assert any("Error constructing agent" in e.message for e in caplog.records) | 
					
						
							| 
									
										
										
										
											2024-10-13 12:12:24 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-07 16:37:02 -05:00
										 |  |  |     await runtime.close() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-13 12:12:24 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-27 11:46:06 -07:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_register_receives_publish_cascade() -> None: | 
					
						
							|  |  |  |     num_agents = 5 | 
					
						
							|  |  |  |     num_initial_messages = 5 | 
					
						
							|  |  |  |     max_rounds = 5 | 
					
						
							|  |  |  |     total_num_calls_expected = 0 | 
					
						
							|  |  |  |     for i in range(0, max_rounds): | 
					
						
							|  |  |  |         total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     runtime = SingleThreadedAgentRuntime() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-27 11:46:06 -07:00
										 |  |  |     # Register agents | 
					
						
							|  |  |  |     for i in range(num_agents): | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |         await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds)) | 
					
						
							| 
									
										
										
										
											2024-07-01 11:53:45 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-21 13:59:59 -07:00
										 |  |  |     runtime.start() | 
					
						
							| 
									
										
										
										
											2024-07-01 11:53:45 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-27 11:46:06 -07:00
										 |  |  |     # Publish messages | 
					
						
							|  |  |  |     for _ in range(num_initial_messages): | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |         await runtime.publish_message(CascadingMessageType(round=1), DefaultTopicId()) | 
					
						
							| 
									
										
										
										
											2024-06-27 11:46:06 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Process until idle. | 
					
						
							| 
									
										
										
										
											2024-08-21 13:59:59 -07:00
										 |  |  |     await runtime.stop_when_idle() | 
					
						
							| 
									
										
										
										
											2024-06-27 11:46:06 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Check that each agent received the correct number of messages. | 
					
						
							|  |  |  |     for i in range(num_agents): | 
					
						
							| 
									
										
										
										
											2024-08-20 14:41:24 -04:00
										 |  |  |         agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent) | 
					
						
							| 
									
										
										
										
											2024-06-27 11:46:06 -07:00
										 |  |  |         assert agent.num_calls == total_num_calls_expected | 
					
						
							| 
									
										
										
										
											2024-08-22 16:53:35 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-07 16:37:02 -05:00
										 |  |  |     await runtime.close() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-13 10:41:15 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-22 16:53:35 -04:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_register_factory_explicit_name() -> None: | 
					
						
							|  |  |  |     runtime = SingleThreadedAgentRuntime() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 19:31:23 -05:00
										 |  |  |     await LoopbackAgent.register(runtime, "name", LoopbackAgent) | 
					
						
							|  |  |  |     await runtime.add_subscription(TypeSubscription("default", "name")) | 
					
						
							| 
									
										
										
										
											2024-08-22 16:53:35 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     runtime.start() | 
					
						
							|  |  |  |     agent_id = AgentId("name", key="default") | 
					
						
							|  |  |  |     topic_id = TopicId("default", "default") | 
					
						
							|  |  |  |     await runtime.publish_message(MessageType(), topic_id=topic_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await runtime.stop_when_idle() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Agent in default namespace should have received the message | 
					
						
							|  |  |  |     long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent) | 
					
						
							|  |  |  |     assert long_running_agent.num_calls == 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Agent in other namespace should not have received the message | 
					
						
							| 
									
										
										
										
											2024-09-13 10:41:15 -04:00
										 |  |  |     other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         AgentId("name", key="other"), type=LoopbackAgent | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-08-22 16:53:35 -04:00
										 |  |  |     assert other_long_running_agent.num_calls == 0 | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-07 16:37:02 -05:00
										 |  |  |     await runtime.close() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_default_subscription() -> None: | 
					
						
							|  |  |  |     runtime = SingleThreadedAgentRuntime() | 
					
						
							|  |  |  |     runtime.start() | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     await LoopbackAgentWithDefaultSubscription.register(runtime, "name", LoopbackAgentWithDefaultSubscription) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  |     agent_id = AgentId("name", key="default") | 
					
						
							|  |  |  |     await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await runtime.stop_when_idle() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     long_running_agent = await runtime.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         agent_id, type=LoopbackAgentWithDefaultSubscription | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  |     assert long_running_agent.num_calls == 1 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     other_long_running_agent = await runtime.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription | 
					
						
							| 
									
										
										
										
											2024-09-13 10:41:15 -04:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  |     assert other_long_running_agent.num_calls == 0 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-07 16:37:02 -05:00
										 |  |  |     await runtime.close() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-13 10:41:15 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  | async def test_type_subscription() -> None: | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  |     runtime = SingleThreadedAgentRuntime() | 
					
						
							|  |  |  |     runtime.start() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     @type_subscription(topic_type="Other") | 
					
						
							|  |  |  |     class LoopbackAgentWithSubscription(LoopbackAgent): ... | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await LoopbackAgentWithSubscription.register(runtime, "name", LoopbackAgentWithSubscription) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     agent_id = AgentId("name", key="default") | 
					
						
							|  |  |  |     await runtime.publish_message(MessageType(), topic_id=TopicId("Other", "default")) | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  |     await runtime.stop_when_idle() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgentWithSubscription) | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  |     assert long_running_agent.num_calls == 1 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     other_long_running_agent = await runtime.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         AgentId("name", key="other"), type=LoopbackAgentWithSubscription | 
					
						
							| 
									
										
										
										
											2024-09-13 10:41:15 -04:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  |     assert other_long_running_agent.num_calls == 0 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-07 16:37:02 -05:00
										 |  |  |     await runtime.close() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.asyncio | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  | async def test_default_subscription_publish_to_other_source() -> None: | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  |     runtime = SingleThreadedAgentRuntime() | 
					
						
							|  |  |  |     runtime.start() | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     await LoopbackAgentWithDefaultSubscription.register(runtime, "name", LoopbackAgentWithDefaultSubscription) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  |     agent_id = AgentId("name", key="default") | 
					
						
							|  |  |  |     await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(source="other")) | 
					
						
							|  |  |  |     await runtime.stop_when_idle() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     long_running_agent = await runtime.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         agent_id, type=LoopbackAgentWithDefaultSubscription | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  |     assert long_running_agent.num_calls == 0 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 13:59:39 -07:00
										 |  |  |     other_long_running_agent = await runtime.try_get_underlying_agent_instance( | 
					
						
							|  |  |  |         AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription | 
					
						
							| 
									
										
										
										
											2024-09-13 10:41:15 -04:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-08-23 16:01:57 -04:00
										 |  |  |     assert other_long_running_agent.num_calls == 1 | 
					
						
							| 
									
										
										
										
											2025-01-07 16:37:02 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |     await runtime.close() | 
					
						
							| 
									
										
										
										
											2025-02-26 13:34:53 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @default_subscription | 
					
						
							|  |  |  | class FailingAgent(RoutedAgent): | 
					
						
							|  |  |  |     def __init__(self) -> None: | 
					
						
							|  |  |  |         super().__init__("A failing agent.") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @event | 
					
						
							|  |  |  |     async def on_new_message_event(self, message: MessageType, ctx: MessageContext) -> None: | 
					
						
							|  |  |  |         raise ValueError("Test exception") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_event_handler_exception_propogates() -> None: | 
					
						
							|  |  |  |     runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False) | 
					
						
							|  |  |  |     await FailingAgent.register(runtime, "name", FailingAgent) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with pytest.raises(ValueError, match="Test exception"): | 
					
						
							|  |  |  |         runtime.start() | 
					
						
							|  |  |  |         await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) | 
					
						
							|  |  |  |         await runtime.stop_when_idle() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await runtime.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_event_handler_exception_multi_message() -> None: | 
					
						
							|  |  |  |     runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False) | 
					
						
							|  |  |  |     await FailingAgent.register(runtime, "name", FailingAgent) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with pytest.raises(ValueError, match="Test exception"): | 
					
						
							|  |  |  |         runtime.start() | 
					
						
							|  |  |  |         await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) | 
					
						
							|  |  |  |         await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) | 
					
						
							|  |  |  |         await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) | 
					
						
							|  |  |  |         await runtime.stop_when_idle() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     await runtime.close() |