2024-05-20 17:30:45 -06:00
|
|
|
from dataclasses import dataclass
|
|
|
|
|
2024-06-05 15:48:14 -04:00
|
|
|
import pytest
|
2024-06-04 10:00:05 -04:00
|
|
|
from agnext.application import SingleThreadedAgentRuntime
|
2024-06-05 15:48:14 -04:00
|
|
|
from agnext.components import TypeRoutedAgent, message_handler
|
|
|
|
from agnext.core import Agent, AgentRuntime, CancellationToken
|
2024-05-20 17:30:45 -06:00
|
|
|
from agnext.core.exceptions import MessageDroppedException
|
|
|
|
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
|
|
|
|
|
2024-06-05 15:48:14 -04:00
|
|
|
|
2024-05-20 17:30:45 -06:00
|
|
|
@dataclass()
class MessageType:
    """Field-less payload that these tests send through the runtime and intervention handlers."""
|
|
|
|
|
2024-06-09 12:11:36 -07:00
|
|
|
class LoopbackAgent(TypeRoutedAgent):  # type: ignore
    """Test agent whose handler returns every received ``MessageType`` unchanged.

    ``num_calls`` records how many times the handler ran, so tests can check
    whether an intervention handler let a message reach the agent.
    """

    def __init__(self, name: str, router: AgentRuntime) -> None:  # type: ignore
        # Hand the name, a fixed description and the runtime to the base class
        # before adding this agent's own state.
        super().__init__(name, "A loop back agent.", router)
        self.num_calls: int = 0  # bumped once per handled message

    @message_handler()  # type: ignore
    async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:  # type: ignore
        """Count this invocation and respond with the very same ``message`` object."""
        # NOTE: parameter names and annotations are kept as-is; @message_handler
        # presumably derives the routed message/return types from them.
        self.num_calls = self.num_calls + 1
        return message
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_intervention_count_messages() -> None:
    """A pass-through ``on_send`` handler is invoked once per sent message.

    The handler counts calls and returns the message unchanged, so the agent
    must still receive it exactly once.
    """

    class DebugInterventionHandler(DefaultInterventionHandler):  # type: ignore
        def __init__(self) -> None:
            # Run the base handler's initializer before adding our own state.
            super().__init__()
            self.num_messages = 0

        async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType:  # type: ignore
            # Count the message and forward it untouched.
            self.num_messages += 1
            return message

    handler = DebugInterventionHandler()
    runtime = SingleThreadedAgentRuntime(before_send=handler)

    long_running = LoopbackAgent("name", runtime)
    response = runtime.send_message(MessageType(), recipient=long_running)

    # Step the runtime manually until the pending response is done.
    while not response.done():
        await runtime.process_next()

    assert handler.num_messages == 1
    assert long_running.num_calls == 1
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_intervention_drop_send() -> None:
    """Returning ``DropMessage`` from ``on_send`` fails the send before the agent runs."""

    class DropSendInterventionHandler(DefaultInterventionHandler):  # type: ignore
        async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType | type[DropMessage]:  # type: ignore
            # The DropMessage class itself (not an instance) is the drop signal.
            return DropMessage  # type: ignore

    rt = SingleThreadedAgentRuntime(before_send=DropSendInterventionHandler())
    agent = LoopbackAgent("name", rt)
    pending = rt.send_message(MessageType(), recipient=agent)

    while not pending.done():
        await rt.process_next()

    # The sender sees the drop as an exception when awaiting the response...
    with pytest.raises(MessageDroppedException):
        await pending
    # ...and the agent's handler never ran.
    assert agent.num_calls == 0
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_intervention_drop_response() -> None:
    """Returning ``DropMessage`` from ``on_response`` fails the send after the agent ran."""

    class DropResponseInterventionHandler(DefaultInterventionHandler):  # type: ignore
        async def on_response(self, message: MessageType, *, sender: Agent, recipient: Agent | None) -> MessageType | type[DropMessage]:  # type: ignore
            # Suppress the agent's reply by returning the DropMessage class.
            return DropMessage  # type: ignore

    rt = SingleThreadedAgentRuntime(before_send=DropResponseInterventionHandler())
    agent = LoopbackAgent("name", rt)
    pending = rt.send_message(MessageType(), recipient=agent)

    while not pending.done():
        await rt.process_next()

    # The reply is dropped, so awaiting the response raises...
    with pytest.raises(MessageDroppedException):
        await pending
    # ...even though the agent did handle the original message.
    assert agent.num_calls == 1
|
2024-06-17 17:34:56 -07:00
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_intervention_raise_exception_on_send() -> None:
    """An exception raised by ``on_send`` surfaces on the response; the agent is never called."""

    class InterventionException(Exception):
        """Marker error raised by the intervention handler below."""

    class ExceptionInterventionHandler(DefaultInterventionHandler):  # type: ignore
        async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType | type[DropMessage]:  # type: ignore
            raise InterventionException()

    rt = SingleThreadedAgentRuntime(before_send=ExceptionInterventionHandler())
    agent = LoopbackAgent("name", rt)
    pending = rt.send_message(MessageType(), recipient=agent)

    while not pending.done():
        await rt.process_next()

    # The handler's own exception type is what the sender observes.
    with pytest.raises(InterventionException):
        await pending
    assert agent.num_calls == 0
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_intervention_raise_exception_on_respond() -> None:
    """An exception raised by ``on_response`` surfaces on the response after the agent ran."""

    class InterventionException(Exception):
        """Marker error raised by the intervention handler below."""

    class ExceptionInterventionHandler(DefaultInterventionHandler):  # type: ignore
        async def on_response(self, message: MessageType, *, sender: Agent, recipient: Agent | None) -> MessageType | type[DropMessage]:  # type: ignore
            raise InterventionException()

    rt = SingleThreadedAgentRuntime(before_send=ExceptionInterventionHandler())
    agent = LoopbackAgent("name", rt)
    pending = rt.send_message(MessageType(), recipient=agent)

    while not pending.done():
        await rt.process_next()

    # The failure happens on the reply path, so the agent itself still ran once.
    with pytest.raises(InterventionException):
        await pending
    assert agent.num_calls == 1
|