autogen/tests/test_cancellation.py

122 lines
3.9 KiB
Python
Raw Normal View History

import asyncio
import pytest
from dataclasses import dataclass
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
from agnext.core.agent import Agent
from agnext.core.agent_runtime import AgentRuntime
from agnext.core.cancellation_token import CancellationToken
@dataclass
class MessageType:
    """Field-less payload used both as the request and the reply in these tests."""
# Note for future readers:
# To request cancellation as a user, interact only with the CancellationToken.
# Cancelling a future directly may not behave the way you expect.
class LongRunningAgent(TypeRoutedAgent[MessageType]):
    """Agent whose handler blocks on a 100-second sleep.

    ``called`` becomes True once the handler starts. ``cancelled`` becomes
    True if the sleep is interrupted by ``asyncio.CancelledError``.
    """

    def __init__(self, name: str, router: AgentRuntime[MessageType]) -> None:
        super().__init__(name, router)
        # Flags the tests inspect after the request has been cancelled.
        self.called = False
        self.cancelled = False

    @message_handler(MessageType)
    async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
        """Sleep, then reply with a fresh ``MessageType``; record and re-raise cancellation."""
        self.called = True
        pending_sleep = asyncio.ensure_future(asyncio.sleep(100))
        # Tie the sleep to the token; cancelling the token is expected to
        # cancel this future (the tests below rely on that).
        cancellation_token.link_future(pending_sleep)
        try:
            await pending_sleep
        except asyncio.CancelledError:
            self.cancelled = True
            raise
        else:
            return MessageType()
class NestingLongRunningAgent(TypeRoutedAgent[MessageType]):
    """Agent that forwards every message to ``nested_agent`` and awaits its reply.

    The incoming cancellation token is passed along with the forwarded message.
    ``called`` and ``cancelled`` record whether the handler ran and whether
    awaiting the nested reply raised ``asyncio.CancelledError``.
    """

    def __init__(self, name: str, router: AgentRuntime[MessageType], nested_agent: Agent[MessageType]) -> None:
        super().__init__(name, router)
        self.called = False
        self.cancelled = False
        self._nested_agent = nested_agent

    @message_handler(MessageType)
    async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
        """Relay ``message`` to the nested agent and return whatever it replies."""
        self.called = True
        # The send happens outside the try: only cancellation of the await
        # itself is recorded in ``self.cancelled``.
        pending_reply = self._send_message(message, self._nested_agent, cancellation_token)
        try:
            reply = await pending_reply
        except asyncio.CancelledError:
            self.cancelled = True
            raise
        return reply
@pytest.mark.asyncio
async def test_cancellation_with_token() -> None:
    """Cancelling the token while the handler sleeps cancels the response and the handler."""
    router = SingleThreadedAgentRuntime[MessageType]()
    long_running = LongRunningAgent("name", router)
    token = CancellationToken()

    response = router.send_message(MessageType(), recipient=long_running, cancellation_token=token)
    assert not response.done()

    # Presumably dispatches the queued message to long_running; cancel
    # through the token only (see the note above LongRunningAgent).
    await router.process_next()
    token.cancel()

    with pytest.raises(asyncio.CancelledError):
        await response

    assert response.done()
    assert long_running.called
    assert long_running.cancelled
@pytest.mark.asyncio
async def test_nested_cancellation_only_outer_called() -> None:
    """Cancelling after one dispatch cancels the outer agent before the inner one runs."""
    router = SingleThreadedAgentRuntime[MessageType]()
    long_running = LongRunningAgent("name", router)
    nested = NestingLongRunningAgent("nested", router, long_running)
    token = CancellationToken()

    response = router.send_message(MessageType(), nested, cancellation_token=token)
    assert not response.done()

    # Only one process_next() call: the outer agent handles its message, but
    # the message it forwards to long_running is never processed.
    await router.process_next()
    token.cancel()

    with pytest.raises(asyncio.CancelledError):
        await response

    assert response.done()
    assert nested.called
    assert nested.cancelled
    assert not long_running.called
    assert not long_running.cancelled
@pytest.mark.asyncio
async def test_nested_cancellation_inner_called() -> None:
    """Cancelling after both agents have started cancels the outer and inner handlers."""
    router = SingleThreadedAgentRuntime[MessageType]()
    long_running = LongRunningAgent("name", router)
    nested = NestingLongRunningAgent("nested", router, long_running)
    token = CancellationToken()

    response = router.send_message(MessageType(), nested, cancellation_token=token)
    assert not response.done()

    await router.process_next()
    # allow the inner agent to process
    await router.process_next()
    token.cancel()

    with pytest.raises(asyncio.CancelledError):
        await response

    assert response.done()
    assert nested.called
    assert nested.cancelled
    assert long_running.called
    assert long_running.cancelled