Make sure thought content is included in handoff context (#6319)

Resolves #6295

Ensure the thought content gets included in handoff message conetxt,
when the only tool call was handoff tool call.
This commit is contained in:
Eric Zhu 2025-04-16 20:22:49 -07:00 committed by GitHub
parent 165c189f0e
commit fb16d5acf9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 95 additions and 10 deletions

View File

@ -1111,6 +1111,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
)
)
handoff_context.append(FunctionExecutionResultMessage(content=tool_call_results))
elif model_result.thought:
# If no tool calls, but a thought exists, include it in the context
handoff_context.append(
AssistantMessage(
content=model_result.thought,
source=agent_name,
)
)
# Return response for the first handoff
return Response(

View File

@ -554,6 +554,7 @@ async def test_handoffs() -> None:
],
usage=RequestUsage(prompt_tokens=42, completion_tokens=43),
cached=False,
thought="Calling handoff function",
)
],
model_info={
@ -576,19 +577,95 @@ async def test_handoffs() -> None:
)
assert HandoffMessage in tool_use_agent.produced_message_types
result = await tool_use_agent.run(task="task")
assert len(result.messages) == 4
assert len(result.messages) == 5
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ToolCallRequestEvent)
assert result.messages[1].models_usage is not None
assert result.messages[1].models_usage.completion_tokens == 43
assert result.messages[1].models_usage.prompt_tokens == 42
assert isinstance(result.messages[2], ToolCallExecutionEvent)
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], HandoffMessage)
assert result.messages[3].content == handoff.message
assert result.messages[3].target == handoff.target
assert isinstance(result.messages[1], ThoughtEvent)
assert result.messages[1].content == "Calling handoff function"
assert isinstance(result.messages[2], ToolCallRequestEvent)
assert result.messages[2].models_usage is not None
assert result.messages[2].models_usage.completion_tokens == 43
assert result.messages[2].models_usage.prompt_tokens == 42
assert isinstance(result.messages[3], ToolCallExecutionEvent)
assert result.messages[3].models_usage is None
assert isinstance(result.messages[4], HandoffMessage)
assert result.messages[4].content == handoff.message
assert result.messages[4].target == handoff.target
assert result.messages[4].models_usage is None
assert result.messages[4].context == [AssistantMessage(content="Calling handoff function", source="tool_use_agent")]
# Test streaming.
model_client.reset()
index = 0
async for message in tool_use_agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio
async def test_handoff_with_tool_call_context() -> None:
handoff = Handoff(target="agent2")
model_client = ReplayChatCompletionClient(
[
CreateResult(
finish_reason="function_calls",
content=[
FunctionCall(id="1", arguments=json.dumps({}), name=handoff.name),
FunctionCall(id="2", arguments=json.dumps({"input": "task"}), name="_pass_function"),
],
usage=RequestUsage(prompt_tokens=42, completion_tokens=43),
cached=False,
thought="Calling handoff function",
)
],
model_info={
"function_calling": True,
"vision": True,
"json_output": True,
"family": ModelFamily.GPT_4O,
"structured_output": True,
},
)
tool_use_agent = AssistantAgent(
"tool_use_agent",
model_client=model_client,
tools=[
_pass_function,
_fail_function,
FunctionTool(_echo_function, description="Echo"),
],
handoffs=[handoff],
)
assert HandoffMessage in tool_use_agent.produced_message_types
result = await tool_use_agent.run(task="task")
assert len(result.messages) == 5
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ThoughtEvent)
assert result.messages[1].content == "Calling handoff function"
assert isinstance(result.messages[2], ToolCallRequestEvent)
assert result.messages[2].models_usage is not None
assert result.messages[2].models_usage.completion_tokens == 43
assert result.messages[2].models_usage.prompt_tokens == 42
assert isinstance(result.messages[3], ToolCallExecutionEvent)
assert result.messages[3].models_usage is None
assert isinstance(result.messages[4], HandoffMessage)
assert result.messages[4].content == handoff.message
assert result.messages[4].target == handoff.target
assert result.messages[4].models_usage is None
assert result.messages[4].context == [
AssistantMessage(
content=[FunctionCall(id="2", arguments=r'{"input": "task"}', name="_pass_function")],
source="tool_use_agent",
thought="Calling handoff function",
),
FunctionExecutionResultMessage(
content=[FunctionExecutionResult(call_id="2", content="pass", is_error=False, name="_pass_function")]
),
]
# Test streaming.
model_client.reset()