mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-03 07:04:16 +00:00
update
This commit is contained in:
parent
b158cb8b5e
commit
336f78870d
@ -32,7 +32,7 @@ from autogen_core.models import (
|
|||||||
ModelFamily,
|
ModelFamily,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from autogen_core.tools import BaseTool, FunctionTool, StaticStreamWorkbench, StaticWorkbench, ToolResult, Workbench
|
from autogen_core.tools import BaseTool, FunctionTool, StaticStreamWorkbench, ToolResult, Workbench
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
@ -754,7 +754,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
|||||||
else:
|
else:
|
||||||
self._workbench = [workbench]
|
self._workbench = [workbench]
|
||||||
else:
|
else:
|
||||||
self._workbench = [StaticWorkbench(self._tools)]
|
self._workbench = [StaticStreamWorkbench(self._tools)]
|
||||||
|
|
||||||
if model_context is not None:
|
if model_context is not None:
|
||||||
self._model_context = model_context
|
self._model_context = model_context
|
||||||
@ -1051,6 +1051,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
|||||||
yield tool_call_msg
|
yield tool_call_msg
|
||||||
|
|
||||||
# STEP 4B: Execute tool calls
|
# STEP 4B: Execute tool calls
|
||||||
|
# Use a queue to handle streaming results from tool calls.
|
||||||
stream = asyncio.Queue[BaseAgentEvent | BaseChatMessage | None]()
|
stream = asyncio.Queue[BaseAgentEvent | BaseChatMessage | None]()
|
||||||
|
|
||||||
async def _execute_tool_calls(
|
async def _execute_tool_calls(
|
||||||
@ -1069,25 +1070,24 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
|||||||
for call in function_calls
|
for call in function_calls
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
# Signal the end of streaming by putting None in the queue.
|
||||||
stream.put_nowait(None)
|
stream.put_nowait(None)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
task = asyncio.create_task(_execute_tool_calls(model_result.content))
|
task = asyncio.create_task(_execute_tool_calls(model_result.content))
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
|
||||||
event = await stream.get()
|
event = await stream.get()
|
||||||
if event is None:
|
if event is None:
|
||||||
|
# End of streaming, break the loop.
|
||||||
break
|
break
|
||||||
if isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage):
|
if isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage):
|
||||||
yield event
|
yield event
|
||||||
inner_messages.append(event)
|
inner_messages.append(event)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Unexpected event type: {type(event)}")
|
raise RuntimeError(f"Unexpected event type: {type(event)}")
|
||||||
except asyncio.CancelledError:
|
|
||||||
task.cancel()
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
# Wait for all tool calls to complete.
|
||||||
executed_calls_and_results = await task
|
executed_calls_and_results = await task
|
||||||
exec_results = [result for _, result in executed_calls_and_results]
|
exec_results = [result for _, result in executed_calls_and_results]
|
||||||
|
|
||||||
@ -1377,7 +1377,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
|||||||
if isinstance(event, ToolResult):
|
if isinstance(event, ToolResult):
|
||||||
tool_result = event
|
tool_result = event
|
||||||
elif isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage):
|
elif isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage):
|
||||||
stream.put_nowait(event)
|
await stream.put(event)
|
||||||
else:
|
else:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Unexpected event type: {type(event)} in tool call streaming.",
|
f"Unexpected event type: {type(event)} in tool call streaming.",
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from autogen_agentchat.agents import AssistantAgent
|
from autogen_agentchat.agents import AssistantAgent
|
||||||
from autogen_agentchat.conditions import MaxMessageTermination
|
from autogen_agentchat.conditions import MaxMessageTermination
|
||||||
|
from autogen_agentchat.messages import TextMessage, ToolCallExecutionEvent, ToolCallRequestEvent
|
||||||
from autogen_agentchat.teams import RoundRobinGroupChat
|
from autogen_agentchat.teams import RoundRobinGroupChat
|
||||||
from autogen_agentchat.tools import AgentTool, TeamTool
|
from autogen_agentchat.tools import AgentTool, TeamTool
|
||||||
from autogen_core import (
|
from autogen_core import (
|
||||||
CancellationToken,
|
CancellationToken,
|
||||||
|
FunctionCall,
|
||||||
)
|
)
|
||||||
|
from autogen_core.models import CreateResult, RequestUsage
|
||||||
from autogen_ext.models.replay import ReplayChatCompletionClient
|
from autogen_ext.models.replay import ReplayChatCompletionClient
|
||||||
from test_group_chat import _EchoAgent # type: ignore[reportPrivateUsage]
|
from test_group_chat import _EchoAgent # type: ignore[reportPrivateUsage]
|
||||||
|
|
||||||
@ -98,3 +101,135 @@ async def test_team_tool_component() -> None:
|
|||||||
assert tool2.name == "Team Tool"
|
assert tool2.name == "Team Tool"
|
||||||
assert tool2.description == "A team tool for testing"
|
assert tool2.description == "A team tool for testing"
|
||||||
assert isinstance(tool2._team, RoundRobinGroupChat) # type: ignore[reportPrivateUsage]
|
assert isinstance(tool2._team, RoundRobinGroupChat) # type: ignore[reportPrivateUsage]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_tool_stream() -> None:
|
||||||
|
"""Test running a task with AgentTool in streaming mode."""
|
||||||
|
|
||||||
|
def _query_function() -> str:
|
||||||
|
return "Test task"
|
||||||
|
|
||||||
|
tool_agent_model_client = ReplayChatCompletionClient(
|
||||||
|
[
|
||||||
|
CreateResult(
|
||||||
|
content=[FunctionCall(name="query_function", arguments="{}", id="1")],
|
||||||
|
finish_reason="function_calls",
|
||||||
|
usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
|
||||||
|
cached=False,
|
||||||
|
),
|
||||||
|
"Summary from tool agent",
|
||||||
|
],
|
||||||
|
model_info={
|
||||||
|
"family": "gpt-41",
|
||||||
|
"function_calling": True,
|
||||||
|
"json_output": True,
|
||||||
|
"multiple_system_messages": True,
|
||||||
|
"structured_output": True,
|
||||||
|
"vision": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
tool_agent = AssistantAgent(
|
||||||
|
name="tool_agent",
|
||||||
|
model_client=tool_agent_model_client,
|
||||||
|
tools=[_query_function],
|
||||||
|
reflect_on_tool_use=True,
|
||||||
|
description="An agent for testing",
|
||||||
|
)
|
||||||
|
tool = AgentTool(tool_agent)
|
||||||
|
|
||||||
|
main_agent_model_client = ReplayChatCompletionClient(
|
||||||
|
[
|
||||||
|
CreateResult(
|
||||||
|
content=[FunctionCall(id="1", name="tool_agent", arguments='{"task": "Input task from main agent"}')],
|
||||||
|
finish_reason="function_calls",
|
||||||
|
usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
|
||||||
|
cached=False,
|
||||||
|
),
|
||||||
|
"Summary from main agent",
|
||||||
|
],
|
||||||
|
model_info={
|
||||||
|
"family": "gpt-41",
|
||||||
|
"function_calling": True,
|
||||||
|
"json_output": True,
|
||||||
|
"multiple_system_messages": True,
|
||||||
|
"structured_output": True,
|
||||||
|
"vision": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
main_agent = AssistantAgent(
|
||||||
|
name="main_agent",
|
||||||
|
model_client=main_agent_model_client,
|
||||||
|
tools=[tool],
|
||||||
|
reflect_on_tool_use=True,
|
||||||
|
description="An agent for testing",
|
||||||
|
)
|
||||||
|
result = await main_agent.run(task="Input task from user", cancellation_token=CancellationToken())
|
||||||
|
assert isinstance(result.messages[0], TextMessage)
|
||||||
|
assert result.messages[0].content == "Input task from user"
|
||||||
|
assert isinstance(result.messages[1], ToolCallRequestEvent)
|
||||||
|
assert isinstance(result.messages[2], TextMessage)
|
||||||
|
assert result.messages[2].content == "Input task from main agent"
|
||||||
|
assert isinstance(result.messages[3], ToolCallRequestEvent)
|
||||||
|
assert isinstance(result.messages[4], ToolCallExecutionEvent)
|
||||||
|
assert isinstance(result.messages[5], TextMessage)
|
||||||
|
assert result.messages[5].content == "Summary from tool agent"
|
||||||
|
assert isinstance(result.messages[6], ToolCallExecutionEvent)
|
||||||
|
assert isinstance(result.messages[7], TextMessage)
|
||||||
|
assert result.messages[7].content == "Summary from main agent"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_team_tool_stream() -> None:
|
||||||
|
"""Test running a task with TeamTool in streaming mode."""
|
||||||
|
agent1 = _EchoAgent("Agent1", "An agent for testing")
|
||||||
|
agent2 = _EchoAgent("Agent2", "Another agent for testing")
|
||||||
|
termination = MaxMessageTermination(max_messages=3)
|
||||||
|
team = RoundRobinGroupChat(
|
||||||
|
[agent1, agent2],
|
||||||
|
termination_condition=termination,
|
||||||
|
)
|
||||||
|
tool = TeamTool(team=team, name="team_tool", description="A team tool for testing")
|
||||||
|
|
||||||
|
model_client = ReplayChatCompletionClient(
|
||||||
|
[
|
||||||
|
CreateResult(
|
||||||
|
content=[FunctionCall(name="team_tool", arguments='{"task": "test task from main agent"}', id="1")],
|
||||||
|
finish_reason="function_calls",
|
||||||
|
usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
|
||||||
|
cached=False,
|
||||||
|
),
|
||||||
|
"Summary from main agent",
|
||||||
|
],
|
||||||
|
model_info={
|
||||||
|
"family": "gpt-41",
|
||||||
|
"function_calling": True,
|
||||||
|
"json_output": True,
|
||||||
|
"multiple_system_messages": True,
|
||||||
|
"structured_output": True,
|
||||||
|
"vision": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
main_agent = AssistantAgent(
|
||||||
|
name="main_agent",
|
||||||
|
model_client=model_client,
|
||||||
|
tools=[tool],
|
||||||
|
reflect_on_tool_use=True,
|
||||||
|
description="An agent for testing",
|
||||||
|
)
|
||||||
|
result = await main_agent.run(task="test task from user", cancellation_token=CancellationToken())
|
||||||
|
assert isinstance(result.messages[0], TextMessage)
|
||||||
|
assert result.messages[0].content == "test task from user"
|
||||||
|
assert isinstance(result.messages[1], ToolCallRequestEvent)
|
||||||
|
assert isinstance(result.messages[2], TextMessage)
|
||||||
|
assert result.messages[2].content == "test task from main agent"
|
||||||
|
assert isinstance(result.messages[3], TextMessage)
|
||||||
|
assert result.messages[3].content == "test task from main agent"
|
||||||
|
assert result.messages[3].source == "Agent1"
|
||||||
|
assert isinstance(result.messages[4], TextMessage)
|
||||||
|
assert result.messages[4].content == "test task from main agent"
|
||||||
|
assert result.messages[4].source == "Agent2"
|
||||||
|
assert isinstance(result.messages[5], ToolCallExecutionEvent)
|
||||||
|
assert isinstance(result.messages[6], TextMessage)
|
||||||
|
assert result.messages[6].content == "Summary from main agent"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user