autogen/python/packages/autogen-agentchat/tests/test_termination_condition.py

434 lines
14 KiB
Python
Raw Permalink Normal View History

import asyncio
from typing import Sequence
import pytest
2024-11-23 22:07:21 +07:00
from autogen_agentchat.base import TerminatedException
from autogen_agentchat.conditions import (
2024-11-21 03:25:53 -05:00
ExternalTermination,
FunctionalTermination,
FunctionCallTermination,
HandoffTermination,
MaxMessageTermination,
SourceMatchTermination,
StopMessageTermination,
TextMentionTermination,
TextMessageTermination,
TimeoutTermination,
TokenUsageTermination,
)
from autogen_agentchat.messages import (
BaseAgentEvent,
BaseChatMessage,
HandoffMessage,
StopMessage,
StructuredMessage,
TextMessage,
ToolCallExecutionEvent,
UserInputRequestedEvent,
)
from autogen_core.models import FunctionExecutionResult, RequestUsage
from pydantic import BaseModel
@pytest.mark.asyncio
async def test_handoff_termination() -> None:
termination = HandoffTermination("target")
assert await termination([]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None
await termination.reset()
assert await termination([HandoffMessage(target="target", source="user", content="Hello")]) is not None
assert termination.terminated
await termination.reset()
assert await termination([HandoffMessage(target="another", source="user", content="Hello")]) is None
assert not termination.terminated
await termination.reset()
assert (
await termination(
[
TextMessage(content="Hello", source="user"),
HandoffMessage(target="target", source="user", content="Hello"),
]
)
is not None
)
assert termination.terminated
await termination.reset()
@pytest.mark.asyncio
async def test_stop_message_termination() -> None:
termination = StopMessageTermination()
assert await termination([]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None
await termination.reset()
assert await termination([StopMessage(content="Stop", source="user")]) is not None
await termination.reset()
assert (
await termination([TextMessage(content="Hello", source="user"), TextMessage(content="World", source="agent")])
is None
)
await termination.reset()
assert (
await termination([TextMessage(content="Hello", source="user"), StopMessage(content="Stop", source="user")])
is not None
)
@pytest.mark.asyncio
async def test_text_message_termination() -> None:
termination = TextMessageTermination()
assert await termination([]) is None
await termination.reset()
assert await termination([StopMessage(content="Hello", source="user")]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is not None
assert termination.terminated
await termination.reset()
assert (
await termination([StopMessage(content="Hello", source="user"), TextMessage(content="World", source="agent")])
is not None
)
assert termination.terminated
with pytest.raises(TerminatedException):
await termination([TextMessage(content="Hello", source="user")])
termination = TextMessageTermination(source="user")
assert await termination([]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is not None
assert termination.terminated
await termination.reset()
termination = TextMessageTermination(source="agent")
assert await termination([]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="agent")]) is not None
assert termination.terminated
@pytest.mark.asyncio
async def test_max_message_termination() -> None:
termination = MaxMessageTermination(2)
assert await termination([]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None
await termination.reset()
assert (
await termination([TextMessage(content="Hello", source="user"), TextMessage(content="World", source="agent")])
is not None
)
termination = MaxMessageTermination(2, include_agent_event=True)
assert await termination([]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None
await termination.reset()
assert (
await termination(
[TextMessage(content="Hello", source="user"), UserInputRequestedEvent(request_id="1", source="agent")]
)
is not None
)
@pytest.mark.asyncio
async def test_mention_termination() -> None:
termination = TextMentionTermination("stop")
assert await termination([]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None
await termination.reset()
assert await termination([TextMessage(content="stop", source="user")]) is not None
await termination.reset()
assert (
await termination([TextMessage(content="Hello", source="user"), TextMessage(content="stop", source="user")])
is not None
)
termination = TextMentionTermination("stop", sources=["agent"])
assert await termination([TextMessage(content="stop", source="user")]) is None
await termination.reset()
assert (
await termination([TextMessage(content="stop", source="user"), TextMessage(content="stop", source="agent")])
is not None
)
@pytest.mark.asyncio
async def test_token_usage_termination() -> None:
termination = TokenUsageTermination(max_total_token=10)
assert await termination([]) is None
await termination.reset()
assert (
await termination(
[
TextMessage(
content="Hello", source="user", models_usage=RequestUsage(prompt_tokens=10, completion_tokens=10)
)
]
)
is not None
)
await termination.reset()
assert (
await termination(
[
TextMessage(
content="Hello", source="user", models_usage=RequestUsage(prompt_tokens=1, completion_tokens=1)
),
TextMessage(
content="World", source="agent", models_usage=RequestUsage(prompt_tokens=1, completion_tokens=1)
),
]
)
is None
)
await termination.reset()
assert (
await termination(
[
TextMessage(
content="Hello", source="user", models_usage=RequestUsage(prompt_tokens=5, completion_tokens=0)
),
TextMessage(
content="stop", source="user", models_usage=RequestUsage(prompt_tokens=0, completion_tokens=5)
),
]
)
is not None
)
@pytest.mark.asyncio
async def test_and_termination() -> None:
termination = MaxMessageTermination(2) & TextMentionTermination("stop")
assert await termination([]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None
await termination.reset()
assert (
await termination([TextMessage(content="Hello", source="user"), TextMessage(content="World", source="agent")])
is None
)
await termination.reset()
assert (
await termination([TextMessage(content="Hello", source="user"), TextMessage(content="stop", source="user")])
is not None
)
@pytest.mark.asyncio
async def test_or_termination() -> None:
termination = MaxMessageTermination(3) | TextMentionTermination("stop")
assert await termination([]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None
await termination.reset()
assert (
await termination([TextMessage(content="Hello", source="user"), TextMessage(content="World", source="agent")])
is None
)
await termination.reset()
assert (
await termination([TextMessage(content="Hello", source="user"), TextMessage(content="stop", source="user")])
is not None
)
await termination.reset()
assert (
await termination([TextMessage(content="Hello", source="user"), TextMessage(content="Hello", source="user")])
is None
)
await termination.reset()
assert (
await termination(
[
TextMessage(content="Hello", source="user"),
TextMessage(content="Hello", source="user"),
TextMessage(content="Hello", source="user"),
]
)
is not None
)
await termination.reset()
assert (
await termination(
[
TextMessage(content="Hello", source="user"),
TextMessage(content="Hello", source="user"),
TextMessage(content="stop", source="user"),
]
)
is not None
)
await termination.reset()
assert (
await termination(
[
TextMessage(content="Hello", source="user"),
TextMessage(content="Hello", source="user"),
TextMessage(content="Hello", source="user"),
TextMessage(content="stop", source="user"),
]
)
is not None
)
@pytest.mark.asyncio
async def test_timeout_termination() -> None:
termination = TimeoutTermination(0.1) # 100ms timeout
assert await termination([]) is None
assert not termination.terminated
await asyncio.sleep(0.2)
assert await termination([]) is not None
assert termination.terminated
await termination.reset()
assert not termination.terminated
assert await termination([]) is None
assert await termination([TextMessage(content="Hello", source="user")]) is None
await asyncio.sleep(0.2)
assert await termination([TextMessage(content="World", source="user")]) is not None
2024-11-21 03:25:53 -05:00
@pytest.mark.asyncio
async def test_external_termination() -> None:
termination = ExternalTermination()
assert await termination([]) is None
assert not termination.terminated
termination.set()
assert await termination([]) is not None
assert termination.terminated
await termination.reset()
assert await termination([]) is None
2024-11-23 22:07:21 +07:00
@pytest.mark.asyncio
async def test_source_match_termination() -> None:
termination = SourceMatchTermination(sources=["Assistant"])
assert await termination([]) is None
continue_messages = [TextMessage(content="Hello", source="agent"), TextMessage(content="Hello", source="user")]
assert await termination(continue_messages) is None
terminate_messages = [
TextMessage(content="Hello", source="agent"),
TextMessage(content="Hello", source="Assistant"),
TextMessage(content="Hello", source="user"),
]
result = await termination(terminate_messages)
assert isinstance(result, StopMessage)
assert termination.terminated
with pytest.raises(TerminatedException):
await termination([])
await termination.reset()
assert not termination.terminated
@pytest.mark.asyncio
async def test_function_call_termination() -> None:
termination = FunctionCallTermination(function_name="test_function")
assert await termination([]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None
await termination.reset()
assert (
await termination(
[TextMessage(content="Hello", source="user"), ToolCallExecutionEvent(content=[], source="assistant")]
)
is None
)
await termination.reset()
assert (
await termination(
[
TextMessage(content="Hello", source="user"),
ToolCallExecutionEvent(
content=[FunctionExecutionResult(content="", name="test_function", call_id="")], source="assistant"
),
]
)
is not None
)
assert termination.terminated
await termination.reset()
assert (
await termination(
[
TextMessage(content="Hello", source="user"),
ToolCallExecutionEvent(
content=[FunctionExecutionResult(content="", name="another_function", call_id="")],
source="assistant",
),
]
)
is None
)
assert not termination.terminated
await termination.reset()
@pytest.mark.asyncio
async def test_functional_termination() -> None:
async def async_termination_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> bool:
if len(messages) < 1:
return False
if isinstance(messages[-1], TextMessage):
return messages[-1].content == "stop"
return False
termination = FunctionalTermination(async_termination_func)
assert await termination([]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None
await termination.reset()
assert await termination([TextMessage(content="stop", source="user")]) is not None
assert termination.terminated
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None
class TestContentType(BaseModel):
content: str
data: str
def sync_termination_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> bool:
if len(messages) < 1:
return False
last_message = messages[-1]
if isinstance(last_message, StructuredMessage) and isinstance(last_message.content, TestContentType): # type: ignore[reportUnknownMemberType]
return last_message.content.data == "stop"
return False
termination = FunctionalTermination(sync_termination_func)
assert await termination([]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None
await termination.reset()
assert (
await termination(
[StructuredMessage[TestContentType](content=TestContentType(content="1", data="stop"), source="user")]
)
is not None
)
assert termination.terminated
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None