2024-05-20 13:32:08 -06:00
|
|
|
import asyncio
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
2024-06-05 15:48:14 -04:00
|
|
|
import pytest
|
2024-06-04 10:00:05 -04:00
|
|
|
from agnext.application import SingleThreadedAgentRuntime
|
2024-06-05 15:48:14 -04:00
|
|
|
from agnext.components import TypeRoutedAgent, message_handler
|
|
|
|
from agnext.core import Agent, AgentRuntime, CancellationToken
|
|
|
|
|
2024-05-20 13:32:08 -06:00
|
|
|
|
|
|
|
@dataclass
class MessageType:
    pass
|
|
|
|
|
|
|
|
# Note for future readers:
# To cancel work, user code should interact only with the CancellationToken.
# Cancelling a future directly may not behave as you expect.
|
|
|
|
|
2024-06-09 12:11:36 -07:00
|
|
|
class LongRunningAgent(TypeRoutedAgent):  # type: ignore
    """Test agent whose handler sleeps for 100 seconds unless cancelled first.

    Attributes:
        called: True once ``on_new_message`` has started running.
        cancelled: True once the handler's sleep raised ``CancelledError``.
    """

    def __init__(self, name: str, router: AgentRuntime) -> None:  # type: ignore
        super().__init__(name, "A long running agent", router)
        self.called = False
        self.cancelled = False

    @message_handler()  # type: ignore
    async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:  # type: ignore
        """Sleep for 100 seconds, tied to ``cancellation_token``, then reply.

        Raises:
            asyncio.CancelledError: re-raised after recording the
                cancellation when the linked sleep is cancelled.
        """
        self.called = True
        # Run the sleep as its own task so the token can be linked to it.
        # Cancelling the token is then the only way this test interrupts it.
        pending_sleep = asyncio.create_task(asyncio.sleep(100))
        cancellation_token.link_future(pending_sleep)
        try:
            await pending_sleep
        except asyncio.CancelledError:
            self.cancelled = True
            raise
        return MessageType()
|
|
|
|
2024-06-09 12:11:36 -07:00
|
|
|
class NestingLongRunningAgent(TypeRoutedAgent):  # type: ignore
    """Test agent that forwards each message to a nested agent and awaits its reply.

    Attributes:
        called: True once ``on_new_message`` has started running.
        cancelled: True once awaiting the nested reply raised ``CancelledError``.
    """

    def __init__(self, name: str, router: AgentRuntime, nested_agent: Agent) -> None:  # type: ignore
        super().__init__(name, "A nesting long running agent", router)
        self.called = False
        self.cancelled = False
        # Recipient of the forwarded message.
        self._nested_agent = nested_agent

    @message_handler()  # type: ignore
    async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:  # type: ignore
        """Forward ``message`` to the nested agent and return its reply.

        The same ``cancellation_token`` is passed on, so the nested request is
        governed by the caller's token.

        Raises:
            asyncio.CancelledError: re-raised after recording the cancellation.
        """
        self.called = True
        inner_reply = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token)
        try:
            result = await inner_reply
        except asyncio.CancelledError:
            self.cancelled = True
            raise
        assert isinstance(result, MessageType)
        return result
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_cancellation_with_token() -> None:
    """Cancelling the token aborts a running handler and the caller's future."""
    runtime = SingleThreadedAgentRuntime()
    agent = LongRunningAgent("name", runtime)
    cancel_token = CancellationToken()

    pending = runtime.send_message(MessageType(), recipient=agent, cancellation_token=cancel_token)
    # Nothing has been processed yet, so no reply can be available.
    assert not pending.done()

    # Process the queued message (starting the handler's sleep), then cancel.
    await runtime.process_next()
    cancel_token.cancel()

    with pytest.raises(asyncio.CancelledError):
        await pending

    assert pending.done()
    assert agent.called
    assert agent.cancelled
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_nested_cancellation_only_outer_called() -> None:
    """Cancel after only the outer agent ran; the nested agent never starts."""
    runtime = SingleThreadedAgentRuntime()
    inner_agent = LongRunningAgent("name", runtime)
    outer_agent = NestingLongRunningAgent("nested", runtime, inner_agent)
    cancel_token = CancellationToken()

    pending = runtime.send_message(MessageType(), recipient=outer_agent, cancellation_token=cancel_token)
    assert not pending.done()

    # Process a single message: the outer handler runs, but the forwarded
    # message to the inner agent is never processed before cancellation.
    await runtime.process_next()
    cancel_token.cancel()

    with pytest.raises(asyncio.CancelledError):
        await pending

    assert pending.done()
    assert outer_agent.called
    assert outer_agent.cancelled
    assert inner_agent.called is False
    assert inner_agent.cancelled is False
|
2024-05-20 13:32:08 -06:00
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_nested_cancellation_inner_called() -> None:
    """Cancel after both agents ran; cancellation reaches the nested agent too."""
    runtime = SingleThreadedAgentRuntime()
    inner_agent = LongRunningAgent("name", runtime)
    outer_agent = NestingLongRunningAgent("nested", runtime, inner_agent)
    cancel_token = CancellationToken()

    pending = runtime.send_message(MessageType(), recipient=outer_agent, cancellation_token=cancel_token)
    assert not pending.done()

    # First step runs the outer handler.
    await runtime.process_next()
    # allow the inner agent to process
    await runtime.process_next()
    cancel_token.cancel()

    with pytest.raises(asyncio.CancelledError):
        await pending

    assert pending.done()
    assert outer_agent.called
    assert outer_agent.cancelled
    assert inner_agent.called
    assert inner_agent.cancelled
|