AssistantAgent no longer sends out StopMessage. We use TextMentionTermination("TERMINATE") on the team instead for default setting. (#4030)

* AssistantAgent no longer sends out StopMessage. We use TextMentionTermination("TERMINATE") on the team instead for default setting.

* Fix test
This commit is contained in:
Eric Zhu 2024-11-01 12:35:26 -07:00 committed by GitHub
parent 173acc6638
commit c3b2597e12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 44 additions and 48 deletions

View File

@ -24,7 +24,6 @@ from ..messages import (
HandoffMessage,
InnerMessage,
ResetMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
@ -232,8 +231,8 @@ class AssistantAgent(BaseChatAgent):
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
if self._handoffs:
return [TextMessage, HandoffMessage, StopMessage]
return [TextMessage, StopMessage]
return [TextMessage, HandoffMessage]
return [TextMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
@ -303,16 +302,9 @@ class AssistantAgent(BaseChatAgent):
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
assert isinstance(result.content, str)
# Detect stop request.
request_stop = "terminate" in result.content.strip().lower()
if request_stop:
yield Response(
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
else:
yield Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
yield Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken

View File

@ -22,7 +22,7 @@ from autogen_agentchat.messages import (
ToolCallMessage,
ToolCallResultMessage,
)
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
from autogen_agentchat.task import MaxMessageTermination, TextMentionTermination
from autogen_agentchat.teams import (
RoundRobinGroupChat,
SelectorGroupChat,
@ -151,7 +151,7 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)
expected_messages = [
"Write a program that prints 'Hello, world!'",
@ -172,7 +172,7 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
@ -247,7 +247,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)
assert len(result.messages) == 6
@ -256,7 +256,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert isinstance(result.messages[2], ToolCallResultMessage) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], StopMessage) # tool use agent response
assert isinstance(result.messages[5], TextMessage) # tool use agent response
context = tool_use_agent._model_context # pyright: ignore
assert context[0].content == "Write a program that prints 'Hello, world!'"
@ -275,7 +275,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
@ -351,7 +351,7 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
)
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)
assert len(result.messages) == 6
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
@ -366,7 +366,7 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
agent1._count = 0 # pyright: ignore
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
@ -401,7 +401,7 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
)
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)
assert len(result.messages) == 5
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
@ -417,7 +417,7 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
agent1._count = 0 # pyright: ignore
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
@ -472,7 +472,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
allow_repeated_speaker=True,
)
result = await team.run(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
)
assert len(result.messages) == 4
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
@ -484,7 +484,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
@ -649,7 +649,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
)
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
team = Swarm([agent1, agent2])
result = await team.run("task", termination_condition=StopMessageTermination())
result = await team.run("task", termination_condition=TextMentionTermination("TERMINATE"))
assert len(result.messages) == 7
assert result.messages[0].content == "task"
assert isinstance(result.messages[1], ToolCallMessage)
@ -663,7 +663,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
stream = team.run_stream("task", termination_condition=StopMessageTermination())
stream = team.run_stream("task", termination_condition=TextMentionTermination("TERMINATE"))
async for message in stream:
if isinstance(message, TaskResult):
assert message == result

View File

@ -18,12 +18,12 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.task import TextMentionTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_core.components.tools import FunctionTool\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
@ -265,7 +265,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -400,7 +400,9 @@
}
],
"source": [
"result = await team.run(\"Write a financial report on American airlines\", termination_condition=StopMessageTermination())\n",
"result = await team.run(\n",
" \"Write a financial report on American airlines\", termination_condition=TextMentionTermination(\"TERMINATE\")\n",
")\n",
"print(result)"
]
}

View File

@ -18,12 +18,12 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.task import TextMentionTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_core.components.tools import FunctionTool\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
@ -161,7 +161,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -332,7 +332,7 @@
"\n",
"result = await team.run(\n",
" task=\"Write a literature review on no code tools for building multi agent ai systems\",\n",
" termination_condition=StopMessageTermination(),\n",
" termination_condition=TextMentionTermination(\"TERMINATE\"),\n",
")"
]
}

View File

@ -13,12 +13,12 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from autogen_agentchat.agents import CodingAssistantAgent\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.task import TextMentionTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
]
@ -69,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -195,7 +195,9 @@
],
"source": [
"group_chat = RoundRobinGroupChat([planner_agent, local_agent, language_agent, travel_summary_agent])\n",
"result = await group_chat.run(task=\"Plan a 3 day trip to Nepal.\", termination_condition=StopMessageTermination())\n",
"result = await group_chat.run(\n",
" task=\"Plan a 3 day trip to Nepal.\", termination_condition=TextMentionTermination(\"TERMINATE\")\n",
")\n",
"print(result)"
]
}

View File

@ -33,7 +33,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -47,7 +47,7 @@
")\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.task import TextMentionTermination\n",
"from autogen_agentchat.teams import SelectorGroupChat\n",
"from autogen_core.base import CancellationToken\n",
"from autogen_core.components.tools import FunctionTool\n",
@ -114,7 +114,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -254,7 +254,7 @@
"team = SelectorGroupChat(\n",
" [user_proxy, flight_broker, travel_assistant], model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\")\n",
")\n",
"await team.run(\"Help user plan a trip and book a flight.\", termination_condition=StopMessageTermination())"
"await team.run(\"Help user plan a trip and book a flight.\", termination_condition=TextMentionTermination(\"TERMINATE\"))"
]
}
],
@ -274,7 +274,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.6"
}
},
"nbformat": 4,

View File

@ -28,7 +28,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -37,7 +37,7 @@
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
"from autogen_agentchat.agents import CodingAssistantAgent\n",
"from autogen_agentchat.logging import ConsoleLogHandler\n",
"from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination\n",
"from autogen_agentchat.task import MaxMessageTermination, TextMentionTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_core.components.models import OpenAIChatCompletionClient\n",
"\n",
@ -140,7 +140,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -178,7 +178,7 @@
"round_robin_team = RoundRobinGroupChat([writing_assistant_agent])\n",
"\n",
"round_robin_team_result = await round_robin_team.run(\n",
" \"Write a unique, Haiku about the weather in Paris\", termination_condition=StopMessageTermination()\n",
" \"Write a unique, Haiku about the weather in Paris\", termination_condition=TextMentionTermination(\"TERMINATE\")\n",
")"
]
}