Mirror of https://github.com/microsoft/autogen.git, synced 2025-12-25 14:09:00 +00:00
Add reflection for claude model in AssistantAgent (#6763)
This commit is contained in: parent 4a3634d6da, commit d2619049f3
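
A minimal usage sketch of what this change enables, assuming ANTHROPIC_API_KEY is set and the autogen-ext Anthropic client is installed; the tool, task, and model name are illustrative and not taken from the commit. With this change, an AssistantAgent backed by a Claude-family model can use reflect_on_tool_use=True (the warning removed below no longer applies), and its reflection request is sent with tool_choice="none" as shown in the hunks that follow.

import asyncio
import os

from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models.anthropic import AnthropicChatCompletionClient


async def get_weather(city: str) -> str:
    """Illustrative tool, not part of this commit."""
    return f"The weather in {city} is 72 degrees and sunny."


async def main() -> None:
    client = AnthropicChatCompletionClient(
        model="claude-3-haiku-20240307",  # same model as the integration tests added below
        api_key=os.environ["ANTHROPIC_API_KEY"],
    )
    agent = AssistantAgent(
        name="assistant",
        model_client=client,
        tools=[get_weather],
        reflect_on_tool_use=True,  # previously triggered a UserWarning for Claude-family models
    )
    result = await agent.run(task="What is the weather in Seattle?")
    print(result.messages[-1].content)


asyncio.run(main())
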
@@ -33,7 +33,6 @@ from autogen_core.models import (
     FunctionExecutionResult,
     FunctionExecutionResultMessage,
     LLMMessage,
-    ModelFamily,
     SystemMessage,
 )
 from autogen_core.tools import BaseTool, FunctionTool, StaticStreamWorkbench, ToolResult, Workbench
@@ -847,16 +846,6 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             self._reflect_on_tool_use = False
         else:
             self._reflect_on_tool_use = reflect_on_tool_use
-        if self._reflect_on_tool_use and ModelFamily.is_claude(model_client.model_info["family"]):
-            warnings.warn(
-                "Claude models may not work with reflection on tool use because Claude requires that any requests including a previous tool use or tool result must include the original tools definition."
-                "Consider setting reflect_on_tool_use to False. "
-                "As an alternative, consider calling the agent in a loop until it stops producing tool calls. "
-                "See [Single-Agent Team](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/teams.html#single-agent-team) "
-                "for more details.",
-                UserWarning,
-                stacklevel=2,
-            )

         # Tool call loop
         self._max_tool_iterations = max_tool_iterations
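
For context, the warning removed above pointed users to the Single-Agent Team pattern as an alternative to reflection. A rough sketch of that loop, assuming autogen_agentchat's RoundRobinGroupChat and TextMessageTermination as described in the linked tutorial (the tool, task, and model name are again illustrative): the single agent is run repeatedly until it replies with plain text, i.e. until it stops producing tool calls.

import asyncio
import os

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import TextMessageTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_ext.models.anthropic import AnthropicChatCompletionClient


async def get_weather(city: str) -> str:
    """Illustrative tool, not part of this commit."""
    return f"The weather in {city} is 72 degrees and sunny."


async def main() -> None:
    client = AnthropicChatCompletionClient(
        model="claude-3-haiku-20240307",
        api_key=os.environ["ANTHROPIC_API_KEY"],
    )
    agent = AssistantAgent("looping_assistant", model_client=client, tools=[get_weather])
    # Loop the single agent until it produces a plain TextMessage
    # (i.e. no further tool calls), instead of using reflect_on_tool_use.
    team = RoundRobinGroupChat(
        [agent],
        termination_condition=TextMessageTermination(source="looping_assistant"),
    )
    result = await team.run(task="What is the weather in Seattle?")
    print(result.messages[-1].content)


asyncio.run(main())
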
@@ -1443,6 +1432,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
                 llm_messages,
                 json_output=output_content_type,
                 cancellation_token=cancellation_token,
+                tool_choice="none", # Do not use tools in reflection flow.
             ):
                 if isinstance(chunk, CreateResult):
                     reflection_result = chunk
@@ -1454,7 +1444,10 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
                     raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
         else:
             reflection_result = await model_client.create(
-                llm_messages, json_output=output_content_type, cancellation_token=cancellation_token
+                llm_messages,
+                json_output=output_content_type,
+                cancellation_token=cancellation_token,
+                tool_choice="none", # Do not use tools in reflection flow.
             )

         if not reflection_result or not isinstance(reflection_result.content, str):
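
The substantive change in both the streaming and non-streaming branches above is tool_choice="none": the reflection request replays the earlier tool call and its result but forbids another tool call, so the model has to answer in plain text. Below is a standalone sketch of that call shape against the ChatCompletionClient interface. The client construction, tool, and message history are illustrative; only the tool_choice="none" keyword mirrors this hunk, and passing the tool definitions alongside it is an assumption based on the constraint described in the removed warning (Claude expects the original tool definitions when the history contains tool use).

import asyncio
import os

from autogen_core import FunctionCall
from autogen_core.models import (
    AssistantMessage,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    SystemMessage,
    UserMessage,
)
from autogen_core.tools import FunctionTool
from autogen_ext.models.anthropic import AnthropicChatCompletionClient


async def get_weather(city: str) -> str:
    return f"The weather in {city} is 72 degrees and sunny."


async def main() -> None:
    client = AnthropicChatCompletionClient(
        model="claude-3-haiku-20240307",
        api_key=os.environ["ANTHROPIC_API_KEY"],
    )
    weather_tool = FunctionTool(get_weather, description="Get the weather for a city.")
    # A conversation that already contains a tool call and its result.
    llm_messages = [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What is the weather in Seattle?", source="user"),
        AssistantMessage(
            content=[FunctionCall(id="call_1", name="get_weather", arguments='{"city": "Seattle"}')],
            source="assistant",
        ),
        FunctionExecutionResultMessage(
            content=[
                FunctionExecutionResult(
                    content="The weather in Seattle is 72 degrees and sunny.",
                    call_id="call_1",
                    is_error=False,
                    name="get_weather",
                )
            ]
        ),
    ]
    # Reflection-style request: no further tool use is allowed.
    reflection_result = await client.create(llm_messages, tools=[weather_tool], tool_choice="none")
    print(reflection_result.content)


asyncio.run(main())
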
@@ -3,6 +3,7 @@
 # Standard library imports
 import asyncio
 import json
+import os
 from typing import Any, List, Optional, Union, cast
 from unittest.mock import AsyncMock, MagicMock, patch

@@ -39,6 +40,7 @@ from autogen_core.models import (
     SystemMessage,
     UserMessage,
 )
+from autogen_ext.models.anthropic import AnthropicChatCompletionClient
 from autogen_ext.models.openai import OpenAIChatCompletionClient
 from autogen_ext.models.replay import ReplayChatCompletionClient
 from autogen_ext.tools.mcp import McpWorkbench, SseServerParams
@@ -1538,19 +1540,6 @@ class TestAssistantAgentSystemMessage:
 class TestAssistantAgentModelCompatibility:
     """Test suite for model compatibility functionality."""

-    @pytest.mark.asyncio
-    async def test_claude_model_warning(self) -> None:
-        """Test warning for Claude models with reflection."""
-        model_client = MagicMock()
-        model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.CLAUDE_3_5_SONNET}
-
-        with pytest.warns(UserWarning, match="Claude models may not work with reflection"):
-            AssistantAgent(
-                name="test_agent",
-                model_client=model_client,
-                reflect_on_tool_use=True,
-            )
-
     @pytest.mark.asyncio
     async def test_vision_compatibility(self) -> None:
         """Test vision model compatibility."""
@@ -2215,7 +2204,7 @@ class TestAssistantAgentThoughtHandling:

         # The last assistant message should have the thought
         last_assistant_msg = assistant_messages[-1]
-        # Fix line 2727 - properly check for thought attribute with type checking
+        # Fix line 2730 - properly check for thought attribute with type checking
         if hasattr(last_assistant_msg, "thought"):
             thought_content = cast(str, last_assistant_msg.thought)
             assert thought_content == "Internal reasoning"
@@ -3453,3 +3442,121 @@ class TestAssistantAgentMessageContext:
         # Verify message conversion
         for msg in context_messages:
             assert isinstance(msg, (SystemMessage, UserMessage, AssistantMessage))
+
+
+class TestAnthropicIntegration:
+    """Test suite for Anthropic model API integration."""
+
+    def _get_anthropic_client(self) -> AnthropicChatCompletionClient:
+        """Create an Anthropic client for testing."""
+        api_key = os.getenv("ANTHROPIC_API_KEY")
+        if not api_key:
+            pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
+
+        return AnthropicChatCompletionClient(
+            model="claude-3-haiku-20240307", # Use haiku for faster/cheaper testing
+            api_key=api_key,
+            temperature=0.0,
+        )
+
+    @pytest.mark.asyncio
+    async def test_anthropic_tool_call_loop_max_iterations_10(self) -> None:
+        """Test Anthropic integration with tool call loop and max_tool_iterations=10."""
+        api_key = os.getenv("ANTHROPIC_API_KEY")
+        if not api_key:
+            pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
+
+        client = self._get_anthropic_client()
+
+        agent = AssistantAgent(
+            name="anthropic_test_agent",
+            model_client=client,
+            tools=[mock_tool_function],
+            max_tool_iterations=10,
+        )
+
+        # Test with a task that might require tool calls
+        result = await agent.run(
+            task="Use the mock_tool_function to process the text 'hello world'. Then provide a summary."
+        )
+
+        # Verify that we got a result
+        assert result is not None
+        assert isinstance(result, TaskResult)
+        assert len(result.messages) > 0
+        # Check that the last message is a non-tool call.
+        assert isinstance(result.messages[-1], TextMessage)
+        # Check that a tool call was made
+        tool_calls = [msg for msg in result.messages if isinstance(msg, ToolCallRequestEvent)]
+        assert len(tool_calls) > 0
+
+        # Check that usage was tracked
+        usage = client.total_usage()
+        assert usage.prompt_tokens > 0
+        assert usage.completion_tokens > 0
+
+    @pytest.mark.asyncio
+    async def test_anthropic_tool_call_loop_max_iterations_1_with_reflection(self) -> None:
+        """Test Anthropic integration with max_tool_iterations=1 and reflect_on_tool_use=True."""
+        api_key = os.getenv("ANTHROPIC_API_KEY")
+        if not api_key:
+            pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
+
+        client = self._get_anthropic_client()
+
+        agent = AssistantAgent(
+            name="anthropic_reflection_agent",
+            model_client=client,
+            tools=[mock_tool_function],
+            max_tool_iterations=1,
+            reflect_on_tool_use=True,
+        )
+
+        # Test with a task that might require tool calls but should be limited to 1 iteration
+        result = await agent.run(
+            task="Use the mock_tool_function to process the text 'test input' and then explain what happened."
+        )
+
+        # Verify that we got a result
+        assert result is not None
+        assert isinstance(result, TaskResult)
+        assert len(result.messages) > 0
+        # Check that the last message is a reflection
+        assert isinstance(result.messages[-1], TextMessage)
+        # Check that a tool call was made
+        tool_calls = [msg for msg in result.messages if isinstance(msg, ToolCallRequestEvent)]
+        assert len(tool_calls) > 0
+
+        # Check that usage was tracked
+        usage = client.total_usage()
+        assert usage.prompt_tokens > 0
+        assert usage.completion_tokens > 0
+
+    @pytest.mark.asyncio
+    async def test_anthropic_basic_text_response(self) -> None:
+        """Test basic Anthropic integration without tools."""
+        api_key = os.getenv("ANTHROPIC_API_KEY")
+        if not api_key:
+            pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
+
+        client = self._get_anthropic_client()
+
+        agent = AssistantAgent(
+            name="anthropic_basic_agent",
+            model_client=client,
+        )
+
+        # Test with a simple task that doesn't require tools
+        result = await agent.run(task="What is 2 + 2? Just answer with the number.")
+
+        # Verify that we got a result
+        assert result is not None
+        assert isinstance(result, TaskResult)
+        # Check that we got a text message with content
+        assert isinstance(result.messages[-1], TextMessage)
+        assert "4" in result.messages[-1].content
+
+        # Check that usage was tracked
+        usage = client.total_usage()
+        assert usage.prompt_tokens > 0
+        assert usage.completion_tokens > 0
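
The integration tests above call a mock_tool_function helper that is defined elsewhere in the test module and is not part of this diff. A minimal hypothetical stand-in with the shape the tests appear to expect (name, signature, and behavior are assumptions) could look like:

async def mock_tool_function(text: str) -> str:
    """Hypothetical stand-in for the test module's mock_tool_function (not shown in this diff)."""
    return f"Processed: {text}"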