diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index d1b4dd173..361103a4e 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -32,7 +32,7 @@ from autogen_core.models import ( ModelFamily, SystemMessage, ) -from autogen_core.tools import BaseTool, FunctionTool, StaticStreamWorkbench, StaticWorkbench, ToolResult, Workbench +from autogen_core.tools import BaseTool, FunctionTool, StaticStreamWorkbench, ToolResult, Workbench from pydantic import BaseModel from typing_extensions import Self @@ -754,7 +754,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): else: self._workbench = [workbench] else: - self._workbench = [StaticWorkbench(self._tools)] + self._workbench = [StaticStreamWorkbench(self._tools)] if model_context is not None: self._model_context = model_context @@ -1051,6 +1051,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): yield tool_call_msg # STEP 4B: Execute tool calls + # Use a queue to handle streaming results from tool calls. stream = asyncio.Queue[BaseAgentEvent | BaseChatMessage | None]() async def _execute_tool_calls( @@ -1069,25 +1070,24 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): for call in function_calls ] ) + # Signal the end of streaming by putting None in the queue. stream.put_nowait(None) return results task = asyncio.create_task(_execute_tool_calls(model_result.content)) while True: - try: - event = await stream.get() - if event is None: - break - if isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage): - yield event - inner_messages.append(event) - else: - raise RuntimeError(f"Unexpected event type: {type(event)}") - except asyncio.CancelledError: - task.cancel() - raise + event = await stream.get() + if event is None: + # End of streaming, break the loop. + break + if isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage): + yield event + inner_messages.append(event) + else: + raise RuntimeError(f"Unexpected event type: {type(event)}") + # Wait for all tool calls to complete. executed_calls_and_results = await task exec_results = [result for _, result in executed_calls_and_results] @@ -1377,7 +1377,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): if isinstance(event, ToolResult): tool_result = event elif isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage): - stream.put_nowait(event) + await stream.put(event) else: warnings.warn( f"Unexpected event type: {type(event)} in tool call streaming.", diff --git a/python/packages/autogen-agentchat/tests/test_task_runner_tool.py b/python/packages/autogen-agentchat/tests/test_task_runner_tool.py index 3d3d58b7d..8249f95d1 100644 --- a/python/packages/autogen-agentchat/tests/test_task_runner_tool.py +++ b/python/packages/autogen-agentchat/tests/test_task_runner_tool.py @@ -1,11 +1,14 @@ import pytest from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.conditions import MaxMessageTermination +from autogen_agentchat.messages import TextMessage, ToolCallExecutionEvent, ToolCallRequestEvent from autogen_agentchat.teams import RoundRobinGroupChat from autogen_agentchat.tools import AgentTool, TeamTool from autogen_core import ( CancellationToken, + FunctionCall, ) +from autogen_core.models import CreateResult, RequestUsage from autogen_ext.models.replay import ReplayChatCompletionClient from test_group_chat import _EchoAgent # type: ignore[reportPrivateUsage] @@ -98,3 +101,135 @@ async def test_team_tool_component() -> None: assert tool2.name == "Team Tool" assert tool2.description == "A team tool for testing" assert isinstance(tool2._team, RoundRobinGroupChat) # type: ignore[reportPrivateUsage] + + +@pytest.mark.asyncio +async def test_agent_tool_stream() -> None: + """Test running a task with AgentTool in streaming mode.""" + + def _query_function() -> str: + return "Test task" + + tool_agent_model_client = ReplayChatCompletionClient( + [ + CreateResult( + content=[FunctionCall(name="query_function", arguments="{}", id="1")], + finish_reason="function_calls", + usage=RequestUsage(prompt_tokens=0, completion_tokens=0), + cached=False, + ), + "Summary from tool agent", + ], + model_info={ + "family": "gpt-41", + "function_calling": True, + "json_output": True, + "multiple_system_messages": True, + "structured_output": True, + "vision": True, + }, + ) + tool_agent = AssistantAgent( + name="tool_agent", + model_client=tool_agent_model_client, + tools=[_query_function], + reflect_on_tool_use=True, + description="An agent for testing", + ) + tool = AgentTool(tool_agent) + + main_agent_model_client = ReplayChatCompletionClient( + [ + CreateResult( + content=[FunctionCall(id="1", name="tool_agent", arguments='{"task": "Input task from main agent"}')], + finish_reason="function_calls", + usage=RequestUsage(prompt_tokens=0, completion_tokens=0), + cached=False, + ), + "Summary from main agent", + ], + model_info={ + "family": "gpt-41", + "function_calling": True, + "json_output": True, + "multiple_system_messages": True, + "structured_output": True, + "vision": True, + }, + ) + + main_agent = AssistantAgent( + name="main_agent", + model_client=main_agent_model_client, + tools=[tool], + reflect_on_tool_use=True, + description="An agent for testing", + ) + result = await main_agent.run(task="Input task from user", cancellation_token=CancellationToken()) + assert isinstance(result.messages[0], TextMessage) + assert result.messages[0].content == "Input task from user" + assert isinstance(result.messages[1], ToolCallRequestEvent) + assert isinstance(result.messages[2], TextMessage) + assert result.messages[2].content == "Input task from main agent" + assert isinstance(result.messages[3], ToolCallRequestEvent) + assert isinstance(result.messages[4], ToolCallExecutionEvent) + assert isinstance(result.messages[5], TextMessage) + assert result.messages[5].content == "Summary from tool agent" + assert isinstance(result.messages[6], ToolCallExecutionEvent) + assert isinstance(result.messages[7], TextMessage) + assert result.messages[7].content == "Summary from main agent" + + +@pytest.mark.asyncio +async def test_team_tool_stream() -> None: + """Test running a task with TeamTool in streaming mode.""" + agent1 = _EchoAgent("Agent1", "An agent for testing") + agent2 = _EchoAgent("Agent2", "Another agent for testing") + termination = MaxMessageTermination(max_messages=3) + team = RoundRobinGroupChat( + [agent1, agent2], + termination_condition=termination, + ) + tool = TeamTool(team=team, name="team_tool", description="A team tool for testing") + + model_client = ReplayChatCompletionClient( + [ + CreateResult( + content=[FunctionCall(name="team_tool", arguments='{"task": "test task from main agent"}', id="1")], + finish_reason="function_calls", + usage=RequestUsage(prompt_tokens=0, completion_tokens=0), + cached=False, + ), + "Summary from main agent", + ], + model_info={ + "family": "gpt-41", + "function_calling": True, + "json_output": True, + "multiple_system_messages": True, + "structured_output": True, + "vision": True, + }, + ) + main_agent = AssistantAgent( + name="main_agent", + model_client=model_client, + tools=[tool], + reflect_on_tool_use=True, + description="An agent for testing", + ) + result = await main_agent.run(task="test task from user", cancellation_token=CancellationToken()) + assert isinstance(result.messages[0], TextMessage) + assert result.messages[0].content == "test task from user" + assert isinstance(result.messages[1], ToolCallRequestEvent) + assert isinstance(result.messages[2], TextMessage) + assert result.messages[2].content == "test task from main agent" + assert isinstance(result.messages[3], TextMessage) + assert result.messages[3].content == "test task from main agent" + assert result.messages[3].source == "Agent1" + assert isinstance(result.messages[4], TextMessage) + assert result.messages[4].content == "test task from main agent" + assert result.messages[4].source == "Agent2" + assert isinstance(result.messages[5], ToolCallExecutionEvent) + assert isinstance(result.messages[6], TextMessage) + assert result.messages[6].content == "Summary from main agent"