fix: make sure system message is present in reflection call (#5926)

Resolves #5919
commit 904cb0f4b8
parent 997ad60a7d
Author: Eric Zhu
Date:   2025-03-13 14:29:46 -07:00

2 changed files with 32 additions and 1 deletion
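
Before this change, _reflect_on_tool_use_flow built the reflection prompt from the model context alone, so the agent's system message was dropped from the second inference. The fix threads system_messages through _process_model_result into the reflection flow. A condensed sketch of the corrected logic (identifiers as in the diff below; the streaming branch and surrounding code are elided):

    # Reflection now prepends the system messages before inference;
    # previously only the model-context messages were sent.
    all_messages = system_messages + await model_context.get_messages()
    llm_messages = cls._get_compatible_context(model_client=model_client, messages=all_messages)
    reflection_result = await model_client.create(llm_messages)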


@@ -784,6 +784,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             inner_messages=inner_messages,
             cancellation_token=cancellation_token,
             agent_name=agent_name,
+            system_messages=system_messages,
             model_context=model_context,
             tools=tools,
             handoff_tools=handoff_tools,
@@ -878,6 +879,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         inner_messages: List[AgentEvent | ChatMessage],
         cancellation_token: CancellationToken,
         agent_name: str,
+        system_messages: List[SystemMessage],
         model_context: ChatCompletionContext,
         tools: List[BaseTool[Any, Any]],
         handoff_tools: List[BaseTool[Any, Any]],
@@ -959,6 +961,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
        # STEP 4D: Reflect or summarize tool results
        if reflect_on_tool_use:
            async for reflection_response in AssistantAgent._reflect_on_tool_use_flow(
+                system_messages=system_messages,
                model_client=model_client,
                model_client_stream=model_client_stream,
                model_context=model_context,
@@ -1039,6 +1042,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
     @classmethod
     async def _reflect_on_tool_use_flow(
         cls,
+        system_messages: List[SystemMessage],
         model_client: ChatCompletionClient,
         model_client_stream: bool,
         model_context: ChatCompletionContext,
@@ -1049,7 +1053,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         If reflect_on_tool_use=True, we do another inference based on tool results
         and yield the final text response (or streaming chunks).
         """
-        all_messages = await model_context.get_messages()
+        all_messages = system_messages + await model_context.get_messages()
         llm_messages = cls._get_compatible_context(model_client=model_client, messages=all_messages)
         reflection_result: Optional[CreateResult] = None
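
For reference, this path is exercised when an agent is configured to reflect on tool results. A usage sketch (the tool and model client here are hypothetical placeholders, not part of the commit):

    agent = AssistantAgent(
        name="assistant",
        model_client=model_client,  # any ChatCompletionClient (placeholder)
        tools=[my_tool],            # hypothetical tool
        system_message="You are a helpful assistant.",
        reflect_on_tool_use=True,   # triggers the reflection inference fixed above
    )
    result = await agent.run(task="task")
    # With this fix, the reflection call also receives the system message.

The test updates below assert the exact message sequence sent to the model on each create call.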


@@ -25,6 +25,7 @@ from autogen_core.models import (
     AssistantMessage,
     CreateResult,
     FunctionExecutionResult,
+    FunctionExecutionResultMessage,
     LLMMessage,
     RequestUsage,
     SystemMessage,
@@ -80,6 +81,15 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     )
     result = await agent.run(task="task")
     # Make sure the create call was made with the correct parameters.
+    assert len(model_client.create_calls) == 1
+    llm_messages = model_client.create_calls[0]["messages"]
+    assert len(llm_messages) == 2
+    assert isinstance(llm_messages[0], SystemMessage)
+    assert llm_messages[0].content == agent._system_messages[0].content  # type: ignore
+    assert isinstance(llm_messages[1], UserMessage)
+    assert llm_messages[1].content == "task"
     assert len(result.messages) == 5
     assert isinstance(result.messages[0], TextMessage)
     assert result.messages[0].models_usage is None
@@ -150,6 +160,23 @@ async def test_run_with_tools_and_reflection() -> None:
     )
     result = await agent.run(task="task")
     # Make sure the create call was made with the correct parameters.
+    assert len(model_client.create_calls) == 2
+    llm_messages = model_client.create_calls[0]["messages"]
+    assert len(llm_messages) == 2
+    assert isinstance(llm_messages[0], SystemMessage)
+    assert llm_messages[0].content == agent._system_messages[0].content  # type: ignore
+    assert isinstance(llm_messages[1], UserMessage)
+    assert llm_messages[1].content == "task"
+    llm_messages = model_client.create_calls[1]["messages"]
+    assert len(llm_messages) == 4
+    assert isinstance(llm_messages[0], SystemMessage)
+    assert llm_messages[0].content == agent._system_messages[0].content  # type: ignore
+    assert isinstance(llm_messages[1], UserMessage)
+    assert llm_messages[1].content == "task"
+    assert isinstance(llm_messages[2], AssistantMessage)
+    assert isinstance(llm_messages[3], FunctionExecutionResultMessage)
     assert len(result.messages) == 4
     assert isinstance(result.messages[0], TextMessage)
     assert result.messages[0].models_usage is None
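
An equivalent, more compact way to state the reflection-call assertions (a sketch reusing the test's llm_messages; not part of the committed test):

    # The second create call must carry the full sequence, in order.
    expected = [SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
    assert [type(m) for m in llm_messages] == expected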