Mirror of https://github.com/microsoft/autogen.git (synced 2025-07-06 00:20:25 +00:00).

Resolves #4075 1. Introduce custom runtime parameter for all AgentChat teams (RoundRobinGroupChat, SelectorGroupChat, etc.). This is done by making sure each team's topics are isolated from other teams, and decoupling state from agent identities. Also, I removed the closure agent from the BaseGroupChat and use the group chat manager agent to relay messages to the output message queue. 2. Added unit tests to test scenarios with custom runtimes by using pytest fixture 3. Refactored existing unit tests to use ReplayChatCompletionClient with a few improvements to the client. 4. Fix a one-liner bug in AssistantAgent that caused deserialized agent to have handoffs. How to use it? ```python import asyncio from autogen_core import SingleThreadedAgentRuntime from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.teams import RoundRobinGroupChat from autogen_agentchat.conditions import TextMentionTermination from autogen_ext.models.replay import ReplayChatCompletionClient async def main() -> None: # Create a runtime runtime = SingleThreadedAgentRuntime() runtime.start() # Create a model client. model_client = ReplayChatCompletionClient( ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], ) # Create agents agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.") agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.") # Create a termination condition termination_condition = TextMentionTermination("10", sources=["assistant1", "assistant2"]) # Create a team team = RoundRobinGroupChat([agent1, agent2], runtime=runtime, termination_condition=termination_condition) # Run the team stream = team.run_stream(task="Count to 10.") async for message in stream: print(message) # Save the state. state = await team.save_state() # Load the state to an existing team. 
await team.load_state(state) # Run the team again model_client.reset() stream = team.run_stream(task="Count to 10.") async for message in stream: print(message) # Create a new team, with the same agent names. agent3 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.") agent4 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.") new_team = RoundRobinGroupChat([agent3, agent4], runtime=runtime, termination_condition=termination_condition) # Load the state to the new team. await new_team.load_state(state) # Run the new team model_client.reset() new_stream = new_team.run_stream(task="Count to 10.") async for message in new_stream: print(message) # Stop the runtime await runtime.stop() asyncio.run(main()) ``` TODOs as future PRs: 1. Documentation. 2. How to handle errors in custom runtime when the agent has exception? --------- Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
222 lines
9.3 KiB
Python
222 lines
9.3 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import AsyncGenerator, Sequence
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from autogen_agentchat import EVENT_LOGGER_NAME
|
|
from autogen_agentchat.agents import (
|
|
BaseChatAgent,
|
|
)
|
|
from autogen_agentchat.base import Response
|
|
from autogen_agentchat.messages import (
|
|
ChatMessage,
|
|
TextMessage,
|
|
)
|
|
from autogen_agentchat.teams import (
|
|
MagenticOneGroupChat,
|
|
)
|
|
from autogen_agentchat.teams._group_chat._magentic_one._magentic_one_orchestrator import MagenticOneOrchestrator
|
|
from autogen_core import AgentId, AgentRuntime, CancellationToken, SingleThreadedAgentRuntime
|
|
from autogen_ext.models.replay import ReplayChatCompletionClient
|
|
from utils import FileLogHandler
|
|
|
|
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
|
logger.setLevel(logging.DEBUG)
|
|
logger.addHandler(FileLogHandler("test_magentic_one_group_chat.log"))
|
|
|
|
|
|
class _EchoAgent(BaseChatAgent):
|
|
def __init__(self, name: str, description: str) -> None:
|
|
super().__init__(name, description)
|
|
self._last_message: str | None = None
|
|
self._total_messages = 0
|
|
|
|
@property
|
|
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
|
return (TextMessage,)
|
|
|
|
@property
|
|
def total_messages(self) -> int:
|
|
return self._total_messages
|
|
|
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
|
if len(messages) > 0:
|
|
assert isinstance(messages[0], TextMessage)
|
|
self._last_message = messages[0].content
|
|
self._total_messages += 1
|
|
return Response(chat_message=TextMessage(content=messages[0].content, source=self.name))
|
|
else:
|
|
assert self._last_message is not None
|
|
self._total_messages += 1
|
|
return Response(chat_message=TextMessage(content=self._last_message, source=self.name))
|
|
|
|
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
|
self._last_message = None
|
|
|
|
|
|
@pytest_asyncio.fixture(params=["single_threaded", "embedded"]) # type: ignore
|
|
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
|
|
if request.param == "single_threaded":
|
|
runtime = SingleThreadedAgentRuntime()
|
|
runtime.start()
|
|
yield runtime
|
|
await runtime.stop()
|
|
elif request.param == "embedded":
|
|
yield None
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_magentic_one_group_chat_cancellation(runtime: AgentRuntime | None) -> None:
|
|
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
|
agent_2 = _EchoAgent("agent_2", description="echo agent 2")
|
|
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
|
|
agent_4 = _EchoAgent("agent_4", description="echo agent 4")
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
chat_completions=["test", "test", json.dumps({"is_request_satisfied": {"answer": True, "reason": "test"}})],
|
|
)
|
|
|
|
# Set max_turns to a large number to avoid stopping due to max_turns before cancellation.
|
|
team = MagenticOneGroupChat(
|
|
participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client, runtime=runtime
|
|
)
|
|
cancellation_token = CancellationToken()
|
|
run_task = asyncio.create_task(
|
|
team.run(
|
|
task="Write a program that prints 'Hello, world!'",
|
|
cancellation_token=cancellation_token,
|
|
)
|
|
)
|
|
|
|
# Cancel the task.
|
|
cancellation_token.cancel()
|
|
with pytest.raises(asyncio.CancelledError):
|
|
await run_task
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_magentic_one_group_chat_basic(runtime: AgentRuntime | None) -> None:
|
|
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
|
agent_2 = _EchoAgent("agent_2", description="echo agent 2")
|
|
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
|
|
agent_4 = _EchoAgent("agent_4", description="echo agent 4")
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
chat_completions=[
|
|
"No facts",
|
|
"No plan",
|
|
json.dumps(
|
|
{
|
|
"is_request_satisfied": {"answer": False, "reason": "test"},
|
|
"is_progress_being_made": {"answer": True, "reason": "test"},
|
|
"is_in_loop": {"answer": False, "reason": "test"},
|
|
"instruction_or_question": {"answer": "Continue task", "reason": "test"},
|
|
"next_speaker": {"answer": "agent_1", "reason": "test"},
|
|
}
|
|
),
|
|
json.dumps(
|
|
{
|
|
"is_request_satisfied": {"answer": True, "reason": "Because"},
|
|
"is_progress_being_made": {"answer": True, "reason": "test"},
|
|
"is_in_loop": {"answer": False, "reason": "test"},
|
|
"instruction_or_question": {"answer": "Task completed", "reason": "Because"},
|
|
"next_speaker": {"answer": "agent_1", "reason": "test"},
|
|
}
|
|
),
|
|
"print('Hello, world!')",
|
|
],
|
|
)
|
|
|
|
team = MagenticOneGroupChat(
|
|
participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client, runtime=runtime
|
|
)
|
|
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
|
assert len(result.messages) == 5
|
|
assert result.messages[2].content == "Continue task"
|
|
assert result.messages[4].content == "print('Hello, world!')"
|
|
assert result.stop_reason is not None and result.stop_reason == "Because"
|
|
|
|
# Test save and load.
|
|
state = await team.save_state()
|
|
team2 = MagenticOneGroupChat(
|
|
participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client, runtime=runtime
|
|
)
|
|
await team2.load_state(state)
|
|
state2 = await team2.save_state()
|
|
assert state == state2
|
|
manager_1 = await team._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
|
AgentId(f"{team._group_chat_manager_name}_{team._team_id}", team._team_id), # pyright: ignore
|
|
MagenticOneOrchestrator, # pyright: ignore
|
|
) # pyright: ignore
|
|
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
|
AgentId(f"{team2._group_chat_manager_name}_{team2._team_id}", team2._team_id), # pyright: ignore
|
|
MagenticOneOrchestrator, # pyright: ignore
|
|
) # pyright: ignore
|
|
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
|
|
assert manager_1._task == manager_2._task # pyright: ignore
|
|
assert manager_1._facts == manager_2._facts # pyright: ignore
|
|
assert manager_1._plan == manager_2._plan # pyright: ignore
|
|
assert manager_1._n_rounds == manager_2._n_rounds # pyright: ignore
|
|
assert manager_1._n_stalls == manager_2._n_stalls # pyright: ignore
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_magentic_one_group_chat_with_stalls(runtime: AgentRuntime | None) -> None:
|
|
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
|
agent_2 = _EchoAgent("agent_2", description="echo agent 2")
|
|
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
|
|
agent_4 = _EchoAgent("agent_4", description="echo agent 4")
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
chat_completions=[
|
|
"No facts",
|
|
"No plan",
|
|
json.dumps(
|
|
{
|
|
"is_request_satisfied": {"answer": False, "reason": "test"},
|
|
"is_progress_being_made": {"answer": False, "reason": "test"},
|
|
"is_in_loop": {"answer": True, "reason": "test"},
|
|
"instruction_or_question": {"answer": "Stalling", "reason": "test"},
|
|
"next_speaker": {"answer": "agent_1", "reason": "test"},
|
|
}
|
|
),
|
|
json.dumps(
|
|
{
|
|
"is_request_satisfied": {"answer": False, "reason": "test"},
|
|
"is_progress_being_made": {"answer": False, "reason": "test"},
|
|
"is_in_loop": {"answer": True, "reason": "test"},
|
|
"instruction_or_question": {"answer": "Stalling again", "reason": "test"},
|
|
"next_speaker": {"answer": "agent_2", "reason": "test"},
|
|
}
|
|
),
|
|
"No facts2",
|
|
"No plan2",
|
|
json.dumps(
|
|
{
|
|
"is_request_satisfied": {"answer": True, "reason": "test"},
|
|
"is_progress_being_made": {"answer": True, "reason": "test"},
|
|
"is_in_loop": {"answer": False, "reason": "test"},
|
|
"instruction_or_question": {"answer": "Task completed", "reason": "test"},
|
|
"next_speaker": {"answer": "agent_3", "reason": "test"},
|
|
}
|
|
),
|
|
"print('Hello, world!')",
|
|
],
|
|
)
|
|
|
|
team = MagenticOneGroupChat(
|
|
participants=[agent_1, agent_2, agent_3, agent_4],
|
|
model_client=model_client,
|
|
max_stalls=2,
|
|
runtime=runtime,
|
|
)
|
|
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
|
assert len(result.messages) == 6
|
|
assert isinstance(result.messages[1].content, str)
|
|
assert result.messages[1].content.startswith("\nWe are working to address the following user request:")
|
|
assert isinstance(result.messages[4].content, str)
|
|
assert result.messages[4].content.startswith("\nWe are working to address the following user request:")
|
|
assert result.stop_reason is not None and result.stop_reason == "test"
|