AssistantAgent to support Workbench (#6393)

Finishing up the work on Workbench: `AssistantAgent` now accepts a `workbench` parameter as an alternative to `tools`, delegating tool listing and execution to a `Workbench` implementation such as `McpWorkbench`. For example:

```python
import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.tools.mcp import StdioServerParams, McpWorkbench

async def main() -> None:
    params = StdioServerParams(
        command="uvx",
        args=["mcp-server-fetch"],
        read_timeout_seconds=60,
    )

    # You can also use `start()` and `stop()` to manage the session.
    async with McpWorkbench(server_params=params) as workbench:
        model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
        assistant = AssistantAgent(
            name="Assistant",
            model_client=model_client,
            workbench=workbench,
            reflect_on_tool_use=True,
        )
        await Console(assistant.run_stream(task="Go to https://github.com/microsoft/autogen and tell me what you see."))
    
asyncio.run(main())
```
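Note that `tools` and `workbench` are mutually exclusive: when only `tools` is given, the agent wraps them in a `StaticWorkbench` internally, and passing both raises a `ValueError`. A minimal sketch of the equivalent explicit setup (the `add` tool and model name here are illustrative, not part of this change):

```python
from autogen_agentchat.agents import AssistantAgent
from autogen_core.tools import FunctionTool, StaticWorkbench
from autogen_ext.models.openai import OpenAIChatCompletionClient


def add(a: int, b: int) -> int:
    """Add two integers."""  # illustrative tool, not part of this change
    return a + b


model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")

# Passing tools=[add] is equivalent to wrapping the tools in a StaticWorkbench:
assistant = AssistantAgent(
    name="Assistant",
    model_client=model_client,
    workbench=StaticWorkbench([FunctionTool(add, description="Add two integers.")]),
)

# Passing both tools and workbench raises:
# ValueError("Tools cannot be used with a workbench.")
```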
Eric Zhu committed 2025-04-24 16:19:36 -07:00
parent 0c9fd64d6e
commit bab0dfd1e7
3 changed files with 227 additions and 37 deletions

File 1 of 3: AssistantAgent implementation

```diff
@@ -32,7 +32,7 @@ from autogen_core.models import (
     ModelFamily,
     SystemMessage,
 )
-from autogen_core.tools import BaseTool, FunctionTool
+from autogen_core.tools import BaseTool, FunctionTool, StaticWorkbench, Workbench
 from pydantic import BaseModel
 from typing_extensions import Self
@@ -66,6 +66,7 @@ class AssistantAgentConfig(BaseModel):
     name: str
     model_client: ComponentModel
     tools: List[ComponentModel] | None
+    workbench: ComponentModel | None = None
     handoffs: List[HandoffBase | str] | None = None
     model_context: ComponentModel | None = None
     memory: List[ComponentModel] | None = None
@@ -168,6 +169,8 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         name (str): The name of the agent.
         model_client (ChatCompletionClient): The model client to use for inference.
         tools (List[BaseTool[Any, Any] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent.
+        workbench (Workbench | None, optional): The workbench to use for the agent.
+            Tools cannot be used when workbench is set and vice versa.
         handoffs (List[HandoffBase | str] | None, optional): The handoff configurations for the agent,
             allowing it to transfer to other agents by responding with a :class:`HandoffMessage`.
             The transfer is only executed when the team is in :class:`~autogen_agentchat.teams.Swarm`.
@@ -334,7 +337,45 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         asyncio.run(main())
 
-    **Example 4: agent with structured output and tool**
+    **Example 4: agent with Model-Context Protocol (MCP) workbench**
+
+    The following example demonstrates how to create an assistant agent with
+    a model client and an :class:`~autogen_ext.tools.mcp.McpWorkbench` for
+    interacting with a Model-Context Protocol (MCP) server.
+
+    .. code-block:: python
+
+        import asyncio
+        from autogen_agentchat.agents import AssistantAgent
+        from autogen_agentchat.ui import Console
+        from autogen_ext.models.openai import OpenAIChatCompletionClient
+        from autogen_ext.tools.mcp import StdioServerParams, McpWorkbench
+
+        async def main() -> None:
+            params = StdioServerParams(
+                command="uvx",
+                args=["mcp-server-fetch"],
+                read_timeout_seconds=60,
+            )
+
+            # You can also use `start()` and `stop()` to manage the session.
+            async with McpWorkbench(server_params=params) as workbench:
+                model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
+                assistant = AssistantAgent(
+                    name="Assistant",
+                    model_client=model_client,
+                    workbench=workbench,
+                    reflect_on_tool_use=True,
+                )
+                await Console(
+                    assistant.run_stream(task="Go to https://github.com/microsoft/autogen and tell me what you see.")
+                )
+
+        asyncio.run(main())
+
+    **Example 5: agent with structured output and tool**
 
     The following example demonstrates how to create an assistant agent with
     a model client configured to use structured output and a tool.
@@ -404,7 +445,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         ---------- assistant ----------
         {"thoughts":"The user expresses a clear positive emotion by stating they are happy today, suggesting an upbeat mood.","response":"happy"}
 
-    **Example 5: agent with bounded model context**
+    **Example 6: agent with bounded model context**
 
     The following example shows how to use a
     :class:`~autogen_core.model_context.BufferedChatCompletionContext`
@@ -465,7 +506,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         That's great! Blue is often associated with calmness and serenity. Do you have a specific shade of blue that you like, or any particular reason why it's your favorite?
         No, you didn't ask a question. I apologize for any misunderstanding. If you have something specific you'd like to discuss or ask, feel free to let me know!
 
-    **Example 6: agent with memory**
+    **Example 7: agent with memory**
 
     The following example shows how to use a list-based memory with the assistant agent.
     The memory is preloaded with some initial content.
@@ -525,7 +566,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         Serve it with a side salad or some garlic bread to complete the meal! Enjoy your dinner!
 
-    **Example 7: agent with `o1-mini`**
+    **Example 8: agent with `o1-mini`**
 
     The following example shows how to use `o1-mini` model with the assistant agent.
@@ -561,7 +602,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
     See `o1 beta limitations <https://platform.openai.com/docs/guides/reasoning#beta-limitations>`_ for more details.
 
-    **Example 8: agent using reasoning model with custom model context.**
+    **Example 9: agent using reasoning model with custom model context.**
 
     The following example shows how to use a reasoning model (DeepSeek R1) with the assistant agent.
     The model context is used to filter out the thought field from the assistant message.
@@ -628,6 +669,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         model_client: ChatCompletionClient,
         *,
         tools: List[BaseTool[Any, Any] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
+        workbench: Workbench | None = None,
         handoffs: List[HandoffBase | str] | None = None,
         model_context: ChatCompletionContext | None = None,
         description: str = "An agent that provides assistance with ability to use tools.",
@@ -711,6 +753,13 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
                 f"Handoff names: {handoff_tool_names}; tool names: {tool_names}"
             )
 
+        if workbench is not None:
+            if self._tools:
+                raise ValueError("Tools cannot be used with a workbench.")
+            self._workbench = workbench
+        else:
+            self._workbench = StaticWorkbench(self._tools)
+
         if model_context is not None:
             self._model_context = model_context
         else:
@@ -774,7 +823,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         model_context = self._model_context
         memory = self._memory
         system_messages = self._system_messages
-        tools = self._tools
+        workbench = self._workbench
         handoff_tools = self._handoff_tools
         handoffs = self._handoffs
         model_client = self._model_client
@@ -807,7 +856,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             model_client_stream=model_client_stream,
             system_messages=system_messages,
             model_context=model_context,
-            tools=tools,
+            workbench=workbench,
             handoff_tools=handoff_tools,
             agent_name=agent_name,
             cancellation_token=cancellation_token,
@@ -844,7 +893,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             agent_name=agent_name,
             system_messages=system_messages,
             model_context=model_context,
-            tools=tools,
+            workbench=workbench,
             handoff_tools=handoff_tools,
             handoffs=handoffs,
             model_client=model_client,
@@ -898,7 +947,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         model_client_stream: bool,
         system_messages: List[SystemMessage],
         model_context: ChatCompletionContext,
-        tools: List[BaseTool[Any, Any]],
+        workbench: Workbench,
         handoff_tools: List[BaseTool[Any, Any]],
         agent_name: str,
         cancellation_token: CancellationToken,
@@ -910,13 +959,13 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         all_messages = await model_context.get_messages()
         llm_messages = cls._get_compatible_context(model_client=model_client, messages=system_messages + all_messages)
 
-        all_tools = tools + handoff_tools
+        tools = (await workbench.list_tools()) + handoff_tools
 
         if model_client_stream:
             model_result: Optional[CreateResult] = None
             async for chunk in model_client.create_stream(
                 llm_messages,
-                tools=all_tools,
+                tools=tools,
                 json_output=output_content_type,
                 cancellation_token=cancellation_token,
             ):
@@ -932,7 +981,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         else:
             model_result = await model_client.create(
                 llm_messages,
-                tools=all_tools,
+                tools=tools,
                 cancellation_token=cancellation_token,
                 json_output=output_content_type,
             )
@@ -947,7 +996,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         agent_name: str,
         system_messages: List[SystemMessage],
         model_context: ChatCompletionContext,
-        tools: List[BaseTool[Any, Any]],
+        workbench: Workbench,
         handoff_tools: List[BaseTool[Any, Any]],
         handoffs: Dict[str, HandoffBase],
         model_client: ChatCompletionClient,
@@ -1006,7 +1055,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             *[
                 cls._execute_tool_call(
                     tool_call=call,
-                    tools=tools,
+                    workbench=workbench,
                     handoff_tools=handoff_tools,
                     agent_name=agent_name,
                     cancellation_token=cancellation_token,
@@ -1238,32 +1287,16 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
     @staticmethod
     async def _execute_tool_call(
         tool_call: FunctionCall,
-        tools: List[BaseTool[Any, Any]],
+        workbench: Workbench,
         handoff_tools: List[BaseTool[Any, Any]],
         agent_name: str,
         cancellation_token: CancellationToken,
     ) -> Tuple[FunctionCall, FunctionExecutionResult]:
         """Execute a single tool call and return the result."""
+        # Load the arguments from the tool call.
         try:
-            all_tools = tools + handoff_tools
-            if not all_tools:
-                raise ValueError("No tools are available.")
-            tool = next((t for t in all_tools if t.name == tool_call.name), None)
-            if tool is None:
-                raise ValueError(f"The tool '{tool_call.name}' is not available.")
-            arguments: Dict[str, Any] = json.loads(tool_call.arguments) if tool_call.arguments else {}
-            result = await tool.run_json(arguments, cancellation_token)
-            result_as_str = tool.return_value_as_string(result)
-            return (
-                tool_call,
-                FunctionExecutionResult(
-                    content=result_as_str,
-                    call_id=tool_call.id,
-                    is_error=False,
-                    name=tool_call.name,
-                ),
-            )
-        except Exception as e:
+            arguments = json.loads(tool_call.arguments)
+        except json.JSONDecodeError as e:
             return (
                 tool_call,
                 FunctionExecutionResult(
@@ -1274,6 +1307,39 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
                 ),
             )
 
+        # Check if the tool call is a handoff.
+        # TODO: consider creating a combined workbench to handle both handoff and normal tools.
+        for handoff_tool in handoff_tools:
+            if tool_call.name == handoff_tool.name:
+                # Run handoff tool call.
+                result = await handoff_tool.run_json(arguments, cancellation_token)
+                result_as_str = handoff_tool.return_value_as_string(result)
+                return (
+                    tool_call,
+                    FunctionExecutionResult(
+                        content=result_as_str,
+                        call_id=tool_call.id,
+                        is_error=False,
+                        name=tool_call.name,
+                    ),
+                )
+
+        # Handle normal tool call using workbench.
+        result = await workbench.call_tool(
+            name=tool_call.name,
+            arguments=arguments,
+            cancellation_token=cancellation_token,
+        )
+        return (
+            tool_call,
+            FunctionExecutionResult(
+                content=result.to_text(),
+                call_id=tool_call.id,
+                is_error=result.is_error,
+                name=tool_call.name,
+            ),
+        )
+
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         """Reset the assistant agent to its initialization state."""
         await self._model_context.clear()
@@ -1304,6 +1370,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             name=self.name,
             model_client=self._model_client.dump_component(),
             tools=[tool.dump_component() for tool in self._tools],
+            workbench=self._workbench.dump_component() if self._workbench else None,
            handoffs=list(self._handoffs.values()) if self._handoffs else None,
             model_context=self._model_context.dump_component(),
             memory=[memory.dump_component() for memory in self._memory] if self._memory else None,
@@ -1336,6 +1403,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             name=config.name,
             model_client=ChatCompletionClient.load_component(config.model_client),
             tools=[BaseTool.load_component(tool) for tool in config.tools] if config.tools else None,
+            workbench=Workbench.load_component(config.workbench) if config.workbench else None,
             handoffs=config.handoffs,
             model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
             memory=[Memory.load_component(memory) for memory in config.memory] if config.memory else None,
```
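In the reworked agent loop above, the model's tool schemas come from `workbench.list_tools()` and non-handoff tool calls are executed via `workbench.call_tool(...)`. A rough standalone sketch of that call path, assuming a `StaticWorkbench` with a single illustrative `echo` tool:

```python
import asyncio

from autogen_core import CancellationToken
from autogen_core.tools import FunctionTool, StaticWorkbench


async def main() -> None:
    def echo(text: str) -> str:
        """Echo the input text."""  # illustrative tool, not part of this change
        return text

    async with StaticWorkbench([FunctionTool(echo, description="Echo")]) as workbench:
        # The same listing the agent combines with handoff tools before calling the model.
        schemas = await workbench.list_tools()
        print([schema["name"] for schema in schemas])  # -> ['echo']

        # The same call _execute_tool_call now makes for a non-handoff tool call.
        result = await workbench.call_tool(
            name="echo",
            arguments={"text": "hello"},
            cancellation_token=CancellationToken(),
        )
        print(result.to_text(), result.is_error)  # -> hello False


asyncio.run(main())
```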

File 2 of 3: AssistantAgent tests

```diff
@@ -33,7 +33,7 @@ from autogen_core.models import (
     UserMessage,
 )
 from autogen_core.models._model_client import ModelFamily
-from autogen_core.tools import BaseTool, FunctionTool
+from autogen_core.tools import BaseTool, FunctionTool, StaticWorkbench
 from autogen_ext.models.openai import OpenAIChatCompletionClient
 from autogen_ext.models.replay import ReplayChatCompletionClient
 from pydantic import BaseModel, ValidationError
@@ -401,6 +401,124 @@ async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
     assert state == state2
 
 
+@pytest.mark.asyncio
+async def test_run_with_workbench() -> None:
+    model_client = ReplayChatCompletionClient(
+        [
+            CreateResult(
+                finish_reason="function_calls",
+                content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
+                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+                cached=False,
+            ),
+            CreateResult(
+                finish_reason="stop",
+                content="Hello",
+                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+                cached=False,
+            ),
+            CreateResult(
+                finish_reason="stop",
+                content="TERMINATE",
+                usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
+                cached=False,
+            ),
+        ],
+        model_info={
+            "function_calling": True,
+            "vision": True,
+            "json_output": True,
+            "family": ModelFamily.GPT_4O,
+            "structured_output": True,
+        },
+    )
+    workbench = StaticWorkbench(
+        [
+            FunctionTool(_pass_function, description="Pass"),
+            FunctionTool(_fail_function, description="Fail"),
+            FunctionTool(_echo_function, description="Echo"),
+        ]
+    )
+
+    # Test raise error when both workbench and tools are provided.
+    with pytest.raises(ValueError):
+        AssistantAgent(
+            "tool_use_agent",
+            model_client=model_client,
+            tools=[
+                _pass_function,
+                _fail_function,
+                FunctionTool(_echo_function, description="Echo"),
+            ],
+            workbench=workbench,
+        )
+
+    agent = AssistantAgent(
+        "tool_use_agent",
+        model_client=model_client,
+        workbench=workbench,
+        reflect_on_tool_use=True,
+    )
+    result = await agent.run(task="task")
+
+    # Make sure the create call was made with the correct parameters.
+    assert len(model_client.create_calls) == 2
+    llm_messages = model_client.create_calls[0]["messages"]
+    assert len(llm_messages) == 2
+    assert isinstance(llm_messages[0], SystemMessage)
+    assert llm_messages[0].content == agent._system_messages[0].content  # type: ignore
+    assert isinstance(llm_messages[1], UserMessage)
+    assert llm_messages[1].content == "task"
+    llm_messages = model_client.create_calls[1]["messages"]
+    assert len(llm_messages) == 4
+    assert isinstance(llm_messages[0], SystemMessage)
+    assert llm_messages[0].content == agent._system_messages[0].content  # type: ignore
+    assert isinstance(llm_messages[1], UserMessage)
+    assert llm_messages[1].content == "task"
+    assert isinstance(llm_messages[2], AssistantMessage)
+    assert isinstance(llm_messages[3], FunctionExecutionResultMessage)
+
+    assert len(result.messages) == 4
+    assert isinstance(result.messages[0], TextMessage)
+    assert result.messages[0].models_usage is None
+    assert isinstance(result.messages[1], ToolCallRequestEvent)
+    assert result.messages[1].models_usage is not None
+    assert result.messages[1].models_usage.completion_tokens == 5
+    assert result.messages[1].models_usage.prompt_tokens == 10
+    assert isinstance(result.messages[2], ToolCallExecutionEvent)
+    assert result.messages[2].models_usage is None
+    assert isinstance(result.messages[3], TextMessage)
+    assert result.messages[3].content == "Hello"
+    assert result.messages[3].models_usage is not None
+    assert result.messages[3].models_usage.completion_tokens == 5
+    assert result.messages[3].models_usage.prompt_tokens == 10
+
+    # Test streaming.
+    model_client.reset()
+    index = 0
+    async for message in agent.run_stream(task="task"):
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1
+
+    # Test state saving and loading.
+    state = await agent.save_state()
+    agent2 = AssistantAgent(
+        "tool_use_agent",
+        model_client=model_client,
+        tools=[
+            _pass_function,
+            _fail_function,
+            FunctionTool(_echo_function, description="Echo"),
+        ],
+    )
+    await agent2.load_state(state)
+    state2 = await agent2.save_state()
+    assert state == state2
+
+
 @pytest.mark.asyncio
 async def test_output_format() -> None:
     class AgentResponse(BaseModel):
```
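The new test drives the agent with `ReplayChatCompletionClient`, which replays a scripted sequence of responses instead of calling a real model. A stripped-down sketch of that testing pattern (the single canned reply is illustrative):

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models.replay import ReplayChatCompletionClient


async def main() -> None:
    # Replays the canned responses in order; no API key or network access needed.
    model_client = ReplayChatCompletionClient(["Hello"])
    agent = AssistantAgent("assistant", model_client=model_client)
    result = await agent.run(task="task")
    print(result.messages[-1].content)  # -> Hello


asyncio.run(main())
```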

File 3 of 3: StaticWorkbench

```diff
@@ -43,7 +43,11 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
     ) -> ToolResult:
         tool = next((tool for tool in self._tools if tool.name == name), None)
         if tool is None:
-            raise ValueError(f"Tool {name} not found in workbench.")
+            return ToolResult(
+                name=name,
+                result=[TextResultContent(content=f"Tool {name} not found.")],
+                is_error=True,
+            )
         if not cancellation_token:
             cancellation_token = CancellationToken()
         if not arguments:
```
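With this change, calling a tool that is not registered returns an error `ToolResult` (which the agent surfaces to the model as a failed `FunctionExecutionResult`) instead of raising. A small sketch of the new behavior, assuming an empty workbench:

```python
import asyncio

from autogen_core.tools import StaticWorkbench


async def main() -> None:
    workbench = StaticWorkbench([])  # no tools registered
    result = await workbench.call_tool(name="missing_tool", arguments={})
    print(result.is_error)   # -> True
    print(result.to_text())  # -> Tool missing_tool not found.


asyncio.run(main())
```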