mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-02 10:50:03 +00:00
feat!: Add message context to signature of intervention handler, add more to docs (#4882)
* Add message context to signature of intervention handler, add more to docs * example * Add to test * Fix pyright * mypy
This commit is contained in:
parent
f4382f01c8
commit
5b9be79fba
@ -23,7 +23,6 @@
|
||||
"from typing import Any\n",
|
||||
"\n",
|
||||
"from autogen_core import (\n",
|
||||
" AgentId,\n",
|
||||
" DefaultInterventionHandler,\n",
|
||||
" DefaultTopicId,\n",
|
||||
" MessageContext,\n",
|
||||
@ -100,7 +99,7 @@
|
||||
" def __init__(self) -> None:\n",
|
||||
" self._termination_value: Termination | None = None\n",
|
||||
"\n",
|
||||
" async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:\n",
|
||||
" async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:\n",
|
||||
" if isinstance(message, Termination):\n",
|
||||
" self._termination_value = message\n",
|
||||
" return message\n",
|
||||
@ -171,7 +170,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -131,7 +131,9 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ToolInterventionHandler(DefaultInterventionHandler):\n",
|
||||
" async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:\n",
|
||||
" async def on_send(\n",
|
||||
" self, message: Any, *, message_context: MessageContext, recipient: AgentId\n",
|
||||
" ) -> Any | type[DropMessage]:\n",
|
||||
" if isinstance(message, FunctionCall):\n",
|
||||
" # Request user prompt for tool execution.\n",
|
||||
" user_input = input(\n",
|
||||
|
||||
@ -31,7 +31,6 @@ from dataclasses import dataclass
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
from autogen_core import (
|
||||
AgentId,
|
||||
CancellationToken,
|
||||
DefaultInterventionHandler,
|
||||
DefaultTopicId,
|
||||
@ -211,7 +210,7 @@ class NeedsUserInputHandler(DefaultInterventionHandler):
|
||||
def __init__(self):
|
||||
self.question_for_user: GetSlowUserMessage | None = None
|
||||
|
||||
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:
|
||||
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:
|
||||
if isinstance(message, GetSlowUserMessage):
|
||||
self.question_for_user = message
|
||||
return message
|
||||
@ -231,7 +230,7 @@ class TerminationHandler(DefaultInterventionHandler):
|
||||
def __init__(self):
|
||||
self.terminateMessage: TerminateMessage | None = None
|
||||
|
||||
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:
|
||||
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:
|
||||
if isinstance(message, TerminateMessage):
|
||||
self.terminateMessage = message
|
||||
return message
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from typing import Any, Protocol, final
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._message_context import MessageContext
|
||||
|
||||
__all__ = [
|
||||
"DropMessage",
|
||||
@ -10,20 +11,59 @@ __all__ = [
|
||||
|
||||
|
||||
@final
|
||||
class DropMessage: ...
|
||||
class DropMessage:
|
||||
"""Marker type for signalling that a message should be dropped by an intervention handler. The type itself should be returned from the handler."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class InterventionHandler(Protocol):
|
||||
"""An intervention handler is a class that can be used to modify, log or drop messages that are being processed by the :class:`autogen_core.base.AgentRuntime`.
|
||||
|
||||
The handler is called when the message is submitted to the runtime.
|
||||
|
||||
Currently the only runtime which supports this is the :class:`autogen_core.base.SingleThreadedAgentRuntime`.
|
||||
|
||||
Note: Returning None from any of the intervention handler methods will result in a warning being issued and treated as "no change". If you intend to drop a message, you should return :class:`DropMessage` explicitly.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from autogen_core import DefaultInterventionHandler, MessageContext, AgentId, SingleThreadedAgentRuntime
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyMessage:
|
||||
content: str
|
||||
|
||||
|
||||
class MyInterventionHandler(DefaultInterventionHandler):
|
||||
async def on_send(self, message: Any, *, message_context: MessageContext, recipient: AgentId) -> MyMessage:
|
||||
if isinstance(message, MyMessage):
|
||||
message.content = message.content.upper()
|
||||
return message
|
||||
|
||||
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handlers=[MyInterventionHandler()])
|
||||
|
||||
"""
|
||||
|
||||
async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]: ...
|
||||
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]: ...
|
||||
async def on_response(
|
||||
self, message: Any, *, sender: AgentId, recipient: AgentId | None
|
||||
) -> Any | type[DropMessage]: ...
|
||||
async def on_send(
|
||||
self, message: Any, *, message_context: MessageContext, recipient: AgentId
|
||||
) -> Any | type[DropMessage]:
|
||||
"""Called when a message is submitted to the AgentRuntime using :meth:`autogen_core.base.AgentRuntime.send_message`."""
|
||||
...
|
||||
|
||||
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]:
|
||||
"""Called when a message is published to the AgentRuntime using :meth:`autogen_core.base.AgentRuntime.publish_message`."""
|
||||
...
|
||||
|
||||
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
|
||||
"""Called when a response is received by the AgentRuntime from an Agent's message handler returning a value."""
|
||||
...
|
||||
|
||||
|
||||
class DefaultInterventionHandler(InterventionHandler):
|
||||
@ -31,10 +71,12 @@ class DefaultInterventionHandler(InterventionHandler):
|
||||
handler methods, that simply returns the message unchanged. Allows for easy
|
||||
subclassing to override only the desired methods."""
|
||||
|
||||
async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:
|
||||
async def on_send(
|
||||
self, message: Any, *, message_context: MessageContext, recipient: AgentId
|
||||
) -> Any | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]:
|
||||
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
|
||||
|
||||
@ -474,7 +474,16 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
|
||||
):
|
||||
try:
|
||||
temp_message = await handler.on_send(message, sender=sender, recipient=recipient)
|
||||
message_context = MessageContext(
|
||||
sender=sender,
|
||||
topic_id=None,
|
||||
is_rpc=True,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
message_id=message_envelope.message_id,
|
||||
)
|
||||
temp_message = await handler.on_send(
|
||||
message, message_context=message_context, recipient=recipient
|
||||
)
|
||||
_warn_if_none(temp_message, "on_send")
|
||||
except BaseException as e:
|
||||
future.set_exception(e)
|
||||
@ -506,7 +515,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
|
||||
):
|
||||
try:
|
||||
temp_message = await handler.on_publish(message, sender=sender)
|
||||
message_context = MessageContext(
|
||||
sender=sender,
|
||||
topic_id=topic_id,
|
||||
is_rpc=False,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
message_id=message_envelope.message_id,
|
||||
)
|
||||
temp_message = await handler.on_publish(message, message_context=message_context)
|
||||
_warn_if_none(temp_message, "on_publish")
|
||||
except BaseException as e:
|
||||
# TODO: we should raise the intervention exception to the publisher.
|
||||
|
||||
@ -1,5 +1,15 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from autogen_core import AgentId, DefaultInterventionHandler, DropMessage, SingleThreadedAgentRuntime
|
||||
from autogen_core import (
|
||||
AgentId,
|
||||
DefaultInterventionHandler,
|
||||
DefaultSubscription,
|
||||
DefaultTopicId,
|
||||
DropMessage,
|
||||
MessageContext,
|
||||
SingleThreadedAgentRuntime,
|
||||
)
|
||||
from autogen_core.exceptions import MessageDroppedException
|
||||
from autogen_test_utils import LoopbackAgent, MessageType
|
||||
|
||||
@ -8,10 +18,20 @@ from autogen_test_utils import LoopbackAgent, MessageType
|
||||
async def test_intervention_count_messages() -> None:
|
||||
class DebugInterventionHandler(DefaultInterventionHandler):
|
||||
def __init__(self) -> None:
|
||||
self.num_messages = 0
|
||||
self.num_send_messages = 0
|
||||
self.num_publish_messages = 0
|
||||
self.num_response_messages = 0
|
||||
|
||||
async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType:
|
||||
self.num_messages += 1
|
||||
async def on_send(self, message: Any, *, message_context: MessageContext, recipient: AgentId) -> Any:
|
||||
self.num_send_messages += 1
|
||||
return message
|
||||
|
||||
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:
|
||||
self.num_publish_messages += 1
|
||||
return message
|
||||
|
||||
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any:
|
||||
self.num_response_messages += 1
|
||||
return message
|
||||
|
||||
handler = DebugInterventionHandler()
|
||||
@ -22,18 +42,28 @@ async def test_intervention_count_messages() -> None:
|
||||
|
||||
_response = await runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
await runtime.stop()
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
assert handler.num_messages == 1
|
||||
assert handler.num_send_messages == 1
|
||||
assert handler.num_response_messages == 1
|
||||
loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
|
||||
assert loopback_agent.num_calls == 1
|
||||
|
||||
runtime.start()
|
||||
await runtime.add_subscription(DefaultSubscription(agent_type="name"))
|
||||
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
assert loopback_agent.num_calls == 2
|
||||
assert handler.num_publish_messages == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_drop_send() -> None:
|
||||
class DropSendInterventionHandler(DefaultInterventionHandler):
|
||||
async def on_send(
|
||||
self, message: MessageType, *, sender: AgentId | None, recipient: AgentId
|
||||
self, message: MessageType, *, message_context: MessageContext, recipient: AgentId
|
||||
) -> MessageType | type[DropMessage]:
|
||||
return DropMessage
|
||||
|
||||
@ -81,7 +111,7 @@ async def test_intervention_raise_exception_on_send() -> None:
|
||||
|
||||
class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore
|
||||
async def on_send(
|
||||
self, message: MessageType, *, sender: AgentId | None, recipient: AgentId
|
||||
self, message: MessageType, *, message_context: MessageContext, recipient: AgentId
|
||||
) -> MessageType | type[DropMessage]: # type: ignore
|
||||
raise InterventionException
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user