External termination condition (#4294)

This commit is contained in:
Eric Zhu 2024-11-21 03:25:53 -05:00 committed by GitHub
parent 0d79b4b2e8
commit 6e4609a76e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 0 deletions

View File

@ -1,5 +1,6 @@
from ._console import Console from ._console import Console
from ._terminations import ( from ._terminations import (
ExternalTermination,
HandoffTermination, HandoffTermination,
MaxMessageTermination, MaxMessageTermination,
StopMessageTermination, StopMessageTermination,
@ -15,5 +16,6 @@ __all__ = [
"TokenUsageTermination", "TokenUsageTermination",
"HandoffTermination", "HandoffTermination",
"TimeoutTermination", "TimeoutTermination",
"ExternalTermination",
"Console", "Console",
] ]

View File

@ -208,3 +208,45 @@ class TimeoutTermination(TerminationCondition):
async def reset(self) -> None: async def reset(self) -> None:
self._start_time = time.monotonic() self._start_time = time.monotonic()
self._terminated = False self._terminated = False
class ExternalTermination(TerminationCondition):
"""A termination condition that is externally controlled
by calling the :meth:`set` method.
Example:
.. code-block:: python
termination = ExternalTermination()
# Run the team in an asyncio task.
...
# Set the termination condition externally
termination.set()
"""
def __init__(self) -> None:
self._terminated = False
self._setted = False
@property
def terminated(self) -> bool:
return self._terminated
def set(self) -> None:
self._setted = True
async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
if self._setted:
self._terminated = True
return StopMessage(content="External termination requested", source="ExternalTermination")
return None
async def reset(self) -> None:
self._terminated = False
self._setted = False

View File

@ -3,6 +3,7 @@ import asyncio
import pytest import pytest
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
from autogen_agentchat.task import ( from autogen_agentchat.task import (
ExternalTermination,
HandoffTermination, HandoffTermination,
MaxMessageTermination, MaxMessageTermination,
StopMessageTermination, StopMessageTermination,
@ -226,3 +227,18 @@ async def test_timeout_termination() -> None:
assert await termination([TextMessage(content="Hello", source="user")]) is None assert await termination([TextMessage(content="Hello", source="user")]) is None
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
assert await termination([TextMessage(content="World", source="user")]) is not None assert await termination([TextMessage(content="World", source="user")]) is not None
@pytest.mark.asyncio
async def test_external_termination() -> None:
termination = ExternalTermination()
assert await termination([]) is None
assert not termination.terminated
termination.set()
assert await termination([]) is not None
assert termination.terminated
await termination.reset()
assert await termination([]) is None