mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-04 23:50:39 +00:00

## Why are these changes needed? This PR addresses critical issues in the AssistantAgent that affect tool handling: **Lack of tool call loop functionality**: The agent could not perform multiple consecutive tool calls in a single turn, limiting its ability to complete complex multi-step tasks that require chaining tool operations. These changes enhance the agent's robustness and capability while maintaining full backward compatibility through feature flags. ## Related issue number Closes #6268 ## Checks - [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [x] I've added tests corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
3408 lines
127 KiB
Python
3408 lines
127 KiB
Python
"""Comprehensive tests for AssistantAgent functionality."""
|
|
|
|
# Standard library imports
|
|
import asyncio
|
|
import json
|
|
from typing import Any, List, Optional, Union, cast
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
# Third-party imports
|
|
import pytest
|
|
|
|
# First-party imports
|
|
from autogen_agentchat.agents import AssistantAgent
|
|
from autogen_agentchat.agents._assistant_agent import AssistantAgentConfig
|
|
from autogen_agentchat.base import Handoff, Response, TaskResult
|
|
from autogen_agentchat.messages import (
|
|
BaseChatMessage,
|
|
HandoffMessage,
|
|
MemoryQueryEvent,
|
|
ModelClientStreamingChunkEvent,
|
|
StructuredMessage,
|
|
TextMessage,
|
|
ThoughtEvent,
|
|
ToolCallExecutionEvent,
|
|
ToolCallRequestEvent,
|
|
ToolCallSummaryMessage,
|
|
)
|
|
from autogen_core import CancellationToken, ComponentModel, FunctionCall
|
|
from autogen_core.memory import Memory, MemoryContent, UpdateContextResult
|
|
from autogen_core.memory import MemoryQueryResult as MemoryQueryResultSet
|
|
from autogen_core.model_context import BufferedChatCompletionContext
|
|
from autogen_core.models import (
|
|
AssistantMessage,
|
|
CreateResult,
|
|
FunctionExecutionResult,
|
|
ModelFamily,
|
|
RequestUsage,
|
|
SystemMessage,
|
|
UserMessage,
|
|
)
|
|
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
|
from autogen_ext.models.replay import ReplayChatCompletionClient
|
|
from autogen_ext.tools.mcp import McpWorkbench, SseServerParams
|
|
from pydantic import BaseModel, ValidationError
|
|
|
|
|
|
def mock_tool_function(param: str) -> str:
|
|
"""Mock tool function for testing.
|
|
|
|
Args:
|
|
param: Input parameter to process
|
|
|
|
Returns:
|
|
Formatted string with the input parameter
|
|
"""
|
|
return f"Tool executed with: {param}"
|
|
|
|
|
|
async def async_mock_tool_function(param: str) -> str:
|
|
"""Async mock tool function for testing.
|
|
|
|
Args:
|
|
param: Input parameter to process
|
|
|
|
Returns:
|
|
Formatted string with the input parameter
|
|
"""
|
|
return f"Async tool executed with: {param}"
|
|
|
|
|
|
def _pass_function(input: str) -> str:
|
|
"""Pass through function for testing.
|
|
|
|
Args:
|
|
input: Input to pass through
|
|
|
|
Returns:
|
|
The string "pass"
|
|
"""
|
|
return "pass"
|
|
|
|
|
|
def _echo_function(input: str) -> str:
|
|
"""Echo function for testing.
|
|
|
|
Args:
|
|
input: Input to echo
|
|
|
|
Returns:
|
|
The input string
|
|
"""
|
|
return input
|
|
|
|
|
|
class MockMemory(Memory):
|
|
"""Mock memory implementation for testing.
|
|
|
|
A simple memory implementation that stores strings and provides basic memory operations
|
|
for testing purposes.
|
|
|
|
Args:
|
|
contents: Optional list of initial memory contents
|
|
"""
|
|
|
|
def __init__(self, contents: Optional[List[str]] = None) -> None:
|
|
"""Initialize mock memory.
|
|
|
|
Args:
|
|
contents: Optional list of initial memory contents
|
|
"""
|
|
self._contents: List[str] = contents or []
|
|
|
|
async def add(self, content: MemoryContent, cancellation_token: Optional[CancellationToken] = None) -> None:
|
|
"""Add content to memory.
|
|
|
|
Args:
|
|
content: Content to add to memory
|
|
cancellation_token: Optional token for cancelling operation
|
|
"""
|
|
self._contents.append(str(content))
|
|
|
|
async def query(
|
|
self, query: Union[str, MemoryContent], cancellation_token: Optional[CancellationToken] = None, **kwargs: Any
|
|
) -> MemoryQueryResultSet:
|
|
"""Query memory contents.
|
|
|
|
Args:
|
|
query: Search query
|
|
cancellation_token: Optional token for cancelling operation
|
|
kwargs: Additional query parameters
|
|
|
|
Returns:
|
|
Query results containing all memory contents
|
|
"""
|
|
results = [MemoryContent(content=content, mime_type="text/plain") for content in self._contents]
|
|
return MemoryQueryResultSet(results=results)
|
|
|
|
async def clear(self, cancellation_token: Optional[CancellationToken] = None) -> None:
|
|
"""Clear all memory contents.
|
|
|
|
Args:
|
|
cancellation_token: Optional token for cancelling operation
|
|
"""
|
|
self._contents.clear()
|
|
|
|
async def close(self) -> None:
|
|
"""Close memory resources."""
|
|
pass
|
|
|
|
async def update_context(self, model_context: Any) -> UpdateContextResult:
|
|
"""Update model context with memory contents.
|
|
|
|
Args:
|
|
model_context: Context to update
|
|
|
|
Returns:
|
|
Update result containing memory contents
|
|
"""
|
|
if self._contents:
|
|
results = [MemoryContent(content=content, mime_type="text/plain") for content in self._contents]
|
|
return UpdateContextResult(memories=MemoryQueryResultSet(results=results))
|
|
return UpdateContextResult(memories=MemoryQueryResultSet(results=[]))
|
|
|
|
def dump_component(self) -> ComponentModel:
|
|
"""Dump memory state as component model.
|
|
|
|
Returns:
|
|
Component model representing memory state
|
|
"""
|
|
return ComponentModel(provider="test", config={"type": "mock_memory"})
|
|
|
|
|
|
class StructuredOutput(BaseModel):
|
|
"""Test structured output model.
|
|
|
|
Attributes:
|
|
content: Main content string
|
|
confidence: Confidence score between 0 and 1
|
|
"""
|
|
|
|
content: str
|
|
confidence: float
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_model_client_stream() -> None:
|
|
mock_client = ReplayChatCompletionClient(
|
|
[
|
|
"Response to message 3",
|
|
]
|
|
)
|
|
agent = AssistantAgent(
|
|
"test_agent",
|
|
model_client=mock_client,
|
|
model_client_stream=True,
|
|
)
|
|
chunks: List[str] = []
|
|
async for message in agent.run_stream(task="task"):
|
|
if isinstance(message, TaskResult):
|
|
assert isinstance(message.messages[-1], TextMessage)
|
|
assert message.messages[-1].content == "Response to message 3"
|
|
elif isinstance(message, ModelClientStreamingChunkEvent):
|
|
chunks.append(message.content)
|
|
assert "".join(chunks) == "Response to message 3"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_model_client_stream_with_tool_calls() -> None:
|
|
mock_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
content=[
|
|
FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
|
|
FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
|
|
],
|
|
finish_reason="function_calls",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
"Example response 2 to task",
|
|
]
|
|
)
|
|
mock_client._model_info["function_calling"] = True # pyright: ignore
|
|
agent = AssistantAgent(
|
|
"test_agent",
|
|
model_client=mock_client,
|
|
model_client_stream=True,
|
|
reflect_on_tool_use=True,
|
|
tools=[_pass_function, _echo_function],
|
|
)
|
|
chunks: List[str] = []
|
|
async for message in agent.run_stream(task="task"):
|
|
if isinstance(message, TaskResult):
|
|
assert isinstance(message.messages[-1], TextMessage)
|
|
assert isinstance(message.messages[1], ToolCallRequestEvent)
|
|
assert message.messages[-1].content == "Example response 2 to task"
|
|
assert message.messages[1].content == [
|
|
FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
|
|
FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
|
|
]
|
|
assert isinstance(message.messages[2], ToolCallExecutionEvent)
|
|
assert message.messages[2].content == [
|
|
FunctionExecutionResult(call_id="1", content="pass", is_error=False, name="_pass_function"),
|
|
FunctionExecutionResult(call_id="3", content="task", is_error=False, name="_echo_function"),
|
|
]
|
|
elif isinstance(message, ModelClientStreamingChunkEvent):
|
|
chunks.append(message.content)
|
|
assert "".join(chunks) == "Example response 2 to task"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_invalid_structured_output_format() -> None:
|
|
class AgentResponse(BaseModel):
|
|
response: str
|
|
status: str
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content='{"response": "Hello"}',
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
]
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="assistant",
|
|
model_client=model_client,
|
|
output_content_type=AgentResponse,
|
|
)
|
|
|
|
with pytest.raises(ValidationError):
|
|
await agent.run()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_structured_message_factory_serialization() -> None:
|
|
class AgentResponse(BaseModel):
|
|
result: str
|
|
status: str
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content=AgentResponse(result="All good", status="ok").model_dump_json(),
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
]
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="structured_agent",
|
|
model_client=model_client,
|
|
output_content_type=AgentResponse,
|
|
output_content_type_format="{result} - {status}",
|
|
)
|
|
|
|
dumped = agent.dump_component()
|
|
restored_agent = AssistantAgent.load_component(dumped)
|
|
result = await restored_agent.run()
|
|
|
|
assert isinstance(result.messages[0], StructuredMessage)
|
|
assert result.messages[0].content.result == "All good" # type: ignore
|
|
assert result.messages[0].content.status == "ok" # type: ignore
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_structured_message_format_string() -> None:
|
|
class AgentResponse(BaseModel):
|
|
field1: str
|
|
field2: str
|
|
|
|
expected = AgentResponse(field1="foo", field2="bar")
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content=expected.model_dump_json(),
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
]
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="formatted_agent",
|
|
model_client=model_client,
|
|
output_content_type=AgentResponse,
|
|
output_content_type_format="{field1} - {field2}",
|
|
)
|
|
|
|
result = await agent.run()
|
|
|
|
assert len(result.messages) == 1
|
|
message = result.messages[0]
|
|
|
|
# Check that it's a StructuredMessage with the correct content model
|
|
assert isinstance(message, StructuredMessage)
|
|
assert isinstance(message.content, AgentResponse) # type: ignore[reportUnknownMemberType]
|
|
assert message.content == expected
|
|
|
|
# Check that the format_string was applied correctly
|
|
assert message.to_model_text() == "foo - bar"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tools_serialize_and_deserialize() -> None:
|
|
def test() -> str:
|
|
return "hello world"
|
|
|
|
client = OpenAIChatCompletionClient(
|
|
model="gpt-4o",
|
|
api_key="API_KEY",
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test",
|
|
model_client=client,
|
|
tools=[test],
|
|
)
|
|
|
|
serialize = agent.dump_component()
|
|
deserialize = AssistantAgent.load_component(serialize)
|
|
|
|
assert deserialize.name == agent.name
|
|
for original, restored in zip(agent._workbench, deserialize._workbench, strict=True): # type: ignore
|
|
assert await original.list_tools() == await restored.list_tools() # type: ignore
|
|
assert agent.component_version == deserialize.component_version
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_workbench_serialize_and_deserialize() -> None:
|
|
workbench = McpWorkbench(server_params=SseServerParams(url="http://test-url"))
|
|
|
|
client = OpenAIChatCompletionClient(
|
|
model="gpt-4o",
|
|
api_key="API_KEY",
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test",
|
|
model_client=client,
|
|
workbench=workbench,
|
|
)
|
|
|
|
serialize = agent.dump_component()
|
|
deserialize = AssistantAgent.load_component(serialize)
|
|
|
|
assert deserialize.name == agent.name
|
|
for original, restored in zip(agent._workbench, deserialize._workbench, strict=True): # type: ignore
|
|
assert isinstance(original, McpWorkbench)
|
|
assert isinstance(restored, McpWorkbench)
|
|
assert original._to_config() == restored._to_config() # type: ignore
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_multiple_workbenches_serialize_and_deserialize() -> None:
|
|
workbenches: List[McpWorkbench] = [
|
|
McpWorkbench(server_params=SseServerParams(url="http://test-url-1")),
|
|
McpWorkbench(server_params=SseServerParams(url="http://test-url-2")),
|
|
]
|
|
|
|
client = OpenAIChatCompletionClient(
|
|
model="gpt-4o",
|
|
api_key="API_KEY",
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_multi",
|
|
model_client=client,
|
|
workbench=workbenches,
|
|
)
|
|
|
|
serialize = agent.dump_component()
|
|
deserialized_agent: AssistantAgent = AssistantAgent.load_component(serialize)
|
|
|
|
assert deserialized_agent.name == agent.name
|
|
assert isinstance(deserialized_agent._workbench, list) # type: ignore
|
|
assert len(deserialized_agent._workbench) == len(workbenches) # type: ignore
|
|
|
|
for original, restored in zip(agent._workbench, deserialized_agent._workbench, strict=True): # type: ignore
|
|
assert isinstance(original, McpWorkbench)
|
|
assert isinstance(restored, McpWorkbench)
|
|
assert original._to_config() == restored._to_config() # type: ignore
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tools_deserialize_aware() -> None:
|
|
dump = """
|
|
{
|
|
"provider": "autogen_agentchat.agents.AssistantAgent",
|
|
"component_type": "agent",
|
|
"version": 1,
|
|
"component_version": 2,
|
|
"description": "An agent that provides assistance with tool use.",
|
|
"label": "AssistantAgent",
|
|
"config": {
|
|
"name": "TestAgent",
|
|
"model_client":{
|
|
"provider": "autogen_ext.models.replay.ReplayChatCompletionClient",
|
|
"component_type": "replay_chat_completion_client",
|
|
"version": 1,
|
|
"component_version": 1,
|
|
"description": "A mock chat completion client that replays predefined responses using an index-based approach.",
|
|
"label": "ReplayChatCompletionClient",
|
|
"config": {
|
|
"chat_completions": [
|
|
{
|
|
"finish_reason": "function_calls",
|
|
"content": [
|
|
{
|
|
"id": "hello",
|
|
"arguments": "{}",
|
|
"name": "hello"
|
|
}
|
|
],
|
|
"usage": {
|
|
"prompt_tokens": 0,
|
|
"completion_tokens": 0
|
|
},
|
|
"cached": false
|
|
}
|
|
],
|
|
"model_info": {
|
|
"vision": false,
|
|
"function_calling": true,
|
|
"json_output": false,
|
|
"family": "unknown",
|
|
"structured_output": false
|
|
}
|
|
}
|
|
},
|
|
"tools": [
|
|
{
|
|
"provider": "autogen_core.tools.FunctionTool",
|
|
"component_type": "tool",
|
|
"version": 1,
|
|
"component_version": 1,
|
|
"description": "Create custom tools by wrapping standard Python functions.",
|
|
"label": "FunctionTool",
|
|
"config": {
|
|
"source_code": "def hello():\\n return 'Hello, World!'\\n",
|
|
"name": "hello",
|
|
"description": "",
|
|
"global_imports": [],
|
|
"has_cancellation_support": false
|
|
}
|
|
}
|
|
],
|
|
"model_context": {
|
|
"provider": "autogen_core.model_context.UnboundedChatCompletionContext",
|
|
"component_type": "chat_completion_context",
|
|
"version": 1,
|
|
"component_version": 1,
|
|
"description": "An unbounded chat completion context that keeps a view of the all the messages.",
|
|
"label": "UnboundedChatCompletionContext",
|
|
"config": {}
|
|
},
|
|
"description": "An agent that provides assistance with ability to use tools.",
|
|
"system_message": "You are a helpful assistant.",
|
|
"model_client_stream": false,
|
|
"reflect_on_tool_use": false,
|
|
"tool_call_summary_format": "{result}",
|
|
"metadata": {}
|
|
}
|
|
}
|
|
|
|
"""
|
|
|
|
# Test that agent can be deserialized from configuration
|
|
config = json.loads(dump)
|
|
agent = AssistantAgent.load_component(config)
|
|
|
|
# Verify the agent was loaded correctly
|
|
assert agent.name == "TestAgent"
|
|
assert agent.description == "An agent that provides assistance with ability to use tools."
|
|
|
|
|
|
class TestAssistantAgentToolCallLoop:
|
|
"""Test suite for tool call loop functionality.
|
|
|
|
Tests the behavior of AssistantAgent's tool call loop feature, which allows
|
|
multiple sequential tool calls before producing a final response.
|
|
"""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_loop_enabled(self) -> None:
|
|
"""Test that tool call loop works when enabled.
|
|
|
|
Verifies that:
|
|
1. Multiple tool calls are executed in sequence
|
|
2. Loop continues until non-tool response
|
|
3. Final response is correct type
|
|
"""
|
|
# Create mock client with multiple tool calls followed by text response
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
# First tool call
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "first"}), name="mock_tool_function")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
# Second tool call (loop continues)
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(id="2", arguments=json.dumps({"param": "second"}), name="mock_tool_function")
|
|
],
|
|
usage=RequestUsage(prompt_tokens=12, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
# Final text response (loop ends)
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Task completed successfully!",
|
|
usage=RequestUsage(prompt_tokens=15, completion_tokens=10),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
max_tool_iterations=3,
|
|
)
|
|
|
|
result = await agent.run(task="Execute multiple tool calls")
|
|
|
|
# Verify multiple model calls were made
|
|
assert len(model_client.create_calls) == 3, f"Expected 3 calls, got {len(model_client.create_calls)}"
|
|
|
|
# Verify final response is text
|
|
final_message = result.messages[-1]
|
|
assert isinstance(final_message, TextMessage)
|
|
assert final_message.content == "Task completed successfully!"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_loop_disabled_default(self) -> None:
|
|
"""Test that tool call loop is disabled by default.
|
|
|
|
Verifies that:
|
|
1. Only one tool call is made when loop is disabled
|
|
2. Agent returns after first tool call
|
|
"""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "test"}), name="mock_tool_function")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
max_tool_iterations=1,
|
|
)
|
|
|
|
result = await agent.run(task="Execute single tool call")
|
|
|
|
# Should only make one model call
|
|
assert len(model_client.create_calls) == 1, f"Expected 1 call, got {len(model_client.create_calls)}"
|
|
assert result is not None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_loop_max_iterations(self) -> None:
|
|
"""Test that tool call loop respects max_iterations limit."""
|
|
# Create responses that would continue forever without max_iterations
|
|
responses: List[CreateResult] = []
|
|
for i in range(15): # More than default max_iterations (10)
|
|
responses.append(
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(id=str(i), arguments=json.dumps({"param": f"call_{i}"}), name="mock_tool_function")
|
|
],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
)
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
responses,
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
max_tool_iterations=5, # Set max iterations to 5
|
|
)
|
|
|
|
result = await agent.run(task="Test max iterations")
|
|
|
|
# Should stop at max_iterations
|
|
assert len(model_client.create_calls) == 5, f"Expected 5 calls, got {len(model_client.create_calls)}"
|
|
# Verify result is not None
|
|
assert result is not None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_loop_with_handoff(self) -> None:
|
|
"""Test that tool call loop stops on handoff."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
# Tool call followed by handoff
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(id="1", arguments=json.dumps({"param": "test"}), name="mock_tool_function"),
|
|
FunctionCall(
|
|
id="2", arguments=json.dumps({"target": "other_agent"}), name="transfer_to_other_agent"
|
|
),
|
|
],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
handoffs=["other_agent"],
|
|
max_tool_iterations=1,
|
|
)
|
|
|
|
result = await agent.run(task="Test handoff in loop")
|
|
|
|
# Should stop at handoff
|
|
assert len(model_client.create_calls) == 1, f"Expected 1 call, got {len(model_client.create_calls)}"
|
|
|
|
# Should return HandoffMessage
|
|
assert isinstance(result.messages[-1], HandoffMessage)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_config_validation(self) -> None:
|
|
"""Test that ToolCallConfig validation works correctly."""
|
|
# Test that max_iterations must be >= 1
|
|
with pytest.raises(
|
|
ValueError, match="Maximum number of tool iterations must be greater than or equal to 1, got 0"
|
|
):
|
|
AssistantAgent(
|
|
name="test_agent",
|
|
model_client=MagicMock(),
|
|
max_tool_iterations=0, # Should raise error
|
|
)
|
|
|
|
|
|
class TestAssistantAgentInitialization:
|
|
"""Test suite for AssistantAgent initialization.
|
|
|
|
Tests various initialization scenarios and configurations of the AssistantAgent class.
|
|
"""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_basic_initialization(self) -> None:
|
|
"""Test basic agent initialization with minimal parameters.
|
|
|
|
Verifies that:
|
|
1. Agent initializes with required parameters
|
|
2. Default values are set correctly
|
|
3. Basic functionality works
|
|
"""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Hello!",
|
|
usage=RequestUsage(prompt_tokens=5, completion_tokens=2),
|
|
cached=False,
|
|
)
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(name="test_agent", model_client=model_client)
|
|
result = await agent.run(task="Say hello")
|
|
|
|
assert isinstance(result.messages[-1], TextMessage)
|
|
assert result.messages[-1].content == "Hello!"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_initialization_with_tools(self) -> None:
|
|
"""Test agent initialization with tools.
|
|
|
|
Verifies that:
|
|
1. Agent accepts tool configurations
|
|
2. Tools are properly registered
|
|
3. Tool calls work correctly
|
|
"""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "test"}), name="mock_tool_function")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
)
|
|
|
|
result = await agent.run(task="Use the tool")
|
|
assert isinstance(result.messages[-1], ToolCallSummaryMessage)
|
|
assert "Tool executed with: test" in result.messages[-1].content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_initialization_with_memory(self) -> None:
|
|
"""Test agent initialization with memory.
|
|
|
|
Verifies that:
|
|
1. Memory is properly integrated
|
|
2. Memory contents affect responses
|
|
3. Memory updates work correctly
|
|
"""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Using memory content",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
memory = MockMemory(contents=["Test memory content"])
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
memory=[memory],
|
|
)
|
|
|
|
result = await agent.run(task="Use memory")
|
|
assert isinstance(result.messages[-1], TextMessage)
|
|
assert result.messages[-1].content == "Using memory content"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_initialization_with_handoffs(self) -> None:
|
|
"""Test agent initialization with handoffs."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
handoffs=["agent1", Handoff(target="agent2")],
|
|
)
|
|
|
|
assert len(agent._handoffs) == 2 # type: ignore[reportPrivateUsage]
|
|
assert "transfer_to_agent1" in agent._handoffs # type: ignore[reportPrivateUsage]
|
|
assert "transfer_to_agent2" in agent._handoffs # type: ignore[reportPrivateUsage]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_initialization_with_custom_model_context(self) -> None:
|
|
"""Test agent initialization with custom model context."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
model_context = BufferedChatCompletionContext(buffer_size=5)
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
model_context=model_context,
|
|
)
|
|
|
|
assert agent._model_context == model_context # type: ignore[reportPrivateUsage]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_initialization_with_structured_output(self) -> None:
|
|
"""Test agent initialization with structured output."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
output_content_type=StructuredOutput,
|
|
)
|
|
|
|
assert agent._output_content_type == StructuredOutput # type: ignore[reportPrivateUsage]
|
|
assert agent._reflect_on_tool_use is True # type: ignore[reportPrivateUsage] # Should be True by default with structured output
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_initialization_with_metadata(self) -> None:
|
|
"""Test agent initialization with metadata."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
metadata = {"key1": "value1", "key2": "value2"}
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
metadata=metadata,
|
|
)
|
|
|
|
assert agent._metadata == metadata # type: ignore[reportPrivateUsage]
|
|
|
|
|
|
class TestAssistantAgentValidation:
|
|
"""Test suite for AssistantAgent validation.
|
|
|
|
Tests various validation scenarios to ensure proper error handling and input validation.
|
|
"""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_names_must_be_unique(self) -> None:
|
|
"""Test validation of unique tool names.
|
|
|
|
Verifies that:
|
|
1. Duplicate tool names are detected
|
|
2. Appropriate error is raised
|
|
"""
|
|
|
|
def duplicate_tool(param: str) -> str:
|
|
"""Test tool with duplicate name.
|
|
|
|
Args:
|
|
param: Input parameter
|
|
|
|
Returns:
|
|
Formatted string with parameter
|
|
"""
|
|
return f"Duplicate tool: {param}"
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
[],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
with pytest.raises(ValueError, match="Tool names must be unique"):
|
|
AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function, duplicate_tool, mock_tool_function],
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_handoff_names_must_be_unique(self) -> None:
|
|
"""Test validation of unique handoff names.
|
|
|
|
Verifies that:
|
|
1. Duplicate handoff names are detected
|
|
2. Appropriate error is raised
|
|
"""
|
|
model_client = ReplayChatCompletionClient(
|
|
[],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
with pytest.raises(ValueError, match="Handoff names must be unique"):
|
|
AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
handoffs=["agent1", "agent2", "agent1"],
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_handoff_names_must_be_unique_from_tool_names(self) -> None:
|
|
"""Test validation of handoff names against tool names.
|
|
|
|
Verifies that:
|
|
1. Handoff names cannot conflict with tool names
|
|
2. Appropriate error is raised
|
|
"""
|
|
|
|
def test_tool() -> str:
|
|
"""Test tool with name that conflicts with handoff.
|
|
|
|
Returns:
|
|
Static test string
|
|
"""
|
|
return "test"
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
[],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
with pytest.raises(ValueError, match="Handoff names must be unique from tool names"):
|
|
AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[test_tool],
|
|
handoffs=["test_tool"],
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_function_calling_required_for_tools(self) -> None:
|
|
"""Test that function calling is required for tools."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
with pytest.raises(ValueError, match="The model does not support function calling"):
|
|
AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_function_calling_required_for_handoffs(self) -> None:
|
|
"""Test that function calling is required for handoffs."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
with pytest.raises(
|
|
ValueError, match="The model does not support function calling, which is needed for handoffs"
|
|
):
|
|
AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
handoffs=["agent1"],
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_memory_type_validation(self) -> None:
|
|
"""Test memory type validation."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
with pytest.raises(TypeError, match="Expected Memory, List\\[Memory\\], or None"):
|
|
AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
memory="invalid_memory", # type: ignore
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tools_and_workbench_mutually_exclusive(self) -> None:
|
|
"""Test that tools and workbench are mutually exclusive."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
workbench = MagicMock()
|
|
|
|
with pytest.raises(ValueError, match="Tools cannot be used with a workbench"):
|
|
AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
workbench=workbench,
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_unsupported_tool_type(self) -> None:
|
|
"""Test error handling for unsupported tool types."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
with pytest.raises(ValueError, match="Unsupported tool type"):
|
|
AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=["invalid_tool"], # type: ignore
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_unsupported_handoff_type(self) -> None:
|
|
"""Test error handling for unsupported handoff types."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
with pytest.raises(ValueError, match="Unsupported handoff type"):
|
|
AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
handoffs=[123], # type: ignore
|
|
)
|
|
|
|
|
|
class TestAssistantAgentStateManagement:
|
|
"""Test suite for AssistantAgent state management."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_save_and_load_state(self) -> None:
|
|
"""Test saving and loading agent state."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
# Mock model context state
|
|
mock_context = MagicMock()
|
|
mock_context.save_state = AsyncMock(return_value={"context": "state"})
|
|
mock_context.load_state = AsyncMock()
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
model_context=mock_context,
|
|
)
|
|
|
|
# Test save state
|
|
state = await agent.save_state()
|
|
assert "llm_context" in state
|
|
|
|
# Test load state
|
|
await agent.load_state(state)
|
|
mock_context.load_state.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_on_reset(self) -> None:
|
|
"""Test agent reset functionality."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
mock_context = MagicMock()
|
|
mock_context.clear = AsyncMock()
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
model_context=mock_context,
|
|
)
|
|
|
|
cancellation_token = CancellationToken()
|
|
await agent.on_reset(cancellation_token)
|
|
|
|
mock_context.clear.assert_called_once()
|
|
|
|
|
|
class TestAssistantAgentProperties:
|
|
"""Test suite for AssistantAgent properties."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_produced_message_types_text_only(self) -> None:
|
|
"""Test produced message types for text-only agent."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
)
|
|
|
|
message_types = agent.produced_message_types
|
|
assert TextMessage in message_types
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_produced_message_types_with_tools(self) -> None:
|
|
"""Test produced message types for agent with tools."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
)
|
|
|
|
message_types = agent.produced_message_types
|
|
assert ToolCallSummaryMessage in message_types
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_produced_message_types_with_handoffs(self) -> None:
|
|
"""Test produced message types for agent with handoffs."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
handoffs=["agent1"],
|
|
)
|
|
|
|
message_types = agent.produced_message_types
|
|
assert HandoffMessage in message_types
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_model_context_property(self) -> None:
|
|
"""Test model_context property access."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
custom_context = BufferedChatCompletionContext(buffer_size=3)
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
model_context=custom_context,
|
|
)
|
|
|
|
assert agent.model_context == custom_context
|
|
|
|
|
|
class TestAssistantAgentErrorHandling:
|
|
"""Test suite for error handling scenarios."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_invalid_json_in_tool_arguments(self) -> None:
|
|
"""Test handling of invalid JSON in tool arguments."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments="invalid json", name="mock_tool_function")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
)
|
|
|
|
result = await agent.run(task="Execute tool with invalid JSON")
|
|
|
|
# Should handle JSON parsing error
|
|
assert isinstance(result.messages[-1], ToolCallSummaryMessage)
|
|
|
|
|
|
class TestAssistantAgentMemoryIntegration:
|
|
"""Test suite for AssistantAgent memory integration.
|
|
|
|
Tests the integration between AssistantAgent and memory components, including:
|
|
- Memory initialization
|
|
- Context updates
|
|
- Query operations
|
|
- Memory persistence
|
|
"""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_memory_updates_context(self) -> None:
|
|
"""Test that memory properly updates model context.
|
|
|
|
Verifies that:
|
|
1. Memory contents are added to context
|
|
2. Context updates trigger appropriate events
|
|
3. Memory query results are properly handled
|
|
"""
|
|
# Setup test memory with initial content
|
|
memory = MockMemory(contents=["Previous conversation about topic A"])
|
|
|
|
# Configure model client with expected response
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Response incorporating memory content",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
# Create agent with memory
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
memory=[memory],
|
|
description="Agent with memory integration",
|
|
)
|
|
|
|
# Track memory events during execution
|
|
memory_events: List[MemoryQueryEvent] = []
|
|
|
|
async def event_handler(event: MemoryQueryEvent) -> None:
|
|
"""Handle memory query events.
|
|
|
|
Args:
|
|
event: Memory query event to process
|
|
"""
|
|
memory_events.append(event)
|
|
|
|
# Create a handler function to capture memory events
|
|
async def handle_memory_events(result: Any) -> None:
|
|
messages: List[BaseChatMessage] = result.messages if hasattr(result, "messages") else []
|
|
for msg in messages:
|
|
if isinstance(msg, MemoryQueryEvent):
|
|
await event_handler(msg)
|
|
|
|
# Run agent
|
|
result = await agent.run(task="Respond using memory context")
|
|
|
|
# Process the events
|
|
await handle_memory_events(result)
|
|
|
|
# Verify memory integration
|
|
assert len(memory_events) > 0, "No memory events were generated"
|
|
assert isinstance(result.messages[-1], TextMessage)
|
|
assert "Response incorporating memory content" in result.messages[-1].content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_memory_persistence(self) -> None:
|
|
"""Test memory persistence across multiple sessions.
|
|
|
|
Verifies:
|
|
1. Memory content persists between sessions
|
|
2. Memory updates are preserved
|
|
3. Context is properly restored
|
|
4. Memory query events are generated correctly
|
|
"""
|
|
# Create memory with initial content
|
|
memory = MockMemory(contents=["Initial memory"])
|
|
|
|
# Create model client
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Response using memory",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Response with updated memory",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": False,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
# Create agent with memory
|
|
agent = AssistantAgent(name="memory_test_agent", model_client=model_client, memory=[memory])
|
|
|
|
# First session
|
|
result1 = await agent.run(task="First task")
|
|
state = await agent.save_state()
|
|
|
|
# Add new memory content
|
|
await memory.add(MemoryContent(content="New memory", mime_type="text/plain"))
|
|
|
|
# Create new agent and restore state
|
|
new_agent = AssistantAgent(name="memory_test_agent", model_client=model_client, memory=[memory])
|
|
await new_agent.load_state(state)
|
|
|
|
# Second session
|
|
result2 = await new_agent.run(task="Second task")
|
|
|
|
# Verify memory persistence
|
|
assert isinstance(result1.messages[-1], TextMessage)
|
|
assert isinstance(result2.messages[-1], TextMessage)
|
|
assert result1.messages[-1].content == "Response using memory"
|
|
assert result2.messages[-1].content == "Response with updated memory"
|
|
|
|
# Verify memory events
|
|
memory_events = [msg for msg in result2.messages if isinstance(msg, MemoryQueryEvent)]
|
|
assert len(memory_events) > 0
|
|
assert any("New memory" in str(event.content) for event in memory_events)
|
|
|
|
|
|
class TestAssistantAgentSystemMessage:
|
|
"""Test suite for system message functionality."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_system_message_none(self) -> None:
|
|
"""Test agent with system_message=None."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
system_message=None,
|
|
)
|
|
|
|
assert agent._system_messages == [] # type: ignore[reportPrivateUsage]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_custom_system_message(self) -> None:
|
|
"""Test agent with custom system message."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
custom_message = "You are a specialized assistant."
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
system_message=custom_message,
|
|
)
|
|
|
|
assert len(agent._system_messages) == 1 # type: ignore[reportPrivateUsage]
|
|
assert agent._system_messages[0].content == custom_message # type: ignore[reportPrivateUsage]
|
|
|
|
|
|
class TestAssistantAgentModelCompatibility:
|
|
"""Test suite for model compatibility functionality."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_claude_model_warning(self) -> None:
|
|
"""Test warning for Claude models with reflection."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.CLAUDE_3_5_SONNET}
|
|
|
|
with pytest.warns(UserWarning, match="Claude models may not work with reflection"):
|
|
AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
reflect_on_tool_use=True,
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_vision_compatibility(self) -> None:
|
|
"""Test vision model compatibility."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": True, "family": ModelFamily.GPT_4O}
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
)
|
|
|
|
# Test _get_compatible_context with vision model
|
|
from autogen_core.models import LLMMessage
|
|
|
|
messages: List[LLMMessage] = [SystemMessage(content="Test")]
|
|
compatible_messages = agent._get_compatible_context(model_client, messages) # type: ignore[reportPrivateUsage]
|
|
|
|
# Should return original messages for vision models
|
|
assert compatible_messages == messages
|
|
|
|
|
|
class TestAssistantAgentComponentSerialization:
|
|
"""Test suite for component serialization functionality."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_to_config_basic_agent(self) -> None:
|
|
"""Test _to_config method with basic agent configuration."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
model_client.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_client"})
|
|
)
|
|
|
|
mock_context = MagicMock()
|
|
mock_context.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_context"})
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
description="Test description",
|
|
system_message="Test system message",
|
|
model_context=mock_context,
|
|
metadata={"key": "value"},
|
|
)
|
|
|
|
config = agent._to_config() # type: ignore[reportPrivateUsage]
|
|
|
|
assert config.name == "test_agent"
|
|
assert config.description == "Test description"
|
|
assert config.system_message == "Test system message"
|
|
assert config.model_client_stream is False
|
|
assert config.reflect_on_tool_use is False
|
|
assert config.max_tool_iterations == 1
|
|
assert config.metadata == {"key": "value"}
|
|
model_client.dump_component.assert_called_once()
|
|
mock_context.dump_component.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_to_config_agent_with_handoffs(self) -> None:
|
|
"""Test _to_config method with agent having handoffs."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.GPT_4O}
|
|
model_client.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_client"})
|
|
)
|
|
|
|
mock_context = MagicMock()
|
|
mock_context.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_context"})
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
handoffs=["agent1", Handoff(target="agent2")],
|
|
model_context=mock_context,
|
|
)
|
|
|
|
config = agent._to_config() # type: ignore[reportPrivateUsage]
|
|
|
|
assert config.handoffs is not None
|
|
assert len(config.handoffs) == 2
|
|
handoff_targets: List[str] = [h.target if hasattr(h, "target") else str(h) for h in config.handoffs] # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
|
|
assert "agent1" in handoff_targets
|
|
assert "agent2" in handoff_targets
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_to_config_agent_with_memory(self) -> None:
|
|
"""Test _to_config method with agent having memory modules."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
model_client.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_client"})
|
|
)
|
|
|
|
mock_context = MagicMock()
|
|
mock_context.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_context"})
|
|
)
|
|
|
|
mock_memory = MockMemory()
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
memory=[mock_memory],
|
|
model_context=mock_context,
|
|
)
|
|
|
|
config = agent._to_config() # type: ignore[reportPrivateUsage]
|
|
|
|
assert config.memory is not None
|
|
assert len(config.memory) == 1
|
|
assert config.memory[0].provider == "test"
|
|
assert config.memory[0].config == {"type": "mock_memory"}
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_to_config_agent_with_workbench(self) -> None:
|
|
"""Test _to_config method with agent having workbench."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.GPT_4O}
|
|
model_client.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_client"})
|
|
)
|
|
|
|
mock_context = MagicMock()
|
|
mock_context.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_context"})
|
|
)
|
|
|
|
mock_workbench = MagicMock()
|
|
mock_workbench.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_workbench"})
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
model_context=mock_context,
|
|
)
|
|
|
|
# Replace the workbench with our mock
|
|
agent._workbench = [mock_workbench] # type: ignore[reportPrivateUsage]
|
|
|
|
config = agent._to_config() # type: ignore[reportPrivateUsage]
|
|
|
|
assert config.workbench is not None
|
|
assert len(config.workbench) == 1
|
|
mock_workbench.dump_component.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_to_config_agent_with_structured_output(self) -> None:
|
|
"""Test _to_config method with agent having structured output."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
model_client.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_client"})
|
|
)
|
|
|
|
mock_context = MagicMock()
|
|
mock_context.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_context"})
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
output_content_type=StructuredOutput,
|
|
model_context=mock_context,
|
|
)
|
|
|
|
config = agent._to_config() # type: ignore[reportPrivateUsage]
|
|
|
|
assert config.structured_message_factory is not None
|
|
assert config.reflect_on_tool_use is True # Should be True with structured output
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_to_config_system_message_none(self) -> None:
|
|
"""Test _to_config method with system_message=None."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
model_client.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_client"})
|
|
)
|
|
|
|
mock_context = MagicMock()
|
|
mock_context.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_context"})
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
system_message=None,
|
|
model_context=mock_context,
|
|
)
|
|
|
|
config = agent._to_config() # type: ignore[reportPrivateUsage]
|
|
|
|
assert config.system_message is None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_from_config_basic_agent(self) -> None:
|
|
"""Test _from_config method with basic agent configuration."""
|
|
mock_model_client = MagicMock()
|
|
mock_model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
with patch("autogen_core.models.ChatCompletionClient.load_component", return_value=mock_model_client):
|
|
config = AssistantAgentConfig(
|
|
name="test_agent",
|
|
model_client=ComponentModel(provider="test", config={"type": "mock_client"}),
|
|
description="Test description",
|
|
system_message="Test system",
|
|
model_client_stream=True,
|
|
reflect_on_tool_use=False,
|
|
tool_call_summary_format="{tool_name}: {result}",
|
|
metadata={"test": "value"},
|
|
)
|
|
|
|
agent = AssistantAgent._from_config(config) # type: ignore[reportPrivateUsage]
|
|
|
|
assert agent.name == "test_agent"
|
|
assert agent.description == "Test description"
|
|
assert agent._model_client_stream is True # type: ignore[reportPrivateUsage]
|
|
assert agent._reflect_on_tool_use is False # type: ignore[reportPrivateUsage]
|
|
assert agent._tool_call_summary_format == "{tool_name}: {result}" # type: ignore[reportPrivateUsage]
|
|
assert agent._metadata == {"test": "value"} # type: ignore[reportPrivateUsage]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_from_config_with_structured_output(self) -> None:
|
|
"""Test _from_config method with structured output configuration."""
|
|
mock_model_client = MagicMock()
|
|
mock_model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
mock_structured_factory = MagicMock()
|
|
mock_structured_factory.format_string = "Test format"
|
|
mock_structured_factory.ContentModel = StructuredOutput
|
|
|
|
with (
|
|
patch("autogen_core.models.ChatCompletionClient.load_component", return_value=mock_model_client),
|
|
patch(
|
|
"autogen_agentchat.messages.StructuredMessageFactory.load_component",
|
|
return_value=mock_structured_factory,
|
|
),
|
|
):
|
|
config = AssistantAgentConfig(
|
|
name="test_agent",
|
|
model_client=ComponentModel(provider="test", config={"type": "mock_client"}),
|
|
description="Test description",
|
|
reflect_on_tool_use=True,
|
|
tool_call_summary_format="{result}",
|
|
structured_message_factory=ComponentModel(provider="test", config={"type": "mock_factory"}),
|
|
)
|
|
|
|
agent = AssistantAgent._from_config(config) # type: ignore[reportPrivateUsage]
|
|
|
|
assert agent._reflect_on_tool_use is True # type: ignore[reportPrivateUsage]
|
|
assert agent._output_content_type == StructuredOutput # type: ignore[reportPrivateUsage]
|
|
assert agent._output_content_type_format == "Test format" # type: ignore[reportPrivateUsage]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_from_config_with_workbench_and_memory(self) -> None:
|
|
"""Test _from_config method with workbench and memory."""
|
|
mock_model_client = MagicMock()
|
|
mock_model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
mock_workbench = MagicMock()
|
|
mock_memory = MockMemory()
|
|
mock_context = MagicMock()
|
|
|
|
with (
|
|
patch("autogen_core.models.ChatCompletionClient.load_component", return_value=mock_model_client),
|
|
patch("autogen_core.tools.Workbench.load_component", return_value=mock_workbench),
|
|
patch("autogen_core.memory.Memory.load_component", return_value=mock_memory),
|
|
patch("autogen_core.model_context.ChatCompletionContext.load_component", return_value=mock_context),
|
|
):
|
|
config = AssistantAgentConfig(
|
|
name="test_agent",
|
|
model_client=ComponentModel(provider="test", config={"type": "mock_client"}),
|
|
description="Test description",
|
|
workbench=[ComponentModel(provider="test", config={"type": "mock_workbench"})],
|
|
memory=[ComponentModel(provider="test", config={"type": "mock_memory"})],
|
|
model_context=ComponentModel(provider="test", config={"type": "mock_context"}),
|
|
reflect_on_tool_use=True,
|
|
tool_call_summary_format="{result}",
|
|
)
|
|
|
|
agent = AssistantAgent._from_config(config) # type: ignore[reportPrivateUsage]
|
|
|
|
assert len(agent._workbench) == 1 # type: ignore[reportPrivateUsage]
|
|
assert agent._memory is not None # type: ignore[reportPrivateUsage]
|
|
assert len(agent._memory) == 1 # type: ignore[reportPrivateUsage]
|
|
assert agent._model_context == mock_context # type: ignore[reportPrivateUsage]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_config_roundtrip_consistency(self) -> None:
|
|
"""Test that converting to config and back preserves agent properties."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.GPT_4O}
|
|
model_client.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_client"})
|
|
)
|
|
|
|
mock_context = MagicMock()
|
|
mock_context.dump_component = MagicMock(
|
|
return_value=ComponentModel(provider="test", config={"type": "mock_context"})
|
|
)
|
|
|
|
original_agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
description="Test description",
|
|
system_message="Test system message",
|
|
model_client_stream=True,
|
|
reflect_on_tool_use=True,
|
|
max_tool_iterations=5,
|
|
tool_call_summary_format="{tool_name}: {result}",
|
|
handoffs=["agent1"],
|
|
model_context=mock_context,
|
|
metadata={"test": "value"},
|
|
)
|
|
|
|
# Convert to config
|
|
config = original_agent._to_config() # type: ignore[reportPrivateUsage]
|
|
|
|
# Verify config properties
|
|
assert config.name == "test_agent"
|
|
assert config.description == "Test description"
|
|
assert config.system_message == "Test system message"
|
|
assert config.model_client_stream is True
|
|
assert config.reflect_on_tool_use is True
|
|
assert config.max_tool_iterations == 5
|
|
assert config.tool_call_summary_format == "{tool_name}: {result}"
|
|
assert config.metadata == {"test": "value"}
|
|
|
|
|
|
class TestAssistantAgentThoughtHandling:
    """Test suite for thought handling functionality."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_thought_event_yielded_from_model_result(self) -> None:
|
|
"""Test that thought events are yielded when model result contains thoughts."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Final response",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
thought="This is my internal thought process",
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": False,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
)
|
|
|
|
messages: List[Any] = []
|
|
async for message in agent.on_messages_stream(
|
|
[TextMessage(content="Test", source="user")], CancellationToken()
|
|
):
|
|
messages.append(message)
|
|
|
|
# Should have ThoughtEvent in the stream
|
|
thought_events = [msg for msg in messages if isinstance(msg, ThoughtEvent)]
|
|
assert len(thought_events) == 1
|
|
assert thought_events[0].content == "This is my internal thought process"
|
|
assert thought_events[0].source == "test_agent"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_thought_event_with_tool_calls(self) -> None:
|
|
"""Test that thought events are yielded when tool calls have thoughts."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "test"}), name="mock_tool_function")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
thought="I need to use this tool to help the user",
|
|
),
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Tool execution completed",
|
|
usage=RequestUsage(prompt_tokens=15, completion_tokens=10),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
max_tool_iterations=1,
|
|
)
|
|
|
|
messages: List[Any] = []
|
|
async for message in agent.on_messages_stream(
|
|
[TextMessage(content="Test", source="user")], CancellationToken()
|
|
):
|
|
messages.append(message)
|
|
|
|
# Should have ThoughtEvent in the stream
|
|
thought_events = [msg for msg in messages if isinstance(msg, ThoughtEvent)]
|
|
assert len(thought_events) == 1
|
|
assert thought_events[0].content == "I need to use this tool to help the user"
|
|
assert thought_events[0].source == "test_agent"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_thought_event_with_reflection(self) -> None:
|
|
"""Test that thought events are yielded during reflection."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
# Initial tool call with thought
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "test"}), name="mock_tool_function")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
thought="Initial thought before tool call",
|
|
),
|
|
# Reflection with thought
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Based on the tool result, here's my response",
|
|
usage=RequestUsage(prompt_tokens=15, completion_tokens=10),
|
|
cached=False,
|
|
thought="Reflection thought after tool execution",
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
reflect_on_tool_use=True,
|
|
model_client_stream=True, # Enable streaming
|
|
)
|
|
|
|
messages: List[Any] = []
|
|
async for message in agent.on_messages_stream(
|
|
[TextMessage(content="Test", source="user")], CancellationToken()
|
|
):
|
|
messages.append(message)
|
|
|
|
# Should have two ThoughtEvents - one for initial call, one for reflection
|
|
thought_events = [msg for msg in messages if isinstance(msg, ThoughtEvent)]
|
|
assert len(thought_events) == 2
|
|
|
|
thought_contents = [event.content for event in thought_events]
|
|
assert "Initial thought before tool call" in thought_contents
|
|
assert "Reflection thought after tool execution" in thought_contents
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_thought_event_with_tool_call_loop(self) -> None:
|
|
"""Test that thought events are yielded in tool call loops."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
# First tool call with thought
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "first"}), name="mock_tool_function")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
thought="First iteration thought",
|
|
),
|
|
# Second tool call with thought
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(id="2", arguments=json.dumps({"param": "second"}), name="mock_tool_function")
|
|
],
|
|
usage=RequestUsage(prompt_tokens=12, completion_tokens=5),
|
|
cached=False,
|
|
thought="Second iteration thought",
|
|
),
|
|
# Final response with thought
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Loop completed",
|
|
usage=RequestUsage(prompt_tokens=15, completion_tokens=10),
|
|
cached=False,
|
|
thought="Final completion thought",
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
max_tool_iterations=3,
|
|
)
|
|
|
|
messages: List[Any] = []
|
|
async for message in agent.on_messages_stream(
|
|
[TextMessage(content="Test", source="user")], CancellationToken()
|
|
):
|
|
messages.append(message)
|
|
|
|
# Should have three ThoughtEvents - one for each iteration
|
|
thought_events = [msg for msg in messages if isinstance(msg, ThoughtEvent)]
|
|
assert len(thought_events) == 3
|
|
|
|
thought_contents = [event.content for event in thought_events]
|
|
assert "First iteration thought" in thought_contents
|
|
assert "Second iteration thought" in thought_contents
|
|
assert "Final completion thought" in thought_contents
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_thought_event_with_handoff(self) -> None:
|
|
"""Test that thought events are included in handoff context."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(
|
|
id="1", arguments=json.dumps({"target": "other_agent"}), name="transfer_to_other_agent"
|
|
)
|
|
],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
thought="I need to hand this off to another agent",
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
handoffs=["other_agent"],
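            # The declared handoff target is what the replayed transfer_to_other_agent call resolves to.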
|
|
max_tool_iterations=1,
|
|
)
|
|
|
|
result = await agent.run(task="Test handoff with thought")
|
|
|
|
# Should have ThoughtEvent in inner messages
|
|
thought_events = [msg for msg in result.messages if isinstance(msg, ThoughtEvent)]
|
|
assert len(thought_events) == 1
|
|
assert thought_events[0].content == "I need to hand this off to another agent"
|
|
|
|
# Should have handoff message with thought in context
|
|
handoff_message = result.messages[-1]
|
|
assert isinstance(handoff_message, HandoffMessage)
|
|
assert len(handoff_message.context) == 1
|
|
assert isinstance(handoff_message.context[0], AssistantMessage)
|
|
assert handoff_message.context[0].content == "I need to hand this off to another agent"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_no_thought_event_when_no_thought(self) -> None:
|
|
"""Test that no thought events are yielded when model result has no thoughts."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Simple response without thought",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
# No thought field
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": False,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
)
|
|
|
|
messages: List[Any] = []
|
|
async for message in agent.on_messages_stream(
|
|
[TextMessage(content="Test", source="user")], CancellationToken()
|
|
):
|
|
messages.append(message)
|
|
|
|
# Should have no ThoughtEvents
|
|
thought_events = [msg for msg in messages if isinstance(msg, ThoughtEvent)]
|
|
assert len(thought_events) == 0
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_thought_event_context_preservation(self) -> None:
|
|
"""Test that thoughts are properly preserved in model context."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Response with thought",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
thought="Internal reasoning",
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": False,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
)
|
|
|
|
await agent.run(task="Test thought preservation")
|
|
|
|
# Check that the model context contains the thought
|
|
messages = await agent.model_context.get_messages()
|
|
assistant_messages = [msg for msg in messages if isinstance(msg, AssistantMessage)]
|
|
assert len(assistant_messages) > 0
|
|
|
|
        # The last assistant message should carry the thought from the model result
        last_assistant_msg = assistant_messages[-1]
        assert hasattr(last_assistant_msg, "thought"), "AssistantMessage should record the thought"
        thought_content = cast(str, last_assistant_msg.thought)
        assert thought_content == "Internal reasoning"
|
|
|
|
|
|
class TestAssistantAgentAdvancedScenarios:
    """Test suite for advanced usage scenarios."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_handoff_without_tool_calls(self) -> None:
|
|
"""Test handoff without any tool calls."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(id="1", arguments=json.dumps({"target": "agent2"}), name="transfer_to_agent2")
|
|
],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
handoffs=["agent2"],
|
|
)
|
|
|
|
result = await agent.run(task="Handoff to agent2")
|
|
|
|
# Should return HandoffMessage
|
|
assert isinstance(result.messages[-1], HandoffMessage)
|
|
assert result.messages[-1].target == "agent2"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_multiple_handoff_warning(self) -> None:
|
|
"""Test warning for multiple handoffs."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(id="1", arguments=json.dumps({"target": "agent2"}), name="transfer_to_agent2"),
|
|
FunctionCall(id="2", arguments=json.dumps({"target": "agent3"}), name="transfer_to_agent3"),
|
|
],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
handoffs=["agent2", "agent3"],
|
|
)
|
|
|
|
with pytest.warns(UserWarning, match="Multiple handoffs detected"):
|
|
result = await agent.run(task="Multiple handoffs")
|
|
|
|
# Should only execute first handoff
|
|
assert isinstance(result.messages[-1], HandoffMessage)
|
|
assert result.messages[-1].target == "agent2"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_structured_output_with_reflection(self) -> None:
|
|
"""Test structured output with reflection enabled."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "test"}), name="mock_tool_function")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content='{"content": "Structured response", "confidence": 0.95}',
|
|
usage=RequestUsage(prompt_tokens=15, completion_tokens=10),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
output_content_type=StructuredOutput,
|
|
reflect_on_tool_use=True,
|
|
)
|
|
|
|
result = await agent.run(task="Test structured output with reflection")
|
|
|
|
        # The reflection step should produce a StructuredMessage with the parsed content
        final_message = result.messages[-1]
        assert isinstance(final_message, StructuredMessage)
        structured_message: StructuredMessage[StructuredOutput] = cast(
            StructuredMessage[StructuredOutput], final_message
        )
        assert structured_message.content.content == "Structured response"
        assert structured_message.content.confidence == 0.95
|
|
|
|
|
|
class TestAssistantAgentAdvancedToolFeatures:
    """Test suite for advanced tool features including custom formatters."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_custom_tool_call_summary_formatter(self) -> None:
|
|
"""Test custom tool call summary formatter functionality."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(id="1", arguments=json.dumps({"param": "success"}), name="mock_tool_function"),
|
|
FunctionCall(id="2", arguments=json.dumps({"param": "error"}), name="mock_tool_function"),
|
|
],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
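        # The formatter is invoked once per (FunctionCall, FunctionExecutionResult) pair;
        # its return values are combined into the final ToolCallSummaryMessage content.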
def custom_formatter(call: FunctionCall, result: FunctionExecutionResult) -> str:
|
|
if result.is_error:
|
|
return f"ERROR in {call.name}: {result.content} (args: {call.arguments})"
|
|
else:
|
|
return f"SUCCESS: {call.name} completed"
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
tool_call_summary_formatter=custom_formatter,
|
|
reflect_on_tool_use=False,
|
|
)
|
|
|
|
result = await agent.run(task="Test custom formatter")
|
|
|
|
        # Should return ToolCallSummaryMessage with custom formatting
        final_message = result.messages[-1]
        assert isinstance(final_message, ToolCallSummaryMessage)
        content = final_message.content
        # Both tool calls succeed, so the formatter's success string appears once per call
        assert content.count("SUCCESS: mock_tool_function completed") == 2
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_custom_tool_call_summary_format_string(self) -> None:
|
|
"""Test custom tool call summary format string."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "test"}), name="mock_tool_function")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
tool_call_summary_format="Tool {tool_name} called with {arguments} -> {result}",
|
|
reflect_on_tool_use=False,
|
|
)
|
|
|
|
result = await agent.run(task="Test custom format string")
|
|
|
|
# Should return ToolCallSummaryMessage with custom format
|
|
final_message = result.messages[-1]
|
|
assert isinstance(final_message, ToolCallSummaryMessage)
|
|
content = final_message.content
|
|
assert "Tool mock_tool_function called with" in content
|
|
assert "Tool executed with: test" in content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_summary_formatter_overrides_format_string(self) -> None:
|
|
"""Test that tool_call_summary_formatter overrides format string."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "test"}), name="mock_tool_function")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
def custom_formatter(call: FunctionCall, result: FunctionExecutionResult) -> str:
|
|
return f"CUSTOM: {call.name} -> {result.content}"
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
tool_call_summary_format="This should be ignored: {result}",
|
|
tool_call_summary_formatter=custom_formatter,
|
|
reflect_on_tool_use=False,
|
|
)
|
|
|
|
result = await agent.run(task="Test formatter override")
|
|
|
|
# Should use custom formatter, not format string
|
|
final_message = result.messages[-1]
|
|
assert isinstance(final_message, ToolCallSummaryMessage)
|
|
content = final_message.content
|
|
assert "CUSTOM: mock_tool_function" in content
|
|
assert "This should be ignored" not in content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_output_content_type_format_string(self) -> None:
|
|
"""Test structured output with custom format string."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content='{"content": "Test response", "confidence": 0.8}',
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": False,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
output_content_type=StructuredOutput,
|
|
output_content_type_format="Response: {content} (Confidence: {confidence})",
|
|
)
|
|
|
|
result = await agent.run(task="Test structured output format")
|
|
|
|
        # Should return StructuredMessage with the parsed structured content
        final_message = result.messages[-1]
        assert isinstance(final_message, StructuredMessage)
        structured_message: StructuredMessage[StructuredOutput] = cast(
            StructuredMessage[StructuredOutput], final_message
        )
        assert structured_message.content.content == "Test response"
        assert structured_message.content.confidence == 0.8
        # The custom format string should be stored on the agent
        assert getattr(agent, "_output_content_type_format", None) == "Response: {content} (Confidence: {confidence})"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_error_handling_with_custom_formatter(self) -> None:
|
|
"""Test error handling in tool calls with custom formatter."""
|
|
|
|
def error_tool(param: str) -> str:
|
|
raise ValueError(f"Tool error with param: {param}")
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "test"}), name="error_tool")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
def error_formatter(call: FunctionCall, result: FunctionExecutionResult) -> str:
|
|
if result.is_error:
|
|
return f"ERROR in {call.name}: {result.content}"
|
|
else:
|
|
return f"SUCCESS: {result.content}"
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[error_tool],
|
|
tool_call_summary_formatter=error_formatter,
|
|
reflect_on_tool_use=False,
|
|
)
|
|
|
|
result = await agent.run(task="Test error handling")
|
|
|
|
# Should return ToolCallSummaryMessage with error formatting
|
|
assert isinstance(result.messages[-1], ToolCallSummaryMessage)
|
|
content = result.messages[-1].content
|
|
assert "ERROR in error_tool" in content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_multiple_tools_with_different_formats(self) -> None:
|
|
"""Test multiple tool calls with different return formats."""
|
|
|
|
def json_tool(data: str) -> str:
|
|
return json.dumps({"result": data, "status": "success"})
|
|
|
|
def simple_tool(text: str) -> str:
|
|
return f"Processed: {text}"
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(id="1", arguments=json.dumps({"data": "json_data"}), name="json_tool"),
|
|
FunctionCall(id="2", arguments=json.dumps({"text": "simple_text"}), name="simple_tool"),
|
|
],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
def smart_formatter(call: FunctionCall, result: FunctionExecutionResult) -> str:
|
|
try:
|
|
# Try to parse as JSON
|
|
parsed = json.loads(result.content)
|
|
return f"{call.name}: {parsed}"
|
|
except json.JSONDecodeError:
|
|
# Plain text
|
|
return f"{call.name}: {result.content}"
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[json_tool, simple_tool],
|
|
tool_call_summary_formatter=smart_formatter,
|
|
reflect_on_tool_use=False,
|
|
)
|
|
|
|
result = await agent.run(task="Test multiple tool formats")
|
|
|
|
# Should handle both JSON and plain text tools
|
|
assert isinstance(result.messages[-1], ToolCallSummaryMessage)
|
|
content = result.messages[-1].content
|
|
assert "json_tool:" in content
|
|
assert "simple_tool:" in content
|
|
assert "Processed: simple_text" in content
|
|
|
|
|
|
class TestAssistantAgentCancellationToken:
    """Test suite for cancellation token handling."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cancellation_during_model_inference(self) -> None:
|
|
"""Test cancellation token during model inference."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
# Mock create method to check cancellation token
|
|
model_client.create = AsyncMock()
|
|
model_client.create.return_value = CreateResult(
|
|
finish_reason="stop",
|
|
content="Response",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
)
|
|
|
|
cancellation_token = CancellationToken()
|
|
result = await agent.on_messages([TextMessage(content="Test", source="user")], cancellation_token)
|
|
|
|
# Verify cancellation token was passed to model client
|
|
model_client.create.assert_called_once()
|
|
call_args = model_client.create.call_args
|
|
assert call_args.kwargs["cancellation_token"] == cancellation_token
|
|
# Verify result is not None
|
|
assert result is not None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cancellation_during_streaming_inference(self) -> None:
|
|
"""Test cancellation token during streaming model inference."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
# Mock create_stream method
|
|
async def mock_create_stream(*args: Any, **kwargs: Any) -> Any:
|
|
yield "chunk1" # First chunk
|
|
yield "chunk2" # Second chunk
|
|
yield CreateResult(
|
|
finish_reason="stop",
|
|
content="chunk1chunk2",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
|
|
model_client.create_stream = mock_create_stream
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
model_client_stream=True,
|
|
)
|
|
|
|
cancellation_token = CancellationToken()
|
|
messages: List[Any] = []
|
|
async for message in agent.on_messages_stream([TextMessage(content="Test", source="user")], cancellation_token):
|
|
messages.append(message)
|
|
|
|
# Should have received streaming chunks and final response
|
|
chunk_events = [msg for msg in messages if isinstance(msg, ModelClientStreamingChunkEvent)]
|
|
assert len(chunk_events) == 2
|
|
assert chunk_events[0].content == "chunk1"
|
|
assert chunk_events[1].content == "chunk2"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cancellation_during_tool_execution(self) -> None:
|
|
"""Test cancellation token during tool execution."""
|
|
|
|
async def slow_tool(param: str) -> str:
|
|
await asyncio.sleep(0.1) # Simulate slow operation
|
|
return f"Slow result: {param}"
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "test"}), name="slow_tool")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[slow_tool],
|
|
)
|
|
|
|
cancellation_token = CancellationToken()
|
|
result = await agent.on_messages([TextMessage(content="Test", source="user")], cancellation_token)
|
|
|
|
# Tool should execute successfully with cancellation token
|
|
assert isinstance(result.chat_message, ToolCallSummaryMessage)
|
|
assert "Slow result: test" in result.chat_message.content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cancellation_during_workbench_tool_execution(self) -> None:
|
|
"""Test cancellation token during workbench tool execution."""
|
|
mock_workbench = MagicMock()
|
|
mock_workbench.list_tools = AsyncMock(return_value=[{"name": "test_tool", "description": "Test tool"}])
|
|
|
|
# Mock tool execution result
|
|
mock_result = MagicMock()
|
|
mock_result.to_text.return_value = "Workbench tool result"
|
|
mock_result.is_error = False
|
|
mock_workbench.call_tool = AsyncMock(return_value=mock_result)
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "test"}), name="test_tool")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
workbench=[mock_workbench],
|
|
)
|
|
|
|
cancellation_token = CancellationToken()
|
|
result = await agent.on_messages([TextMessage(content="Test", source="user")], cancellation_token)
|
|
|
|
# Verify cancellation token was passed to workbench
|
|
mock_workbench.call_tool.assert_called_once()
|
|
call_args = mock_workbench.call_tool.call_args
|
|
assert call_args.kwargs["cancellation_token"] == cancellation_token
|
|
# Verify result is not None
|
|
assert result is not None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cancellation_during_memory_operations(self) -> None:
|
|
"""Test cancellation token during memory operations."""
|
|
mock_memory = MagicMock()
|
|
mock_memory.update_context = AsyncMock(return_value=None)
|
|
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
model_client.create = AsyncMock(
|
|
return_value=CreateResult(
|
|
finish_reason="stop",
|
|
content="Response",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
memory=[mock_memory],
|
|
)
|
|
|
|
cancellation_token = CancellationToken()
|
|
await agent.on_messages([TextMessage(content="Test", source="user")], cancellation_token)
|
|
|
|
# Memory update_context should be called
|
|
mock_memory.update_context.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_reset_with_cancellation_token(self) -> None:
|
|
"""Test agent reset with cancellation token."""
|
|
mock_context = MagicMock()
|
|
mock_context.clear = AsyncMock()
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=MagicMock(),
|
|
model_context=mock_context,
|
|
)
|
|
|
|
cancellation_token = CancellationToken()
|
|
await agent.on_reset(cancellation_token)
|
|
|
|
# Context clear should be called
|
|
mock_context.clear.assert_called_once()
|
|
|
|
|
|
class TestAssistantAgentStreamingEdgeCases:
    """Test suite for streaming edge cases and error scenarios."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_streaming_with_empty_chunks(self) -> None:
|
|
"""Test streaming with empty chunks."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
async def mock_create_stream(*args: Any, **kwargs: Any) -> Any:
|
|
yield "" # Empty chunk
|
|
yield "content"
|
|
yield "" # Another empty chunk
|
|
yield CreateResult(
|
|
finish_reason="stop",
|
|
content="content",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
|
|
model_client.create_stream = mock_create_stream
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
model_client_stream=True,
|
|
)
|
|
|
|
messages: List[Any] = []
|
|
async for message in agent.on_messages_stream(
|
|
[TextMessage(content="Test", source="user")], CancellationToken()
|
|
):
|
|
messages.append(message)
|
|
|
|
# Should handle empty chunks gracefully
|
|
chunk_events = [msg for msg in messages if isinstance(msg, ModelClientStreamingChunkEvent)]
|
|
assert len(chunk_events) == 3 # Including empty chunks
|
|
assert chunk_events[0].content == ""
|
|
assert chunk_events[1].content == "content"
|
|
assert chunk_events[2].content == ""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_streaming_with_invalid_chunk_type(self) -> None:
|
|
"""Test streaming with invalid chunk type raises error."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
async def mock_create_stream(*args: Any, **kwargs: Any) -> Any:
|
|
yield "valid_chunk"
|
|
yield 123 # Invalid chunk type
|
|
yield CreateResult(
|
|
finish_reason="stop",
|
|
content="content",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
|
|
model_client.create_stream = mock_create_stream
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
model_client_stream=True,
|
|
)
|
|
|
|
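        # Stream chunks must be str or a final CreateResult; any other type should be rejected.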
with pytest.raises(RuntimeError, match="Invalid chunk type"):
|
|
async for _ in agent.on_messages_stream([TextMessage(content="Test", source="user")], CancellationToken()):
|
|
pass
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_streaming_without_final_result(self) -> None:
|
|
"""Test streaming without final CreateResult raises error."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
async def mock_create_stream(*args: Any, **kwargs: Any) -> Any:
|
|
yield "chunk1"
|
|
yield "chunk2"
|
|
# No final CreateResult
|
|
|
|
model_client.create_stream = mock_create_stream
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
model_client_stream=True,
|
|
)
|
|
|
|
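        # A well-formed stream must end with a CreateResult; here it ends after plain chunks.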
with pytest.raises(RuntimeError, match="No final model result in streaming mode"):
|
|
async for _ in agent.on_messages_stream([TextMessage(content="Test", source="user")], CancellationToken()):
|
|
pass
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_streaming_with_tool_calls_and_reflection(self) -> None:
|
|
"""Test streaming with tool calls followed by reflection."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": True, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
call_count = 0
|
|
|
|
async def mock_create_stream(*args: Any, **kwargs: Any) -> Any:
|
|
nonlocal call_count
|
|
call_count += 1
|
|
|
|
if call_count == 1:
|
|
# First call: tool call
|
|
yield CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "test"}), name="mock_tool_function")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
else:
|
|
# Second call: reflection streaming
|
|
yield "Reflection "
|
|
yield "response "
|
|
yield "complete"
|
|
yield CreateResult(
|
|
finish_reason="stop",
|
|
content="Reflection response complete",
|
|
usage=RequestUsage(prompt_tokens=15, completion_tokens=10),
|
|
cached=False,
|
|
)
|
|
|
|
model_client.create_stream = mock_create_stream
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
reflect_on_tool_use=True,
|
|
model_client_stream=True,
|
|
)
|
|
|
|
messages: List[Any] = []
|
|
async for message in agent.on_messages_stream(
|
|
[TextMessage(content="Test", source="user")], CancellationToken()
|
|
):
|
|
messages.append(message)
|
|
|
|
# Should have tool call events, execution events, and streaming chunks for reflection
|
|
tool_call_events = [msg for msg in messages if isinstance(msg, ToolCallRequestEvent)]
|
|
tool_exec_events = [msg for msg in messages if isinstance(msg, ToolCallExecutionEvent)]
|
|
chunk_events = [msg for msg in messages if isinstance(msg, ModelClientStreamingChunkEvent)]
|
|
|
|
assert len(tool_call_events) == 1
|
|
assert len(tool_exec_events) == 1
|
|
assert len(chunk_events) == 3 # Three reflection chunks
|
|
assert chunk_events[0].content == "Reflection "
|
|
assert chunk_events[1].content == "response "
|
|
assert chunk_events[2].content == "complete"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_streaming_with_large_chunks(self) -> None:
|
|
"""Test streaming with large chunks."""
|
|
model_client = MagicMock()
|
|
model_client.model_info = {"function_calling": False, "vision": False, "family": ModelFamily.GPT_4O}
|
|
|
|
large_chunk = "x" * 10000 # 10KB chunk
|
|
|
|
async def mock_create_stream(*args: Any, **kwargs: Any) -> Any:
|
|
yield large_chunk
|
|
yield CreateResult(
|
|
finish_reason="stop",
|
|
content=large_chunk,
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
)
|
|
|
|
model_client.create_stream = mock_create_stream
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
model_client_stream=True,
|
|
)
|
|
|
|
messages: List[Any] = []
|
|
async for message in agent.on_messages_stream(
|
|
[TextMessage(content="Test", source="user")], CancellationToken()
|
|
):
|
|
messages.append(message)
|
|
|
|
# Should handle large chunks
|
|
chunk_events = [msg for msg in messages if isinstance(msg, ModelClientStreamingChunkEvent)]
|
|
assert len(chunk_events) == 1
|
|
assert len(chunk_events[0].content) == 10000
|
|
|
|
|
|
class TestAssistantAgentWorkbenchIntegration:
    """Test suite for comprehensive workbench testing."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_multiple_workbenches(self) -> None:
|
|
"""Test agent with multiple workbenches."""
|
|
mock_workbench1 = MagicMock()
|
|
mock_workbench1.list_tools = AsyncMock(return_value=[{"name": "tool1", "description": "Tool from workbench 1"}])
|
|
mock_result1 = MagicMock()
|
|
mock_result1.to_text.return_value = "Result from workbench 1"
|
|
mock_result1.is_error = False
|
|
mock_workbench1.call_tool = AsyncMock(return_value=mock_result1)
|
|
|
|
mock_workbench2 = MagicMock()
|
|
mock_workbench2.list_tools = AsyncMock(return_value=[{"name": "tool2", "description": "Tool from workbench 2"}])
|
|
mock_result2 = MagicMock()
|
|
mock_result2.to_text.return_value = "Result from workbench 2"
|
|
mock_result2.is_error = False
|
|
mock_workbench2.call_tool = AsyncMock(return_value=mock_result2)
|
|
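        # Each workbench advertises a different tool, so the agent must route each call
        # to the workbench that lists the requested tool name.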
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(id="1", arguments=json.dumps({"param": "test1"}), name="tool1"),
|
|
FunctionCall(id="2", arguments=json.dumps({"param": "test2"}), name="tool2"),
|
|
],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
workbench=[mock_workbench1, mock_workbench2],
|
|
)
|
|
|
|
result = await agent.run(task="Test multiple workbenches")
|
|
|
|
# Both workbenches should be called
|
|
mock_workbench1.call_tool.assert_called_once()
|
|
mock_workbench2.call_tool.assert_called_once()
|
|
|
|
# Should return summary with both results
|
|
assert isinstance(result.messages[-1], ToolCallSummaryMessage)
|
|
content = result.messages[-1].content
|
|
assert "Result from workbench 1" in content
|
|
assert "Result from workbench 2" in content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_workbench_tool_not_found(self) -> None:
|
|
"""Test handling when tool is not found in any workbench."""
|
|
mock_workbench = MagicMock()
|
|
mock_workbench.list_tools = AsyncMock(
|
|
return_value=[{"name": "available_tool", "description": "Available tool"}]
|
|
)
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", arguments=json.dumps({"param": "test"}), name="missing_tool")],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
workbench=[mock_workbench],
|
|
)
|
|
|
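        # The model requests a tool no workbench provides; the failure should be reported
        # in the tool call summary instead of raising.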
|
result = await agent.run(task="Test missing tool")
|
|
|
|
# Should return error message for missing tool
|
|
assert isinstance(result.messages[-1], ToolCallSummaryMessage)
|
|
content = result.messages[-1].content
|
|
assert "tool 'missing_tool' not found" in content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_workbench_concurrent_tool_execution(self) -> None:
|
|
"""Test concurrent execution of multiple workbench tools."""
|
|
mock_workbench = MagicMock()
|
|
mock_workbench.list_tools = AsyncMock(
|
|
return_value=[
|
|
{"name": "concurrent_tool1", "description": "Concurrent tool 1"},
|
|
{"name": "concurrent_tool2", "description": "Concurrent tool 2"},
|
|
]
|
|
)
|
|
|
|
call_order: List[str] = []
|
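        # Start/end markers let the test observe how the two tool invocations interleave.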
|
|
|
async def mock_call_tool(name: str, **kwargs: Any) -> Any:
|
|
call_order.append(f"start_{name}")
|
|
await asyncio.sleep(0.01) # Simulate work
|
|
call_order.append(f"end_{name}")
|
|
|
|
mock_result = MagicMock()
|
|
mock_result.to_text.return_value = f"Result from {name}"
|
|
mock_result.is_error = False
|
|
return mock_result
|
|
|
|
mock_workbench.call_tool = mock_call_tool
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(id="1", arguments=json.dumps({"param": "test1"}), name="concurrent_tool1"),
|
|
FunctionCall(id="2", arguments=json.dumps({"param": "test2"}), name="concurrent_tool2"),
|
|
],
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="test_agent",
|
|
model_client=model_client,
|
|
workbench=[mock_workbench],
|
|
)
|
|
|
|
result = await agent.run(task="Test concurrent execution")
|
|
|
|
        # Both tools should have been invoked (call_order records start/end markers for each)
        assert "start_concurrent_tool1" in call_order
        assert "start_concurrent_tool2" in call_order
|
|
|
|
# Both results should be present
|
|
assert isinstance(result.messages[-1], ToolCallSummaryMessage)
|
|
content = result.messages[-1].content
|
|
assert "Result from concurrent_tool1" in content
|
|
assert "Result from concurrent_tool2" in content
|
|
|
|
|
|
class TestAssistantAgentComplexIntegration:
    """Test suite for complex integration scenarios."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_complete_workflow_with_all_features(self) -> None:
|
|
"""Test agent with tools, handoffs, memory, streaming, and reflection."""
|
|
# Setup memory
|
|
memory = MockMemory(["User prefers detailed explanations"])
|
|
|
|
# Setup model client with complex workflow
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
# Initial tool call
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(id="1", arguments=json.dumps({"param": "analysis"}), name="mock_tool_function")
|
|
],
|
|
usage=RequestUsage(prompt_tokens=20, completion_tokens=10),
|
|
cached=False,
|
|
thought="I need to analyze this first",
|
|
),
|
|
# Reflection result
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Based on the analysis, I can provide a detailed response. The user prefers comprehensive explanations.",
|
|
usage=RequestUsage(prompt_tokens=30, completion_tokens=15),
|
|
cached=False,
|
|
thought="I should be thorough based on user preference",
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="comprehensive_agent",
|
|
model_client=model_client,
|
|
tools=[mock_tool_function],
|
|
handoffs=["specialist_agent"],
|
|
memory=[memory],
|
|
reflect_on_tool_use=True,
|
|
model_client_stream=True,
|
|
tool_call_summary_format="Analysis: {result}",
|
|
metadata={"test": "comprehensive"},
|
|
)
|
|
|
|
messages: List[Any] = []
|
|
async for message in agent.on_messages_stream(
|
|
[TextMessage(content="Analyze this complex scenario", source="user")], CancellationToken()
|
|
):
|
|
messages.append(message)
|
|
|
|
# Should have all types of events
|
|
memory_events = [msg for msg in messages if isinstance(msg, MemoryQueryEvent)]
|
|
thought_events = [msg for msg in messages if isinstance(msg, ThoughtEvent)]
|
|
tool_events = [msg for msg in messages if isinstance(msg, ToolCallRequestEvent)]
|
|
execution_events = [msg for msg in messages if isinstance(msg, ToolCallExecutionEvent)]
|
|
chunk_events = [msg for msg in messages if isinstance(msg, ModelClientStreamingChunkEvent)]
|
|
|
|
assert len(memory_events) > 0
|
|
assert len(thought_events) == 2 # Initial and reflection thoughts
|
|
assert len(tool_events) == 1
|
|
assert len(execution_events) == 1
|
|
        assert len(chunk_events) == 0  # No streaming chunks: the replayed results contain no string chunks
|
|
|
|
# Final response should be TextMessage from reflection
|
|
final_response = None
|
|
for msg in reversed(messages):
|
|
if isinstance(msg, Response):
|
|
final_response = msg
|
|
break
|
|
|
|
assert final_response is not None
|
|
assert isinstance(final_response.chat_message, TextMessage)
|
|
assert "comprehensive explanations" in final_response.chat_message.content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_error_recovery_in_complex_workflow(self) -> None:
|
|
"""Test error recovery in complex workflow with multiple failures."""
|
|
|
|
def failing_tool(param: str) -> str:
|
|
if param == "fail":
|
|
raise ValueError("Tool failure")
|
|
return f"Success: {param}"
|
|
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
# Multiple tool calls, some failing
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(id="1", arguments=json.dumps({"param": "success"}), name="failing_tool"),
|
|
FunctionCall(id="2", arguments=json.dumps({"param": "fail"}), name="failing_tool"),
|
|
FunctionCall(id="3", arguments=json.dumps({"param": "success2"}), name="failing_tool"),
|
|
],
|
|
usage=RequestUsage(prompt_tokens=20, completion_tokens=10),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": True,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
def error_aware_formatter(call: FunctionCall, result: FunctionExecutionResult) -> str:
|
|
if result.is_error:
|
|
return f"⚠️ {call.name} failed: {result.content}"
|
|
else:
|
|
return f"✅ {call.name}: {result.content}"
|
|
|
|
agent = AssistantAgent(
|
|
name="error_recovery_agent",
|
|
model_client=model_client,
|
|
tools=[failing_tool],
|
|
tool_call_summary_formatter=error_aware_formatter,
|
|
reflect_on_tool_use=False,
|
|
)
|
|
|
|
result = await agent.run(task="Test error recovery")
|
|
|
|
# Should handle mixed success/failure gracefully
|
|
assert isinstance(result.messages[-1], ToolCallSummaryMessage)
|
|
content = result.messages[-1].content
|
|
assert "✅ failing_tool: Success: success" in content
|
|
assert "⚠️ failing_tool failed:" in content
|
|
assert "✅ failing_tool: Success: success2" in content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_state_persistence_across_interactions(self) -> None:
|
|
"""Test that agent state persists correctly across multiple interactions."""
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
# First interaction
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="First response",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
# Second interaction
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Second response, remembering context",
|
|
usage=RequestUsage(prompt_tokens=15, completion_tokens=8),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": False,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
agent = AssistantAgent(
|
|
name="stateful_agent",
|
|
model_client=model_client,
|
|
system_message="Remember previous conversations",
|
|
)
|
|
|
|
# First interaction
|
|
result1 = await agent.run(task="First task")
|
|
final_message_1 = result1.messages[-1]
|
|
assert isinstance(final_message_1, TextMessage)
|
|
assert final_message_1.content == "First response"
|
|
|
|
# Save state
|
|
state = await agent.save_state()
|
|
assert "llm_context" in state
|
|
|
|
# Second interaction
|
|
result2 = await agent.run(task="Second task, referring to first")
|
|
        # The second run's final message should reflect the persisted context
        final_message_2 = result2.messages[-1]
        assert isinstance(final_message_2, TextMessage)
        assert final_message_2.content == "Second response, remembering context"
|
|
|
|
# Verify context contains both interactions
|
|
context_messages = await agent.model_context.get_messages()
|
|
user_messages = [
|
|
msg for msg in context_messages if hasattr(msg, "source") and getattr(msg, "source", None) == "user"
|
|
]
|
|
assert len(user_messages) == 2
|
|
|
|
|
|
class TestAssistantAgentMessageContext:
    """Test suite for message context handling in AssistantAgent.

    Tests various scenarios of message handling, context updates, and state management.
    """
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_add_messages_to_context(self) -> None:
|
|
"""Test adding different message types to context.
|
|
|
|
Verifies:
|
|
1. Regular messages are added correctly
|
|
2. Handoff messages with context are handled properly
|
|
3. Message order is preserved
|
|
4. Model messages are converted correctly
|
|
"""
|
|
# Setup test context
|
|
model_context = BufferedChatCompletionContext(buffer_size=10)
|
|
|
|
# Create test messages
|
|
regular_msg = TextMessage(content="Regular message", source="user")
|
|
handoff_msg = HandoffMessage(content="Handoff message", source="agent1", target="agent2")
|
|
|
|
# Add messages to context
|
|
await AssistantAgent._add_messages_to_context(model_context=model_context, messages=[regular_msg, handoff_msg]) # type: ignore[reportPrivateUsage]
|
|
|
|
# Verify context contents
|
|
context_messages = await model_context.get_messages()
|
|
|
|
        # Should contain exactly the two added messages: the regular message and the handoff
        # (a handoff without attached context contributes a single message)
        assert len(context_messages) == 2
|
|
|
|
# Verify message order and content - only the added messages should be present
|
|
assert isinstance(context_messages[0], UserMessage)
|
|
assert context_messages[0].content == "Regular message"
|
|
|
|
assert isinstance(context_messages[1], UserMessage)
|
|
assert context_messages[1].content == "Handoff message"
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_complex_model_context(self) -> None:
|
|
"""Test complex model context management scenarios.
|
|
|
|
Verifies:
|
|
1. Large context handling
|
|
2. Mixed message type handling
|
|
3. Context size limits
|
|
4. Message filtering
|
|
"""
|
|
# Setup test context with limited size
|
|
model_context = BufferedChatCompletionContext(buffer_size=5)
|
|
|
|
# Create a mix of message types
|
|
messages: List[BaseChatMessage] = [
|
|
TextMessage(content="First message", source="user"),
|
|
StructuredMessage[StructuredOutput](
|
|
content=StructuredOutput(content="Structured data", confidence=0.9), source="agent"
|
|
),
|
|
ToolCallSummaryMessage(content="Tool result", source="agent", tool_calls=[], results=[]),
|
|
HandoffMessage(content="Handoff", source="agent1", target="agent2"),
|
|
]
|
|
|
|
# Add messages to context
|
|
await AssistantAgent._add_messages_to_context(model_context=model_context, messages=messages) # type: ignore[reportPrivateUsage]
|
|
|
|
# Verify context management
|
|
context_messages = await model_context.get_messages()
|
|
|
|
# Should respect buffer size limit
|
|
assert len(context_messages) <= 5
|
|
|
|
# Verify message conversion
|
|
for msg in context_messages:
|
|
assert isinstance(msg, (SystemMessage, UserMessage, AssistantMessage))
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_memory_persistence(self) -> None:
|
|
"""Test memory persistence across multiple sessions.
|
|
|
|
Verifies:
|
|
1. Memory content persists between sessions
|
|
2. Memory updates are preserved
|
|
3. Context is properly restored
|
|
4. Memory query events are generated correctly
|
|
"""
|
|
# Create memory with initial content
|
|
memory = MockMemory(contents=["Initial memory"])
|
|
|
|
# Create model client
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Response using memory",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
CreateResult(
|
|
finish_reason="stop",
|
|
content="Response with updated memory",
|
|
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
|
cached=False,
|
|
),
|
|
],
|
|
model_info={
|
|
"function_calling": False,
|
|
"vision": False,
|
|
"json_output": False,
|
|
"family": ModelFamily.GPT_4O,
|
|
"structured_output": False,
|
|
},
|
|
)
|
|
|
|
# Create agent with memory
|
|
agent = AssistantAgent(name="memory_test_agent", model_client=model_client, memory=[memory])
|
|
|
|
# First session
|
|
result1 = await agent.run(task="First task")
|
|
state = await agent.save_state()
|
|
|
|
# Add new memory content
|
|
await memory.add(MemoryContent(content="New memory", mime_type="text/plain"))
|
|
|
|
# Create new agent and restore state
|
|
new_agent = AssistantAgent(name="memory_test_agent", model_client=model_client, memory=[memory])
|
|
await new_agent.load_state(state)
|
|
|
|
# Second session
|
|
result2 = await new_agent.run(task="Second task")
|
|
|
|
# Verify memory persistence
|
|
assert isinstance(result1.messages[-1], TextMessage)
|
|
assert isinstance(result2.messages[-1], TextMessage)
|
|
assert result1.messages[-1].content == "Response using memory"
|
|
assert result2.messages[-1].content == "Response with updated memory"
|
|
|
|
# Verify memory events
|
|
memory_events = [msg for msg in result2.messages if isinstance(msg, MemoryQueryEvent)]
|
|
assert len(memory_events) > 0
|
|
assert any("New memory" in str(event.content) for event in memory_events)
|