Move intervention objects to root module (#4859)

* Move intervention to root

* usage
This commit is contained in:
Jack Gerrits 2024-12-30 16:09:37 -05:00 committed by GitHub
parent 0569689e6b
commit c58eb9d120
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 241 additions and 215 deletions

View File

@ -1,179 +1,179 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Termination using Intervention Handler\n",
"\n",
"```{note}\n",
"This method is valid when using {py:class}`~autogen_core.SingleThreadedAgentRuntime`.\n",
"```\n",
"\n",
"There are many different ways to handle termination in `autogen_core`. Ultimately, the goal is to detect that the runtime no longer needs to be executed and you can proceed to finalization tasks. One way to do this is to use an {py:class}`autogen_core.base.intervention.InterventionHandler` to detect a termination message and then act on it."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import Any\n",
"\n",
"from autogen_core import (\n",
" AgentId,\n",
" DefaultTopicId,\n",
" MessageContext,\n",
" RoutedAgent,\n",
" SingleThreadedAgentRuntime,\n",
" default_subscription,\n",
" message_handler,\n",
")\n",
"from autogen_core.base.intervention import DefaultInterventionHandler"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we define a dataclass for regular message and message that will be used to signal termination."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class Message:\n",
" content: Any\n",
"\n",
"\n",
"@dataclass\n",
"class Termination:\n",
" reason: str"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We code our agent to publish a termination message when it decides it is time to terminate."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"@default_subscription\n",
"class AnAgent(RoutedAgent):\n",
" def __init__(self) -> None:\n",
" super().__init__(\"MyAgent\")\n",
" self.received = 0\n",
"\n",
" @message_handler\n",
" async def on_new_message(self, message: Message, ctx: MessageContext) -> None:\n",
" self.received += 1\n",
" if self.received > 3:\n",
" await self.publish_message(Termination(reason=\"Reached maximum number of messages\"), DefaultTopicId())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we create an InterventionHandler that will detect the termination message and act on it. This one hooks into publishes and when it encounters `Termination` it alters its internal state to indicate that termination has been requested."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class TerminationHandler(DefaultInterventionHandler):\n",
" def __init__(self) -> None:\n",
" self._termination_value: Termination | None = None\n",
"\n",
" async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:\n",
" if isinstance(message, Termination):\n",
" self._termination_value = message\n",
" return message\n",
"\n",
" @property\n",
" def termination_value(self) -> Termination | None:\n",
" return self._termination_value\n",
"\n",
" @property\n",
" def has_terminated(self) -> bool:\n",
" return self._termination_value is not None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we add this handler to the runtime and use it to detect termination and stop the runtime when the termination message is received."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Termination(reason='Reached maximum number of messages')\n"
]
}
],
"source": [
"termination_handler = TerminationHandler()\n",
"runtime = SingleThreadedAgentRuntime(intervention_handlers=[termination_handler])\n",
"\n",
"await AnAgent.register(runtime, \"my_agent\", AnAgent)\n",
"\n",
"runtime.start()\n",
"\n",
"# Publish more than 3 messages to trigger termination.\n",
"await runtime.publish_message(Message(\"hello\"), DefaultTopicId())\n",
"await runtime.publish_message(Message(\"hello\"), DefaultTopicId())\n",
"await runtime.publish_message(Message(\"hello\"), DefaultTopicId())\n",
"await runtime.publish_message(Message(\"hello\"), DefaultTopicId())\n",
"\n",
"# Wait for termination.\n",
"await runtime.stop_when(lambda: termination_handler.has_terminated)\n",
"\n",
"print(termination_handler.termination_value)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Termination using Intervention Handler\n",
"\n",
"```{note}\n",
"This method is valid when using {py:class}`~autogen_core.SingleThreadedAgentRuntime`.\n",
"```\n",
"\n",
"There are many different ways to handle termination in `autogen_core`. Ultimately, the goal is to detect that the runtime no longer needs to be executed and you can proceed to finalization tasks. One way to do this is to use an {py:class}`autogen_core.base.intervention.InterventionHandler` to detect a termination message and then act on it."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import Any\n",
"\n",
"from autogen_core import (\n",
" AgentId,\n",
" DefaultInterventionHandler,\n",
" DefaultTopicId,\n",
" MessageContext,\n",
" RoutedAgent,\n",
" SingleThreadedAgentRuntime,\n",
" default_subscription,\n",
" message_handler,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we define a dataclass for regular message and message that will be used to signal termination."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class Message:\n",
" content: Any\n",
"\n",
"\n",
"@dataclass\n",
"class Termination:\n",
" reason: str"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We code our agent to publish a termination message when it decides it is time to terminate."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"@default_subscription\n",
"class AnAgent(RoutedAgent):\n",
" def __init__(self) -> None:\n",
" super().__init__(\"MyAgent\")\n",
" self.received = 0\n",
"\n",
" @message_handler\n",
" async def on_new_message(self, message: Message, ctx: MessageContext) -> None:\n",
" self.received += 1\n",
" if self.received > 3:\n",
" await self.publish_message(Termination(reason=\"Reached maximum number of messages\"), DefaultTopicId())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we create an InterventionHandler that will detect the termination message and act on it. This one hooks into publishes and when it encounters `Termination` it alters its internal state to indicate that termination has been requested."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class TerminationHandler(DefaultInterventionHandler):\n",
" def __init__(self) -> None:\n",
" self._termination_value: Termination | None = None\n",
"\n",
" async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:\n",
" if isinstance(message, Termination):\n",
" self._termination_value = message\n",
" return message\n",
"\n",
" @property\n",
" def termination_value(self) -> Termination | None:\n",
" return self._termination_value\n",
"\n",
" @property\n",
" def has_terminated(self) -> bool:\n",
" return self._termination_value is not None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we add this handler to the runtime and use it to detect termination and stop the runtime when the termination message is received."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Termination(reason='Reached maximum number of messages')\n"
]
}
],
"source": [
"termination_handler = TerminationHandler()\n",
"runtime = SingleThreadedAgentRuntime(intervention_handlers=[termination_handler])\n",
"\n",
"await AnAgent.register(runtime, \"my_agent\", AnAgent)\n",
"\n",
"runtime.start()\n",
"\n",
"# Publish more than 3 messages to trigger termination.\n",
"await runtime.publish_message(Message(\"hello\"), DefaultTopicId())\n",
"await runtime.publish_message(Message(\"hello\"), DefaultTopicId())\n",
"await runtime.publish_message(Message(\"hello\"), DefaultTopicId())\n",
"await runtime.publish_message(Message(\"hello\"), DefaultTopicId())\n",
"\n",
"# Wait for termination.\n",
"await runtime.stop_when(lambda: termination_handler.has_terminated)\n",
"\n",
"print(termination_handler.termination_value)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -22,13 +22,14 @@
"from autogen_core import (\n",
" AgentId,\n",
" AgentType,\n",
" DefaultInterventionHandler,\n",
" DropMessage,\n",
" FunctionCall,\n",
" MessageContext,\n",
" RoutedAgent,\n",
" SingleThreadedAgentRuntime,\n",
" message_handler,\n",
")\n",
"from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage\n",
"from autogen_core.models import (\n",
" ChatCompletionClient,\n",
" LLMMessage,\n",

View File

@ -33,6 +33,7 @@ from typing import Any, Mapping, Optional
from autogen_core import (
AgentId,
CancellationToken,
DefaultInterventionHandler,
DefaultTopicId,
FunctionCall,
MessageContext,
@ -41,7 +42,6 @@ from autogen_core import (
message_handler,
type_subscription,
)
from autogen_core.base.intervention import DefaultInterventionHandler
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import (
AssistantMessage,

View File

@ -31,6 +31,11 @@ from ._constants import (
from ._default_subscription import DefaultSubscription, default_subscription, type_subscription
from ._default_topic import DefaultTopicId
from ._image import Image
from ._intervention import (
DefaultInterventionHandler,
DropMessage,
InterventionHandler,
)
from ._message_context import MessageContext
from ._message_handler_context import MessageHandlerContext
from ._routed_agent import RoutedAgent, event, message_handler, rpc
@ -111,4 +116,7 @@ __all__ = [
"ComponentConfigImpl",
"ComponentModel",
"ComponentType",
"DropMessage",
"InterventionHandler",
"DefaultInterventionHandler",
]

View File

@ -0,0 +1,41 @@
from typing import Any, Protocol, final
from ._agent_id import AgentId
__all__ = [
"DropMessage",
"InterventionHandler",
"DefaultInterventionHandler",
]
@final
class DropMessage: ...
class InterventionHandler(Protocol):
"""An intervention handler is a class that can be used to modify, log or drop messages that are being processed by the :class:`autogen_core.base.AgentRuntime`.
Note: Returning None from any of the intervention handler methods will result in a warning being issued and treated as "no change". If you intend to drop a message, you should return :class:`DropMessage` explicitly.
"""
async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]: ...
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]: ...
async def on_response(
self, message: Any, *, sender: AgentId, recipient: AgentId | None
) -> Any | type[DropMessage]: ...
class DefaultInterventionHandler(InterventionHandler):
"""Simple class that provides a default implementation for all intervention
handler methods, that simply returns the message unchanged. Allows for easy
subclassing to override only the desired methods."""
async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:
return message
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]:
return message
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
return message

View File

@ -28,6 +28,7 @@ from ._agent_metadata import AgentMetadata
from ._agent_runtime import AgentRuntime
from ._agent_type import AgentType
from ._cancellation_token import CancellationToken
from ._intervention import DropMessage, InterventionHandler
from ._message_context import MessageContext
from ._message_handler_context import MessageHandlerContext
from ._runtime_impl_helpers import SubscriptionManager, get_impl
@ -35,7 +36,6 @@ from ._serialization import MessageSerializer, SerializationRegistry
from ._subscription import Subscription
from ._telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata
from ._topic import TopicId
from .base.intervention import DropMessage, InterventionHandler
from .exceptions import MessageDroppedException
logger = logging.getLogger("autogen_core")

View File

@ -1,45 +1,22 @@
from typing import Any, Awaitable, Callable, Protocol, final
from typing_extensions import deprecated
from .._agent_id import AgentId
from .._intervention import DefaultInterventionHandler as DefaultInterventionHandlerAlias
from .._intervention import DropMessage as DropMessageAlias
from .._intervention import InterventionHandler as InterventionHandlerAliass
__all__ = [
"DropMessage",
"InterventionFunction",
"InterventionHandler",
"DefaultInterventionHandler",
]
@final
class DropMessage: ...
# Final so can't inherit and deprecate
DropMessage = DropMessageAlias
InterventionFunction = Callable[[Any], Any | Awaitable[type[DropMessage]]]
@deprecated("Moved to autogen_core.InterventionHandler. Will remove this in 0.4.0.", stacklevel=2)
class InterventionHandler(InterventionHandlerAliass): ...
class InterventionHandler(Protocol):
"""An intervention handler is a class that can be used to modify, log or drop messages that are being processed by the :class:`autogen_core.base.AgentRuntime`.
Note: Returning None from any of the intervention handler methods will result in a warning being issued and treated as "no change". If you intend to drop a message, you should return :class:`DropMessage` explicitly.
"""
async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]: ...
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]: ...
async def on_response(
self, message: Any, *, sender: AgentId, recipient: AgentId | None
) -> Any | type[DropMessage]: ...
class DefaultInterventionHandler(InterventionHandler):
"""Simple class that provides a default implementation for all intervention
handler methods, that simply returns the message unchanged. Allows for easy
subclassing to override only the desired methods."""
async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:
return message
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]:
return message
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
return message
@deprecated("Moved to autogen_core.DefaultInterventionHandler. Will remove this in 0.4.0.", stacklevel=2)
class DefaultInterventionHandler(DefaultInterventionHandlerAlias): ...

View File

@ -1,6 +1,5 @@
import pytest
from autogen_core import AgentId, SingleThreadedAgentRuntime
from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage
from autogen_core import AgentId, DefaultInterventionHandler, DropMessage, SingleThreadedAgentRuntime
from autogen_core.exceptions import MessageDroppedException
from autogen_test_utils import LoopbackAgent, MessageType