mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-25 05:59:19 +00:00
Agent name termination (#4123)
This commit is contained in:
parent
8f4d8c89c3
commit
0b5eaf1240
@ -7,6 +7,7 @@ from ._terminations import (
|
||||
TextMentionTermination,
|
||||
TimeoutTermination,
|
||||
TokenUsageTermination,
|
||||
SourceMatchTermination,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -17,5 +18,6 @@ __all__ = [
|
||||
"HandoffTermination",
|
||||
"TimeoutTermination",
|
||||
"ExternalTermination",
|
||||
"SourceMatchTermination",
|
||||
"Console",
|
||||
]
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import time
|
||||
from typing import Sequence
|
||||
from typing import Sequence, List
|
||||
|
||||
from ..base import TerminatedException, TerminationCondition
|
||||
from ..messages import AgentMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
|
||||
@ -251,3 +251,36 @@ class ExternalTermination(TerminationCondition):
|
||||
async def reset(self) -> None:
|
||||
self._terminated = False
|
||||
self._setted = False
|
||||
|
||||
|
||||
class SourceMatchTermination(TerminationCondition):
|
||||
"""Terminate the conversation after a specific source responds.
|
||||
|
||||
Args:
|
||||
sources (List[str]): List of source names to terminate the conversation.
|
||||
|
||||
Raises:
|
||||
TerminatedException: If the termination condition has already been reached.
|
||||
"""
|
||||
|
||||
def __init__(self, sources: List[str]) -> None:
|
||||
self._sources = sources
|
||||
self._terminated = False
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
if not messages:
|
||||
return None
|
||||
for message in messages:
|
||||
if message.source in self._sources:
|
||||
self._terminated = True
|
||||
return StopMessage(content=f"'{message.source}' answered", source="SourceMatchTermination")
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._terminated = False
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from autogen_agentchat.base import TerminatedException
|
||||
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
|
||||
from autogen_agentchat.task import (
|
||||
ExternalTermination,
|
||||
@ -10,6 +11,7 @@ from autogen_agentchat.task import (
|
||||
TextMentionTermination,
|
||||
TimeoutTermination,
|
||||
TokenUsageTermination,
|
||||
SourceMatchTermination,
|
||||
)
|
||||
from autogen_core.components.models import RequestUsage
|
||||
|
||||
@ -242,3 +244,26 @@ async def test_external_termination() -> None:
|
||||
|
||||
await termination.reset()
|
||||
assert await termination([]) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_source_match_termination() -> None:
|
||||
termination = SourceMatchTermination(sources=["Assistant"])
|
||||
assert await termination([]) is None
|
||||
|
||||
continue_messages = [TextMessage(content="Hello", source="agent"), TextMessage(content="Hello", source="user")]
|
||||
assert await termination(continue_messages) is None
|
||||
|
||||
terminate_messages = [
|
||||
TextMessage(content="Hello", source="agent"),
|
||||
TextMessage(content="Hello", source="Assistant"),
|
||||
TextMessage(content="Hello", source="user"),
|
||||
]
|
||||
result = await termination(terminate_messages)
|
||||
assert isinstance(result, StopMessage)
|
||||
assert termination.terminated
|
||||
|
||||
with pytest.raises(TerminatedException):
|
||||
await termination([])
|
||||
await termination.reset()
|
||||
assert not termination.terminated
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user