diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 92d2ccaa4..e7ff6cc29 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -1022,10 +1022,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): # Collect normal tool calls (not handoff) into the handoff context tool_calls: List[FunctionCall] = [] tool_call_results: List[FunctionExecutionResult] = [] + # Collect the results returned by handoff_tool. By default, the message attribute will returned. + selected_handoff_message = selected_handoff.message for exec_call, exec_result in executed_calls_and_results: if exec_call.name not in handoffs: tool_calls.append(exec_call) tool_call_results.append(exec_result) + elif exec_call.name == selected_handoff.name: + selected_handoff_message = exec_result.content handoff_context: List[LLMMessage] = [] if len(tool_calls) > 0: @@ -1042,7 +1046,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): # Return response for the first handoff return Response( chat_message=HandoffMessage( - content=selected_handoff.message, + content=selected_handoff_message, target=selected_handoff.target, source=agent_name, context=handoff_context, diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_handoff.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_handoff.py index 7afef094b..6820990a8 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_handoff.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_handoff.py @@ -24,6 +24,7 @@ class Handoff(BaseModel): message: str = Field(default="") """The message to the target agent. + By default, it will be the result for the handoff tool. If not provided, it is generated from the target agent's name.""" @model_validator(mode="before") @@ -54,3 +55,8 @@ class Handoff(BaseModel): return self.message return FunctionTool(_handoff_tool, name=self.name, description=self.description, strict=True) + + """ + The tool that can be used to handoff to the target agent. + Typically, the results of the tool's execution are provided to the target agent. + """ diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index f4f50f52b..db4fe42b7 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -1,6 +1,6 @@ import json import logging -from typing import List +from typing import Dict, List import pytest from autogen_agentchat import EVENT_LOGGER_NAME @@ -32,9 +32,10 @@ from autogen_core.models import ( UserMessage, ) from autogen_core.models._model_client import ModelFamily -from autogen_core.tools import FunctionTool +from autogen_core.tools import BaseTool, FunctionTool from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_ext.models.replay import ReplayChatCompletionClient +from pydantic import BaseModel from utils import FileLogHandler logger = logging.getLogger(EVENT_LOGGER_NAME) @@ -458,6 +459,158 @@ async def test_handoffs() -> None: index += 1 +@pytest.mark.asyncio +async def test_custom_handoffs() -> None: + name = "transfer_to_agent2" + description = "Handoff to agent2." + next_action = "next_action" + + class TextCommandHandOff(Handoff): + @property + def handoff_tool(self) -> BaseTool[BaseModel, BaseModel]: + """Create a handoff tool from this handoff configuration.""" + + def _next_action(action: str) -> str: + """Returns the action you want the user to perform""" + return action + + return FunctionTool(_next_action, name=self.name, description=self.description, strict=True) + + handoff = TextCommandHandOff(name=name, description=description, target="agent2") + model_client = ReplayChatCompletionClient( + [ + CreateResult( + finish_reason="function_calls", + content=[ + FunctionCall(id="1", arguments=json.dumps({"action": next_action}), name=handoff.name), + ], + usage=RequestUsage(prompt_tokens=42, completion_tokens=43), + cached=False, + ) + ], + model_info={ + "function_calling": True, + "vision": True, + "json_output": True, + "family": ModelFamily.GPT_4O, + "structured_output": True, + }, + ) + tool_use_agent = AssistantAgent( + "tool_use_agent", + model_client=model_client, + tools=[ + _pass_function, + _fail_function, + FunctionTool(_echo_function, description="Echo"), + ], + handoffs=[handoff], + ) + assert HandoffMessage in tool_use_agent.produced_message_types + result = await tool_use_agent.run(task="task") + assert len(result.messages) == 4 + assert isinstance(result.messages[0], TextMessage) + assert result.messages[0].models_usage is None + assert isinstance(result.messages[1], ToolCallRequestEvent) + assert result.messages[1].models_usage is not None + assert result.messages[1].models_usage.completion_tokens == 43 + assert result.messages[1].models_usage.prompt_tokens == 42 + assert isinstance(result.messages[2], ToolCallExecutionEvent) + assert result.messages[2].models_usage is None + assert isinstance(result.messages[3], HandoffMessage) + assert result.messages[3].content == next_action + assert result.messages[3].target == handoff.target + + assert result.messages[3].models_usage is None + + # Test streaming. + model_client.reset() + index = 0 + async for message in tool_use_agent.run_stream(task="task"): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + + +@pytest.mark.asyncio +async def test_custom_object_handoffs() -> None: + """test handoff tool return a object""" + name = "transfer_to_agent2" + description = "Handoff to agent2." + next_action = {"action": "next_action"} # using a map, not a str + + class DictCommandHandOff(Handoff): + @property + def handoff_tool(self) -> BaseTool[BaseModel, BaseModel]: + """Create a handoff tool from this handoff configuration.""" + + def _next_action(action: str) -> Dict[str, str]: + """Returns the action you want the user to perform""" + return {"action": action} + + return FunctionTool(_next_action, name=self.name, description=self.description, strict=True) + + handoff = DictCommandHandOff(name=name, description=description, target="agent2") + model_client = ReplayChatCompletionClient( + [ + CreateResult( + finish_reason="function_calls", + content=[ + FunctionCall(id="1", arguments=json.dumps({"action": "next_action"}), name=handoff.name), + ], + usage=RequestUsage(prompt_tokens=42, completion_tokens=43), + cached=False, + ) + ], + model_info={ + "function_calling": True, + "vision": True, + "json_output": True, + "family": ModelFamily.GPT_4O, + "structured_output": True, + }, + ) + tool_use_agent = AssistantAgent( + "tool_use_agent", + model_client=model_client, + tools=[ + _pass_function, + _fail_function, + FunctionTool(_echo_function, description="Echo"), + ], + handoffs=[handoff], + ) + assert HandoffMessage in tool_use_agent.produced_message_types + result = await tool_use_agent.run(task="task") + assert len(result.messages) == 4 + assert isinstance(result.messages[0], TextMessage) + assert result.messages[0].models_usage is None + assert isinstance(result.messages[1], ToolCallRequestEvent) + assert result.messages[1].models_usage is not None + assert result.messages[1].models_usage.completion_tokens == 43 + assert result.messages[1].models_usage.prompt_tokens == 42 + assert isinstance(result.messages[2], ToolCallExecutionEvent) + assert result.messages[2].models_usage is None + assert isinstance(result.messages[3], HandoffMessage) + # the content will return as a string, because the function call will convert to string + assert result.messages[3].content == str(next_action) + assert result.messages[3].target == handoff.target + + assert result.messages[3].models_usage is None + + # Test streaming. + model_client.reset() + index = 0 + async for message in tool_use_agent.run_stream(task="task"): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + + @pytest.mark.asyncio async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None: model_client = ReplayChatCompletionClient(["Hello"])