feat!: Add message context to signature of intervention handler, add more to docs (#4882)

* Add message context to signature of intervention handler, add more to docs

* example

* Add to test

* Fix pyright

* mypy
This commit is contained in:
Jack Gerrits 2025-01-07 12:51:35 -05:00 committed by GitHub
parent f4382f01c8
commit 5b9be79fba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 113 additions and 25 deletions

View File

@ -23,7 +23,6 @@
"from typing import Any\n",
"\n",
"from autogen_core import (\n",
" AgentId,\n",
" DefaultInterventionHandler,\n",
" DefaultTopicId,\n",
" MessageContext,\n",
@ -100,7 +99,7 @@
" def __init__(self) -> None:\n",
" self._termination_value: Termination | None = None\n",
"\n",
" async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:\n",
" async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:\n",
" if isinstance(message, Termination):\n",
" self._termination_value = message\n",
" return message\n",
@ -171,7 +170,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.5"
}
},
"nbformat": 4,

View File

@ -131,7 +131,9 @@
"outputs": [],
"source": [
"class ToolInterventionHandler(DefaultInterventionHandler):\n",
" async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:\n",
" async def on_send(\n",
" self, message: Any, *, message_context: MessageContext, recipient: AgentId\n",
" ) -> Any | type[DropMessage]:\n",
" if isinstance(message, FunctionCall):\n",
" # Request user prompt for tool execution.\n",
" user_input = input(\n",

View File

@ -31,7 +31,6 @@ from dataclasses import dataclass
from typing import Any, Mapping, Optional
from autogen_core import (
AgentId,
CancellationToken,
DefaultInterventionHandler,
DefaultTopicId,
@ -211,7 +210,7 @@ class NeedsUserInputHandler(DefaultInterventionHandler):
def __init__(self):
self.question_for_user: GetSlowUserMessage | None = None
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:
if isinstance(message, GetSlowUserMessage):
self.question_for_user = message
return message
@ -231,7 +230,7 @@ class TerminationHandler(DefaultInterventionHandler):
def __init__(self):
self.terminateMessage: TerminateMessage | None = None
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:
if isinstance(message, TerminateMessage):
self.terminateMessage = message
return message

View File

@ -1,6 +1,7 @@
from typing import Any, Protocol, final
from ._agent_id import AgentId
from ._message_context import MessageContext
__all__ = [
"DropMessage",
@ -10,20 +11,59 @@ __all__ = [
@final
class DropMessage: ...
class DropMessage:
"""Marker type for signalling that a message should be dropped by an intervention handler. The type itself should be returned from the handler."""
...
class InterventionHandler(Protocol):
"""An intervention handler is a class that can be used to modify, log or drop messages that are being processed by the :class:`autogen_core.base.AgentRuntime`.
The handler is called when the message is submitted to the runtime.
Currently the only runtime which supports this is the :class:`autogen_core.base.SingleThreadedAgentRuntime`.
Note: Returning None from any of the intervention handler methods will result in a warning being issued and treated as "no change". If you intend to drop a message, you should return :class:`DropMessage` explicitly.
Example:
.. code-block:: python
from autogen_core import DefaultInterventionHandler, MessageContext, AgentId, SingleThreadedAgentRuntime
from dataclasses import dataclass
from typing import Any
@dataclass
class MyMessage:
content: str
class MyInterventionHandler(DefaultInterventionHandler):
async def on_send(self, message: Any, *, message_context: MessageContext, recipient: AgentId) -> MyMessage:
if isinstance(message, MyMessage):
message.content = message.content.upper()
return message
runtime = SingleThreadedAgentRuntime(intervention_handlers=[MyInterventionHandler()])
"""
async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]: ...
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]: ...
async def on_response(
self, message: Any, *, sender: AgentId, recipient: AgentId | None
) -> Any | type[DropMessage]: ...
async def on_send(
self, message: Any, *, message_context: MessageContext, recipient: AgentId
) -> Any | type[DropMessage]:
"""Called when a message is submitted to the AgentRuntime using :meth:`autogen_core.base.AgentRuntime.send_message`."""
...
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]:
"""Called when a message is published to the AgentRuntime using :meth:`autogen_core.base.AgentRuntime.publish_message`."""
...
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
"""Called when a response is received by the AgentRuntime from an Agent's message handler returning a value."""
...
class DefaultInterventionHandler(InterventionHandler):
@ -31,10 +71,12 @@ class DefaultInterventionHandler(InterventionHandler):
handler methods, that simply returns the message unchanged. Allows for easy
subclassing to override only the desired methods."""
async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:
async def on_send(
self, message: Any, *, message_context: MessageContext, recipient: AgentId
) -> Any | type[DropMessage]:
return message
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]:
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]:
return message
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:

View File

@ -474,7 +474,16 @@ class SingleThreadedAgentRuntime(AgentRuntime):
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
temp_message = await handler.on_send(message, sender=sender, recipient=recipient)
message_context = MessageContext(
sender=sender,
topic_id=None,
is_rpc=True,
cancellation_token=message_envelope.cancellation_token,
message_id=message_envelope.message_id,
)
temp_message = await handler.on_send(
message, message_context=message_context, recipient=recipient
)
_warn_if_none(temp_message, "on_send")
except BaseException as e:
future.set_exception(e)
@ -506,7 +515,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
temp_message = await handler.on_publish(message, sender=sender)
message_context = MessageContext(
sender=sender,
topic_id=topic_id,
is_rpc=False,
cancellation_token=message_envelope.cancellation_token,
message_id=message_envelope.message_id,
)
temp_message = await handler.on_publish(message, message_context=message_context)
_warn_if_none(temp_message, "on_publish")
except BaseException as e:
# TODO: we should raise the intervention exception to the publisher.

View File

@ -1,5 +1,15 @@
from typing import Any
import pytest
from autogen_core import AgentId, DefaultInterventionHandler, DropMessage, SingleThreadedAgentRuntime
from autogen_core import (
AgentId,
DefaultInterventionHandler,
DefaultSubscription,
DefaultTopicId,
DropMessage,
MessageContext,
SingleThreadedAgentRuntime,
)
from autogen_core.exceptions import MessageDroppedException
from autogen_test_utils import LoopbackAgent, MessageType
@ -8,10 +18,20 @@ from autogen_test_utils import LoopbackAgent, MessageType
async def test_intervention_count_messages() -> None:
class DebugInterventionHandler(DefaultInterventionHandler):
def __init__(self) -> None:
self.num_messages = 0
self.num_send_messages = 0
self.num_publish_messages = 0
self.num_response_messages = 0
async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType:
self.num_messages += 1
async def on_send(self, message: Any, *, message_context: MessageContext, recipient: AgentId) -> Any:
self.num_send_messages += 1
return message
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:
self.num_publish_messages += 1
return message
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any:
self.num_response_messages += 1
return message
handler = DebugInterventionHandler()
@ -22,18 +42,28 @@ async def test_intervention_count_messages() -> None:
_response = await runtime.send_message(MessageType(), recipient=loopback)
await runtime.stop()
await runtime.stop_when_idle()
assert handler.num_messages == 1
assert handler.num_send_messages == 1
assert handler.num_response_messages == 1
loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
assert loopback_agent.num_calls == 1
runtime.start()
await runtime.add_subscription(DefaultSubscription(agent_type="name"))
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
await runtime.stop_when_idle()
assert loopback_agent.num_calls == 2
assert handler.num_publish_messages == 1
@pytest.mark.asyncio
async def test_intervention_drop_send() -> None:
class DropSendInterventionHandler(DefaultInterventionHandler):
async def on_send(
self, message: MessageType, *, sender: AgentId | None, recipient: AgentId
self, message: MessageType, *, message_context: MessageContext, recipient: AgentId
) -> MessageType | type[DropMessage]:
return DropMessage
@ -81,7 +111,7 @@ async def test_intervention_raise_exception_on_send() -> None:
class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore
async def on_send(
self, message: MessageType, *, sender: AgentId | None, recipient: AgentId
self, message: MessageType, *, message_context: MessageContext, recipient: AgentId
) -> MessageType | type[DropMessage]: # type: ignore
raise InterventionException