mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-27 06:59:03 +00:00
TextMessageTerminationCondition for agentchat (#5742)
Closes #5732 --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
906b09e451
commit
9d4236b1ce
@ -25,9 +25,9 @@ class SocietyOfMindAgentConfig(BaseModel):
|
||||
name: str
|
||||
team: ComponentModel
|
||||
model_client: ComponentModel
|
||||
description: str
|
||||
instruction: str
|
||||
response_prompt: str
|
||||
description: str | None = None
|
||||
instruction: str | None = None
|
||||
response_prompt: str | None = None
|
||||
|
||||
|
||||
class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
|
||||
@ -103,13 +103,16 @@ class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
|
||||
"""str: The default response prompt to use when generating a response using
|
||||
the inner team's messages. It assumes the role of 'system'."""
|
||||
|
||||
DEFAULT_DESCRIPTION = "An agent that uses an inner team of agents to generate responses."
|
||||
"""str: The default description for a SocietyOfMindAgent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
team: Team,
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
description: str = "An agent that uses an inner team of agents to generate responses.",
|
||||
description: str = DEFAULT_DESCRIPTION,
|
||||
instruction: str = DEFAULT_INSTRUCTION,
|
||||
response_prompt: str = DEFAULT_RESPONSE_PROMPT,
|
||||
) -> None:
|
||||
@ -212,7 +215,7 @@ class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
|
||||
name=config.name,
|
||||
team=team,
|
||||
model_client=model_client,
|
||||
description=config.description,
|
||||
instruction=config.instruction,
|
||||
response_prompt=config.response_prompt,
|
||||
description=config.description or cls.DEFAULT_DESCRIPTION,
|
||||
instruction=config.instruction or cls.DEFAULT_INSTRUCTION,
|
||||
response_prompt=config.response_prompt or cls.DEFAULT_RESPONSE_PROMPT,
|
||||
)
|
||||
|
||||
@ -10,6 +10,7 @@ from ._terminations import (
|
||||
SourceMatchTermination,
|
||||
StopMessageTermination,
|
||||
TextMentionTermination,
|
||||
TextMessageTermination,
|
||||
TimeoutTermination,
|
||||
TokenUsageTermination,
|
||||
)
|
||||
@ -23,4 +24,5 @@ __all__ = [
|
||||
"TimeoutTermination",
|
||||
"ExternalTermination",
|
||||
"SourceMatchTermination",
|
||||
"TextMessageTermination",
|
||||
]
|
||||
|
||||
@ -6,7 +6,15 @@ from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..base import TerminatedException, TerminationCondition
|
||||
from ..messages import AgentEvent, BaseChatMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage
|
||||
from ..messages import (
|
||||
AgentEvent,
|
||||
BaseChatMessage,
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
)
|
||||
|
||||
|
||||
class StopMessageTerminationConfig(BaseModel):
|
||||
@ -428,3 +436,56 @@ class SourceMatchTermination(TerminationCondition, Component[SourceMatchTerminat
|
||||
@classmethod
|
||||
def _from_config(cls, config: SourceMatchTerminationConfig) -> Self:
|
||||
return cls(sources=config.sources)
|
||||
|
||||
|
||||
class TextMessageTerminationConfig(BaseModel):
|
||||
"""Configuration for the TextMessageTermination termination condition."""
|
||||
|
||||
source: str | None = None
|
||||
"""The source of the text message to terminate the conversation."""
|
||||
|
||||
|
||||
class TextMessageTermination(TerminationCondition, Component[TextMessageTerminationConfig]):
|
||||
"""Terminate the conversation if a :class:`~autogen_agentchat.messages.TextMessage` is received.
|
||||
|
||||
This termination condition checks for TextMessage instances in the message sequence. When a TextMessage is found,
|
||||
it terminates the conversation if either:
|
||||
- No source was specified (terminates on any TextMessage)
|
||||
- The message source matches the specified source
|
||||
|
||||
Args:
|
||||
source (str | None, optional): The source name to match against incoming messages. If None, matches any source.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
component_config_schema = TextMessageTerminationConfig
|
||||
component_provider_override = "autogen_agentchat.conditions.TextMessageTermination"
|
||||
|
||||
def __init__(self, source: str | None = None) -> None:
|
||||
self._terminated = False
|
||||
self._source = source
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
if isinstance(message, TextMessage) and (self._source is None or message.source == self._source):
|
||||
self._terminated = True
|
||||
return StopMessage(
|
||||
content=f"Text message received from '{message.source}'", source="TextMessageTermination"
|
||||
)
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._terminated = False
|
||||
|
||||
def _to_config(self) -> TextMessageTerminationConfig:
|
||||
return TextMessageTerminationConfig(source=self._source)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: TextMessageTerminationConfig) -> Self:
|
||||
return cls(source=config.source)
|
||||
|
||||
@ -9,6 +9,7 @@ from autogen_agentchat.conditions import (
|
||||
SourceMatchTermination,
|
||||
StopMessageTermination,
|
||||
TextMentionTermination,
|
||||
TextMessageTermination,
|
||||
TimeoutTermination,
|
||||
TokenUsageTermination,
|
||||
)
|
||||
@ -62,6 +63,40 @@ async def test_stop_message_termination() -> None:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_message_termination() -> None:
|
||||
termination = TextMessageTermination()
|
||||
assert await termination([]) is None
|
||||
await termination.reset()
|
||||
assert await termination([StopMessage(content="Hello", source="user")]) is None
|
||||
await termination.reset()
|
||||
assert await termination([TextMessage(content="Hello", source="user")]) is not None
|
||||
assert termination.terminated
|
||||
await termination.reset()
|
||||
assert (
|
||||
await termination([StopMessage(content="Hello", source="user"), TextMessage(content="World", source="agent")])
|
||||
is not None
|
||||
)
|
||||
assert termination.terminated
|
||||
with pytest.raises(TerminatedException):
|
||||
await termination([TextMessage(content="Hello", source="user")])
|
||||
|
||||
termination = TextMessageTermination(source="user")
|
||||
assert await termination([]) is None
|
||||
await termination.reset()
|
||||
assert await termination([TextMessage(content="Hello", source="user")]) is not None
|
||||
assert termination.terminated
|
||||
await termination.reset()
|
||||
|
||||
termination = TextMessageTermination(source="agent")
|
||||
assert await termination([]) is None
|
||||
await termination.reset()
|
||||
assert await termination([TextMessage(content="Hello", source="user")]) is None
|
||||
await termination.reset()
|
||||
assert await termination([TextMessage(content="Hello", source="agent")]) is not None
|
||||
assert termination.terminated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_message_termination() -> None:
|
||||
termination = MaxMessageTermination(2)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user