This commit is contained in:
Eric Zhu 2025-06-22 09:03:56 +08:00
parent b158cb8b5e
commit 336f78870d
2 changed files with 150 additions and 15 deletions

View File

@ -32,7 +32,7 @@ from autogen_core.models import (
ModelFamily,
SystemMessage,
)
from autogen_core.tools import BaseTool, FunctionTool, StaticStreamWorkbench, StaticWorkbench, ToolResult, Workbench
from autogen_core.tools import BaseTool, FunctionTool, StaticStreamWorkbench, ToolResult, Workbench
from pydantic import BaseModel
from typing_extensions import Self
@ -754,7 +754,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
else:
self._workbench = [workbench]
else:
self._workbench = [StaticWorkbench(self._tools)]
self._workbench = [StaticStreamWorkbench(self._tools)]
if model_context is not None:
self._model_context = model_context
@ -1051,6 +1051,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
yield tool_call_msg
# STEP 4B: Execute tool calls
# Use a queue to handle streaming results from tool calls.
stream = asyncio.Queue[BaseAgentEvent | BaseChatMessage | None]()
async def _execute_tool_calls(
@ -1069,25 +1070,24 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
for call in function_calls
]
)
# Signal the end of streaming by putting None in the queue.
stream.put_nowait(None)
return results
task = asyncio.create_task(_execute_tool_calls(model_result.content))
while True:
try:
event = await stream.get()
if event is None:
break
if isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage):
yield event
inner_messages.append(event)
else:
raise RuntimeError(f"Unexpected event type: {type(event)}")
except asyncio.CancelledError:
task.cancel()
raise
event = await stream.get()
if event is None:
# End of streaming, break the loop.
break
if isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage):
yield event
inner_messages.append(event)
else:
raise RuntimeError(f"Unexpected event type: {type(event)}")
# Wait for all tool calls to complete.
executed_calls_and_results = await task
exec_results = [result for _, result in executed_calls_and_results]
@ -1377,7 +1377,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
if isinstance(event, ToolResult):
tool_result = event
elif isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage):
stream.put_nowait(event)
await stream.put(event)
else:
warnings.warn(
f"Unexpected event type: {type(event)} in tool call streaming.",

View File

@ -1,11 +1,14 @@
import pytest
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.messages import TextMessage, ToolCallExecutionEvent, ToolCallRequestEvent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.tools import AgentTool, TeamTool
from autogen_core import (
CancellationToken,
FunctionCall,
)
from autogen_core.models import CreateResult, RequestUsage
from autogen_ext.models.replay import ReplayChatCompletionClient
from test_group_chat import _EchoAgent # type: ignore[reportPrivateUsage]
@ -98,3 +101,135 @@ async def test_team_tool_component() -> None:
assert tool2.name == "Team Tool"
assert tool2.description == "A team tool for testing"
assert isinstance(tool2._team, RoundRobinGroupChat) # type: ignore[reportPrivateUsage]
@pytest.mark.asyncio
async def test_agent_tool_stream() -> None:
    """Test running a task with AgentTool in streaming mode.

    Wraps an inner assistant agent (which calls a plain function tool) in an
    ``AgentTool`` and verifies that the outer agent invoking that tool surfaces
    the inner agent's request/execution events and summaries in order.
    """

    def _query_function() -> str:
        # Stub tool executed by the inner (wrapped) agent.
        return "Test task"

    # Scripted replies for the inner agent: a tool call, then a reflection.
    inner_model_client = ReplayChatCompletionClient(
        [
            CreateResult(
                content=[FunctionCall(name="query_function", arguments="{}", id="1")],
                finish_reason="function_calls",
                usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
                cached=False,
            ),
            "Summary from tool agent",
        ],
        model_info={
            "family": "gpt-41",
            "function_calling": True,
            "json_output": True,
            "multiple_system_messages": True,
            "structured_output": True,
            "vision": True,
        },
    )
    inner_agent = AssistantAgent(
        name="tool_agent",
        model_client=inner_model_client,
        tools=[_query_function],
        reflect_on_tool_use=True,
        description="An agent for testing",
    )
    agent_tool = AgentTool(inner_agent)

    # Scripted replies for the outer agent: call the AgentTool, then reflect.
    outer_model_client = ReplayChatCompletionClient(
        [
            CreateResult(
                content=[FunctionCall(id="1", name="tool_agent", arguments='{"task": "Input task from main agent"}')],
                finish_reason="function_calls",
                usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
                cached=False,
            ),
            "Summary from main agent",
        ],
        model_info={
            "family": "gpt-41",
            "function_calling": True,
            "json_output": True,
            "multiple_system_messages": True,
            "structured_output": True,
            "vision": True,
        },
    )
    outer_agent = AssistantAgent(
        name="main_agent",
        model_client=outer_model_client,
        tools=[agent_tool],
        reflect_on_tool_use=True,
        description="An agent for testing",
    )

    result = await outer_agent.run(task="Input task from user", cancellation_token=CancellationToken())

    # Expected sequence: user input, outer tool-call request, inner task echo,
    # inner request/execution, inner summary, outer execution, outer summary.
    assert isinstance(result.messages[0], TextMessage)
    assert result.messages[0].content == "Input task from user"
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert isinstance(result.messages[2], TextMessage)
    assert result.messages[2].content == "Input task from main agent"
    assert isinstance(result.messages[3], ToolCallRequestEvent)
    assert isinstance(result.messages[4], ToolCallExecutionEvent)
    assert isinstance(result.messages[5], TextMessage)
    assert result.messages[5].content == "Summary from tool agent"
    assert isinstance(result.messages[6], ToolCallExecutionEvent)
    assert isinstance(result.messages[7], TextMessage)
    assert result.messages[7].content == "Summary from main agent"
@pytest.mark.asyncio
async def test_team_tool_stream() -> None:
    """Test running a task with TeamTool in streaming mode.

    Wraps a two-agent round-robin team in a ``TeamTool`` and verifies that the
    main agent invoking it streams each inner team message before the tool
    execution event and the final reflection summary.
    """
    echo_one = _EchoAgent("Agent1", "An agent for testing")
    echo_two = _EchoAgent("Agent2", "Another agent for testing")
    # Stop the inner team after three messages (task + one echo per agent).
    stop_after_three = MaxMessageTermination(max_messages=3)
    inner_team = RoundRobinGroupChat(
        [echo_one, echo_two],
        termination_condition=stop_after_three,
    )
    team_tool = TeamTool(team=inner_team, name="team_tool", description="A team tool for testing")

    # Scripted replies for the main agent: call the TeamTool, then reflect.
    replay_client = ReplayChatCompletionClient(
        [
            CreateResult(
                content=[FunctionCall(name="team_tool", arguments='{"task": "test task from main agent"}', id="1")],
                finish_reason="function_calls",
                usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
                cached=False,
            ),
            "Summary from main agent",
        ],
        model_info={
            "family": "gpt-41",
            "function_calling": True,
            "json_output": True,
            "multiple_system_messages": True,
            "structured_output": True,
            "vision": True,
        },
    )
    main_agent = AssistantAgent(
        name="main_agent",
        model_client=replay_client,
        tools=[team_tool],
        reflect_on_tool_use=True,
        description="An agent for testing",
    )

    result = await main_agent.run(task="test task from user", cancellation_token=CancellationToken())

    # Expected sequence: user input, tool-call request, inner task message,
    # both echo replies, tool execution event, reflection summary.
    assert isinstance(result.messages[0], TextMessage)
    assert result.messages[0].content == "test task from user"
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert isinstance(result.messages[2], TextMessage)
    assert result.messages[2].content == "test task from main agent"
    assert isinstance(result.messages[3], TextMessage)
    assert result.messages[3].content == "test task from main agent"
    assert result.messages[3].source == "Agent1"
    assert isinstance(result.messages[4], TextMessage)
    assert result.messages[4].content == "test task from main agent"
    assert result.messages[4].source == "Agent2"
    assert isinstance(result.messages[5], ToolCallExecutionEvent)
    assert isinstance(result.messages[6], TextMessage)
    assert result.messages[6].content == "Summary from main agent"