| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  | import asyncio | 
					
						
							|  |  |  | from dataclasses import dataclass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 15:48:14 -04:00
										 |  |  | import pytest | 
					
						
							| 
									
										
										
										
											2024-06-04 10:00:05 -04:00
										 |  |  | from agnext.application import SingleThreadedAgentRuntime | 
					
						
							| 
									
										
										
										
											2024-06-05 15:48:14 -04:00
										 |  |  | from agnext.components import TypeRoutedAgent, message_handler | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  | from agnext.core import AgentId, CancellationToken | 
					
						
							| 
									
										
										
										
											2024-06-05 15:48:14 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  | 
 | 
					
						
							|  |  |  | @dataclass | 
					
						
							|  |  |  | class MessageType: | 
					
						
							|  |  |  |     ... | 
					
						
							|  |  |  | 
 | 
					
						
# Note for future readers:
# To cancel an in-flight request, interact only with the CancellationToken.
# Cancelling the underlying future directly may not behave as you expect.
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  | class LongRunningAgent(TypeRoutedAgent): | 
					
						
							|  |  |  |     def __init__(self) -> None: | 
					
						
							|  |  |  |         super().__init__("A long running agent") | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |         self.called = False | 
					
						
							|  |  |  |         self.cancelled = False | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-13 19:44:51 -04:00
										 |  |  |     @message_handler | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |         self.called = True | 
					
						
							|  |  |  |         sleep = asyncio.ensure_future(asyncio.sleep(100)) | 
					
						
							|  |  |  |         cancellation_token.link_future(sleep) | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             await sleep | 
					
						
							|  |  |  |             return MessageType() | 
					
						
							|  |  |  |         except asyncio.CancelledError: | 
					
						
							|  |  |  |             self.cancelled = True | 
					
						
							|  |  |  |             raise | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  | class NestingLongRunningAgent(TypeRoutedAgent): | 
					
						
							|  |  |  |     def __init__(self, nested_agent: AgentId) -> None: | 
					
						
							|  |  |  |         super().__init__("A nesting long running agent") | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |         self.called = False | 
					
						
							|  |  |  |         self.cancelled = False | 
					
						
							|  |  |  |         self._nested_agent = nested_agent | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-13 19:44:51 -04:00
										 |  |  |     @message_handler | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |         self.called = True | 
					
						
							| 
									
										
										
										
											2024-07-01 11:53:45 -04:00
										 |  |  |         response = self.send_message(message, self._nested_agent, cancellation_token=cancellation_token) | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |         try: | 
					
						
							| 
									
										
										
										
											2024-05-23 16:00:05 -04:00
										 |  |  |             val = await response | 
					
						
							|  |  |  |             assert isinstance(val, MessageType) | 
					
						
							|  |  |  |             return val | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |         except asyncio.CancelledError: | 
					
						
							|  |  |  |             self.cancelled = True | 
					
						
							|  |  |  |             raise | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_cancellation_with_token() -> None: | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     runtime = SingleThreadedAgentRuntime() | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-23 11:49:38 -07:00
										 |  |  |     long_running = await runtime.register_and_get("long_running", LongRunningAgent) | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |     token = CancellationToken() | 
					
						
							| 
									
										
										
										
											2024-07-01 11:53:45 -04:00
										 |  |  |     response = asyncio.create_task(runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token)) | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |     assert not response.done() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-01 11:53:45 -04:00
										 |  |  |     while len(runtime.unprocessed_messages) == 0: | 
					
						
							|  |  |  |         await asyncio.sleep(0.01) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     await runtime.process_next() | 
					
						
							| 
									
										
										
										
											2024-07-01 11:53:45 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |     token.cancel() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with pytest.raises(asyncio.CancelledError): | 
					
						
							|  |  |  |         await response | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert response.done() | 
					
						
							| 
									
										
										
										
											2024-07-23 16:38:37 -07:00
										 |  |  |     long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent) | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     assert long_running_agent.called | 
					
						
							|  |  |  |     assert long_running_agent.cancelled | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_nested_cancellation_only_outer_called() -> None: | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     runtime = SingleThreadedAgentRuntime() | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-23 11:49:38 -07:00
										 |  |  |     long_running = await runtime.register_and_get("long_running", LongRunningAgent) | 
					
						
							|  |  |  |     nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running)) | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  | 
 | 
					
						
							|  |  |  |     token = CancellationToken() | 
					
						
							| 
									
										
										
										
											2024-07-01 11:53:45 -04:00
										 |  |  |     response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token)) | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |     assert not response.done() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-01 11:53:45 -04:00
										 |  |  |     while len(runtime.unprocessed_messages) == 0: | 
					
						
							|  |  |  |         await asyncio.sleep(0.01) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     await runtime.process_next() | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |     token.cancel() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with pytest.raises(asyncio.CancelledError): | 
					
						
							|  |  |  |         await response | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert response.done() | 
					
						
							| 
									
										
										
										
											2024-07-23 16:38:37 -07:00
										 |  |  |     nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent) | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     assert nested_agent.called | 
					
						
							|  |  |  |     assert nested_agent.cancelled | 
					
						
							| 
									
										
										
										
											2024-07-23 16:38:37 -07:00
										 |  |  |     long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent) | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     assert long_running_agent.called is False | 
					
						
							|  |  |  |     assert long_running_agent.cancelled is False | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_nested_cancellation_inner_called() -> None: | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     runtime = SingleThreadedAgentRuntime() | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-23 11:49:38 -07:00
										 |  |  |     long_running = await runtime.register_and_get("long_running", LongRunningAgent ) | 
					
						
							|  |  |  |     nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running)) | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  | 
 | 
					
						
							|  |  |  |     token = CancellationToken() | 
					
						
							| 
									
										
										
										
											2024-07-01 11:53:45 -04:00
										 |  |  |     response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token)) | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |     assert not response.done() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-01 11:53:45 -04:00
										 |  |  |     while len(runtime.unprocessed_messages) == 0: | 
					
						
							|  |  |  |         await asyncio.sleep(0.01) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     await runtime.process_next() | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |     # allow the inner agent to process | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     await runtime.process_next() | 
					
						
							| 
									
										
										
										
											2024-05-20 13:32:08 -06:00
										 |  |  |     token.cancel() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with pytest.raises(asyncio.CancelledError): | 
					
						
							|  |  |  |         await response | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert response.done() | 
					
						
							| 
									
										
										
										
											2024-07-23 16:38:37 -07:00
										 |  |  |     nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent) | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     assert nested_agent.called | 
					
						
							|  |  |  |     assert nested_agent.cancelled | 
					
						
							| 
									
										
										
										
											2024-07-23 16:38:37 -07:00
										 |  |  |     long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent) | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |     assert long_running_agent.called | 
					
						
							|  |  |  |     assert long_running_agent.cancelled |