diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 6ecf6cf73..3e50c0e94 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -23,7 +23,7 @@ from autogen_core.models import ( ModelFamily, SystemMessage, ) -from autogen_core.tools import BaseTool, FunctionTool, StaticWorkbench, Workbench +from autogen_core.tools import BaseTool, FunctionTool, StaticStreamWorkbench, ToolResult, Workbench from pydantic import BaseModel from typing_extensions import Self @@ -745,7 +745,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): else: self._workbench = [workbench] else: - self._workbench = [StaticWorkbench(self._tools)] + self._workbench = [StaticStreamWorkbench(self._tools)] if model_context is not None: self._model_context = model_context @@ -1053,18 +1053,44 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): yield tool_call_msg # STEP 4B: Execute tool calls - executed_calls_and_results = await asyncio.gather( - *[ - cls._execute_tool_call( - tool_call=call, - workbench=workbench, - handoff_tools=handoff_tools, - agent_name=agent_name, - cancellation_token=cancellation_token, - ) - for call in model_result.content - ] - ) + # Use a queue to handle streaming results from tool calls. 
+ stream = asyncio.Queue[BaseAgentEvent | BaseChatMessage | None]() + + async def _execute_tool_calls( + function_calls: List[FunctionCall], + ) -> List[Tuple[FunctionCall, FunctionExecutionResult]]: + results = await asyncio.gather( + *[ + cls._execute_tool_call( + tool_call=call, + workbench=workbench, + handoff_tools=handoff_tools, + agent_name=agent_name, + cancellation_token=cancellation_token, + stream=stream, + ) + for call in function_calls + ] + ) + # Signal the end of streaming by putting None in the queue. + stream.put_nowait(None) + return results + + task = asyncio.create_task(_execute_tool_calls(model_result.content)) + + while True: + event = await stream.get() + if event is None: + # End of streaming, break the loop. + break + if isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage): + yield event + inner_messages.append(event) + else: + raise RuntimeError(f"Unexpected event type: {type(event)}") + + # Wait for all tool calls to complete. + executed_calls_and_results = await task exec_results = [result for _, result in executed_calls_and_results] # Yield ToolCallExecutionEvent @@ -1311,6 +1337,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): handoff_tools: List[BaseTool[Any, Any]], agent_name: str, cancellation_token: CancellationToken, + stream: asyncio.Queue[BaseAgentEvent | BaseChatMessage | None], ) -> Tuple[FunctionCall, FunctionExecutionResult]: """Execute a single tool call and return the result.""" # Load the arguments from the tool call. 
@@ -1348,18 +1375,38 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): for wb in workbench: tools = await wb.list_tools() if any(t["name"] == tool_call.name for t in tools): - result = await wb.call_tool( - name=tool_call.name, - arguments=arguments, - cancellation_token=cancellation_token, - call_id=tool_call.id, - ) + if isinstance(wb, StaticStreamWorkbench): + tool_result: ToolResult | None = None + async for event in wb.call_tool_stream( + name=tool_call.name, + arguments=arguments, + cancellation_token=cancellation_token, + call_id=tool_call.id, + ): + if isinstance(event, ToolResult): + tool_result = event + elif isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage): + await stream.put(event) + else: + warnings.warn( + f"Unexpected event type: {type(event)} in tool call streaming.", + UserWarning, + stacklevel=2, + ) + assert isinstance(tool_result, ToolResult), "Tool result should not be None in streaming mode." + else: + tool_result = await wb.call_tool( + name=tool_call.name, + arguments=arguments, + cancellation_token=cancellation_token, + call_id=tool_call.id, + ) return ( tool_call, FunctionExecutionResult( - content=result.to_text(), + content=tool_result.to_text(), call_id=tool_call.id, - is_error=result.is_error, + is_error=tool_result.is_error, name=tool_call.name, ), ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/tools/_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/tools/_agent.py index 3d6ff9160..e4d51195d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/tools/_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/tools/_agent.py @@ -11,6 +11,10 @@ class AgentToolConfig(BaseModel): """Configuration for the AgentTool.""" agent: ComponentModel + """The agent to be used for running the task.""" + + return_value_as_last_message: bool = False + """Whether to return the value as the last message of the task result.""" class 
AgentTool(TaskRunnerTool, Component[AgentToolConfig]): @@ -20,6 +24,11 @@ class AgentTool(TaskRunnerTool, Component[AgentToolConfig]): Args: agent (BaseChatAgent): The agent to be used for running the task. + return_value_as_last_message (bool): Whether to use the last message content of the task result + as the return value of the tool in :meth:`~autogen_agentchat.tools.TaskRunnerTool.return_value_as_string`. + If set to True, the last message content will be returned as a string. + If set to False, the tool will return all messages in the task result as a string concatenated together, + with each message prefixed by its source (e.g., "writer: ...", "assistant: ..."). Example: @@ -57,15 +66,18 @@ class AgentTool(TaskRunnerTool, Component[AgentToolConfig]): component_config_schema = AgentToolConfig component_provider_override = "autogen_agentchat.tools.AgentTool" - def __init__(self, agent: BaseChatAgent) -> None: + def __init__(self, agent: BaseChatAgent, return_value_as_last_message: bool = False) -> None: self._agent = agent - super().__init__(agent, agent.name, agent.description) + super().__init__( + agent, agent.name, agent.description, return_value_as_last_message=return_value_as_last_message + ) def _to_config(self) -> AgentToolConfig: return AgentToolConfig( agent=self._agent.dump_component(), + return_value_as_last_message=self._return_value_as_last_message, ) @classmethod def _from_config(cls, config: AgentToolConfig) -> Self: - return cls(BaseChatAgent.load_component(config.agent)) + return cls(BaseChatAgent.load_component(config.agent), config.return_value_as_last_message) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/tools/_task_runner_tool.py b/python/packages/autogen-agentchat/src/autogen_agentchat/tools/_task_runner_tool.py index 5aa8672d7..0db95ed2b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/tools/_task_runner_tool.py +++ 
b/python/packages/autogen-agentchat/src/autogen_agentchat/tools/_task_runner_tool.py @@ -1,13 +1,13 @@ from abc import ABC -from typing import Annotated, Any, List, Mapping +from typing import Annotated, Any, AsyncGenerator, List, Mapping from autogen_core import CancellationToken -from autogen_core.tools import BaseTool +from autogen_core.tools import BaseStreamTool from pydantic import BaseModel from ..agents import BaseChatAgent from ..base import TaskResult -from ..messages import BaseChatMessage +from ..messages import BaseAgentEvent, BaseChatMessage from ..teams import BaseGroupChat @@ -17,13 +17,20 @@ class TaskRunnerToolArgs(BaseModel): task: Annotated[str, "The task to be executed."] -class TaskRunnerTool(BaseTool[TaskRunnerToolArgs, TaskResult], ABC): +class TaskRunnerTool(BaseStreamTool[TaskRunnerToolArgs, BaseAgentEvent | BaseChatMessage, TaskResult], ABC): """An base class for tool that can be used to run a task using a team or an agent.""" component_type = "tool" - def __init__(self, task_runner: BaseGroupChat | BaseChatAgent, name: str, description: str) -> None: + def __init__( + self, + task_runner: BaseGroupChat | BaseChatAgent, + name: str, + description: str, + return_value_as_last_message: bool, + ) -> None: self._task_runner = task_runner + self._return_value_as_last_message = return_value_as_last_message super().__init__( args_type=TaskRunnerToolArgs, return_type=TaskResult, @@ -32,10 +39,23 @@ class TaskRunnerTool(BaseTool[TaskRunnerToolArgs, TaskResult], ABC): ) async def run(self, args: TaskRunnerToolArgs, cancellation_token: CancellationToken) -> TaskResult: + """Run the task and return the result.""" return await self._task_runner.run(task=args.task, cancellation_token=cancellation_token) + async def run_stream( + self, args: TaskRunnerToolArgs, cancellation_token: CancellationToken + ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]: + """Run the task and yield events or messages as they are produced, the final 
:class:`TaskResult` + will be yielded at the end.""" + async for event in self._task_runner.run_stream(task=args.task, cancellation_token=cancellation_token): + yield event + def return_value_as_string(self, value: TaskResult) -> str: """Convert the task result to a string.""" + if self._return_value_as_last_message: + if value.messages and isinstance(value.messages[-1], BaseChatMessage): + return value.messages[-1].to_model_text() + raise ValueError("The last message is not a BaseChatMessage.") parts: List[str] = [] for message in value.messages: if isinstance(message, BaseChatMessage): diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/tools/_team.py b/python/packages/autogen-agentchat/src/autogen_agentchat/tools/_team.py index c7d9e6025..28bbb0daa 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/tools/_team.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/tools/_team.py @@ -11,8 +11,13 @@ class TeamToolConfig(BaseModel): """Configuration for the TeamTool.""" name: str + """The name of the tool.""" description: str + """The name and description of the tool.""" team: ComponentModel + """The team to be used for running the task.""" + return_value_as_last_message: bool = False + """Whether to return the value as the last message of the task result.""" class TeamTool(TaskRunnerTool, Component[TeamToolConfig]): @@ -24,22 +29,92 @@ class TeamTool(TaskRunnerTool, Component[TeamToolConfig]): team (BaseGroupChat): The team to be used for running the task. name (str): The name of the tool. description (str): The description of the tool. + return_value_as_last_message (bool): Whether to use the last message content of the task result + as the return value of the tool in :meth:`~autogen_agentchat.tools.TaskRunnerTool.return_value_as_string`. + If set to True, the last message content will be returned as a string. 
+ If set to False, the tool will return all messages in the task result as a string concatenated together, + with each message prefixed by its source (e.g., "writer: ...", "assistant: ..."). + + Example: + + .. code-block:: python + + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.conditions import SourceMatchTermination + from autogen_agentchat.teams import RoundRobinGroupChat + from autogen_agentchat.tools import TeamTool + from autogen_agentchat.ui import Console + from autogen_ext.models.ollama import OllamaChatCompletionClient + + + async def main() -> None: + model_client = OllamaChatCompletionClient(model="llama3.2") + + writer = AssistantAgent(name="writer", model_client=model_client, system_message="You are a helpful assistant.") + reviewer = AssistantAgent( + name="reviewer", model_client=model_client, system_message="You are a critical reviewer." + ) + summarizer = AssistantAgent( + name="summarizer", + model_client=model_client, + system_message="You combine the review and produce a revised response.", + ) + team = RoundRobinGroupChat( + [writer, reviewer, summarizer], termination_condition=SourceMatchTermination(sources=["summarizer"]) + ) + + # Create a TeamTool that uses the team to run tasks, returning the last message as the result. + tool = TeamTool( + team=team, name="writing_team", description="A tool for writing tasks.", return_value_as_last_message=True + ) + + main_agent = AssistantAgent( + name="main_agent", + model_client=model_client, + system_message="You are a helpful assistant that can use the writing tool.", + tools=[tool], + ) + # For handling each events manually. + # async for message in main_agent.run_stream( + # task="Write a short story about a robot learning to love.", + # ): + # print(message) + # Use Console to display the messages in a more readable format. 
+ await Console( + main_agent.run_stream( + task="Write a short story about a robot learning to love.", + ) + ) + + + if __name__ == "__main__": + import asyncio + + asyncio.run(main()) """ component_config_schema = TeamToolConfig component_provider_override = "autogen_agentchat.tools.TeamTool" - def __init__(self, team: BaseGroupChat, name: str, description: str) -> None: + def __init__( + self, team: BaseGroupChat, name: str, description: str, return_value_as_last_message: bool = False + ) -> None: self._team = team - super().__init__(team, name, description) + super().__init__(team, name, description, return_value_as_last_message=return_value_as_last_message) def _to_config(self) -> TeamToolConfig: return TeamToolConfig( name=self._name, description=self._description, team=self._team.dump_component(), + return_value_as_last_message=self._return_value_as_last_message, ) @classmethod def _from_config(cls, config: TeamToolConfig) -> Self: - return cls(BaseGroupChat.load_component(config.team), config.name, config.description) + return cls( + BaseGroupChat.load_component(config.team), + config.name, + config.description, + config.return_value_as_last_message, + ) diff --git a/python/packages/autogen-agentchat/tests/test_task_runner_tool.py b/python/packages/autogen-agentchat/tests/test_task_runner_tool.py index 3d3d58b7d..244952459 100644 --- a/python/packages/autogen-agentchat/tests/test_task_runner_tool.py +++ b/python/packages/autogen-agentchat/tests/test_task_runner_tool.py @@ -1,11 +1,14 @@ import pytest from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.conditions import MaxMessageTermination +from autogen_agentchat.messages import TextMessage, ToolCallExecutionEvent, ToolCallRequestEvent from autogen_agentchat.teams import RoundRobinGroupChat from autogen_agentchat.tools import AgentTool, TeamTool from autogen_core import ( CancellationToken, + FunctionCall, ) +from autogen_core.models import CreateResult, RequestUsage from 
autogen_ext.models.replay import ReplayChatCompletionClient from test_group_chat import _EchoAgent # type: ignore[reportPrivateUsage] @@ -98,3 +101,139 @@ async def test_team_tool_component() -> None: assert tool2.name == "Team Tool" assert tool2.description == "A team tool for testing" assert isinstance(tool2._team, RoundRobinGroupChat) # type: ignore[reportPrivateUsage] + + +@pytest.mark.asyncio +async def test_agent_tool_stream() -> None: + """Test running a task with AgentTool in streaming mode.""" + + def _query_function() -> str: + return "Test task" + + tool_agent_model_client = ReplayChatCompletionClient( + [ + CreateResult( + content=[FunctionCall(name="query_function", arguments="{}", id="1")], + finish_reason="function_calls", + usage=RequestUsage(prompt_tokens=0, completion_tokens=0), + cached=False, + ), + "Summary from tool agent", + ], + model_info={ + "family": "gpt-41", + "function_calling": True, + "json_output": True, + "multiple_system_messages": True, + "structured_output": True, + "vision": True, + }, + ) + tool_agent = AssistantAgent( + name="tool_agent", + model_client=tool_agent_model_client, + tools=[_query_function], + reflect_on_tool_use=True, + description="An agent for testing", + ) + tool = AgentTool(tool_agent) + + main_agent_model_client = ReplayChatCompletionClient( + [ + CreateResult( + content=[FunctionCall(id="1", name="tool_agent", arguments='{"task": "Input task from main agent"}')], + finish_reason="function_calls", + usage=RequestUsage(prompt_tokens=0, completion_tokens=0), + cached=False, + ), + "Summary from main agent", + ], + model_info={ + "family": "gpt-41", + "function_calling": True, + "json_output": True, + "multiple_system_messages": True, + "structured_output": True, + "vision": True, + }, + ) + + main_agent = AssistantAgent( + name="main_agent", + model_client=main_agent_model_client, + tools=[tool], + reflect_on_tool_use=True, + description="An agent for testing", + ) + result = await main_agent.run(task="Input 
task from user", cancellation_token=CancellationToken()) + assert isinstance(result.messages[0], TextMessage) + assert result.messages[0].content == "Input task from user" + assert isinstance(result.messages[1], ToolCallRequestEvent) + assert isinstance(result.messages[2], TextMessage) + assert result.messages[2].content == "Input task from main agent" + assert isinstance(result.messages[3], ToolCallRequestEvent) + assert isinstance(result.messages[4], ToolCallExecutionEvent) + assert isinstance(result.messages[5], TextMessage) + assert result.messages[5].content == "Summary from tool agent" + assert isinstance(result.messages[6], ToolCallExecutionEvent) + assert result.messages[6].content[0].content == "tool_agent: Summary from tool agent" + assert isinstance(result.messages[7], TextMessage) + assert result.messages[7].content == "Summary from main agent" + + +@pytest.mark.asyncio +async def test_team_tool_stream() -> None: + """Test running a task with TeamTool in streaming mode.""" + agent1 = _EchoAgent("Agent1", "An agent for testing") + agent2 = _EchoAgent("Agent2", "Another agent for testing") + termination = MaxMessageTermination(max_messages=3) + team = RoundRobinGroupChat( + [agent1, agent2], + termination_condition=termination, + ) + tool = TeamTool( + team=team, name="team_tool", description="A team tool for testing", return_value_as_last_message=True + ) + + model_client = ReplayChatCompletionClient( + [ + CreateResult( + content=[FunctionCall(name="team_tool", arguments='{"task": "test task from main agent"}', id="1")], + finish_reason="function_calls", + usage=RequestUsage(prompt_tokens=0, completion_tokens=0), + cached=False, + ), + "Summary from main agent", + ], + model_info={ + "family": "gpt-41", + "function_calling": True, + "json_output": True, + "multiple_system_messages": True, + "structured_output": True, + "vision": True, + }, + ) + main_agent = AssistantAgent( + name="main_agent", + model_client=model_client, + tools=[tool], + 
reflect_on_tool_use=True, + description="An agent for testing", + ) + result = await main_agent.run(task="test task from user", cancellation_token=CancellationToken()) + assert isinstance(result.messages[0], TextMessage) + assert result.messages[0].content == "test task from user" + assert isinstance(result.messages[1], ToolCallRequestEvent) + assert isinstance(result.messages[2], TextMessage) + assert result.messages[2].content == "test task from main agent" + assert isinstance(result.messages[3], TextMessage) + assert result.messages[3].content == "test task from main agent" + assert result.messages[3].source == "Agent1" + assert isinstance(result.messages[4], TextMessage) + assert result.messages[4].content == "test task from main agent" + assert result.messages[4].source == "Agent2" + assert isinstance(result.messages[5], ToolCallExecutionEvent) + assert result.messages[5].content[0].content == "test task from main agent" + assert isinstance(result.messages[6], TextMessage) + assert result.messages[6].content == "Summary from main agent" diff --git a/python/packages/autogen-core/src/autogen_core/tools/__init__.py b/python/packages/autogen-core/src/autogen_core/tools/__init__.py index 4f48463a0..d6bbaf577 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/tools/__init__.py @@ -1,18 +1,21 @@ -from ._base import BaseTool, BaseToolWithState, ParametersSchema, Tool, ToolSchema +from ._base import BaseStreamTool, BaseTool, BaseToolWithState, ParametersSchema, StreamTool, Tool, ToolSchema from ._function_tool import FunctionTool -from ._static_workbench import StaticWorkbench +from ._static_workbench import StaticStreamWorkbench, StaticWorkbench from ._workbench import ImageResultContent, TextResultContent, ToolResult, Workbench __all__ = [ "Tool", + "StreamTool", "ToolSchema", "ParametersSchema", "BaseTool", "BaseToolWithState", + "BaseStreamTool", "FunctionTool", "Workbench", "ToolResult", 
"TextResultContent", "ImageResultContent", "StaticWorkbench", + "StaticStreamWorkbench", ] diff --git a/python/packages/autogen-core/src/autogen_core/tools/_base.py b/python/packages/autogen-core/src/autogen_core/tools/_base.py index 6292e259b..8936c9361 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/_base.py +++ b/python/packages/autogen-core/src/autogen_core/tools/_base.py @@ -2,7 +2,7 @@ import json import logging from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar, cast, runtime_checkable +from typing import Any, AsyncGenerator, Dict, Generic, Mapping, Protocol, Type, TypeVar, cast, runtime_checkable import jsonref from pydantic import BaseModel @@ -61,9 +61,17 @@ class Tool(Protocol): async def load_state_json(self, state: Mapping[str, Any]) -> None: ... +@runtime_checkable +class StreamTool(Tool, Protocol): + def run_json_stream( + self, args: Mapping[str, Any], cancellation_token: CancellationToken, call_id: str | None = None + ) -> AsyncGenerator[Any, None]: ... + + ArgsT = TypeVar("ArgsT", bound=BaseModel, contravariant=True) ReturnT = TypeVar("ReturnT", bound=BaseModel, covariant=True) StateT = TypeVar("StateT", bound=BaseModel) +StreamT = TypeVar("StreamT", bound=BaseModel, covariant=True) class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT], ComponentBase[BaseModel]): @@ -187,6 +195,59 @@ class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT], ComponentBase[BaseModel]): pass +class BaseStreamTool( + BaseTool[ArgsT, ReturnT], StreamTool, ABC, Generic[ArgsT, StreamT, ReturnT], ComponentBase[BaseModel] +): + component_type = "tool" + + @abstractmethod + def run_stream(self, args: ArgsT, cancellation_token: CancellationToken) -> AsyncGenerator[StreamT | ReturnT, None]: + """Run the tool with the provided arguments and return a stream of data and end with the final return value.""" + ... 
+ + async def run_json_stream( + self, + args: Mapping[str, Any], + cancellation_token: CancellationToken, + call_id: str | None = None, + ) -> AsyncGenerator[StreamT | ReturnT, None]: + """Run the tool with the provided arguments in a dictionary and return a stream of data + from the tool's :meth:`run_stream` method and end with the final return value. + + Args: + args (Mapping[str, Any]): The arguments to pass to the tool. + cancellation_token (CancellationToken): A token to cancel the operation if needed. + call_id (str | None): An optional identifier for the tool call, used for tracing. + + Returns: + AsyncGenerator[StreamT | ReturnT, None]: A generator yielding results from the tool's :meth:`run_stream` method. + """ + return_value: ReturnT | StreamT | None = None + with trace_tool_span( + tool_name=self._name, + tool_description=self._description, + tool_call_id=call_id, + ): + # Execute the tool's run_stream method + async for result in self.run_stream(self._args_type.model_validate(args), cancellation_token): + return_value = result + yield result + + assert return_value is not None, "The tool must yield a final return value at the end of the stream." 
+ if not isinstance(return_value, self._return_type): + raise TypeError( + f"Expected return value of type {self._return_type.__name__}, but got {type(return_value).__name__}" + ) + + # Log the tool call event + event = ToolCallEvent( + tool_name=self.name, + arguments=dict(args), # Using the raw args passed to run_json + result=self.return_value_as_string(return_value), + ) + logger.info(event) + + class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT], ComponentBase[BaseModel]): def __init__( self, diff --git a/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py b/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py index e11cad9ad..71e9ca4af 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py +++ b/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py @@ -1,14 +1,14 @@ import asyncio import builtins -from typing import Any, Dict, List, Literal, Mapping +from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping from pydantic import BaseModel from typing_extensions import Self from .._cancellation_token import CancellationToken from .._component_config import Component, ComponentModel -from ._base import BaseTool, ToolSchema -from ._workbench import TextResultContent, ToolResult, Workbench +from ._base import BaseTool, StreamTool, ToolSchema +from ._workbench import StreamWorkbench, TextResultContent, ToolResult, Workbench class StaticWorkbenchConfig(BaseModel): @@ -108,3 +108,61 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]): else: error_message += f"{str(error)}\n" return error_message.strip() + + +class StaticStreamWorkbench(StaticWorkbench, StreamWorkbench): + """ + A workbench that provides a static set of tools that do not change after + each tool execution, and supports streaming results. 
+ """ + + component_provider_override = "autogen_core.tools.StaticStreamWorkbench" + + async def call_tool_stream( + self, + name: str, + arguments: Mapping[str, Any] | None = None, + cancellation_token: CancellationToken | None = None, + call_id: str | None = None, + ) -> AsyncGenerator[Any | ToolResult, None]: + tool = next((tool for tool in self._tools if tool.name == name), None) + if tool is None: + yield ToolResult( + name=name, + result=[TextResultContent(content=f"Tool {name} not found.")], + is_error=True, + ) + return + if not cancellation_token: + cancellation_token = CancellationToken() + if not arguments: + arguments = {} + try: + actual_tool_output: Any | None = None + if isinstance(tool, StreamTool): + previous_result: Any | None = None + try: + async for result in tool.run_json_stream(arguments, cancellation_token, call_id=call_id): + if previous_result is not None: + yield previous_result + previous_result = result + actual_tool_output = previous_result + except Exception as e: + # If there was a previous result before the exception, yield it first + if previous_result is not None: + yield previous_result + # Then yield the error result + result_str = self._format_errors(e) + yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=True) + return + else: + # If the tool is not a stream tool, we run it normally and yield the result + result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id)) + cancellation_token.link_future(result_future) + actual_tool_output = await result_future + is_error = False + result_str = tool.return_value_as_string(actual_tool_output) + except Exception as e: + result_str = self._format_errors(e) + is_error = True + yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error) diff --git a/python/packages/autogen-core/src/autogen_core/tools/_workbench.py 
b/python/packages/autogen-core/src/autogen_core/tools/_workbench.py index 5d387680f..7869c5d42 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/_workbench.py +++ b/python/packages/autogen-core/src/autogen_core/tools/_workbench.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from types import TracebackType -from typing import Any, List, Literal, Mapping, Optional, Type +from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Type from pydantic import BaseModel, Field from typing_extensions import Annotated, Self @@ -189,3 +189,28 @@ class Workbench(ABC, ComponentBase[BaseModel]): It calls the :meth:`~autogen_core.tools.WorkBench.stop` method to stop the workbench. """ await self.stop() + + +class StreamWorkbench(Workbench, ABC): + """A workbench that supports streaming results from tool calls.""" + + @abstractmethod + def call_tool_stream( + self, + name: str, + arguments: Mapping[str, Any] | None = None, + cancellation_token: CancellationToken | None = None, + call_id: str | None = None, + ) -> AsyncGenerator[Any | ToolResult, None]: + """ + Call a tool in the workbench and return a stream of results. + + Args: + name (str): The name of the tool to call. + arguments (Mapping[str, Any] | None): The arguments to pass to the tool + If None, the tool will be called with no arguments. + cancellation_token (CancellationToken | None): An optional cancellation token + to cancel the tool execution. + call_id (str | None): An optional identifier for the tool call, used for tracing. + """ + ... 
from typing import Annotated, AsyncGenerator

import pytest
from autogen_core._cancellation_token import CancellationToken
from autogen_core.code_executor import ImportFromModule
from autogen_core.tools import (
    BaseStreamTool,
    FunctionTool,
    StaticStreamWorkbench,
    StaticWorkbench,
    TextResultContent,
    ToolResult,
    Workbench,
)
from pydantic import BaseModel


# Pydantic models used as the argument / intermediate / final types of the
# test stream tools below.
class StreamArgs(BaseModel):
    count: int


class StreamResult(BaseModel):
    final_count: int


class StreamItem(BaseModel):
    current: int


class StreamTool(BaseStreamTool[StreamArgs, StreamItem, StreamResult]):
    """A stream tool that yields one StreamItem per count, then a StreamResult."""

    def __init__(self) -> None:
        super().__init__(
            args_type=StreamArgs,
            return_type=StreamResult,
            name="test_stream_tool",
            description="A test stream tool that counts up to a number.",
        )

    async def run(self, args: StreamArgs, cancellation_token: CancellationToken) -> StreamResult:
        # For the regular run method, just return the final result
        return StreamResult(final_count=args.count)

    async def run_stream(
        self, args: StreamArgs, cancellation_token: CancellationToken
    ) -> AsyncGenerator[StreamItem | StreamResult, None]:
        # Yield one intermediate item per count; stop early if cancelled.
        for i in range(1, args.count + 1):
            if cancellation_token.is_cancelled():
                break
            yield StreamItem(current=i)
        # The final result is yielded unconditionally, even after cancellation.
        yield StreamResult(final_count=args.count)


class StreamToolWithError(BaseStreamTool[StreamArgs, StreamItem, StreamResult]):
    """A stream tool that yields one item and then raises, to test error paths."""

    def __init__(self) -> None:
        super().__init__(
            args_type=StreamArgs,
            return_type=StreamResult,
            name="test_stream_tool_error",
            description="A test stream tool that raises an error.",
        )

    async def run(self, args: StreamArgs, cancellation_token: CancellationToken) -> StreamResult:
        # For the regular run method, just raise the error
        raise ValueError("Stream tool error")

    async def run_stream(
        self, args: StreamArgs, cancellation_token: CancellationToken
    ) -> AsyncGenerator[StreamItem | StreamResult, None]:
        # One intermediate item, then fail mid-stream.
        yield StreamItem(current=1)
        raise ValueError("Stream tool error")


@pytest.mark.asyncio
async def test_static_stream_workbench_call_tool_stream() -> None:
    """Test call_tool_stream with streaming tools and regular tools."""

    def regular_tool_func(x: Annotated[int, "The number to double."]) -> int:
        return x * 2

    regular_tool = FunctionTool(
        regular_tool_func,
        name="regular_tool",
        description="A regular tool that doubles a number.",
        global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
    )

    stream_tool = StreamTool()
    stream_tool_with_error = StreamToolWithError()

    async with StaticStreamWorkbench(tools=[regular_tool, stream_tool, stream_tool_with_error]) as workbench:
        # Test streaming tool
        results: list[StreamItem | StreamResult | ToolResult] = []
        async for result in workbench.call_tool_stream("test_stream_tool", {"count": 3}):
            results.append(result)

        # Should get 3 intermediate results and 1 final result
        assert len(results) == 4

        # Check intermediate results (StreamItem objects)
        for i, result in enumerate(results[:3]):
            assert isinstance(result, StreamItem)
            assert result.current == i + 1

        # Check final result (ToolResult): the final StreamResult is converted
        # into a ToolResult whose text content serializes the model fields.
        final_result = results[-1]
        assert isinstance(final_result, ToolResult)
        assert final_result.name == "test_stream_tool"
        assert final_result.is_error is False
        assert final_result.result[0].type == "TextResultContent"
        assert "final_count" in final_result.result[0].content

        # Test regular (non-streaming) tool
        results_regular: list[ToolResult] = []
        async for result in workbench.call_tool_stream("regular_tool", {"x": 5}):
            results_regular.append(result)  # type: ignore

        # Should get only 1 result for non-streaming tool
        assert len(results_regular) == 1
        final_result = results_regular[0]
        assert final_result.name == "regular_tool"
        assert final_result.is_error is False
        assert final_result.result[0].content == "10"

        # Test streaming tool with error
        results_error: list[StreamItem | ToolResult] = []
        async for result in workbench.call_tool_stream("test_stream_tool_error", {"count": 3}):
            results_error.append(result)  # type: ignore

        # Should get 1 intermediate result and 1 error result
        assert len(results_error) == 2

        # Check intermediate result
        intermediate_result = results_error[0]
        assert isinstance(intermediate_result, StreamItem)
        assert intermediate_result.current == 1

        # Check error result: the mid-stream exception is surfaced as an
        # error ToolResult rather than propagating out of the iterator.
        error_result = results_error[1]
        assert isinstance(error_result, ToolResult)
        assert error_result.name == "test_stream_tool_error"
        assert error_result.is_error is True
        result_content = error_result.result[0]
        assert isinstance(result_content, TextResultContent)
        assert "Stream tool error" in result_content.content

        # Test tool not found
        results_not_found: list[ToolResult] = []
        async for result in workbench.call_tool_stream("nonexistent_tool", {"x": 5}):
            results_not_found.append(result)  # type: ignore

        assert len(results_not_found) == 1
        error_result = results_not_found[0]
        assert error_result.name == "nonexistent_tool"
        assert error_result.is_error is True
        result_content = error_result.result[0]
        assert isinstance(result_content, TextResultContent)
        assert "Tool nonexistent_tool not found" in result_content.content

        # Test with the minimum count (count=1): one intermediate item plus
        # the final result.
        results_no_args: list[StreamItem | StreamResult | ToolResult] = []
        async for result in workbench.call_tool_stream("test_stream_tool", {"count": 1}):
            results_no_args.append(result)  # type: ignore

        assert len(results_no_args) == 2  # 1 intermediate + 1 final

        # Test with None arguments
        results_none: list[ToolResult] = []
        async for result in workbench.call_tool_stream("regular_tool", None):
            results_none.append(result)  # type: ignore

        # The call itself still completes and yields a single ToolResult, but
        # it is an error result because the required argument is missing.
        assert len(results_none) == 1
        result = results_none[0]
        assert result.name == "regular_tool"
        # This should error because x is required
        assert result.is_error is True


@pytest.mark.asyncio
async def test_static_stream_workbench_call_tool_stream_cancellation() -> None:
    """Test call_tool_stream with cancellation token."""
    stream_tool = StreamTool()

    async with StaticStreamWorkbench(tools=[stream_tool]) as workbench:
        # Test with cancellation token
        cancellation_token = CancellationToken()

        results: list[StreamItem | StreamResult | ToolResult] = []
        async for result in workbench.call_tool_stream("test_stream_tool", {"count": 5}, cancellation_token):
            results.append(result)  # type: ignore
            if len(results) == 2:  # Cancel after 2 results
                cancellation_token.cancel()

        # Should get at least 2 results before cancellation
        assert len(results) >= 2


@pytest.mark.asyncio
async def test_static_stream_workbench_inheritance() -> None:
    """Test that StaticStreamWorkbench inherits from both StaticWorkbench and StreamWorkbench."""
    stream_tool = StreamTool()

    async with StaticStreamWorkbench(tools=[stream_tool]) as workbench:
        # Test that it has regular workbench functionality
        tools = await workbench.list_tools()
        assert len(tools) == 1
        assert tools[0]["name"] == "test_stream_tool"

        # Test regular call_tool method
        result = await workbench.call_tool("test_stream_tool", {"count": 2})
        assert result.name == "test_stream_tool"
        assert result.is_error is False

        # Test streaming functionality exists
        assert hasattr(workbench, "call_tool_stream")
        results: list[StreamItem | StreamResult | ToolResult] = []
        async for result in workbench.call_tool_stream("test_stream_tool", {"count": 2}):
            results.append(result)  # type: ignore
        assert len(results) == 3  # 2 intermediate + 1 final