diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
index f066e0822..b51ded56e 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -32,7 +32,7 @@ from autogen_core.models import (
     ModelFamily,
     SystemMessage,
 )
-from autogen_core.tools import BaseTool, FunctionTool
+from autogen_core.tools import BaseTool, FunctionTool, StaticWorkbench, Workbench
 from pydantic import BaseModel
 from typing_extensions import Self
 
@@ -66,6 +66,7 @@ class AssistantAgentConfig(BaseModel):
     name: str
     model_client: ComponentModel
     tools: List[ComponentModel] | None
+    workbench: ComponentModel | None = None
     handoffs: List[HandoffBase | str] | None = None
     model_context: ComponentModel | None = None
     memory: List[ComponentModel] | None = None
@@ -168,6 +169,8 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         name (str): The name of the agent.
         model_client (ChatCompletionClient): The model client to use for inference.
         tools (List[BaseTool[Any, Any] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent.
+        workbench (Workbench | None, optional): The workbench to use for the agent.
+            Tools cannot be used when workbench is set and vice versa.
         handoffs (List[HandoffBase | str] | None, optional): The handoff configurations for the agent,
             allowing it to transfer to other agents by responding with a :class:`HandoffMessage`.
             The transfer is only executed when the team is in :class:`~autogen_agentchat.teams.Swarm`.
@@ -334,7 +337,45 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
 
         asyncio.run(main())
 
-    **Example 4: agent with structured output and tool**
+    **Example 4: agent with Model-Context Protocol (MCP) workbench**
+
+    The following example demonstrates how to create an assistant agent with
+    a model client and an :class:`~autogen_ext.tools.mcp.McpWorkbench` for
+    interacting with a Model-Context Protocol (MCP) server.
+
+    .. code-block:: python
+
+        import asyncio
+        from autogen_agentchat.agents import AssistantAgent
+        from autogen_agentchat.ui import Console
+        from autogen_ext.models.openai import OpenAIChatCompletionClient
+        from autogen_ext.tools.mcp import StdioServerParams, McpWorkbench
+
+
+        async def main() -> None:
+            params = StdioServerParams(
+                command="uvx",
+                args=["mcp-server-fetch"],
+                read_timeout_seconds=60,
+            )
+
+            # You can also use `start()` and `stop()` to manage the session.
+            async with McpWorkbench(server_params=params) as workbench:
+                model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
+                assistant = AssistantAgent(
+                    name="Assistant",
+                    model_client=model_client,
+                    workbench=workbench,
+                    reflect_on_tool_use=True,
+                )
+                await Console(
+                    assistant.run_stream(task="Go to https://github.com/microsoft/autogen and tell me what you see.")
+                )
+
+
+        asyncio.run(main())
+
+    **Example 5: agent with structured output and tool**
 
     The following example demonstrates how to create an assistant agent with
     a model client configured to use structured output and a tool.
@@ -404,7 +445,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         ---------- assistant ----------
         {"thoughts":"The user expresses a clear positive emotion by stating they are happy today, suggesting an upbeat mood.","response":"happy"}
 
-    **Example 5: agent with bounded model context**
+    **Example 6: agent with bounded model context**
 
     The following example shows how to use a
     :class:`~autogen_core.model_context.BufferedChatCompletionContext`
@@ -465,7 +506,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         That's great! Blue is often associated with calmness and serenity. Do you have a specific shade of blue that you like, or any particular reason why it's your favorite?
         No, you didn't ask a question. I apologize for any misunderstanding. If you have something specific you'd like to discuss or ask, feel free to let me know!
 
-    **Example 6: agent with memory**
+    **Example 7: agent with memory**
 
     The following example shows how to use a list-based memory with the assistant agent.
     The memory is preloaded with some initial content.
@@ -525,7 +566,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
 
         Serve it with a side salad or some garlic bread to complete the meal! Enjoy your dinner!
 
-    **Example 7: agent with `o1-mini`**
+    **Example 8: agent with `o1-mini`**
 
     The following example shows how to use `o1-mini` model with the assistant agent.
 
@@ -561,7 +602,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
 
     See `o1 beta limitations <https://platform.openai.com/docs/guides/reasoning#beta-limitations>`_ for more details.
 
-    **Example 8: agent using reasoning model with custom model context.**
+    **Example 9: agent using reasoning model with custom model context.**
 
     The following example shows how to use a reasoning model (DeepSeek R1) with the assistant agent.
     The model context is used to filter out the thought field from the assistant message.
@@ -628,6 +669,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         model_client: ChatCompletionClient,
         *,
         tools: List[BaseTool[Any, Any] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
+        workbench: Workbench | None = None,
         handoffs: List[HandoffBase | str] | None = None,
         model_context: ChatCompletionContext | None = None,
         description: str = "An agent that provides assistance with ability to use tools.",
@@ -711,6 +753,13 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
                 f"Handoff names: {handoff_tool_names}; tool names: {tool_names}"
             )
 
+        if workbench is not None:
+            if self._tools:
+                raise ValueError("Tools cannot be used with a workbench.")
+            self._workbench = workbench
+        else:
+            self._workbench = StaticWorkbench(self._tools)
+
         if model_context is not None:
             self._model_context = model_context
         else:
@@ -774,7 +823,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         model_context = self._model_context
         memory = self._memory
         system_messages = self._system_messages
-        tools = self._tools
+        workbench = self._workbench
         handoff_tools = self._handoff_tools
         handoffs = self._handoffs
         model_client = self._model_client
@@ -807,7 +856,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             model_client_stream=model_client_stream,
             system_messages=system_messages,
             model_context=model_context,
-            tools=tools,
+            workbench=workbench,
             handoff_tools=handoff_tools,
             agent_name=agent_name,
             cancellation_token=cancellation_token,
@@ -844,7 +893,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             agent_name=agent_name,
             system_messages=system_messages,
             model_context=model_context,
-            tools=tools,
+            workbench=workbench,
             handoff_tools=handoff_tools,
             handoffs=handoffs,
             model_client=model_client,
@@ -898,7 +947,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         model_client_stream: bool,
         system_messages: List[SystemMessage],
         model_context: ChatCompletionContext,
-        tools: List[BaseTool[Any, Any]],
+        workbench: Workbench,
        handoff_tools: List[BaseTool[Any, Any]],
         agent_name: str,
         cancellation_token: CancellationToken,
@@ -910,13 +959,13 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         all_messages = await model_context.get_messages()
         llm_messages = cls._get_compatible_context(model_client=model_client, messages=system_messages + all_messages)
 
-        all_tools = tools + handoff_tools
+        tools = (await workbench.list_tools()) + handoff_tools
 
         if model_client_stream:
             model_result: Optional[CreateResult] = None
             async for chunk in model_client.create_stream(
                 llm_messages,
-                tools=all_tools,
+                tools=tools,
                 json_output=output_content_type,
                 cancellation_token=cancellation_token,
             ):
@@ -932,7 +981,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         else:
             model_result = await model_client.create(
                 llm_messages,
-                tools=all_tools,
+                tools=tools,
                 cancellation_token=cancellation_token,
                 json_output=output_content_type,
             )
@@ -947,7 +996,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         agent_name: str,
         system_messages: List[SystemMessage],
         model_context: ChatCompletionContext,
-        tools: List[BaseTool[Any, Any]],
+        workbench: Workbench,
         handoff_tools: List[BaseTool[Any, Any]],
         handoffs: Dict[str, HandoffBase],
         model_client: ChatCompletionClient,
@@ -1006,7 +1055,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             *[
                 cls._execute_tool_call(
                     tool_call=call,
-                    tools=tools,
+                    workbench=workbench,
                    handoff_tools=handoff_tools,
                    agent_name=agent_name,
                    cancellation_token=cancellation_token,
@@ -1238,32 +1287,16 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
     @staticmethod
     async def _execute_tool_call(
         tool_call: FunctionCall,
-        tools: List[BaseTool[Any, Any]],
+        workbench: Workbench,
         handoff_tools: List[BaseTool[Any, Any]],
         agent_name: str,
         cancellation_token: CancellationToken,
     ) -> Tuple[FunctionCall, FunctionExecutionResult]:
         """Execute a single tool call and return the result."""
+        # Load the arguments from the tool call.
         try:
-            all_tools = tools + handoff_tools
-            if not all_tools:
-                raise ValueError("No tools are available.")
-            tool = next((t for t in all_tools if t.name == tool_call.name), None)
-            if tool is None:
-                raise ValueError(f"The tool '{tool_call.name}' is not available.")
-            arguments: Dict[str, Any] = json.loads(tool_call.arguments) if tool_call.arguments else {}
-            result = await tool.run_json(arguments, cancellation_token)
-            result_as_str = tool.return_value_as_string(result)
-            return (
-                tool_call,
-                FunctionExecutionResult(
-                    content=result_as_str,
-                    call_id=tool_call.id,
-                    is_error=False,
-                    name=tool_call.name,
-                ),
-            )
-        except Exception as e:
+            arguments = json.loads(tool_call.arguments)
+        except json.JSONDecodeError as e:
             return (
                 tool_call,
                 FunctionExecutionResult(
@@ -1274,6 +1307,39 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
                ),
             )
 
+        # Check if the tool call is a handoff.
+        # TODO: consider creating a combined workbench to handle both handoff and normal tools.
+        for handoff_tool in handoff_tools:
+            if tool_call.name == handoff_tool.name:
+                # Run handoff tool call.
+                result = await handoff_tool.run_json(arguments, cancellation_token)
+                result_as_str = handoff_tool.return_value_as_string(result)
+                return (
+                    tool_call,
+                    FunctionExecutionResult(
+                        content=result_as_str,
+                        call_id=tool_call.id,
+                        is_error=False,
+                        name=tool_call.name,
+                    ),
+                )
+
+        # Handle normal tool call using workbench.
+        result = await workbench.call_tool(
+            name=tool_call.name,
+            arguments=arguments,
+            cancellation_token=cancellation_token,
+        )
+        return (
+            tool_call,
+            FunctionExecutionResult(
+                content=result.to_text(),
+                call_id=tool_call.id,
+                is_error=result.is_error,
+                name=tool_call.name,
+            ),
+        )
+
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         """Reset the assistant agent to its initialization state."""
         await self._model_context.clear()
@@ -1304,6 +1370,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             name=self.name,
             model_client=self._model_client.dump_component(),
             tools=[tool.dump_component() for tool in self._tools],
+            workbench=self._workbench.dump_component() if self._workbench else None,
             handoffs=list(self._handoffs.values()) if self._handoffs else None,
             model_context=self._model_context.dump_component(),
             memory=[memory.dump_component() for memory in self._memory] if self._memory else None,
@@ -1336,6 +1403,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             name=config.name,
             model_client=ChatCompletionClient.load_component(config.model_client),
             tools=[BaseTool.load_component(tool) for tool in config.tools] if config.tools else None,
+            workbench=Workbench.load_component(config.workbench) if config.workbench else None,
             handoffs=config.handoffs,
             model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
             memory=[Memory.load_component(memory) for memory in config.memory] if config.memory else None,
diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
index 6aab0374f..6ad75f2cc 100644
--- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py
+++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
@@ -33,7 +33,7 @@ from autogen_core.models import (
     UserMessage,
 )
 from autogen_core.models._model_client import ModelFamily
-from autogen_core.tools import BaseTool, FunctionTool
+from autogen_core.tools import BaseTool, FunctionTool, StaticWorkbench
 from autogen_ext.models.openai import OpenAIChatCompletionClient
 from autogen_ext.models.replay import ReplayChatCompletionClient
 from pydantic import BaseModel, ValidationError
@@ -401,6 +401,124 @@ async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
     assert state == state2
 
 
+@pytest.mark.asyncio
+async def test_run_with_workbench() -> None:
+    model_client = ReplayChatCompletionClient(
+        [
+            CreateResult(
+                finish_reason="function_calls",
+                content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
+                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+                cached=False,
+            ),
+            CreateResult(
+                finish_reason="stop",
+                content="Hello",
+                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+                cached=False,
+            ),
+            CreateResult(
+                finish_reason="stop",
+                content="TERMINATE",
+                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+                cached=False,
+            ),
+        ],
+        model_info={
+            "function_calling": True,
+            "vision": True,
+            "json_output": True,
+            "family": ModelFamily.GPT_4O,
+            "structured_output": True,
+        },
+    )
+    workbench = StaticWorkbench(
+        [
+            FunctionTool(_pass_function, description="Pass"),
+            FunctionTool(_fail_function, description="Fail"),
+            FunctionTool(_echo_function, description="Echo"),
+        ]
+    )
+
+    # Test raise error when both workbench and tools are provided.
+    with pytest.raises(ValueError):
+        AssistantAgent(
+            "tool_use_agent",
+            model_client=model_client,
+            tools=[
+                _pass_function,
+                _fail_function,
+                FunctionTool(_echo_function, description="Echo"),
+            ],
+            workbench=workbench,
+        )
+
+    agent = AssistantAgent(
+        "tool_use_agent",
+        model_client=model_client,
+        workbench=workbench,
+        reflect_on_tool_use=True,
+    )
+    result = await agent.run(task="task")
+
+    # Make sure the create call was made with the correct parameters.
+    assert len(model_client.create_calls) == 2
+    llm_messages = model_client.create_calls[0]["messages"]
+    assert len(llm_messages) == 2
+    assert isinstance(llm_messages[0], SystemMessage)
+    assert llm_messages[0].content == agent._system_messages[0].content  # type: ignore
+    assert isinstance(llm_messages[1], UserMessage)
+    assert llm_messages[1].content == "task"
+    llm_messages = model_client.create_calls[1]["messages"]
+    assert len(llm_messages) == 4
+    assert isinstance(llm_messages[0], SystemMessage)
+    assert llm_messages[0].content == agent._system_messages[0].content  # type: ignore
+    assert isinstance(llm_messages[1], UserMessage)
+    assert llm_messages[1].content == "task"
+    assert isinstance(llm_messages[2], AssistantMessage)
+    assert isinstance(llm_messages[3], FunctionExecutionResultMessage)
+
+    assert len(result.messages) == 4
+    assert isinstance(result.messages[0], TextMessage)
+    assert result.messages[0].models_usage is None
+    assert isinstance(result.messages[1], ToolCallRequestEvent)
+    assert result.messages[1].models_usage is not None
+    assert result.messages[1].models_usage.completion_tokens == 5
+    assert result.messages[1].models_usage.prompt_tokens == 10
+    assert isinstance(result.messages[2], ToolCallExecutionEvent)
+    assert result.messages[2].models_usage is None
+    assert isinstance(result.messages[3], TextMessage)
+    assert result.messages[3].content == "Hello"
+    assert result.messages[3].models_usage is not None
+    assert result.messages[3].models_usage.completion_tokens == 5
+    assert result.messages[3].models_usage.prompt_tokens == 10
+
+    # Test streaming.
+    model_client.reset()
+    index = 0
+    async for message in agent.run_stream(task="task"):
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1
+
+    # Test state saving and loading.
+    state = await agent.save_state()
+    agent2 = AssistantAgent(
+        "tool_use_agent",
+        model_client=model_client,
+        tools=[
+            _pass_function,
+            _fail_function,
+            FunctionTool(_echo_function, description="Echo"),
+        ],
+    )
+    await agent2.load_state(state)
+    state2 = await agent2.save_state()
+    assert state == state2
+
+
 @pytest.mark.asyncio
 async def test_output_format() -> None:
     class AgentResponse(BaseModel):
diff --git a/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py b/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py
index 5a997defa..2e762232e 100644
--- a/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py
+++ b/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py
@@ -43,7 +43,11 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
     ) -> ToolResult:
         tool = next((tool for tool in self._tools if tool.name == name), None)
         if tool is None:
-            raise ValueError(f"Tool {name} not found in workbench.")
+            return ToolResult(
+                name=name,
+                result=[TextResultContent(content=f"Tool {name} not found.")],
+                is_error=True,
+            )
         if not cancellation_token:
             cancellation_token = CancellationToken()
         if not arguments:
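
For reference, a minimal, self-contained sketch (not part of this patch) of the workbench path the diff introduces. It uses only APIs that appear in the diff itself: StaticWorkbench, Workbench.list_tools(), Workbench.call_tool(), and ToolResult.to_text()/is_error; the `add` tool and its name are hypothetical stand-ins.

import asyncio

from autogen_core.tools import FunctionTool, StaticWorkbench


async def main() -> None:
    def add(a: int, b: int) -> int:
        """Add two integers."""
        return a + b

    # AssistantAgent now wraps plain `tools` in a StaticWorkbench when no
    # workbench is passed; here we build one directly.
    workbench = StaticWorkbench([FunctionTool(add, description="Add two integers.")])

    # The agent gathers tool schemas via list_tools() instead of reading
    # its tools list directly.
    print(await workbench.list_tools())

    # Normal tool call: executed by name through the workbench.
    ok = await workbench.call_tool(name="add", arguments={"a": 1, "b": 2})
    print(ok.to_text(), ok.is_error)

    # Unknown tool: per this patch, StaticWorkbench.call_tool() returns an
    # error ToolResult ("Tool ... not found.", is_error=True) instead of
    # raising ValueError, so the error flows back to the model as a
    # FunctionExecutionResult.
    missing = await workbench.call_tool(name="does_not_exist", arguments={})
    print(missing.to_text(), missing.is_error)


asyncio.run(main())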