mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-12 07:21:18 +00:00
AssistantAgent to support Workbench (#6393)
Finishing up the work on workbench.
```python
import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.tools.mcp import StdioServerParams, McpWorkbench
async def main() -> None:
params = StdioServerParams(
command="uvx",
args=["mcp-server-fetch"],
read_timeout_seconds=60,
)
# You can also use `start()` and `stop()` to manage the session.
async with McpWorkbench(server_params=params) as workbench:
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
assistant = AssistantAgent(
name="Assistant",
model_client=model_client,
workbench=workbench,
reflect_on_tool_use=True,
)
await Console(assistant.run_stream(task="Go to https://github.com/microsoft/autogen and tell me what you see."))
asyncio.run(main())
```
This commit is contained in:
parent
0c9fd64d6e
commit
bab0dfd1e7
@ -32,7 +32,7 @@ from autogen_core.models import (
|
||||
ModelFamily,
|
||||
SystemMessage,
|
||||
)
|
||||
from autogen_core.tools import BaseTool, FunctionTool
|
||||
from autogen_core.tools import BaseTool, FunctionTool, StaticWorkbench, Workbench
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
@ -66,6 +66,7 @@ class AssistantAgentConfig(BaseModel):
|
||||
name: str
|
||||
model_client: ComponentModel
|
||||
tools: List[ComponentModel] | None
|
||||
workbench: ComponentModel | None = None
|
||||
handoffs: List[HandoffBase | str] | None = None
|
||||
model_context: ComponentModel | None = None
|
||||
memory: List[ComponentModel] | None = None
|
||||
@ -168,6 +169,8 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
name (str): The name of the agent.
|
||||
model_client (ChatCompletionClient): The model client to use for inference.
|
||||
tools (List[BaseTool[Any, Any] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent.
|
||||
workbench (Workbench | None, optional): The workbench to use for the agent.
|
||||
Tools cannot be used when workbench is set and vice versa.
|
||||
handoffs (List[HandoffBase | str] | None, optional): The handoff configurations for the agent,
|
||||
allowing it to transfer to other agents by responding with a :class:`HandoffMessage`.
|
||||
The transfer is only executed when the team is in :class:`~autogen_agentchat.teams.Swarm`.
|
||||
@ -334,7 +337,45 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
**Example 4: agent with structured output and tool**
|
||||
**Example 4: agent with Model-Context Protocol (MCP) workbench**
|
||||
|
||||
The following example demonstrates how to create an assistant agent with
|
||||
a model client and an :class:`~autogen_ext.tools.mcp.McpWorkbench` for
|
||||
interacting with a Model-Context Protocol (MCP) server.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.ui import Console
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_ext.tools.mcp import StdioServerParams, McpWorkbench
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
params = StdioServerParams(
|
||||
command="uvx",
|
||||
args=["mcp-server-fetch"],
|
||||
read_timeout_seconds=60,
|
||||
)
|
||||
|
||||
# You can also use `start()` and `stop()` to manage the session.
|
||||
async with McpWorkbench(server_params=params) as workbench:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
|
||||
assistant = AssistantAgent(
|
||||
name="Assistant",
|
||||
model_client=model_client,
|
||||
workbench=workbench,
|
||||
reflect_on_tool_use=True,
|
||||
)
|
||||
await Console(
|
||||
assistant.run_stream(task="Go to https://github.com/microsoft/autogen and tell me what you see.")
|
||||
)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
**Example 5: agent with structured output and tool**
|
||||
|
||||
The following example demonstrates how to create an assistant agent with
|
||||
a model client configured to use structured output and a tool.
|
||||
@ -404,7 +445,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
---------- assistant ----------
|
||||
{"thoughts":"The user expresses a clear positive emotion by stating they are happy today, suggesting an upbeat mood.","response":"happy"}
|
||||
|
||||
**Example 5: agent with bounded model context**
|
||||
**Example 6: agent with bounded model context**
|
||||
|
||||
The following example shows how to use a
|
||||
:class:`~autogen_core.model_context.BufferedChatCompletionContext`
|
||||
@ -465,7 +506,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
That's great! Blue is often associated with calmness and serenity. Do you have a specific shade of blue that you like, or any particular reason why it's your favorite?
|
||||
No, you didn't ask a question. I apologize for any misunderstanding. If you have something specific you'd like to discuss or ask, feel free to let me know!
|
||||
|
||||
**Example 6: agent with memory**
|
||||
**Example 7: agent with memory**
|
||||
|
||||
The following example shows how to use a list-based memory with the assistant agent.
|
||||
The memory is preloaded with some initial content.
|
||||
@ -525,7 +566,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
|
||||
Serve it with a side salad or some garlic bread to complete the meal! Enjoy your dinner!
|
||||
|
||||
**Example 7: agent with `o1-mini`**
|
||||
**Example 8: agent with `o1-mini`**
|
||||
|
||||
The following example shows how to use `o1-mini` model with the assistant agent.
|
||||
|
||||
@ -561,7 +602,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
See `o1 beta limitations <https://platform.openai.com/docs/guides/reasoning#beta-limitations>`_ for more details.
|
||||
|
||||
|
||||
**Example 8: agent using reasoning model with custom model context.**
|
||||
**Example 9: agent using reasoning model with custom model context.**
|
||||
|
||||
The following example shows how to use a reasoning model (DeepSeek R1) with the assistant agent.
|
||||
The model context is used to filter out the thought field from the assistant message.
|
||||
@ -628,6 +669,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
tools: List[BaseTool[Any, Any] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
|
||||
workbench: Workbench | None = None,
|
||||
handoffs: List[HandoffBase | str] | None = None,
|
||||
model_context: ChatCompletionContext | None = None,
|
||||
description: str = "An agent that provides assistance with ability to use tools.",
|
||||
@ -711,6 +753,13 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
f"Handoff names: {handoff_tool_names}; tool names: {tool_names}"
|
||||
)
|
||||
|
||||
if workbench is not None:
|
||||
if self._tools:
|
||||
raise ValueError("Tools cannot be used with a workbench.")
|
||||
self._workbench = workbench
|
||||
else:
|
||||
self._workbench = StaticWorkbench(self._tools)
|
||||
|
||||
if model_context is not None:
|
||||
self._model_context = model_context
|
||||
else:
|
||||
@ -774,7 +823,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
model_context = self._model_context
|
||||
memory = self._memory
|
||||
system_messages = self._system_messages
|
||||
tools = self._tools
|
||||
workbench = self._workbench
|
||||
handoff_tools = self._handoff_tools
|
||||
handoffs = self._handoffs
|
||||
model_client = self._model_client
|
||||
@ -807,7 +856,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
model_client_stream=model_client_stream,
|
||||
system_messages=system_messages,
|
||||
model_context=model_context,
|
||||
tools=tools,
|
||||
workbench=workbench,
|
||||
handoff_tools=handoff_tools,
|
||||
agent_name=agent_name,
|
||||
cancellation_token=cancellation_token,
|
||||
@ -844,7 +893,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
agent_name=agent_name,
|
||||
system_messages=system_messages,
|
||||
model_context=model_context,
|
||||
tools=tools,
|
||||
workbench=workbench,
|
||||
handoff_tools=handoff_tools,
|
||||
handoffs=handoffs,
|
||||
model_client=model_client,
|
||||
@ -898,7 +947,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
model_client_stream: bool,
|
||||
system_messages: List[SystemMessage],
|
||||
model_context: ChatCompletionContext,
|
||||
tools: List[BaseTool[Any, Any]],
|
||||
workbench: Workbench,
|
||||
handoff_tools: List[BaseTool[Any, Any]],
|
||||
agent_name: str,
|
||||
cancellation_token: CancellationToken,
|
||||
@ -910,13 +959,13 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
all_messages = await model_context.get_messages()
|
||||
llm_messages = cls._get_compatible_context(model_client=model_client, messages=system_messages + all_messages)
|
||||
|
||||
all_tools = tools + handoff_tools
|
||||
tools = (await workbench.list_tools()) + handoff_tools
|
||||
|
||||
if model_client_stream:
|
||||
model_result: Optional[CreateResult] = None
|
||||
async for chunk in model_client.create_stream(
|
||||
llm_messages,
|
||||
tools=all_tools,
|
||||
tools=tools,
|
||||
json_output=output_content_type,
|
||||
cancellation_token=cancellation_token,
|
||||
):
|
||||
@ -932,7 +981,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
else:
|
||||
model_result = await model_client.create(
|
||||
llm_messages,
|
||||
tools=all_tools,
|
||||
tools=tools,
|
||||
cancellation_token=cancellation_token,
|
||||
json_output=output_content_type,
|
||||
)
|
||||
@ -947,7 +996,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
agent_name: str,
|
||||
system_messages: List[SystemMessage],
|
||||
model_context: ChatCompletionContext,
|
||||
tools: List[BaseTool[Any, Any]],
|
||||
workbench: Workbench,
|
||||
handoff_tools: List[BaseTool[Any, Any]],
|
||||
handoffs: Dict[str, HandoffBase],
|
||||
model_client: ChatCompletionClient,
|
||||
@ -1006,7 +1055,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
*[
|
||||
cls._execute_tool_call(
|
||||
tool_call=call,
|
||||
tools=tools,
|
||||
workbench=workbench,
|
||||
handoff_tools=handoff_tools,
|
||||
agent_name=agent_name,
|
||||
cancellation_token=cancellation_token,
|
||||
@ -1238,32 +1287,16 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
@staticmethod
|
||||
async def _execute_tool_call(
|
||||
tool_call: FunctionCall,
|
||||
tools: List[BaseTool[Any, Any]],
|
||||
workbench: Workbench,
|
||||
handoff_tools: List[BaseTool[Any, Any]],
|
||||
agent_name: str,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Tuple[FunctionCall, FunctionExecutionResult]:
|
||||
"""Execute a single tool call and return the result."""
|
||||
# Load the arguments from the tool call.
|
||||
try:
|
||||
all_tools = tools + handoff_tools
|
||||
if not all_tools:
|
||||
raise ValueError("No tools are available.")
|
||||
tool = next((t for t in all_tools if t.name == tool_call.name), None)
|
||||
if tool is None:
|
||||
raise ValueError(f"The tool '{tool_call.name}' is not available.")
|
||||
arguments: Dict[str, Any] = json.loads(tool_call.arguments) if tool_call.arguments else {}
|
||||
result = await tool.run_json(arguments, cancellation_token)
|
||||
result_as_str = tool.return_value_as_string(result)
|
||||
return (
|
||||
tool_call,
|
||||
FunctionExecutionResult(
|
||||
content=result_as_str,
|
||||
call_id=tool_call.id,
|
||||
is_error=False,
|
||||
name=tool_call.name,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
arguments = json.loads(tool_call.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
return (
|
||||
tool_call,
|
||||
FunctionExecutionResult(
|
||||
@ -1274,6 +1307,39 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
),
|
||||
)
|
||||
|
||||
# Check if the tool call is a handoff.
|
||||
# TODO: consider creating a combined workbench to handle both handoff and normal tools.
|
||||
for handoff_tool in handoff_tools:
|
||||
if tool_call.name == handoff_tool.name:
|
||||
# Run handoff tool call.
|
||||
result = await handoff_tool.run_json(arguments, cancellation_token)
|
||||
result_as_str = handoff_tool.return_value_as_string(result)
|
||||
return (
|
||||
tool_call,
|
||||
FunctionExecutionResult(
|
||||
content=result_as_str,
|
||||
call_id=tool_call.id,
|
||||
is_error=False,
|
||||
name=tool_call.name,
|
||||
),
|
||||
)
|
||||
|
||||
# Handle normal tool call using workbench.
|
||||
result = await workbench.call_tool(
|
||||
name=tool_call.name,
|
||||
arguments=arguments,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
return (
|
||||
tool_call,
|
||||
FunctionExecutionResult(
|
||||
content=result.to_text(),
|
||||
call_id=tool_call.id,
|
||||
is_error=result.is_error,
|
||||
name=tool_call.name,
|
||||
),
|
||||
)
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Reset the assistant agent to its initialization state."""
|
||||
await self._model_context.clear()
|
||||
@ -1304,6 +1370,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
name=self.name,
|
||||
model_client=self._model_client.dump_component(),
|
||||
tools=[tool.dump_component() for tool in self._tools],
|
||||
workbench=self._workbench.dump_component() if self._workbench else None,
|
||||
handoffs=list(self._handoffs.values()) if self._handoffs else None,
|
||||
model_context=self._model_context.dump_component(),
|
||||
memory=[memory.dump_component() for memory in self._memory] if self._memory else None,
|
||||
@ -1336,6 +1403,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
name=config.name,
|
||||
model_client=ChatCompletionClient.load_component(config.model_client),
|
||||
tools=[BaseTool.load_component(tool) for tool in config.tools] if config.tools else None,
|
||||
workbench=Workbench.load_component(config.workbench) if config.workbench else None,
|
||||
handoffs=config.handoffs,
|
||||
model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
|
||||
memory=[Memory.load_component(memory) for memory in config.memory] if config.memory else None,
|
||||
|
||||
@ -33,7 +33,7 @@ from autogen_core.models import (
|
||||
UserMessage,
|
||||
)
|
||||
from autogen_core.models._model_client import ModelFamily
|
||||
from autogen_core.tools import BaseTool, FunctionTool
|
||||
from autogen_core.tools import BaseTool, FunctionTool, StaticWorkbench
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_ext.models.replay import ReplayChatCompletionClient
|
||||
from pydantic import BaseModel, ValidationError
|
||||
@ -401,6 +401,124 @@ async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
|
||||
assert state == state2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_workbench() -> None:
|
||||
model_client = ReplayChatCompletionClient(
|
||||
[
|
||||
CreateResult(
|
||||
finish_reason="function_calls",
|
||||
content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
|
||||
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
||||
cached=False,
|
||||
),
|
||||
CreateResult(
|
||||
finish_reason="stop",
|
||||
content="Hello",
|
||||
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
||||
cached=False,
|
||||
),
|
||||
CreateResult(
|
||||
finish_reason="stop",
|
||||
content="TERMINATE",
|
||||
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
||||
cached=False,
|
||||
),
|
||||
],
|
||||
model_info={
|
||||
"function_calling": True,
|
||||
"vision": True,
|
||||
"json_output": True,
|
||||
"family": ModelFamily.GPT_4O,
|
||||
"structured_output": True,
|
||||
},
|
||||
)
|
||||
workbench = StaticWorkbench(
|
||||
[
|
||||
FunctionTool(_pass_function, description="Pass"),
|
||||
FunctionTool(_fail_function, description="Fail"),
|
||||
FunctionTool(_echo_function, description="Echo"),
|
||||
]
|
||||
)
|
||||
|
||||
# Test raise error when both workbench and tools are provided.
|
||||
with pytest.raises(ValueError):
|
||||
AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=model_client,
|
||||
tools=[
|
||||
_pass_function,
|
||||
_fail_function,
|
||||
FunctionTool(_echo_function, description="Echo"),
|
||||
],
|
||||
workbench=workbench,
|
||||
)
|
||||
|
||||
agent = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=model_client,
|
||||
workbench=workbench,
|
||||
reflect_on_tool_use=True,
|
||||
)
|
||||
result = await agent.run(task="task")
|
||||
|
||||
# Make sure the create call was made with the correct parameters.
|
||||
assert len(model_client.create_calls) == 2
|
||||
llm_messages = model_client.create_calls[0]["messages"]
|
||||
assert len(llm_messages) == 2
|
||||
assert isinstance(llm_messages[0], SystemMessage)
|
||||
assert llm_messages[0].content == agent._system_messages[0].content # type: ignore
|
||||
assert isinstance(llm_messages[1], UserMessage)
|
||||
assert llm_messages[1].content == "task"
|
||||
llm_messages = model_client.create_calls[1]["messages"]
|
||||
assert len(llm_messages) == 4
|
||||
assert isinstance(llm_messages[0], SystemMessage)
|
||||
assert llm_messages[0].content == agent._system_messages[0].content # type: ignore
|
||||
assert isinstance(llm_messages[1], UserMessage)
|
||||
assert llm_messages[1].content == "task"
|
||||
assert isinstance(llm_messages[2], AssistantMessage)
|
||||
assert isinstance(llm_messages[3], FunctionExecutionResultMessage)
|
||||
|
||||
assert len(result.messages) == 4
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].models_usage is None
|
||||
assert isinstance(result.messages[1], ToolCallRequestEvent)
|
||||
assert result.messages[1].models_usage is not None
|
||||
assert result.messages[1].models_usage.completion_tokens == 5
|
||||
assert result.messages[1].models_usage.prompt_tokens == 10
|
||||
assert isinstance(result.messages[2], ToolCallExecutionEvent)
|
||||
assert result.messages[2].models_usage is None
|
||||
assert isinstance(result.messages[3], TextMessage)
|
||||
assert result.messages[3].content == "Hello"
|
||||
assert result.messages[3].models_usage is not None
|
||||
assert result.messages[3].models_usage.completion_tokens == 5
|
||||
assert result.messages[3].models_usage.prompt_tokens == 10
|
||||
|
||||
# Test streaming.
|
||||
model_client.reset()
|
||||
index = 0
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
index += 1
|
||||
|
||||
# Test state saving and loading.
|
||||
state = await agent.save_state()
|
||||
agent2 = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=model_client,
|
||||
tools=[
|
||||
_pass_function,
|
||||
_fail_function,
|
||||
FunctionTool(_echo_function, description="Echo"),
|
||||
],
|
||||
)
|
||||
await agent2.load_state(state)
|
||||
state2 = await agent2.save_state()
|
||||
assert state == state2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_format() -> None:
|
||||
class AgentResponse(BaseModel):
|
||||
|
||||
@ -43,7 +43,11 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
|
||||
) -> ToolResult:
|
||||
tool = next((tool for tool in self._tools if tool.name == name), None)
|
||||
if tool is None:
|
||||
raise ValueError(f"Tool {name} not found in workbench.")
|
||||
return ToolResult(
|
||||
name=name,
|
||||
result=[TextResultContent(content=f"Tool {name} not found.")],
|
||||
is_error=True,
|
||||
)
|
||||
if not cancellation_token:
|
||||
cancellation_token = CancellationToken()
|
||||
if not arguments:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user