# autogen/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py
import copy
import json
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import pytest
from autogen_core import CacheStore, FunctionCall
from autogen_core.models import (
    ChatCompletionClient,
    CreateResult,
    LLMMessage,
    RequestUsage,
    SystemMessage,
    UserMessage,
)
from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache
from autogen_ext.models.replay import ReplayChatCompletionClient
from pydantic import BaseModel


def get_test_data(
    num_messages: int = 3,
) -> Tuple[list[str], list[str], SystemMessage, ChatCompletionClient, ChatCompletionCache]:
    responses = [f"This is dummy message number {i}" for i in range(num_messages)]
    prompts = [f"This is dummy prompt number {i}" for i in range(num_messages)]
    system_prompt = SystemMessage(content="This is a system prompt")
    replay_client = ReplayChatCompletionClient(responses)
    replay_client.set_cached_bool_value(False)
    cached_client = ChatCompletionCache(replay_client)
    return responses, prompts, system_prompt, replay_client, cached_client
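

# The tests below depend on the cache key covering both the messages and the
# extra create() arguments, so changing either produces a miss. A minimal
# sketch of that idea (illustrative only; this is NOT the actual keying code
# used by ChatCompletionCache):
def _illustrative_cache_key(messages: List[LLMMessage], **create_kwargs: Any) -> str:
    import hashlib

    payload = json.dumps(
        [message.model_dump() for message in messages] + sorted(create_kwargs.items()),
        default=str,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()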


@pytest.mark.asyncio
async def test_cache_basic_with_args() -> None:
    responses, prompts, system_prompt, _, cached_client = get_test_data()

    response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0, CreateResult)
    assert not response0.cached
    assert response0.content == responses[0]

    response1 = await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")])
    assert not response1.cached
    assert response1.content == responses[1]

    # Cached output.
    response0_cached = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0_cached, CreateResult)
    assert response0_cached.cached
    assert response0_cached.content == responses[0]

    # Cache miss if args change.
    response2 = await cached_client.create(
        [system_prompt, UserMessage(content=prompts[0], source="user")], json_output=True
    )
    assert isinstance(response2, CreateResult)
    assert not response2.cached
    assert response2.content == responses[2]


@pytest.mark.asyncio
async def test_cache_structured_output_with_args() -> None:
    responses, prompts, system_prompt, _, cached_client = get_test_data(num_messages=4)

    class Answer(BaseModel):
        thought: str
        answer: str

    class Answer2(BaseModel):
        calculation: str
        answer: str

    response0 = await cached_client.create(
        [system_prompt, UserMessage(content=prompts[0], source="user")], json_output=Answer
    )
    assert isinstance(response0, CreateResult)
    assert not response0.cached
    assert response0.content == responses[0]

    response1 = await cached_client.create(
        [system_prompt, UserMessage(content=prompts[1], source="user")], json_output=Answer
    )
    assert not response1.cached
    assert response1.content == responses[1]

    # Cached output.
    response0_cached = await cached_client.create(
        [system_prompt, UserMessage(content=prompts[0], source="user")], json_output=Answer
    )
    assert isinstance(response0_cached, CreateResult)
    assert response0_cached.cached
    assert response0_cached.content == responses[0]

    # Without the json_output argument, the cache should not be hit.
    response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0, CreateResult)
    assert not response0.cached
    assert response0.content == responses[2]

    # With a different output type, the cache should not be hit.
    response0 = await cached_client.create(
        [system_prompt, UserMessage(content=prompts[1], source="user")], json_output=Answer2
    )
    assert isinstance(response0, CreateResult)
    assert not response0.cached
    assert response0.content == responses[3]


@pytest.mark.asyncio
async def test_cache_model_and_count_api() -> None:
    _, prompts, system_prompt, replay_client, cached_client = get_test_data()

    assert replay_client.model_info == cached_client.model_info
    assert replay_client.capabilities == cached_client.capabilities

    messages: List[LLMMessage] = [system_prompt, UserMessage(content=prompts[0], source="user")]
    assert replay_client.count_tokens(messages) == cached_client.count_tokens(messages)
    assert replay_client.remaining_tokens(messages) == cached_client.remaining_tokens(messages)


@pytest.mark.asyncio
async def test_cache_token_usage() -> None:
    responses, prompts, system_prompt, replay_client, cached_client = get_test_data()

    response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0, CreateResult)
    assert not response0.cached
    assert response0.content == responses[0]
    actual_usage0 = copy.copy(cached_client.actual_usage())
    total_usage0 = copy.copy(cached_client.total_usage())

    response1 = await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")])
    assert not response1.cached
    assert response1.content == responses[1]
    actual_usage1 = copy.copy(cached_client.actual_usage())
    total_usage1 = copy.copy(cached_client.total_usage())

    assert total_usage1.prompt_tokens > total_usage0.prompt_tokens
    assert total_usage1.completion_tokens > total_usage0.completion_tokens
    assert actual_usage1.prompt_tokens == actual_usage0.prompt_tokens
    assert actual_usage1.completion_tokens == actual_usage0.completion_tokens

    # Cached output.
    response0_cached = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0_cached, CreateResult)
    assert response0_cached.cached
    assert response0_cached.content == responses[0]
    total_usage2 = copy.copy(cached_client.total_usage())
    assert total_usage2.prompt_tokens == total_usage1.prompt_tokens
    assert total_usage2.completion_tokens == total_usage1.completion_tokens

    assert cached_client.actual_usage() == replay_client.actual_usage()
    assert cached_client.total_usage() == replay_client.total_usage()
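

# The invariants pinned above: total_usage() grows with each real API call but
# not with cache hits, while actual_usage() is unchanged between the two real
# calls because the dummy prompts and responses are the same size.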


@pytest.mark.asyncio
async def test_cache_create_stream() -> None:
    _, prompts, system_prompt, _, cached_client = get_test_data()

    original_streamed_results: List[Union[str, CreateResult]] = []
    async for completion in cached_client.create_stream(
        [system_prompt, UserMessage(content=prompts[0], source="user")]
    ):
        original_streamed_results.append(copy.copy(completion))
    total_usage0 = copy.copy(cached_client.total_usage())

    cached_completion_results: List[Union[str, CreateResult]] = []
    async for completion in cached_client.create_stream(
        [system_prompt, UserMessage(content=prompts[0], source="user")]
    ):
        cached_completion_results.append(copy.copy(completion))
    total_usage1 = copy.copy(cached_client.total_usage())

    assert total_usage1.prompt_tokens == total_usage0.prompt_tokens
    assert total_usage1.completion_tokens == total_usage0.completion_tokens

    for original, cached in zip(original_streamed_results, cached_completion_results, strict=False):
        if isinstance(original, str):
            assert original == cached
        elif isinstance(original, CreateResult) and isinstance(cached, CreateResult):
            assert original.content == cached.content
            assert cached.cached
            assert not original.cached
        else:
            raise ValueError(f"Unexpected types: {type(original)} and {type(cached)}")

    # test serialization
    # cached_client_config = cached_client.dump_component()
    # loaded_client = ChatCompletionCache.load_component(cached_client_config)
    # assert loaded_client.client == cached_client.client


class MockCacheStore(CacheStore[CHAT_CACHE_VALUE_TYPE]):
    """Mock cache store for testing deserialization scenarios."""

    def __init__(self, return_value: Optional[CHAT_CACHE_VALUE_TYPE] = None) -> None:
        self._return_value = return_value
        self._storage: Dict[str, CHAT_CACHE_VALUE_TYPE] = {}

    def get(self, key: str, default: Optional[CHAT_CACHE_VALUE_TYPE] = None) -> Optional[CHAT_CACHE_VALUE_TYPE]:
        return self._return_value  # type: ignore

    def set(self, key: str, value: CHAT_CACHE_VALUE_TYPE) -> None:
        self._storage[key] = value

    def _to_config(self) -> BaseModel:
        """Dummy implementation for testing."""
        return BaseModel()

    @classmethod
    def _from_config(cls, _config: BaseModel) -> "MockCacheStore":
        """Dummy implementation for testing."""
        return cls()
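

# Usage sketch for the mock: it answers every get() with the same value, which
# is all the _check_cache tests below need to simulate whatever shape a backend
# (e.g. Redis) hands back. Illustrative wiring only:
def _example_mock_wiring() -> ChatCompletionCache:
    client = ReplayChatCompletionClient(["dummy response"])
    store = MockCacheStore(return_value=None)  # None simulates a permanent cache miss
    return ChatCompletionCache(client, store)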


def test_check_cache_redis_dict_deserialization_success() -> None:
    """Test _check_cache when Redis cache returns a dict that can be deserialized to CreateResult.

    This tests the core Redis deserialization fix where Redis returns serialized Pydantic
    models as dictionaries instead of the original objects.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create a CreateResult instance (simulating deserialized Redis data)
    create_result = CreateResult(
        content="test response from redis",
        usage=RequestUsage(prompt_tokens=15, completion_tokens=8),
        cached=False,
        finish_reason="stop",
    )

    # Mock cache store that returns a CreateResult (simulating Redis behavior)
    mock_store = MockCacheStore(return_value=create_result)
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Test _check_cache method directly using proper test data
    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    assert cached_result is not None
    assert isinstance(cached_result, CreateResult)
    assert cached_result.content == "test response from redis"
    assert cache_key is not None


def test_check_cache_redis_dict_deserialization_failure() -> None:
    """Test _check_cache gracefully handles corrupted Redis data.

    This ensures the system degrades gracefully when Redis returns corrupted
    or invalid data that cannot be deserialized back to CreateResult.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Mock cache store that returns None (simulating deserialization failure)
    mock_store = MockCacheStore(return_value=None)
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Test _check_cache method directly using proper test data
    messages = [system_prompt, UserMessage(content=prompts[1], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    # Should return None (cache miss) when deserialization fails
    assert cached_result is None
    assert cache_key is not None


def test_check_cache_redis_streaming_dict_deserialization() -> None:
    """Test _check_cache with Redis streaming data containing dicts that need deserialization.

    This tests the streaming scenario where Redis returns a list containing
    serialized CreateResult objects as dictionaries mixed with string chunks.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create a list with CreateResult objects mixed with strings (streaming scenario)
    create_result = CreateResult(
        content="final streaming response from redis",
        usage=RequestUsage(prompt_tokens=12, completion_tokens=6),
        cached=False,
        finish_reason="stop",
    )
    cached_list: List[Union[str, CreateResult]] = [
        "streaming chunk 1",
        create_result,  # Proper CreateResult object
        "streaming chunk 2",
    ]

    # Mock cache store that returns the list with CreateResults (simulating Redis streaming)
    mock_store = MockCacheStore(return_value=cached_list)
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Test _check_cache method directly using proper test data
    messages = [system_prompt, UserMessage(content=prompts[2], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    assert cached_result is not None
    assert isinstance(cached_result, list)
    assert len(cached_result) == 3
    assert cached_result[0] == "streaming chunk 1"
    assert isinstance(cached_result[1], CreateResult)
    assert cached_result[1].content == "final streaming response from redis"
    assert cached_result[2] == "streaming chunk 2"
    assert cache_key is not None


def test_check_cache_redis_streaming_deserialization_failure() -> None:
    """Test _check_cache gracefully handles corrupted Redis streaming data.

    This ensures the system degrades gracefully when Redis returns streaming
    data with corrupted CreateResult dictionaries that cannot be deserialized.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data(num_messages=4)

    # Mock cache store that returns None (simulating deserialization failure)
    mock_store = MockCacheStore(return_value=None)
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Test _check_cache method directly using proper test data
    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    # Should return None (cache miss) when deserialization fails
    assert cached_result is None
    assert cache_key is not None


def test_check_cache_dict_reconstruction_success() -> None:
    """Test _check_cache successfully reconstructs CreateResult from a dict.

    This tests the line: cached_result = CreateResult.model_validate(cached_result)
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create a dict that can be successfully validated as CreateResult
    valid_dict = {
        "content": "reconstructed response",
        "usage": {"prompt_tokens": 10, "completion_tokens": 5},
        "cached": False,
        "finish_reason": "stop",
    }

    # Create a MockCacheStore that returns the dict directly (simulating Redis)
    mock_store = MockCacheStore(return_value=cast(Any, valid_dict))
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Test _check_cache method
    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    # Should successfully reconstruct the CreateResult from dict
    assert cached_result is not None
    assert isinstance(cached_result, CreateResult)
    assert cached_result.content == "reconstructed response"
    assert cache_key is not None
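

# The reconstruction step the test above exercises, in miniature: a CreateResult
# serialized to JSON and parsed back is a plain dict, and model_validate restores
# it. Sketch only; the helper name is ours, not the library's.
def _round_trip_reconstruction_demo() -> None:
    original = CreateResult(
        content="round trip",
        usage=RequestUsage(prompt_tokens=1, completion_tokens=1),
        cached=False,
        finish_reason="stop",
    )
    as_dict = json.loads(original.model_dump_json())  # what a JSON-backed store hands back
    restored = CreateResult.model_validate(as_dict)
    assert restored.content == original.content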


def test_check_cache_dict_reconstruction_failure() -> None:
    """Test _check_cache handles ValidationError when dict cannot be reconstructed.

    This tests the except ValidationError block for single dicts.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create an invalid dict that will fail CreateResult validation
    invalid_dict = {
        "invalid_field": "value",
        "missing_required_fields": True,
    }

    # Create a MockCacheStore that returns the invalid dict
    mock_store = MockCacheStore(return_value=cast(Any, invalid_dict))
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Test _check_cache method
    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    # Should return None (cache miss) when reconstruction fails
    assert cached_result is None
    assert cache_key is not None


def test_check_cache_list_reconstruction_success() -> None:
    """Test _check_cache successfully reconstructs CreateResult objects from dicts in a list.

    This tests the line: reconstructed_list.append(CreateResult.model_validate(item))
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create a list with valid dicts that can be reconstructed
    valid_dict1 = {
        "content": "first result",
        "usage": {"prompt_tokens": 8, "completion_tokens": 3},
        "cached": False,
        "finish_reason": "stop",
    }
    valid_dict2 = {
        "content": "second result",
        "usage": {"prompt_tokens": 12, "completion_tokens": 7},
        "cached": False,
        "finish_reason": "stop",
    }
    cached_list = [
        "streaming chunk 1",
        valid_dict1,
        "streaming chunk 2",
        valid_dict2,
    ]

    # Create a MockCacheStore that returns the list with dicts
    mock_store = MockCacheStore(return_value=cast(Any, cached_list))
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Test _check_cache method
    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    # Should successfully reconstruct the list with CreateResult objects
    assert cached_result is not None
    assert isinstance(cached_result, list)
    assert len(cached_result) == 4
    assert cached_result[0] == "streaming chunk 1"
    assert isinstance(cached_result[1], CreateResult)
    assert cached_result[1].content == "first result"
    assert cached_result[2] == "streaming chunk 2"
    assert isinstance(cached_result[3], CreateResult)
    assert cached_result[3].content == "second result"
    assert cache_key is not None


def test_check_cache_list_reconstruction_failure() -> None:
    """Test _check_cache handles ValidationError when list contains invalid dicts.

    This tests the except ValidationError block for lists.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create a list with an invalid dict that will fail validation
    invalid_dict = {
        "invalid_field": "value",
        "missing_required": True,
    }
    cached_list = [
        "streaming chunk 1",
        invalid_dict,  # This will cause ValidationError
        "streaming chunk 2",
    ]

    # Create a MockCacheStore that returns the list with the invalid dict
    mock_store = MockCacheStore(return_value=cast(Any, cached_list))
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Test _check_cache method
    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    # Should return None (cache miss) when list reconstruction fails
    assert cached_result is None
    assert cache_key is not None


def test_check_cache_already_correct_type() -> None:
    """Test _check_cache returns data as-is when it's already the correct type.

    This tests the final return path when no reconstruction is needed.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create a proper CreateResult object (already the correct type)
    create_result = CreateResult(
        content="already correct type",
        usage=RequestUsage(prompt_tokens=15, completion_tokens=8),
        cached=False,
        finish_reason="stop",
    )

    # Create a MockCacheStore that returns the proper type
    mock_store = MockCacheStore(return_value=create_result)
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Test _check_cache method
    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    # Should return the same object without reconstruction
    assert cached_result is not None
    assert cached_result is create_result  # Same object reference
    assert isinstance(cached_result, CreateResult)
    assert cached_result.content == "already correct type"
    assert cache_key is not None


def test_check_cache_string_json_deserialization_success() -> None:
    """Test _check_cache when Redis cache returns a string containing valid JSON.

    This tests the fix for the Redis string caching issue where Redis returns
    string data instead of dict/CreateResult, causing cache misses.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create a JSON string representing a valid CreateResult
    create_result_json = json.dumps(
        {
            "content": "response from string json",
            "usage": {"prompt_tokens": 12, "completion_tokens": 6},
            "cached": False,
            "finish_reason": "stop",
            "logprobs": None,
            "thought": None,
        }
    )

    # Mock cache store that returns the JSON string (simulating Redis behavior)
    mock_store = MockCacheStore(return_value=cast(Any, create_result_json))
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Test _check_cache method directly
    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    # Should successfully reconstruct the CreateResult from the JSON string
    assert cached_result is not None
    assert isinstance(cached_result, CreateResult)
    assert cached_result.content == "response from string json"
    assert cached_result.usage.prompt_tokens == 12
    assert cached_result.usage.completion_tokens == 6
    assert cache_key is not None
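

# The string path the fix adds, presumably: if the cached value is a str, parse
# it with json.loads first, then validate as above. A minimal sketch of that
# order of operations (our helper, not the library's code):
def _string_cache_value_demo() -> None:
    raw = json.dumps(
        {
            "content": "hi",
            "usage": {"prompt_tokens": 1, "completion_tokens": 1},
            "cached": False,
            "finish_reason": "stop",
        }
    )
    parsed = json.loads(raw)  # str -> dict (or list, for streaming entries)
    restored = CreateResult.model_validate(parsed)
    assert restored.content == "hi"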


def test_check_cache_string_json_list_deserialization_success() -> None:
    """Test _check_cache when Redis cache returns a string containing a valid JSON list.

    This tests the fix for streaming results stored as JSON strings in Redis.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create a JSON string representing a streaming result list
    streaming_list_json = json.dumps(
        [
            "streaming chunk 1",
            {
                "content": "streaming response from json",
                "usage": {"prompt_tokens": 8, "completion_tokens": 4},
                "cached": False,
                "finish_reason": "stop",
                "logprobs": None,
                "thought": None,
            },
            "streaming chunk 2",
        ]
    )

    # Mock cache store that returns the JSON string (simulating Redis streaming)
    mock_store = MockCacheStore(return_value=cast(Any, streaming_list_json))
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Test _check_cache method directly
    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    # Should successfully reconstruct the list from the JSON string
    assert cached_result is not None
    assert isinstance(cached_result, list)
    assert len(cached_result) == 3
    assert cached_result[0] == "streaming chunk 1"
    assert isinstance(cached_result[1], CreateResult)
    assert cached_result[1].content == "streaming response from json"
    assert cached_result[2] == "streaming chunk 2"
    assert cache_key is not None


def test_check_cache_string_invalid_json_failure() -> None:
    """Test _check_cache gracefully handles invalid JSON strings.

    This ensures the system degrades gracefully when Redis returns corrupted
    string data that cannot be parsed as JSON.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create an invalid JSON string
    invalid_json_string = '{"content": "test", invalid json}'

    # Mock cache store that returns the invalid JSON string
    mock_store = MockCacheStore(return_value=cast(Any, invalid_json_string))
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Test _check_cache method directly
    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    # Should return None (cache miss) when JSON parsing fails
    assert cached_result is None
    assert cache_key is not None


def test_check_cache_string_invalid_data_failure() -> None:
    """Test _check_cache gracefully handles JSON strings with an invalid data structure.

    This ensures the system handles JSON that parses but doesn't represent valid CreateResult data.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create a JSON string that parses but has an invalid structure
    invalid_data_json = json.dumps({"invalid_structure": "not a CreateResult"})

    # Mock cache store that returns the invalid data JSON string
    mock_store = MockCacheStore(return_value=cast(Any, invalid_data_json))
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Test _check_cache method directly
    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    # Should return None (cache miss) when validation fails
    assert cached_result is None
    assert cache_key is not None


@pytest.mark.asyncio
async def test_redis_streaming_cache_integration() -> None:
    """Integration test for the Redis streaming cache scenario.

    This test covers the original streaming cache issues:
    1. Cache is stored after streaming completes (not before)
    2. Redis cache properly handles lists containing CreateResult objects
    3. ChatCompletionCache properly reconstructs CreateResult from Redis dicts
    """
    from unittest.mock import MagicMock

    # Skip this test if redis is not available
    pytest.importorskip("redis")
    from autogen_ext.cache_store.redis import RedisStore

    # Use standardized test data
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Mock the Redis instance to control what gets stored/retrieved
    redis_instance = MagicMock()
    redis_store = RedisStore[CHAT_CACHE_VALUE_TYPE](redis_instance)

    # Create the cached client with the Redis store
    cached_client = ChatCompletionCache(replay_client, redis_store)

    # Simulate the first streaming call (should cache after completion)
    first_stream_results: List[Union[str, CreateResult]] = []
    async for chunk in cached_client.create_stream([system_prompt, UserMessage(content=prompts[0], source="user")]):
        first_stream_results.append(copy.copy(chunk))

    # Verify Redis set was called with the complete streaming results
    redis_instance.set.assert_called_once()
    call_args = redis_instance.set.call_args
    serialized_data = call_args[0][1]

    # Verify the serialized data represents the complete stream
    assert isinstance(serialized_data, bytes)
    deserialized = json.loads(serialized_data.decode("utf-8"))
    assert isinstance(deserialized, list)
    # Type narrowing: after the isinstance check, deserialized is known to be a list
    deserialized_list: List[Union[str, Dict[str, Union[str, int]]]] = deserialized
    # Should contain both string chunks and the final CreateResult (as a dict)
    assert len(deserialized_list) > 0

    # Reset the mock for the second call
    redis_instance.reset_mock()

    # Configure Redis to return the serialized data (simulating a cache hit)
    redis_instance.get.return_value = serialized_data

    # The second streaming call should hit the cache
    second_stream_results: List[Union[str, CreateResult]] = []
    async for chunk in cached_client.create_stream([system_prompt, UserMessage(content=prompts[0], source="user")]):
        second_stream_results.append(copy.copy(chunk))

    # Verify Redis get was called but set was not (cache hit)
    redis_instance.get.assert_called_once()
    redis_instance.set.assert_not_called()

    # Verify both streams have the same content
    assert len(first_stream_results) == len(second_stream_results)

    # Verify cached results are marked as cached
    for first, second in zip(first_stream_results, second_stream_results, strict=True):
        if isinstance(first, CreateResult) and isinstance(second, CreateResult):
            assert not first.cached  # First call should not be cached
            assert second.cached  # Second call should be cached
            assert first.content == second.content
        elif isinstance(first, str) and isinstance(second, str):
            assert first == second
        else:
            pytest.fail(f"Unexpected chunk types: {type(first)}, {type(second)}")
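

# With a real server the wiring is the same, minus the MagicMock. A sketch,
# assuming a local Redis on the default port (not executed here):
#
#   import redis
#
#   store = RedisStore[CHAT_CACHE_VALUE_TYPE](redis.Redis(host="localhost", port=6379))
#   cached = ChatCompletionCache(model_client, store)  # model_client: any ChatCompletionClient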


@pytest.mark.asyncio
async def test_cache_cross_compatibility_create_to_stream() -> None:
    """Test that the create() cache can be used by a create_stream() call.

    This tests the scenario where:
    1. User calls create() - stores a CreateResult
    2. User calls create_stream() with the same inputs - should get a cache hit and yield the CreateResult
    """
    responses, prompts, system_prompt, _, cached_client = get_test_data()

    # First call: create() - should cache a CreateResult
    create_result = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(create_result, CreateResult)
    assert not create_result.cached
    assert create_result.content == responses[0]

    # Second call: create_stream() with the same inputs - should hit the cache
    stream_results: List[Union[str, CreateResult]] = []
    async for chunk in cached_client.create_stream([system_prompt, UserMessage(content=prompts[0], source="user")]):
        stream_results.append(copy.copy(chunk))

    # Should yield exactly two items: the string content + the cached CreateResult
    assert len(stream_results) == 2

    # First item should be the string content
    assert isinstance(stream_results[0], str)
    assert stream_results[0] == responses[0]

    # Second item should be the cached CreateResult
    assert isinstance(stream_results[1], CreateResult)
    assert stream_results[1].cached  # Should be marked as cached
    assert stream_results[1].content == responses[0]

    # Capture usage, then verify that another stream call makes no new API calls
    initial_usage = cached_client.total_usage()

    # Third call: create_stream() again - should still hit the cache
    stream_results_2: List[Union[str, CreateResult]] = []
    async for chunk in cached_client.create_stream([system_prompt, UserMessage(content=prompts[0], source="user")]):
        stream_results_2.append(copy.copy(chunk))

    # Usage should be the same (no new API calls)
    assert cached_client.total_usage().prompt_tokens == initial_usage.prompt_tokens
    assert cached_client.total_usage().completion_tokens == initial_usage.completion_tokens


@pytest.mark.asyncio
async def test_cache_cross_compatibility_stream_to_create() -> None:
    """Test that the create_stream() cache can be used by a create() call.

    This tests the scenario where:
    1. User calls create_stream() - stores a List[Union[str, CreateResult]]
    2. User calls create() with the same inputs - should get a cache hit and return the final CreateResult
    """
    _, prompts, system_prompt, _, cached_client = get_test_data()

    # First call: create_stream() - should cache a List[Union[str, CreateResult]]
    first_stream_results: List[Union[str, CreateResult]] = []
    async for chunk in cached_client.create_stream([system_prompt, UserMessage(content=prompts[0], source="user")]):
        first_stream_results.append(copy.copy(chunk))

    # Verify we got streaming results
    assert len(first_stream_results) > 0
    final_create_result = None
    for item in first_stream_results:
        if isinstance(item, CreateResult):
            final_create_result = item
            break
    assert final_create_result is not None
    assert not final_create_result.cached  # First call should not be cached

    # Second call: create() with the same inputs - should hit the streaming cache
    create_result = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(create_result, CreateResult)
    assert create_result.cached  # Should be marked as cached
    assert create_result.content == final_create_result.content

    # Capture usage, then verify that another create() call makes no new API calls
    initial_usage = cached_client.total_usage()

    # Third call: create() again - should still hit the cache
    create_result_2 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])

    # Usage should be the same (no new API calls)
    assert cached_client.total_usage().prompt_tokens == initial_usage.prompt_tokens
    assert cached_client.total_usage().completion_tokens == initial_usage.completion_tokens
    assert create_result_2.cached
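

# The extraction that create() presumably performs on a cached stream list, in
# miniature (the helper name is ours, not the library's):
def _final_result_of(stream: List[Union[str, CreateResult]]) -> Optional[CreateResult]:
    return next((item for item in stream if isinstance(item, CreateResult)), None)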


@pytest.mark.asyncio
async def test_cache_cross_compatibility_mixed_sequence() -> None:
    """Test a mixed sequence of create() and create_stream() calls with caching.

    This tests a realistic scenario with multiple interleaved calls:
    create() → create_stream() → create() → create_stream()
    """
    responses, prompts, system_prompt, _, cached_client = get_test_data(num_messages=4)

    # Call 1: create() with prompt[0] - should make an API call
    result1 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert not result1.cached
    assert result1.content == responses[0]
    usage_after_1 = copy.copy(cached_client.total_usage())

    # Call 2: create_stream() with prompt[0] - should hit the cache from call 1
    stream1_results: List[Union[str, CreateResult]] = []
    async for chunk in cached_client.create_stream([system_prompt, UserMessage(content=prompts[0], source="user")]):
        stream1_results.append(chunk)
    assert len(stream1_results) == 2  # Should yield string content + cached CreateResult
    assert isinstance(stream1_results[0], str)  # First item: string content
    assert stream1_results[0] == responses[0]
    assert isinstance(stream1_results[1], CreateResult)  # Second item: cached CreateResult
    assert stream1_results[1].cached
    usage_after_2 = copy.copy(cached_client.total_usage())
    # No new API call should have been made
    assert usage_after_2.prompt_tokens == usage_after_1.prompt_tokens

    # Call 3: create_stream() with prompt[1] - should make a new API call
    stream2_results: List[Union[str, CreateResult]] = []
    async for chunk in cached_client.create_stream([system_prompt, UserMessage(content=prompts[1], source="user")]):
        stream2_results.append(copy.copy(chunk))
    # Should have made a new API call
    usage_after_3 = copy.copy(cached_client.total_usage())
    assert usage_after_3.prompt_tokens > usage_after_2.prompt_tokens
    # Find the final CreateResult
    final_result = None
    for item in stream2_results:
        if isinstance(item, CreateResult):
            final_result = item
            break
    assert final_result is not None
    assert not final_result.cached

    # Call 4: create() with prompt[1] - should hit the cache from call 3
    result4 = await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")])
    assert result4.cached
    assert result4.content == final_result.content
    usage_after_4 = copy.copy(cached_client.total_usage())
    # No new API call should have been made
    assert usage_after_4.prompt_tokens == usage_after_3.prompt_tokens


@pytest.mark.asyncio
async def test_cache_streaming_list_without_create_result() -> None:
    """Test the edge case where the streaming cache contains only strings (no CreateResult).

    This could happen if streaming was interrupted or in unusual scenarios.
    The create() method should handle this gracefully by falling through to make a real API call.
    """
    responses, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create a mock cache store that returns a list with only strings (no CreateResult)
    string_only_list: List[Union[str, CreateResult]] = ["Hello", " world", "!"]
    mock_store = MockCacheStore(return_value=string_only_list)
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Call create() - should fall through and make an API call since there is no CreateResult in the cached list
    result = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(result, CreateResult)
    assert not result.cached  # Should be from a real API call, not the cache
    assert result.content == responses[0]


@pytest.mark.asyncio
async def test_create_stream_with_cached_non_streaming_result_string_content() -> None:
    """
    Test that when create_stream() finds a cached non-streaming result with string content,
    it yields both the content string as a streaming chunk and then the CreateResult.
    """
    responses, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create a CreateResult with string content (simulating a cached non-streaming result)
    cached_create_result = CreateResult(
        content=responses[0],  # This is a string
        finish_reason="stop",
        usage=RequestUsage(prompt_tokens=10, completion_tokens=20),
        cached=False,  # Will be set to True when retrieved from cache
    )

    # Mock cache store that returns the non-streaming CreateResult
    mock_store = MockCacheStore(return_value=cached_create_result)
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Call create_stream() - should yield the string content first, then the CreateResult
    stream_results: List[Union[str, CreateResult]] = []
    async for chunk in cached_client.create_stream([system_prompt, UserMessage(content=prompts[0], source="user")]):
        stream_results.append(copy.copy(chunk))

    # Should have exactly 2 items: the string content, then the CreateResult
    assert len(stream_results) == 2

    # First item should be the string content
    assert isinstance(stream_results[0], str)
    assert stream_results[0] == responses[0]

    # Second item should be the CreateResult
    assert isinstance(stream_results[1], CreateResult)
    assert stream_results[1].content == responses[0]
    assert stream_results[1].finish_reason == "stop"
    assert stream_results[1].cached is True  # Should be marked as cached


@pytest.mark.asyncio
async def test_create_stream_with_cached_non_streaming_result_empty_content() -> None:
    """
    Test that when create_stream() finds a cached non-streaming result with empty string content,
    it only yields the CreateResult (no separate string chunk).
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create a CreateResult with empty string content
    cached_create_result = CreateResult(
        content="",  # Empty string
        finish_reason="stop",
        usage=RequestUsage(prompt_tokens=10, completion_tokens=0),
        cached=False,
    )

    # Mock cache store that returns the non-streaming CreateResult
    mock_store = MockCacheStore(return_value=cached_create_result)
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Call create_stream() - should yield only the CreateResult (no string chunk)
    stream_results: List[Union[str, CreateResult]] = []
    async for chunk in cached_client.create_stream([system_prompt, UserMessage(content=prompts[0], source="user")]):
        stream_results.append(copy.copy(chunk))

    # Should have exactly 1 item: just the CreateResult
    assert len(stream_results) == 1

    # The only item should be the CreateResult
    assert isinstance(stream_results[0], CreateResult)
    assert stream_results[0].content == ""
    assert stream_results[0].finish_reason == "stop"
    assert stream_results[0].cached is True


@pytest.mark.asyncio
async def test_create_stream_with_cached_non_streaming_result_non_string_content() -> None:
    """
    Test that when create_stream() finds a cached non-streaming result with non-string content,
    it only yields the CreateResult (no separate string chunk).
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Create a CreateResult with non-string content (e.g., a list of function calls)
    cached_create_result = CreateResult(
        content=[
            FunctionCall(id="call_123", name="test_func", arguments='{"param": "value"}')
        ],  # List of FunctionCall objects
        finish_reason="function_calls",  # Valid finish reason for function calls
        usage=RequestUsage(prompt_tokens=10, completion_tokens=15),
        cached=False,
    )

    # Mock cache store that returns the non-streaming CreateResult
    mock_store = MockCacheStore(return_value=cached_create_result)
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Call create_stream() - should yield only the CreateResult (no string chunk)
    stream_results: List[Union[str, CreateResult]] = []
    async for chunk in cached_client.create_stream([system_prompt, UserMessage(content=prompts[0], source="user")]):
        stream_results.append(copy.copy(chunk))

    # Should have exactly 1 item: just the CreateResult
    assert len(stream_results) == 1

    # The only item should be the CreateResult
    assert isinstance(stream_results[0], CreateResult)
    expected_function_call = FunctionCall(id="call_123", name="test_func", arguments='{"param": "value"}')
    assert stream_results[0].content == [expected_function_call]
    assert stream_results[0].finish_reason == "function_calls"
    assert stream_results[0].cached is True
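

# The replay behavior the three tests above pin down, in miniature: a cached
# non-streaming CreateResult with non-empty string content is surfaced as the
# string chunk followed by the result; otherwise just the result. Sketch only,
# not the library's code.
def _replayed_chunks(result: CreateResult) -> List[Union[str, CreateResult]]:
    if isinstance(result.content, str) and result.content:
        return [result.content, result]
    return [result]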