microsoft/autogen (https://github.com/microsoft/autogen.git)
Fix Redis Deserialization Error (#6952)
commit 29a84e293c (parent db10c9dd2c)
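
Background for the change: a Redis-backed cache store persists values as JSON, so on a cache hit it hands ChatCompletionCache a plain dict (or a list of dicts and strings) rather than the CreateResult that was stored; the old code simply cast the value, so the dict slipped through and broke downstream handling. The sketch below is illustrative only, not part of the commit, and assumes nothing beyond pydantic v2 and the CreateResult/RequestUsage models that appear in the diff.

# Illustrative sketch (not from the commit): what a JSON round trip through a
# Redis-like store does to a cached CreateResult, and how the fix recovers it.
import json

from autogen_core.models import CreateResult, RequestUsage

original = CreateResult(
    content="hello",
    usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
    cached=False,
    finish_reason="stop",
)

stored = original.model_dump_json()  # what effectively lands in Redis: a JSON string
retrieved = json.loads(stored)       # what comes back out: a plain dict, not a CreateResult
assert isinstance(retrieved, dict)

# The patch applies exactly this reconstruction before returning a cache hit:
restored = CreateResult.model_validate(retrieved)
assert isinstance(restored, CreateResult) and restored.content == "hello"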
ChatCompletionCache implementation (autogen_ext.models.cache):

@@ -1,7 +1,7 @@
 import hashlib
 import json
 import warnings
-from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Sequence, Union, cast
+from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Sequence, Union

 from autogen_core import CacheStore, CancellationToken, Component, ComponentModel, InMemoryStore
 from autogen_core.models import (
@@ -13,7 +13,7 @@ from autogen_core.models import (
     RequestUsage,
 )
 from autogen_core.tools import Tool, ToolSchema
-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationError
 from typing_extensions import Self

 CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]]
@@ -126,8 +126,29 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
         serialized_data = json.dumps(data, sort_keys=True)
         cache_key = hashlib.sha256(serialized_data.encode()).hexdigest()

-        cached_result = cast(Optional[CreateResult], self.store.get(cache_key))
+        cached_result = self.store.get(cache_key)
         if cached_result is not None:
+            # Handle case where cache store returns dict instead of CreateResult (e.g., Redis)
+            if isinstance(cached_result, dict):
+                try:
+                    cached_result = CreateResult.model_validate(cached_result)
+                except ValidationError:
+                    # If reconstruction fails, treat as cache miss
+                    return None, cache_key
+            elif isinstance(cached_result, list):
+                # Handle streaming results - reconstruct CreateResult instances from dicts
+                try:
+                    reconstructed_list: List[Union[str, CreateResult]] = []
+                    for item in cached_result:
+                        if isinstance(item, dict):
+                            reconstructed_list.append(CreateResult.model_validate(item))
+                        else:
+                            reconstructed_list.append(item)
+                    cached_result = reconstructed_list
+                except ValidationError:
+                    # If reconstruction fails, treat as cache miss
+                    return None, cache_key
+            # If it's already the right type (CreateResult or list), return as-is
             return cached_result, cache_key

         return None, cache_key
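
For context on how the fixed cache behaves end to end, here is a hypothetical usage sketch that is not part of the commit. JsonRoundTripStore is a stand-in for a Redis-backed CacheStore that only keeps JSON; the ReplayChatCompletionClient constructor argument (a list of canned responses) is assumed from its usual usage, and the ChatCompletionCache wiring mirrors the tests below.

# Hypothetical sketch: a toy CacheStore that JSON-encodes values on set() and
# returns plain dicts/lists from get(), mimicking a Redis-backed store.
import asyncio
import json
from typing import Dict, Optional

from autogen_core import CacheStore
from autogen_core.models import CreateResult, UserMessage
from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache
from autogen_ext.models.replay import ReplayChatCompletionClient
from pydantic import BaseModel


class JsonRoundTripStore(CacheStore[CHAT_CACHE_VALUE_TYPE]):
    """Toy store: values only survive as JSON, as they would in Redis."""

    def __init__(self) -> None:
        self._data: Dict[str, str] = {}

    def get(self, key: str, default: Optional[CHAT_CACHE_VALUE_TYPE] = None) -> Optional[CHAT_CACHE_VALUE_TYPE]:
        raw = self._data.get(key)
        return json.loads(raw) if raw is not None else default  # type: ignore

    def set(self, key: str, value: CHAT_CACHE_VALUE_TYPE) -> None:
        if isinstance(value, CreateResult):
            self._data[key] = value.model_dump_json()
        else:  # streaming entry: list of str chunks and CreateResult objects
            self._data[key] = json.dumps([v if isinstance(v, str) else v.model_dump() for v in value])

    def _to_config(self) -> BaseModel:
        return BaseModel()

    @classmethod
    def _from_config(cls, _config: BaseModel) -> "JsonRoundTripStore":
        return cls()


async def main() -> None:
    client = ReplayChatCompletionClient(["canned reply"])  # assumed: list of responses to replay
    cached_client = ChatCompletionCache(client, JsonRoundTripStore())
    messages = [UserMessage(content="hi", source="user")]
    first = await cached_client.create(messages)   # miss: calls the replay client, stores JSON
    second = await cached_client.create(messages)  # hit: store returns a dict, cache rebuilds it
    assert isinstance(second, CreateResult) and second.content == first.content


asyncio.run(main())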
ChatCompletionCache tests:

@@ -1,15 +1,17 @@
 import copy
-from typing import List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, cast

 import pytest
 from autogen_core import CacheStore
 from autogen_core.models import (
     ChatCompletionClient,
     CreateResult,
     LLMMessage,
     RequestUsage,
     SystemMessage,
     UserMessage,
 )
-from autogen_ext.models.cache import ChatCompletionCache
+from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache
 from autogen_ext.models.replay import ReplayChatCompletionClient
+from pydantic import BaseModel
@@ -184,3 +186,297 @@ async def test_cache_create_stream() -> None:
     # cached_client_config = cached_client.dump_component()
     # loaded_client = ChatCompletionCache.load_component(cached_client_config)
     # assert loaded_client.client == cached_client.client
+
+
+class MockCacheStore(CacheStore[CHAT_CACHE_VALUE_TYPE]):
+    """Mock cache store for testing deserialization scenarios."""
+
+    def __init__(self, return_value: Optional[CHAT_CACHE_VALUE_TYPE] = None) -> None:
+        self._return_value = return_value
+        self._storage: Dict[str, CHAT_CACHE_VALUE_TYPE] = {}
+
+    def get(self, key: str, default: Optional[CHAT_CACHE_VALUE_TYPE] = None) -> Optional[CHAT_CACHE_VALUE_TYPE]:
+        return self._return_value  # type: ignore
+
+    def set(self, key: str, value: CHAT_CACHE_VALUE_TYPE) -> None:
+        self._storage[key] = value
+
+    def _to_config(self) -> BaseModel:
+        """Dummy implementation for testing."""
+        return BaseModel()
+
+    @classmethod
+    def _from_config(cls, _config: BaseModel) -> "MockCacheStore":
+        """Dummy implementation for testing."""
+        return cls()
+
+
+def test_check_cache_redis_dict_deserialization_success() -> None:
+    """Test _check_cache when Redis cache returns a dict that can be deserialized to CreateResult.
+    This tests the core Redis deserialization fix where Redis returns serialized Pydantic
+    models as dictionaries instead of the original objects.
+    """
+    _, prompts, system_prompt, replay_client, _ = get_test_data()
+
+    # Create a CreateResult instance (simulating deserialized Redis data)
+    create_result = CreateResult(
+        content="test response from redis",
+        usage=RequestUsage(prompt_tokens=15, completion_tokens=8),
+        cached=False,
+        finish_reason="stop",
+    )
+
+    # Mock cache store that returns a CreateResult (simulating Redis behavior)
+    mock_store = MockCacheStore(return_value=create_result)
+    cached_client = ChatCompletionCache(replay_client, mock_store)
+
+    # Test _check_cache method directly using proper test data
+    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
+    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore
+
+    assert cached_result is not None
+    assert isinstance(cached_result, CreateResult)
+    assert cached_result.content == "test response from redis"
+    assert cache_key is not None
+
+
+def test_check_cache_redis_dict_deserialization_failure() -> None:
+    """Test _check_cache gracefully handles corrupted Redis data.
+    This ensures the system degrades gracefully when Redis returns corrupted
+    or invalid data that cannot be deserialized back to CreateResult.
+    """
+    _, prompts, system_prompt, replay_client, _ = get_test_data()
+
+    # Mock cache store that returns None (simulating deserialization failure)
+    mock_store = MockCacheStore(return_value=None)
+    cached_client = ChatCompletionCache(replay_client, mock_store)
+
+    # Test _check_cache method directly using proper test data
+    messages = [system_prompt, UserMessage(content=prompts[1], source="user")]
+    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore
+
+    # Should return None (cache miss) when deserialization fails
+    assert cached_result is None
+    assert cache_key is not None
+
+
+def test_check_cache_redis_streaming_dict_deserialization() -> None:
+    """Test _check_cache with Redis streaming data containing dicts that need deserialization.
+    This tests the streaming scenario where Redis returns a list containing
+    serialized CreateResult objects as dictionaries mixed with string chunks.
+    """
+    _, prompts, system_prompt, replay_client, _ = get_test_data()
+
+    # Create a list with CreateResult objects mixed with strings (streaming scenario)
+    create_result = CreateResult(
+        content="final streaming response from redis",
+        usage=RequestUsage(prompt_tokens=12, completion_tokens=6),
+        cached=False,
+        finish_reason="stop",
+    )
+
+    cached_list: List[Union[str, CreateResult]] = [
+        "streaming chunk 1",
+        create_result,  # Proper CreateResult object
+        "streaming chunk 2",
+    ]
+
+    # Mock cache store that returns the list with CreateResults (simulating Redis streaming)
+    mock_store = MockCacheStore(return_value=cached_list)
+    cached_client = ChatCompletionCache(replay_client, mock_store)
+
+    # Test _check_cache method directly using proper test data
+    messages = [system_prompt, UserMessage(content=prompts[2], source="user")]
+    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore
+
+    assert cached_result is not None
+    assert isinstance(cached_result, list)
+    assert len(cached_result) == 3
+    assert cached_result[0] == "streaming chunk 1"
+    assert isinstance(cached_result[1], CreateResult)
+    assert cached_result[1].content == "final streaming response from redis"
+    assert cached_result[2] == "streaming chunk 2"
+    assert cache_key is not None
+
+
+def test_check_cache_redis_streaming_deserialization_failure() -> None:
+    """Test _check_cache gracefully handles corrupted Redis streaming data.
+    This ensures the system degrades gracefully when Redis returns streaming
+    data with corrupted CreateResult dictionaries that cannot be deserialized.
+    """
+    _, prompts, system_prompt, replay_client, _ = get_test_data(num_messages=4)
+
+    # Mock cache store that returns None (simulating deserialization failure)
+    mock_store = MockCacheStore(return_value=None)
+    cached_client = ChatCompletionCache(replay_client, mock_store)
+
+    # Test _check_cache method directly using proper test data
+    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
+    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore
+
+    # Should return None (cache miss) when deserialization fails
+    assert cached_result is None
+    assert cache_key is not None
+
+
+def test_check_cache_dict_reconstruction_success() -> None:
+    """Test _check_cache successfully reconstructs CreateResult from a dict.
+    This tests the line: cached_result = CreateResult.model_validate(cached_result)
+    """
+    _, prompts, system_prompt, replay_client, _ = get_test_data()
+
+    # Create a dict that can be successfully validated as CreateResult
+    valid_dict = {
+        "content": "reconstructed response",
+        "usage": {"prompt_tokens": 10, "completion_tokens": 5},
+        "cached": False,
+        "finish_reason": "stop",
+    }
+
+    # Create a MockCacheStore that returns the dict directly (simulating Redis)
+    mock_store = MockCacheStore(return_value=cast(Any, valid_dict))
+    cached_client = ChatCompletionCache(replay_client, mock_store)
+
+    # Test _check_cache method
+    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
+    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore
+
+    # Should successfully reconstruct the CreateResult from dict
+    assert cached_result is not None
+    assert isinstance(cached_result, CreateResult)
+    assert cached_result.content == "reconstructed response"
+    assert cache_key is not None
+
+
+def test_check_cache_dict_reconstruction_failure() -> None:
+    """Test _check_cache handles ValidationError when dict cannot be reconstructed.
+    This tests the except ValidationError block for single dicts.
+    """
+    _, prompts, system_prompt, replay_client, _ = get_test_data()
+
+    # Create an invalid dict that will fail CreateResult validation
+    invalid_dict = {
+        "invalid_field": "value",
+        "missing_required_fields": True,
+    }
+
+    # Create a MockCacheStore that returns the invalid dict
+    mock_store = MockCacheStore(return_value=cast(Any, invalid_dict))
+    cached_client = ChatCompletionCache(replay_client, mock_store)
+
+    # Test _check_cache method
+    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
+    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore
+
+    # Should return None (cache miss) when reconstruction fails
+    assert cached_result is None
+    assert cache_key is not None
+
+
+def test_check_cache_list_reconstruction_success() -> None:
+    """Test _check_cache successfully reconstructs CreateResult objects from dicts in a list.
+    This tests the line: reconstructed_list.append(CreateResult.model_validate(item))
+    """
+    _, prompts, system_prompt, replay_client, _ = get_test_data()
+
+    # Create a list with valid dicts that can be reconstructed
+    valid_dict1 = {
+        "content": "first result",
+        "usage": {"prompt_tokens": 8, "completion_tokens": 3},
+        "cached": False,
+        "finish_reason": "stop",
+    }
+    valid_dict2 = {
+        "content": "second result",
+        "usage": {"prompt_tokens": 12, "completion_tokens": 7},
+        "cached": False,
+        "finish_reason": "stop",
+    }
+
+    cached_list = [
+        "streaming chunk 1",
+        valid_dict1,
+        "streaming chunk 2",
+        valid_dict2,
+    ]
+
+    # Create a MockCacheStore that returns the list with dicts
+    mock_store = MockCacheStore(return_value=cast(Any, cached_list))
+    cached_client = ChatCompletionCache(replay_client, mock_store)
+
+    # Test _check_cache method
+    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
+    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore
+
+    # Should successfully reconstruct the list with CreateResult objects
+    assert cached_result is not None
+    assert isinstance(cached_result, list)
+    assert len(cached_result) == 4
+    assert cached_result[0] == "streaming chunk 1"
+    assert isinstance(cached_result[1], CreateResult)
+    assert cached_result[1].content == "first result"
+    assert cached_result[2] == "streaming chunk 2"
+    assert isinstance(cached_result[3], CreateResult)
+    assert cached_result[3].content == "second result"
+    assert cache_key is not None
+
+
+def test_check_cache_list_reconstruction_failure() -> None:
+    """Test _check_cache handles ValidationError when list contains invalid dicts.
+    This tests the except ValidationError block for lists.
+    """
+    _, prompts, system_prompt, replay_client, _ = get_test_data()
+
+    # Create a list with an invalid dict that will fail validation
+    invalid_dict = {
+        "invalid_field": "value",
+        "missing_required": True,
+    }
+
+    cached_list = [
+        "streaming chunk 1",
+        invalid_dict,  # This will cause ValidationError
+        "streaming chunk 2",
+    ]
+
+    # Create a MockCacheStore that returns the list with invalid dict
+    mock_store = MockCacheStore(return_value=cast(Any, cached_list))
+    cached_client = ChatCompletionCache(replay_client, mock_store)
+
+    # Test _check_cache method
+    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
+    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore
+
+    # Should return None (cache miss) when list reconstruction fails
+    assert cached_result is None
+    assert cache_key is not None
+
+
+def test_check_cache_already_correct_type() -> None:
+    """Test _check_cache returns data as-is when it's already the correct type.
+    This tests the final return path when no reconstruction is needed.
+    """
+    _, prompts, system_prompt, replay_client, _ = get_test_data()
+
+    # Create a proper CreateResult object (already correct type)
+    create_result = CreateResult(
+        content="already correct type",
+        usage=RequestUsage(prompt_tokens=15, completion_tokens=8),
+        cached=False,
+        finish_reason="stop",
+    )
+
+    # Create a MockCacheStore that returns the proper type
+    mock_store = MockCacheStore(return_value=create_result)
+    cached_client = ChatCompletionCache(replay_client, mock_store)
+
+    # Test _check_cache method
+    messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
+    cached_result, cache_key = cached_client._check_cache(messages, [], None, {})  # type: ignore
+
+    # Should return the same object without reconstruction
+    assert cached_result is not None
+    assert cached_result is create_result  # Same object reference
+    assert isinstance(cached_result, CreateResult)
+    assert cached_result.content == "already correct type"
+    assert cache_key is not None
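
A short illustration of why the streaming branch tested above exists (illustrative only, not from the commit): a streamed completion is cached as a list of string chunks plus the final CreateResult, and a JSON round trip such as Redis performs keeps the strings but flattens the model into a dict, which _check_cache now rebuilds item by item.

# Illustrative only: how a streamed completion ends up as a list in the cache and
# what a JSON-backed store hands back. The encode/decode below is a hypothetical
# stand-in for a Redis round trip, not an autogen API.
import json
from typing import List, Union

from autogen_core.models import CreateResult, RequestUsage

chunks: List[Union[str, CreateResult]] = [
    "streaming chunk 1",
    "streaming chunk 2",
    CreateResult(
        content="final streaming response",
        usage=RequestUsage(prompt_tokens=12, completion_tokens=6),
        cached=False,
        finish_reason="stop",
    ),
]

# A JSON round trip keeps the strings but turns the CreateResult into a dict ...
raw = json.dumps([c if isinstance(c, str) else c.model_dump() for c in chunks])
decoded = json.loads(raw)
assert isinstance(decoded[2], dict)

# ... which is why the list branch rebuilds each dict item with CreateResult.model_validate.
rebuilt = [c if isinstance(c, str) else CreateResult.model_validate(c) for c in decoded]
assert isinstance(rebuilt[2], CreateResult)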