AssistantAgent to support Workbench (#6393)

Finishing up the work on Workbench: `AssistantAgent` can now take a `workbench` (for example, an `McpWorkbench`) in place of a list of tools; the two options are mutually exclusive.

```python
import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.tools.mcp import StdioServerParams, McpWorkbench

async def main() -> None:
    params = StdioServerParams(
        command="uvx",
        args=["mcp-server-fetch"],
        read_timeout_seconds=60,
    )

    # You can also use `start()` and `stop()` to manage the session.
    async with McpWorkbench(server_params=params) as workbench:
        model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
        assistant = AssistantAgent(
            name="Assistant",
            model_client=model_client,
            workbench=workbench,
            reflect_on_tool_use=True,
        )
        await Console(assistant.run_stream(task="Go to https://github.com/microsoft/autogen and tell me what you see."))

asyncio.run(main())
```
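The inline comment mentions `start()` and `stop()`; here is a minimal sketch of managing the session explicitly instead of with `async with`, assuming both methods are awaitable (as the async context-manager form implies):

```python
import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.tools.mcp import StdioServerParams, McpWorkbench

async def main() -> None:
    params = StdioServerParams(
        command="uvx",
        args=["mcp-server-fetch"],
        read_timeout_seconds=60,
    )

    workbench = McpWorkbench(server_params=params)
    await workbench.start()  # open the MCP session explicitly
    try:
        model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
        assistant = AssistantAgent(
            name="Assistant",
            model_client=model_client,
            workbench=workbench,
            reflect_on_tool_use=True,
        )
        await Console(assistant.run_stream(task="Go to https://github.com/microsoft/autogen and tell me what you see."))
    finally:
        await workbench.stop()  # close the session even if the run fails

asyncio.run(main())
```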
Eric Zhu 2025-04-24 16:19:36 -07:00 committed by GitHub
parent 0c9fd64d6e
commit bab0dfd1e7
3 changed files with 227 additions and 37 deletions
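The diff below programs against a small workbench surface: `list_tools()` for the schemas offered to the model and `call_tool(...)` for execution. A runnable sketch of that contract against a `StaticWorkbench`; the `fetch` tool here is a hypothetical stand-in for `mcp-server-fetch`:

```python
import asyncio
from autogen_core.tools import FunctionTool, StaticWorkbench

async def fetch(url: str) -> str:
    """Hypothetical stand-in for the MCP fetch tool."""
    return f"<contents of {url}>"

async def main() -> None:
    workbench = StaticWorkbench([FunctionTool(fetch, description="Fetch a URL.")])

    # These schemas are handed to the model in place of a plain tools list.
    for schema in await workbench.list_tools():
        print(schema["name"])

    # Non-handoff tool calls are routed through call_tool.
    result = await workbench.call_tool(name="fetch", arguments={"url": "https://example.com"})
    print(result.is_error, result.to_text())

asyncio.run(main())
```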

File 1 of 3: the `AssistantAgent` implementation.

@@ -32,7 +32,7 @@ from autogen_core.models import (
ModelFamily,
SystemMessage,
)
from autogen_core.tools import BaseTool, FunctionTool
from autogen_core.tools import BaseTool, FunctionTool, StaticWorkbench, Workbench
from pydantic import BaseModel
from typing_extensions import Self
@@ -66,6 +66,7 @@ class AssistantAgentConfig(BaseModel):
name: str
model_client: ComponentModel
tools: List[ComponentModel] | None
workbench: ComponentModel | None = None
handoffs: List[HandoffBase | str] | None = None
model_context: ComponentModel | None = None
memory: List[ComponentModel] | None = None
@@ -168,6 +169,8 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
name (str): The name of the agent.
model_client (ChatCompletionClient): The model client to use for inference.
tools (List[BaseTool[Any, Any] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent.
workbench (Workbench | None, optional): The workbench to use for the agent.
Tools cannot be used when a workbench is set, and vice versa.
handoffs (List[HandoffBase | str] | None, optional): The handoff configurations for the agent,
allowing it to transfer to other agents by responding with a :class:`HandoffMessage`.
The transfer is only executed when the team is in :class:`~autogen_agentchat.teams.Swarm`.
@@ -334,7 +337,45 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
asyncio.run(main())
**Example 4: agent with structured output and tool**
**Example 4: agent with Model-Context Protocol (MCP) workbench**
The following example demonstrates how to create an assistant agent with
a model client and an :class:`~autogen_ext.tools.mcp.McpWorkbench` for
interacting with a Model-Context Protocol (MCP) server.
.. code-block:: python
import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.tools.mcp import StdioServerParams, McpWorkbench
async def main() -> None:
params = StdioServerParams(
command="uvx",
args=["mcp-server-fetch"],
read_timeout_seconds=60,
)
# You can also use `start()` and `stop()` to manage the session.
async with McpWorkbench(server_params=params) as workbench:
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
assistant = AssistantAgent(
name="Assistant",
model_client=model_client,
workbench=workbench,
reflect_on_tool_use=True,
)
await Console(
assistant.run_stream(task="Go to https://github.com/microsoft/autogen and tell me what you see.")
)
asyncio.run(main())
**Example 5: agent with structured output and tool**
The following example demonstrates how to create an assistant agent with
a model client configured to use structured output and a tool.
@@ -404,7 +445,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
---------- assistant ----------
{"thoughts":"The user expresses a clear positive emotion by stating they are happy today, suggesting an upbeat mood.","response":"happy"}
**Example 5: agent with bounded model context**
**Example 6: agent with bounded model context**
The following example shows how to use a
:class:`~autogen_core.model_context.BufferedChatCompletionContext`
@@ -465,7 +506,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
That's great! Blue is often associated with calmness and serenity. Do you have a specific shade of blue that you like, or any particular reason why it's your favorite?
No, you didn't ask a question. I apologize for any misunderstanding. If you have something specific you'd like to discuss or ask, feel free to let me know!
**Example 6: agent with memory**
**Example 7: agent with memory**
The following example shows how to use a list-based memory with the assistant agent.
The memory is preloaded with some initial content.
@@ -525,7 +566,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
Serve it with a side salad or some garlic bread to complete the meal! Enjoy your dinner!
**Example 7: agent with `o1-mini`**
**Example 8: agent with `o1-mini`**
The following example shows how to use `o1-mini` model with the assistant agent.
@@ -561,7 +602,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
See `o1 beta limitations <https://platform.openai.com/docs/guides/reasoning#beta-limitations>`_ for more details.
**Example 8: agent using reasoning model with custom model context.**
**Example 9: agent using reasoning model with custom model context.**
The following example shows how to use a reasoning model (DeepSeek R1) with the assistant agent.
The model context is used to filter out the thought field from the assistant message.
@@ -628,6 +669,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_client: ChatCompletionClient,
*,
tools: List[BaseTool[Any, Any] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
workbench: Workbench | None = None,
handoffs: List[HandoffBase | str] | None = None,
model_context: ChatCompletionContext | None = None,
description: str = "An agent that provides assistance with ability to use tools.",
@@ -711,6 +753,13 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
f"Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
if workbench is not None:
if self._tools:
raise ValueError("Tools cannot be used with a workbench.")
self._workbench = workbench
else:
self._workbench = StaticWorkbench(self._tools)
if model_context is not None:
self._model_context = model_context
else:
@@ -774,7 +823,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_context = self._model_context
memory = self._memory
system_messages = self._system_messages
tools = self._tools
workbench = self._workbench
handoff_tools = self._handoff_tools
handoffs = self._handoffs
model_client = self._model_client
@@ -807,7 +856,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_client_stream=model_client_stream,
system_messages=system_messages,
model_context=model_context,
tools=tools,
workbench=workbench,
handoff_tools=handoff_tools,
agent_name=agent_name,
cancellation_token=cancellation_token,
@@ -844,7 +893,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
agent_name=agent_name,
system_messages=system_messages,
model_context=model_context,
tools=tools,
workbench=workbench,
handoff_tools=handoff_tools,
handoffs=handoffs,
model_client=model_client,
@@ -898,7 +947,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_client_stream: bool,
system_messages: List[SystemMessage],
model_context: ChatCompletionContext,
tools: List[BaseTool[Any, Any]],
workbench: Workbench,
handoff_tools: List[BaseTool[Any, Any]],
agent_name: str,
cancellation_token: CancellationToken,
@@ -910,13 +959,13 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
all_messages = await model_context.get_messages()
llm_messages = cls._get_compatible_context(model_client=model_client, messages=system_messages + all_messages)
all_tools = tools + handoff_tools
tools = (await workbench.list_tools()) + handoff_tools
if model_client_stream:
model_result: Optional[CreateResult] = None
async for chunk in model_client.create_stream(
llm_messages,
tools=all_tools,
tools=tools,
json_output=output_content_type,
cancellation_token=cancellation_token,
):
@@ -932,7 +981,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
else:
model_result = await model_client.create(
llm_messages,
tools=all_tools,
tools=tools,
cancellation_token=cancellation_token,
json_output=output_content_type,
)
@@ -947,7 +996,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
agent_name: str,
system_messages: List[SystemMessage],
model_context: ChatCompletionContext,
tools: List[BaseTool[Any, Any]],
workbench: Workbench,
handoff_tools: List[BaseTool[Any, Any]],
handoffs: Dict[str, HandoffBase],
model_client: ChatCompletionClient,
@@ -1006,7 +1055,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
*[
cls._execute_tool_call(
tool_call=call,
tools=tools,
workbench=workbench,
handoff_tools=handoff_tools,
agent_name=agent_name,
cancellation_token=cancellation_token,
@@ -1238,32 +1287,16 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
@staticmethod
async def _execute_tool_call(
tool_call: FunctionCall,
tools: List[BaseTool[Any, Any]],
workbench: Workbench,
handoff_tools: List[BaseTool[Any, Any]],
agent_name: str,
cancellation_token: CancellationToken,
) -> Tuple[FunctionCall, FunctionExecutionResult]:
"""Execute a single tool call and return the result."""
# Load the arguments from the tool call.
try:
all_tools = tools + handoff_tools
if not all_tools:
raise ValueError("No tools are available.")
tool = next((t for t in all_tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
arguments: Dict[str, Any] = json.loads(tool_call.arguments) if tool_call.arguments else {}
result = await tool.run_json(arguments, cancellation_token)
result_as_str = tool.return_value_as_string(result)
return (
tool_call,
FunctionExecutionResult(
content=result_as_str,
call_id=tool_call.id,
is_error=False,
name=tool_call.name,
),
)
except Exception as e:
arguments = json.loads(tool_call.arguments)
except json.JSONDecodeError as e:
return (
tool_call,
FunctionExecutionResult(
@@ -1274,6 +1307,39 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
),
)
# Check if the tool call is a handoff.
# TODO: consider creating a combined workbench to handle both handoff and normal tools.
for handoff_tool in handoff_tools:
if tool_call.name == handoff_tool.name:
# Run handoff tool call.
result = await handoff_tool.run_json(arguments, cancellation_token)
result_as_str = handoff_tool.return_value_as_string(result)
return (
tool_call,
FunctionExecutionResult(
content=result_as_str,
call_id=tool_call.id,
is_error=False,
name=tool_call.name,
),
)
# Handle normal tool call using workbench.
result = await workbench.call_tool(
name=tool_call.name,
arguments=arguments,
cancellation_token=cancellation_token,
)
return (
tool_call,
FunctionExecutionResult(
content=result.to_text(),
call_id=tool_call.id,
is_error=result.is_error,
name=tool_call.name,
),
)
async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Reset the assistant agent to its initialization state."""
await self._model_context.clear()
@@ -1304,6 +1370,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
name=self.name,
model_client=self._model_client.dump_component(),
tools=[tool.dump_component() for tool in self._tools],
workbench=self._workbench.dump_component() if self._workbench else None,
handoffs=list(self._handoffs.values()) if self._handoffs else None,
model_context=self._model_context.dump_component(),
memory=[memory.dump_component() for memory in self._memory] if self._memory else None,
@@ -1336,6 +1403,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
name=config.name,
model_client=ChatCompletionClient.load_component(config.model_client),
tools=[BaseTool.load_component(tool) for tool in config.tools] if config.tools else None,
workbench=Workbench.load_component(config.workbench) if config.workbench else None,
handoffs=config.handoffs,
model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
memory=[Memory.load_component(memory) for memory in config.memory] if config.memory else None,
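The key behavioral change in this file: when the agent is constructed with plain `tools`, they are now wrapped in a `StaticWorkbench`, so a single workbench code path serves both configurations. A sketch of the equivalence, using a hypothetical `get_weather` tool (not part of this PR):

```python
from autogen_agentchat.agents import AssistantAgent
from autogen_core.tools import FunctionTool, StaticWorkbench
from autogen_ext.models.openai import OpenAIChatCompletionClient

async def get_weather(city: str) -> str:
    """Hypothetical tool, for illustration only."""
    return f"The weather in {city} is sunny."

model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")

# Passing tools directly: internally wrapped as StaticWorkbench(tools).
agent_a = AssistantAgent(name="a", model_client=model_client, tools=[get_weather])

# Passing an equivalent workbench explicitly.
workbench = StaticWorkbench([FunctionTool(get_weather, description="Get the weather.")])
agent_b = AssistantAgent(name="b", model_client=model_client, workbench=workbench)

# Passing both raises: ValueError("Tools cannot be used with a workbench.")
```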

File 2 of 3: the `AssistantAgent` tests.

@@ -33,7 +33,7 @@ from autogen_core.models import (
UserMessage,
)
from autogen_core.models._model_client import ModelFamily
from autogen_core.tools import BaseTool, FunctionTool
from autogen_core.tools import BaseTool, FunctionTool, StaticWorkbench
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.replay import ReplayChatCompletionClient
from pydantic import BaseModel, ValidationError
@@ -401,6 +401,124 @@ async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
assert state == state2
@pytest.mark.asyncio
async def test_run_with_workbench() -> None:
model_client = ReplayChatCompletionClient(
[
CreateResult(
finish_reason="function_calls",
content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
),
CreateResult(
finish_reason="stop",
content="Hello",
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
),
CreateResult(
finish_reason="stop",
content="TERMINATE",
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
),
],
model_info={
"function_calling": True,
"vision": True,
"json_output": True,
"family": ModelFamily.GPT_4O,
"structured_output": True,
},
)
workbench = StaticWorkbench(
[
FunctionTool(_pass_function, description="Pass"),
FunctionTool(_fail_function, description="Fail"),
FunctionTool(_echo_function, description="Echo"),
]
)
# Test that an error is raised when both workbench and tools are provided.
with pytest.raises(ValueError):
AssistantAgent(
"tool_use_agent",
model_client=model_client,
tools=[
_pass_function,
_fail_function,
FunctionTool(_echo_function, description="Echo"),
],
workbench=workbench,
)
agent = AssistantAgent(
"tool_use_agent",
model_client=model_client,
workbench=workbench,
reflect_on_tool_use=True,
)
result = await agent.run(task="task")
# Make sure the create call was made with the correct parameters.
assert len(model_client.create_calls) == 2
llm_messages = model_client.create_calls[0]["messages"]
assert len(llm_messages) == 2
assert isinstance(llm_messages[0], SystemMessage)
assert llm_messages[0].content == agent._system_messages[0].content # type: ignore
assert isinstance(llm_messages[1], UserMessage)
assert llm_messages[1].content == "task"
llm_messages = model_client.create_calls[1]["messages"]
assert len(llm_messages) == 4
assert isinstance(llm_messages[0], SystemMessage)
assert llm_messages[0].content == agent._system_messages[0].content # type: ignore
assert isinstance(llm_messages[1], UserMessage)
assert llm_messages[1].content == "task"
assert isinstance(llm_messages[2], AssistantMessage)
assert isinstance(llm_messages[3], FunctionExecutionResultMessage)
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ToolCallRequestEvent)
assert result.messages[1].models_usage is not None
assert result.messages[1].models_usage.completion_tokens == 5
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallExecutionEvent)
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], TextMessage)
assert result.messages[3].content == "Hello"
assert result.messages[3].models_usage is not None
assert result.messages[3].models_usage.completion_tokens == 5
assert result.messages[3].models_usage.prompt_tokens == 10
# Test streaming.
model_client.reset()
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
# Test state saving and loading.
state = await agent.save_state()
agent2 = AssistantAgent(
"tool_use_agent",
model_client=model_client,
tools=[
_pass_function,
_fail_function,
FunctionTool(_echo_function, description="Echo"),
],
)
await agent2.load_state(state)
state2 = await agent2.save_state()
assert state == state2
@pytest.mark.asyncio
async def test_output_format() -> None:
class AgentResponse(BaseModel):

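One behavior the test above pins down: state saved from a workbench-backed agent loads into a tools-backed agent, since the saved state does not depend on how tools were supplied. A condensed sketch of that round trip (the `pass_function` helper and the `model_info` values mirror the test's setup):

```python
import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_core.models import ModelFamily
from autogen_core.tools import FunctionTool, StaticWorkbench
from autogen_ext.models.replay import ReplayChatCompletionClient

async def pass_function(input: str) -> str:
    """Hypothetical stand-in for the test's _pass_function."""
    return "pass"

async def main() -> None:
    model_client = ReplayChatCompletionClient(
        ["Hello"],
        model_info={
            "function_calling": True,
            "vision": False,
            "json_output": False,
            "family": ModelFamily.UNKNOWN,
            "structured_output": False,
        },
    )
    workbench = StaticWorkbench([FunctionTool(pass_function, description="Pass")])
    agent = AssistantAgent("tool_use_agent", model_client=model_client, workbench=workbench)
    await agent.run(task="task")

    # The saved state is agnostic to tools-vs-workbench construction.
    state = await agent.save_state()
    agent2 = AssistantAgent("tool_use_agent", model_client=model_client, tools=[pass_function])
    await agent2.load_state(state)
    assert await agent2.save_state() == state

asyncio.run(main())
```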
File 3 of 3: the `StaticWorkbench` implementation.

@@ -43,7 +43,11 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
) -> ToolResult:
tool = next((tool for tool in self._tools if tool.name == name), None)
if tool is None:
raise ValueError(f"Tool {name} not found in workbench.")
return ToolResult(
name=name,
result=[TextResultContent(content=f"Tool {name} not found.")],
is_error=True,
)
if not cancellation_token:
cancellation_token = CancellationToken()
if not arguments:
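The net effect of this last change: asking `StaticWorkbench` for an unknown tool no longer raises `ValueError`; it returns an error `ToolResult`, which `_execute_tool_call` feeds back to the model as a failed `FunctionExecutionResult`. A minimal sketch of the new behavior (assuming `StaticWorkbench` needs no explicit `start()`, as the agent's internal use suggests; the `echo` tool is hypothetical):

```python
import asyncio
from autogen_core.tools import FunctionTool, StaticWorkbench

async def echo(text: str) -> str:
    """Hypothetical tool for illustration."""
    return text

async def main() -> None:
    workbench = StaticWorkbench([FunctionTool(echo, description="Echo the input.")])

    # Known tool: runs normally.
    ok = await workbench.call_tool(name="echo", arguments={"text": "hi"})
    print(ok.is_error, ok.to_text())  # False hi

    # Unknown tool: now an error result instead of a raised ValueError,
    # so the model sees "Tool missing not found." and can recover.
    bad = await workbench.call_tool(name="missing", arguments={})
    print(bad.is_error, bad.to_text())  # True Tool missing not found.

asyncio.run(main())
```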