Add reflection for claude model in AssistantAgent (#6763)

This commit is contained in:
Eric Zhu 2025-07-07 15:13:52 -07:00 committed by GitHub
parent 4a3634d6da
commit d2619049f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 126 additions and 26 deletions

View File

@ -33,7 +33,6 @@ from autogen_core.models import (
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
ModelFamily,
SystemMessage,
)
from autogen_core.tools import BaseTool, FunctionTool, StaticStreamWorkbench, ToolResult, Workbench
@ -847,16 +846,6 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
self._reflect_on_tool_use = False
else:
self._reflect_on_tool_use = reflect_on_tool_use
if self._reflect_on_tool_use and ModelFamily.is_claude(model_client.model_info["family"]):
warnings.warn(
"Claude models may not work with reflection on tool use because Claude requires that any requests including a previous tool use or tool result must include the original tools definition."
"Consider setting reflect_on_tool_use to False. "
"As an alternative, consider calling the agent in a loop until it stops producing tool calls. "
"See [Single-Agent Team](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/teams.html#single-agent-team) "
"for more details.",
UserWarning,
stacklevel=2,
)
# Tool call loop
self._max_tool_iterations = max_tool_iterations
@ -1443,6 +1432,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
llm_messages,
json_output=output_content_type,
cancellation_token=cancellation_token,
tool_choice="none", # Do not use tools in reflection flow.
):
if isinstance(chunk, CreateResult):
reflection_result = chunk
@ -1454,7 +1444,10 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
else:
reflection_result = await model_client.create(
llm_messages, json_output=output_content_type, cancellation_token=cancellation_token
llm_messages,
json_output=output_content_type,
cancellation_token=cancellation_token,
tool_choice="none", # Do not use tools in reflection flow.
)
if not reflection_result or not isinstance(reflection_result.content, str):

View File

@ -3,6 +3,7 @@
# Standard library imports
import asyncio
import json
import os
from typing import Any, List, Optional, Union, cast
from unittest.mock import AsyncMock, MagicMock, patch
@ -39,6 +40,7 @@ from autogen_core.models import (
SystemMessage,
UserMessage,
)
from autogen_ext.models.anthropic import AnthropicChatCompletionClient
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.replay import ReplayChatCompletionClient
from autogen_ext.tools.mcp import McpWorkbench, SseServerParams
@ -1538,19 +1540,6 @@ class TestAssistantAgentSystemMessage:
class TestAssistantAgentModelCompatibility:
"""Test suite for model compatibility functionality."""
@pytest.mark.asyncio
async def test_claude_model_warning(self) -> None:
"""Test warning for Claude models with reflection."""
model_client = MagicMock()
model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.CLAUDE_3_5_SONNET}
with pytest.warns(UserWarning, match="Claude models may not work with reflection"):
AssistantAgent(
name="test_agent",
model_client=model_client,
reflect_on_tool_use=True,
)
@pytest.mark.asyncio
async def test_vision_compatibility(self) -> None:
"""Test vision model compatibility."""
@ -2215,7 +2204,7 @@ class TestAssistantAgentThoughtHandling:
# The last assistant message should have the thought
last_assistant_msg = assistant_messages[-1]
# Fix line 2727 - properly check for thought attribute with type checking
# Fix line 2730 - properly check for thought attribute with type checking
if hasattr(last_assistant_msg, "thought"):
thought_content = cast(str, last_assistant_msg.thought)
assert thought_content == "Internal reasoning"
@ -3453,3 +3442,121 @@ class TestAssistantAgentMessageContext:
# Verify message conversion
for msg in context_messages:
assert isinstance(msg, (SystemMessage, UserMessage, AssistantMessage))
class TestAnthropicIntegration:
    """Test suite for live Anthropic model API integration.

    These tests call the real Anthropic API and are skipped automatically
    (via ``pytest.skip`` inside ``_get_anthropic_client``) when the
    ``ANTHROPIC_API_KEY`` environment variable is not set.
    """

    def _get_anthropic_client(self) -> AnthropicChatCompletionClient:
        """Create an Anthropic client for testing, skipping the test if no API key is set."""
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            # pytest.skip raises, so callers never see a client without a key.
            pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
        return AnthropicChatCompletionClient(
            model="claude-3-haiku-20240307",  # Use haiku for faster/cheaper testing
            api_key=api_key,
            temperature=0.0,  # Deterministic output keeps assertions stable.
        )

    @pytest.mark.asyncio
    async def test_anthropic_tool_call_loop_max_iterations_10(self) -> None:
        """Test Anthropic integration with tool call loop and max_tool_iterations=10."""
        # _get_anthropic_client skips when ANTHROPIC_API_KEY is absent.
        client = self._get_anthropic_client()

        agent = AssistantAgent(
            name="anthropic_test_agent",
            model_client=client,
            tools=[mock_tool_function],
            max_tool_iterations=10,
        )

        # Test with a task that might require tool calls
        result = await agent.run(
            task="Use the mock_tool_function to process the text 'hello world'. Then provide a summary."
        )

        # Verify that we got a result
        assert result is not None
        assert isinstance(result, TaskResult)
        assert len(result.messages) > 0

        # Check that the last message is a non-tool call.
        assert isinstance(result.messages[-1], TextMessage)

        # Check that a tool call was made
        tool_calls = [msg for msg in result.messages if isinstance(msg, ToolCallRequestEvent)]
        assert len(tool_calls) > 0

        # Check that usage was tracked
        usage = client.total_usage()
        assert usage.prompt_tokens > 0
        assert usage.completion_tokens > 0

    @pytest.mark.asyncio
    async def test_anthropic_tool_call_loop_max_iterations_1_with_reflection(self) -> None:
        """Test Anthropic integration with max_tool_iterations=1 and reflect_on_tool_use=True."""
        # _get_anthropic_client skips when ANTHROPIC_API_KEY is absent.
        client = self._get_anthropic_client()

        agent = AssistantAgent(
            name="anthropic_reflection_agent",
            model_client=client,
            tools=[mock_tool_function],
            max_tool_iterations=1,
            reflect_on_tool_use=True,
        )

        # Test with a task that might require tool calls but should be limited to 1 iteration
        result = await agent.run(
            task="Use the mock_tool_function to process the text 'test input' and then explain what happened."
        )

        # Verify that we got a result
        assert result is not None
        assert isinstance(result, TaskResult)
        assert len(result.messages) > 0

        # Check that the last message is a reflection
        assert isinstance(result.messages[-1], TextMessage)

        # Check that a tool call was made
        tool_calls = [msg for msg in result.messages if isinstance(msg, ToolCallRequestEvent)]
        assert len(tool_calls) > 0

        # Check that usage was tracked
        usage = client.total_usage()
        assert usage.prompt_tokens > 0
        assert usage.completion_tokens > 0

    @pytest.mark.asyncio
    async def test_anthropic_basic_text_response(self) -> None:
        """Test basic Anthropic integration without tools."""
        # _get_anthropic_client skips when ANTHROPIC_API_KEY is absent.
        client = self._get_anthropic_client()

        agent = AssistantAgent(
            name="anthropic_basic_agent",
            model_client=client,
        )

        # Test with a simple task that doesn't require tools
        result = await agent.run(task="What is 2 + 2? Just answer with the number.")

        # Verify that we got a result
        assert result is not None
        assert isinstance(result, TaskResult)

        # Check that we got a text message with content
        assert isinstance(result.messages[-1], TextMessage)
        assert "4" in result.messages[-1].content

        # Check that usage was tracked
        usage = client.total_usage()
        assert usage.prompt_tokens > 0
        assert usage.completion_tokens > 0