Fix Redis Deserialization Error (#6952)

This commit is contained in:
Ben Constable 2025-08-19 16:24:19 +01:00 committed by GitHub
parent db10c9dd2c
commit 29a84e293c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 322 additions and 5 deletions

View File

@ -1,7 +1,7 @@
import hashlib
import json
import warnings
from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Sequence, Union, cast
from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Sequence, Union
from autogen_core import CacheStore, CancellationToken, Component, ComponentModel, InMemoryStore
from autogen_core.models import (
@ -13,7 +13,7 @@ from autogen_core.models import (
RequestUsage,
)
from autogen_core.tools import Tool, ToolSchema
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from typing_extensions import Self
CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]]
@ -126,8 +126,29 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
serialized_data = json.dumps(data, sort_keys=True)
cache_key = hashlib.sha256(serialized_data.encode()).hexdigest()
cached_result = cast(Optional[CreateResult], self.store.get(cache_key))
cached_result = self.store.get(cache_key)
if cached_result is not None:
# Handle case where cache store returns dict instead of CreateResult (e.g., Redis)
if isinstance(cached_result, dict):
try:
cached_result = CreateResult.model_validate(cached_result)
except ValidationError:
# If reconstruction fails, treat as cache miss
return None, cache_key
elif isinstance(cached_result, list):
# Handle streaming results - reconstruct CreateResult instances from dicts
try:
reconstructed_list: List[Union[str, CreateResult]] = []
for item in cached_result:
if isinstance(item, dict):
reconstructed_list.append(CreateResult.model_validate(item))
else:
reconstructed_list.append(item)
cached_result = reconstructed_list
except ValidationError:
# If reconstruction fails, treat as cache miss
return None, cache_key
# If it's already the right type (CreateResult or list), return as-is
return cached_result, cache_key
return None, cache_key

View File

@ -1,15 +1,17 @@
import copy
from typing import List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import pytest
from autogen_core import CacheStore
from autogen_core.models import (
ChatCompletionClient,
CreateResult,
LLMMessage,
RequestUsage,
SystemMessage,
UserMessage,
)
from autogen_ext.models.cache import ChatCompletionCache
from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache
from autogen_ext.models.replay import ReplayChatCompletionClient
from pydantic import BaseModel
@ -184,3 +186,297 @@ async def test_cache_create_stream() -> None:
# cached_client_config = cached_client.dump_component()
# loaded_client = ChatCompletionCache.load_component(cached_client_config)
# assert loaded_client.client == cached_client.client
class MockCacheStore(CacheStore[CHAT_CACHE_VALUE_TYPE]):
    """Cache store stub that answers every lookup with one preset value.

    Lets a test hand ``ChatCompletionCache`` an arbitrary payload (a
    ``CreateResult``, a plain dict, a list, or ``None``) regardless of the
    cache key that is requested.
    """

    def __init__(self, return_value: Optional[CHAT_CACHE_VALUE_TYPE] = None) -> None:
        # Writes are recorded here but are never served back by ``get``.
        self._storage: Dict[str, CHAT_CACHE_VALUE_TYPE] = {}
        self._return_value = return_value

    def get(self, key: str, default: Optional[CHAT_CACHE_VALUE_TYPE] = None) -> Optional[CHAT_CACHE_VALUE_TYPE]:
        """Return the preset value; ``key`` and ``default`` are ignored."""
        return self._return_value  # type: ignore

    def set(self, key: str, value: CHAT_CACHE_VALUE_TYPE) -> None:
        """Record ``value`` under ``key`` in the internal dict."""
        self._storage[key] = value

    def _to_config(self) -> BaseModel:
        """Return an empty model; only present to satisfy the interface."""
        return BaseModel()

    @classmethod
    def _from_config(cls, _config: BaseModel) -> "MockCacheStore":
        """Build a store with no preset value; ``_config`` is ignored."""
        return cls()
def test_check_cache_redis_dict_deserialization_success() -> None:
    """Test _check_cache when the store returns a dict that deserializes to CreateResult.

    A Redis-backed store hands back serialized Pydantic models as plain
    dictionaries instead of the original objects. The payload here is the
    ``model_dump()`` of a real ``CreateResult`` so that the dict branch of
    ``_check_cache`` is the one being exercised.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    original = CreateResult(
        content="test response from redis",
        usage=RequestUsage(prompt_tokens=15, completion_tokens=8),
        cached=False,
        finish_reason="stop",
    )
    # Store the dict form, not the model: that is what needs reconstruction.
    redis_payload = original.model_dump()
    mock_store = MockCacheStore(return_value=cast(Any, redis_payload))
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Call _check_cache directly so only the lookup path is under test.
    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    assert cached_result is not None
    assert isinstance(cached_result, CreateResult)
    assert cached_result.content == "test response from redis"
    assert cached_result.usage.prompt_tokens == 15
    assert cache_key is not None
def test_check_cache_redis_dict_deserialization_failure() -> None:
    """Test _check_cache gracefully handles corrupted Redis data.

    The store returns a dict that has the expected keys but a value of the
    wrong shape, so ``CreateResult`` validation fails. ``_check_cache`` must
    report this as a cache miss rather than raising.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # ``usage`` should be a mapping of token counts; a bare string cannot validate.
    corrupted_payload = {
        "content": "test response from redis",
        "usage": "not-a-usage-object",
        "cached": False,
        "finish_reason": "stop",
    }
    mock_store = MockCacheStore(return_value=cast(Any, corrupted_payload))
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Call _check_cache directly so only the lookup path is under test.
    messages = [system_prompt, UserMessage(content=prompts[1], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    # Should return None (cache miss) when deserialization fails
    assert cached_result is None
    assert cache_key is not None
def test_check_cache_redis_streaming_dict_deserialization() -> None:
    """Test _check_cache with Redis streaming data containing dicts that need deserialization.

    The store returns a list in which the ``CreateResult`` has been
    serialized to a dictionary and sits between plain string chunks.
    ``_check_cache`` must rebuild the model and leave the strings untouched.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    create_result = CreateResult(
        content="final streaming response from redis",
        usage=RequestUsage(prompt_tokens=12, completion_tokens=6),
        cached=False,
        finish_reason="stop",
    )
    # The model is stored as its dict form so the per-item reconstruction runs.
    cached_list = [
        "streaming chunk 1",
        create_result.model_dump(),
        "streaming chunk 2",
    ]
    mock_store = MockCacheStore(return_value=cast(Any, cached_list))
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Call _check_cache directly so only the lookup path is under test.
    messages = [system_prompt, UserMessage(content=prompts[2], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    assert cached_result is not None
    assert isinstance(cached_result, list)
    assert len(cached_result) == 3
    assert cached_result[0] == "streaming chunk 1"
    assert isinstance(cached_result[1], CreateResult)
    assert cached_result[1].content == "final streaming response from redis"
    assert cached_result[2] == "streaming chunk 2"
    assert cache_key is not None
def test_check_cache_redis_streaming_deserialization_failure() -> None:
    """Test _check_cache gracefully handles corrupted Redis streaming data.

    The store returns a streaming list whose dictionary entry cannot be
    validated as a ``CreateResult``. ``_check_cache`` must treat the whole
    entry as a cache miss rather than raising or returning a partial list.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data(num_messages=4)

    # ``usage`` should be a mapping of token counts; a bare string cannot validate.
    corrupted_item = {
        "content": "final streaming response from redis",
        "usage": "not-a-usage-object",
        "cached": False,
        "finish_reason": "stop",
    }
    cached_list = [
        "streaming chunk 1",
        corrupted_item,
        "streaming chunk 2",
    ]
    mock_store = MockCacheStore(return_value=cast(Any, cached_list))
    cached_client = ChatCompletionCache(replay_client, mock_store)

    # Call _check_cache directly so only the lookup path is under test.
    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore

    # Should return None (cache miss) when deserialization fails
    assert cached_result is None
    assert cache_key is not None
def test_check_cache_dict_reconstruction_success() -> None:
    """A well-formed dict returned by the store is rebuilt into a CreateResult.

    Covers the ``CreateResult.model_validate(cached_result)`` path for a
    single (non-streaming) cache entry.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # Payload shaped like a serialized CreateResult, as a dict-returning store would give.
    payload = {
        "content": "reconstructed response",
        "usage": {"prompt_tokens": 10, "completion_tokens": 5},
        "cached": False,
        "finish_reason": "stop",
    }
    store = MockCacheStore(return_value=cast(Any, payload))
    client = ChatCompletionCache(replay_client, store)

    request = [system_prompt, UserMessage(content=prompts[0], source="user")]
    hit, key = client._check_cache(request, [], None, {})  # type: ignore

    # The dict must come back as a real CreateResult instance.
    assert hit is not None
    assert isinstance(hit, CreateResult)
    assert hit.content == "reconstructed response"
    assert key is not None
def test_check_cache_dict_reconstruction_failure() -> None:
    """A dict that cannot be validated as CreateResult is reported as a cache miss.

    Covers the ``except ValidationError`` branch for a single cache entry.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # None of CreateResult's fields are present, so validation must fail.
    payload = {
        "invalid_field": "value",
        "missing_required_fields": True,
    }
    store = MockCacheStore(return_value=cast(Any, payload))
    client = ChatCompletionCache(replay_client, store)

    request = [system_prompt, UserMessage(content=prompts[0], source="user")]
    hit, key = client._check_cache(request, [], None, {})  # type: ignore

    # Failure is swallowed: no result, but the key is still computed.
    assert hit is None
    assert key is not None
def test_check_cache_list_reconstruction_success() -> None:
    """Dict entries inside a cached list are rebuilt into CreateResult objects.

    Covers the per-item ``CreateResult.model_validate(item)`` path used for
    streaming cache entries; string chunks must pass through unchanged.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    first_payload = {
        "content": "first result",
        "usage": {"prompt_tokens": 8, "completion_tokens": 3},
        "cached": False,
        "finish_reason": "stop",
    }
    second_payload = {
        "content": "second result",
        "usage": {"prompt_tokens": 12, "completion_tokens": 7},
        "cached": False,
        "finish_reason": "stop",
    }
    # Strings and serialized results interleaved, as in a streamed response.
    stream = ["streaming chunk 1", first_payload, "streaming chunk 2", second_payload]

    store = MockCacheStore(return_value=cast(Any, stream))
    client = ChatCompletionCache(replay_client, store)

    request = [system_prompt, UserMessage(content=prompts[0], source="user")]
    hit, key = client._check_cache(request, [], None, {})  # type: ignore

    assert hit is not None
    assert isinstance(hit, list)
    assert len(hit) == 4

    # String chunks keep their position and value.
    assert hit[0] == "streaming chunk 1"
    assert hit[2] == "streaming chunk 2"

    # Dict entries are replaced by reconstructed models.
    assert isinstance(hit[1], CreateResult)
    assert hit[1].content == "first result"
    assert isinstance(hit[3], CreateResult)
    assert hit[3].content == "second result"

    assert key is not None
def test_check_cache_list_reconstruction_failure() -> None:
    """A cached list containing an unvalidatable dict is reported as a cache miss.

    Covers the ``except ValidationError`` branch for list (streaming) entries.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    # The middle entry has none of CreateResult's fields, so validation must fail.
    stream = [
        "streaming chunk 1",
        {
            "invalid_field": "value",
            "missing_required": True,
        },
        "streaming chunk 2",
    ]
    store = MockCacheStore(return_value=cast(Any, stream))
    client = ChatCompletionCache(replay_client, store)

    request = [system_prompt, UserMessage(content=prompts[0], source="user")]
    hit, key = client._check_cache(request, [], None, {})  # type: ignore

    # One bad item invalidates the whole entry; the key is still computed.
    assert hit is None
    assert key is not None
def test_check_cache_already_correct_type() -> None:
    """A CreateResult returned by the store is passed through untouched.

    Covers the final return path of ``_check_cache`` where no reconstruction
    is needed; the very same object must come back.
    """
    _, prompts, system_prompt, replay_client, _ = get_test_data()

    expected = CreateResult(
        content="already correct type",
        usage=RequestUsage(prompt_tokens=15, completion_tokens=8),
        cached=False,
        finish_reason="stop",
    )
    client = ChatCompletionCache(replay_client, MockCacheStore(return_value=expected))

    request = [system_prompt, UserMessage(content=prompts[0], source="user")]
    hit, key = client._check_cache(request, [], None, {})  # type: ignore

    assert hit is not None
    # Identity, not just equality: no copy or re-validation should occur.
    assert hit is expected
    assert isinstance(hit, CreateResult)
    assert hit.content == "already correct type"
    assert key is not None