2024-06-05 15:48:14 -04:00
|
|
|
import pytest
|
2024-06-04 10:00:05 -04:00
|
|
|
from agnext.application import SingleThreadedAgentRuntime
|
2024-06-18 14:53:18 -04:00
|
|
|
from agnext.core import AgentId
|
2024-05-20 17:30:45 -06:00
|
|
|
from agnext.core.exceptions import MessageDroppedException
|
|
|
|
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
|
2024-06-18 14:53:18 -04:00
|
|
|
from test_utils import LoopbackAgent, MessageType
|
2024-05-20 17:30:45 -06:00
|
|
|
|
2024-06-05 15:48:14 -04:00
|
|
|
|
2024-05-20 17:30:45 -06:00
|
|
|
@pytest.mark.asyncio
async def test_intervention_count_messages() -> None:
    """The ``on_send`` hook observes every message that is sent.

    A pass-through handler that only counts messages must see exactly one
    send, and the message must still reach the loopback agent.
    """

    class CountingInterventionHandler(DefaultInterventionHandler):
        def __init__(self) -> None:
            self.num_messages = 0

        async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType:
            # Record the send, then forward the message unchanged.
            self.num_messages += 1
            return message

    counter = CountingInterventionHandler()
    runtime = SingleThreadedAgentRuntime(intervention_handler=counter)
    agent_id = await runtime.register_and_get("name", LoopbackAgent)

    run_context = runtime.start()
    await runtime.send_message(MessageType(), recipient=agent_id)
    await run_context.stop()

    assert counter.num_messages == 1
    agent: LoopbackAgent = await runtime._get_agent(agent_id)  # type: ignore
    assert agent.num_calls == 1
|
2024-05-20 17:30:45 -06:00
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_intervention_drop_send() -> None:
    """Returning ``DropMessage`` from ``on_send`` prevents delivery.

    The sender receives ``MessageDroppedException`` and the recipient agent
    is never invoked.
    """

    class DropSendInterventionHandler(DefaultInterventionHandler):
        async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType | type[DropMessage]:
            # Veto every outgoing message before it reaches its recipient.
            return DropMessage

    runtime = SingleThreadedAgentRuntime(intervention_handler=DropSendInterventionHandler())
    agent_id = await runtime.register_and_get("name", LoopbackAgent)
    run_context = runtime.start()

    message = MessageType()
    with pytest.raises(MessageDroppedException):
        await runtime.send_message(message, recipient=agent_id)

    await run_context.stop()

    agent: LoopbackAgent = await runtime._get_agent(agent_id)  # type: ignore
    assert agent.num_calls == 0
|
2024-05-20 17:30:45 -06:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_intervention_drop_response() -> None:
    """Returning ``DropMessage`` from ``on_response`` discards the reply.

    The request itself is still delivered, so the loopback agent handles it
    once. The response is dropped, and the sender sees
    ``MessageDroppedException``.
    """

    class DropResponseInterventionHandler(DefaultInterventionHandler):
        async def on_response(self, message: MessageType, *, sender: AgentId, recipient: AgentId | None) -> MessageType | type[DropMessage]:
            # Veto every response on its way back to the original sender.
            return DropMessage

    handler = DropResponseInterventionHandler()
    runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
    loopback = await runtime.register_and_get("name", LoopbackAgent)
    run_context = runtime.start()

    with pytest.raises(MessageDroppedException):
        _response = await runtime.send_message(MessageType(), recipient=loopback)

    await run_context.stop()

    # The drop happens on the response path, so the recipient must still have
    # handled the request exactly once. This matches the postcondition asserted
    # by test_intervention_raise_exception_on_respond, and it distinguishes
    # "response dropped" from "request dropped".
    loopback_agent: LoopbackAgent = await runtime._get_agent(loopback)  # type: ignore
    assert loopback_agent.num_calls == 1
|
2024-05-20 17:30:45 -06:00
|
|
|
|
2024-06-17 17:34:56 -07:00
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_intervention_raise_exception_on_send() -> None:
    """An exception raised inside ``on_send`` propagates to the sender.

    Because the hook fails before delivery, the loopback agent is never
    called.
    """

    class InterventionException(Exception):
        pass

    class ExceptionInterventionHandler(DefaultInterventionHandler):  # type: ignore
        async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType | type[DropMessage]:  # type: ignore
            raise InterventionException

    runtime = SingleThreadedAgentRuntime(intervention_handler=ExceptionInterventionHandler())
    agent_id = await runtime.register_and_get("name", LoopbackAgent)
    run_context = runtime.start()

    with pytest.raises(InterventionException):
        await runtime.send_message(MessageType(), recipient=agent_id)

    await run_context.stop()

    agent: LoopbackAgent = await runtime._get_agent(agent_id)  # type: ignore
    assert agent.num_calls == 0
|
2024-06-17 17:34:56 -07:00
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_intervention_raise_exception_on_respond() -> None:
    """An exception raised inside ``on_response`` propagates to the sender.

    The request is delivered before the hook fails, so the loopback agent is
    called exactly once.
    """

    class InterventionException(Exception):
        pass

    class ExceptionInterventionHandler(DefaultInterventionHandler):  # type: ignore
        async def on_response(self, message: MessageType, *, sender: AgentId, recipient: AgentId | None) -> MessageType | type[DropMessage]:  # type: ignore
            raise InterventionException

    runtime = SingleThreadedAgentRuntime(intervention_handler=ExceptionInterventionHandler())
    agent_id = await runtime.register_and_get("name", LoopbackAgent)
    run_context = runtime.start()

    with pytest.raises(InterventionException):
        await runtime.send_message(MessageType(), recipient=agent_id)

    await run_context.stop()

    agent: LoopbackAgent = await runtime._get_agent(agent_id)  # type: ignore
    assert agent.num_calls == 1
|