autogen/python/tests/test_intervention.py

120 lines
4.5 KiB
Python
Raw Normal View History

import pytest
2024-06-04 10:00:05 -04:00
from agnext.application import SingleThreadedAgentRuntime
from agnext.core import AgentId
2024-05-20 17:30:45 -06:00
from agnext.core.exceptions import MessageDroppedException
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
from test_utils import LoopbackAgent, MessageType
2024-05-20 17:30:45 -06:00
2024-05-20 17:30:45 -06:00
@pytest.mark.asyncio
async def test_intervention_count_messages() -> None:
    """Check that an intervention handler's ``on_send`` hook sees each sent message.

    One message is sent to a ``LoopbackAgent``. Both the handler's counter
    and the agent's own call counter must end up at exactly one.
    """

    class DebugInterventionHandler(DefaultInterventionHandler):
        """Pass-through handler that counts how many messages ``on_send`` intercepts."""

        def __init__(self) -> None:
            super().__init__()
            self.num_messages = 0

        async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType:
            self.num_messages += 1
            # Returning the message unchanged lets delivery to the recipient proceed.
            return message

    handler = DebugInterventionHandler()
    runtime = SingleThreadedAgentRuntime(intervention_handler=handler)

    loopback = await runtime.register_and_get("name", LoopbackAgent)
    run_context = runtime.start()
    try:
        await runtime.send_message(MessageType(), recipient=loopback)
    finally:
        # Always shut the started runtime down, even if sending fails.
        await run_context.stop()

    assert handler.num_messages == 1
    loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
    assert loopback_agent.num_calls == 1


@pytest.mark.asyncio
async def test_intervention_drop_send() -> None:
    """Check that returning ``DropMessage`` from ``on_send`` blocks delivery.

    ``send_message`` must raise ``MessageDroppedException``, and the recipient
    agent must never have handled the message.
    """

    class DropSendInterventionHandler(DefaultInterventionHandler):
        """Handler that drops every outgoing message before it is delivered."""

        async def on_send(
            self, message: MessageType, *, sender: AgentId | None, recipient: AgentId
        ) -> MessageType | type[DropMessage]:
            return DropMessage

    handler = DropSendInterventionHandler()
    runtime = SingleThreadedAgentRuntime(intervention_handler=handler)

    loopback = await runtime.register_and_get("name", LoopbackAgent)
    run_context = runtime.start()
    try:
        with pytest.raises(MessageDroppedException):
            await runtime.send_message(MessageType(), recipient=loopback)
    finally:
        # Must sit outside ``pytest.raises``: statements after the raising call
        # inside that block never execute, which would leave the runtime running.
        await run_context.stop()

    loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
    assert loopback_agent.num_calls == 0


@pytest.mark.asyncio
async def test_intervention_drop_response() -> None:
    """Check that returning ``DropMessage`` from ``on_response`` drops the reply.

    The request still reaches the agent (it handles it once), but the caller
    of ``send_message`` receives ``MessageDroppedException`` instead of a
    response.
    """

    class DropResponseInterventionHandler(DefaultInterventionHandler):
        """Handler that lets requests through but drops every response."""

        async def on_response(
            self, message: MessageType, *, sender: AgentId, recipient: AgentId | None
        ) -> MessageType | type[DropMessage]:
            return DropMessage

    handler = DropResponseInterventionHandler()
    runtime = SingleThreadedAgentRuntime(intervention_handler=handler)

    loopback = await runtime.register_and_get("name", LoopbackAgent)
    run_context = runtime.start()
    try:
        with pytest.raises(MessageDroppedException):
            await runtime.send_message(MessageType(), recipient=loopback)
    finally:
        # Must sit outside ``pytest.raises``: statements after the raising call
        # inside that block never execute, which would leave the runtime running.
        await run_context.stop()

    # Only the response was intercepted, so the agent did process the request.
    loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
    assert loopback_agent.num_calls == 1


@pytest.mark.asyncio
async def test_intervention_raise_exception_on_send() -> None:
    """Check that an exception raised in ``on_send`` propagates to the sender.

    The exception must surface from ``send_message``, and the recipient agent
    must never have handled the message.
    """

    class InterventionException(Exception):
        """Marker exception raised by the handler under test."""

    class ExceptionInterventionHandler(DefaultInterventionHandler):  # type: ignore
        """Handler whose ``on_send`` always raises ``InterventionException``."""

        async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType | type[DropMessage]:  # type: ignore
            raise InterventionException

    handler = ExceptionInterventionHandler()
    runtime = SingleThreadedAgentRuntime(intervention_handler=handler)

    loopback = await runtime.register_and_get("name", LoopbackAgent)
    run_context = runtime.start()
    try:
        with pytest.raises(InterventionException):
            await runtime.send_message(MessageType(), recipient=loopback)
    finally:
        # Must sit outside ``pytest.raises``: statements after the raising call
        # inside that block never execute, which would leave the runtime running.
        await run_context.stop()

    loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
    assert loopback_agent.num_calls == 0


@pytest.mark.asyncio
async def test_intervention_raise_exception_on_respond() -> None:
    """Check that an exception raised in ``on_response`` propagates to the sender.

    The request is still delivered (the agent handles it once), but the
    exception from the response hook must surface from ``send_message``.
    """

    class InterventionException(Exception):
        """Marker exception raised by the handler under test."""

    class ExceptionInterventionHandler(DefaultInterventionHandler):  # type: ignore
        """Handler whose ``on_response`` always raises ``InterventionException``."""

        async def on_response(self, message: MessageType, *, sender: AgentId, recipient: AgentId | None) -> MessageType | type[DropMessage]:  # type: ignore
            raise InterventionException

    handler = ExceptionInterventionHandler()
    runtime = SingleThreadedAgentRuntime(intervention_handler=handler)

    loopback = await runtime.register_and_get("name", LoopbackAgent)
    run_context = runtime.start()
    try:
        with pytest.raises(InterventionException):
            await runtime.send_message(MessageType(), recipient=loopback)
    finally:
        # Must sit outside ``pytest.raises``: statements after the raising call
        # inside that block never execute, which would leave the runtime running.
        await run_context.stop()

    loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
    assert loopback_agent.num_calls == 1