diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 3e4996fe8..f5ba404d5 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -1111,6 +1111,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): ) ) handoff_context.append(FunctionExecutionResultMessage(content=tool_call_results)) + elif model_result.thought: + # If no tool calls, but a thought exists, include it in the context + handoff_context.append( + AssistantMessage( + content=model_result.thought, + source=agent_name, + ) + ) # Return response for the first handoff return Response( diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 192ad2fe9..6aab0374f 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -554,6 +554,7 @@ async def test_handoffs() -> None: ], usage=RequestUsage(prompt_tokens=42, completion_tokens=43), cached=False, + thought="Calling handoff function", ) ], model_info={ @@ -576,19 +577,95 @@ async def test_handoffs() -> None: ) assert HandoffMessage in tool_use_agent.produced_message_types result = await tool_use_agent.run(task="task") - assert len(result.messages) == 4 + assert len(result.messages) == 5 assert isinstance(result.messages[0], TextMessage) assert result.messages[0].models_usage is None - assert isinstance(result.messages[1], ToolCallRequestEvent) - assert result.messages[1].models_usage is not None - assert result.messages[1].models_usage.completion_tokens == 43 - assert result.messages[1].models_usage.prompt_tokens == 42 - assert isinstance(result.messages[2], ToolCallExecutionEvent) - assert result.messages[2].models_usage is None - assert isinstance(result.messages[3], HandoffMessage) - assert result.messages[3].content == handoff.message - assert result.messages[3].target == handoff.target + assert isinstance(result.messages[1], ThoughtEvent) + assert result.messages[1].content == "Calling handoff function" + assert isinstance(result.messages[2], ToolCallRequestEvent) + assert result.messages[2].models_usage is not None + assert result.messages[2].models_usage.completion_tokens == 43 + assert result.messages[2].models_usage.prompt_tokens == 42 + assert isinstance(result.messages[3], ToolCallExecutionEvent) assert result.messages[3].models_usage is None + assert isinstance(result.messages[4], HandoffMessage) + assert result.messages[4].content == handoff.message + assert result.messages[4].target == handoff.target + assert result.messages[4].models_usage is None + assert result.messages[4].context == [AssistantMessage(content="Calling handoff function", source="tool_use_agent")] + + # Test streaming. + model_client.reset() + index = 0 + async for message in tool_use_agent.run_stream(task="task"): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + + +@pytest.mark.asyncio +async def test_handoff_with_tool_call_context() -> None: + handoff = Handoff(target="agent2") + model_client = ReplayChatCompletionClient( + [ + CreateResult( + finish_reason="function_calls", + content=[ + FunctionCall(id="1", arguments=json.dumps({}), name=handoff.name), + FunctionCall(id="2", arguments=json.dumps({"input": "task"}), name="_pass_function"), + ], + usage=RequestUsage(prompt_tokens=42, completion_tokens=43), + cached=False, + thought="Calling handoff function", + ) + ], + model_info={ + "function_calling": True, + "vision": True, + "json_output": True, + "family": ModelFamily.GPT_4O, + "structured_output": True, + }, + ) + tool_use_agent = AssistantAgent( + "tool_use_agent", + model_client=model_client, + tools=[ + _pass_function, + _fail_function, + FunctionTool(_echo_function, description="Echo"), + ], + handoffs=[handoff], + ) + assert HandoffMessage in tool_use_agent.produced_message_types + result = await tool_use_agent.run(task="task") + assert len(result.messages) == 5 + assert isinstance(result.messages[0], TextMessage) + assert result.messages[0].models_usage is None + assert isinstance(result.messages[1], ThoughtEvent) + assert result.messages[1].content == "Calling handoff function" + assert isinstance(result.messages[2], ToolCallRequestEvent) + assert result.messages[2].models_usage is not None + assert result.messages[2].models_usage.completion_tokens == 43 + assert result.messages[2].models_usage.prompt_tokens == 42 + assert isinstance(result.messages[3], ToolCallExecutionEvent) + assert result.messages[3].models_usage is None + assert isinstance(result.messages[4], HandoffMessage) + assert result.messages[4].content == handoff.message + assert result.messages[4].target == handoff.target + assert result.messages[4].models_usage is None + assert result.messages[4].context == [ + AssistantMessage( + content=[FunctionCall(id="2", arguments=r'{"input": "task"}', name="_pass_function")], + source="tool_use_agent", + thought="Calling handoff function", + ), + FunctionExecutionResultMessage( + content=[FunctionExecutionResult(call_id="2", content="pass", is_error=False, name="_pass_function")] + ), + ] # Test streaming. model_client.reset()