mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-19 03:24:46 +00:00
Fix Redis Deserialization Error (#6952)
This commit is contained in:
parent
db10c9dd2c
commit
29a84e293c
@ -1,7 +1,7 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Sequence, Union, cast
|
from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Sequence, Union
|
||||||
|
|
||||||
from autogen_core import CacheStore, CancellationToken, Component, ComponentModel, InMemoryStore
|
from autogen_core import CacheStore, CancellationToken, Component, ComponentModel, InMemoryStore
|
||||||
from autogen_core.models import (
|
from autogen_core.models import (
|
||||||
@ -13,7 +13,7 @@ from autogen_core.models import (
|
|||||||
RequestUsage,
|
RequestUsage,
|
||||||
)
|
)
|
||||||
from autogen_core.tools import Tool, ToolSchema
|
from autogen_core.tools import Tool, ToolSchema
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, ValidationError
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]]
|
CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]]
|
||||||
@ -126,8 +126,29 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
|
|||||||
serialized_data = json.dumps(data, sort_keys=True)
|
serialized_data = json.dumps(data, sort_keys=True)
|
||||||
cache_key = hashlib.sha256(serialized_data.encode()).hexdigest()
|
cache_key = hashlib.sha256(serialized_data.encode()).hexdigest()
|
||||||
|
|
||||||
cached_result = cast(Optional[CreateResult], self.store.get(cache_key))
|
cached_result = self.store.get(cache_key)
|
||||||
if cached_result is not None:
|
if cached_result is not None:
|
||||||
|
# Handle case where cache store returns dict instead of CreateResult (e.g., Redis)
|
||||||
|
if isinstance(cached_result, dict):
|
||||||
|
try:
|
||||||
|
cached_result = CreateResult.model_validate(cached_result)
|
||||||
|
except ValidationError:
|
||||||
|
# If reconstruction fails, treat as cache miss
|
||||||
|
return None, cache_key
|
||||||
|
elif isinstance(cached_result, list):
|
||||||
|
# Handle streaming results - reconstruct CreateResult instances from dicts
|
||||||
|
try:
|
||||||
|
reconstructed_list: List[Union[str, CreateResult]] = []
|
||||||
|
for item in cached_result:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
reconstructed_list.append(CreateResult.model_validate(item))
|
||||||
|
else:
|
||||||
|
reconstructed_list.append(item)
|
||||||
|
cached_result = reconstructed_list
|
||||||
|
except ValidationError:
|
||||||
|
# If reconstruction fails, treat as cache miss
|
||||||
|
return None, cache_key
|
||||||
|
# If it's already the right type (CreateResult or list), return as-is
|
||||||
return cached_result, cache_key
|
return cached_result, cache_key
|
||||||
|
|
||||||
return None, cache_key
|
return None, cache_key
|
||||||
|
|||||||
@ -1,15 +1,17 @@
|
|||||||
import copy
|
import copy
|
||||||
from typing import List, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from autogen_core import CacheStore
|
||||||
from autogen_core.models import (
|
from autogen_core.models import (
|
||||||
ChatCompletionClient,
|
ChatCompletionClient,
|
||||||
CreateResult,
|
CreateResult,
|
||||||
LLMMessage,
|
LLMMessage,
|
||||||
|
RequestUsage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from autogen_ext.models.cache import ChatCompletionCache
|
from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache
|
||||||
from autogen_ext.models.replay import ReplayChatCompletionClient
|
from autogen_ext.models.replay import ReplayChatCompletionClient
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -184,3 +186,297 @@ async def test_cache_create_stream() -> None:
|
|||||||
# cached_client_config = cached_client.dump_component()
|
# cached_client_config = cached_client.dump_component()
|
||||||
# loaded_client = ChatCompletionCache.load_component(cached_client_config)
|
# loaded_client = ChatCompletionCache.load_component(cached_client_config)
|
||||||
# assert loaded_client.client == cached_client.client
|
# assert loaded_client.client == cached_client.client
|
||||||
|
|
||||||
|
|
||||||
|
class MockCacheStore(CacheStore[CHAT_CACHE_VALUE_TYPE]):
|
||||||
|
"""Mock cache store for testing deserialization scenarios."""
|
||||||
|
|
||||||
|
def __init__(self, return_value: Optional[CHAT_CACHE_VALUE_TYPE] = None) -> None:
|
||||||
|
self._return_value = return_value
|
||||||
|
self._storage: Dict[str, CHAT_CACHE_VALUE_TYPE] = {}
|
||||||
|
|
||||||
|
def get(self, key: str, default: Optional[CHAT_CACHE_VALUE_TYPE] = None) -> Optional[CHAT_CACHE_VALUE_TYPE]:
|
||||||
|
return self._return_value # type: ignore
|
||||||
|
|
||||||
|
def set(self, key: str, value: CHAT_CACHE_VALUE_TYPE) -> None:
|
||||||
|
self._storage[key] = value
|
||||||
|
|
||||||
|
def _to_config(self) -> BaseModel:
|
||||||
|
"""Dummy implementation for testing."""
|
||||||
|
return BaseModel()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_config(cls, _config: BaseModel) -> "MockCacheStore":
|
||||||
|
"""Dummy implementation for testing."""
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_cache_redis_dict_deserialization_success() -> None:
|
||||||
|
"""Test _check_cache when Redis cache returns a dict that can be deserialized to CreateResult.
|
||||||
|
This tests the core Redis deserialization fix where Redis returns serialized Pydantic
|
||||||
|
models as dictionaries instead of the original objects.
|
||||||
|
"""
|
||||||
|
_, prompts, system_prompt, replay_client, _ = get_test_data()
|
||||||
|
|
||||||
|
# Create a CreateResult instance (simulating deserialized Redis data)
|
||||||
|
create_result = CreateResult(
|
||||||
|
content="test response from redis",
|
||||||
|
usage=RequestUsage(prompt_tokens=15, completion_tokens=8),
|
||||||
|
cached=False,
|
||||||
|
finish_reason="stop",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock cache store that returns a CreateResult (simulating Redis behavior)
|
||||||
|
mock_store = MockCacheStore(return_value=create_result)
|
||||||
|
cached_client = ChatCompletionCache(replay_client, mock_store)
|
||||||
|
|
||||||
|
# Test _check_cache method directly using proper test data
|
||||||
|
messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
|
||||||
|
cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore
|
||||||
|
|
||||||
|
assert cached_result is not None
|
||||||
|
assert isinstance(cached_result, CreateResult)
|
||||||
|
assert cached_result.content == "test response from redis"
|
||||||
|
assert cache_key is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_cache_redis_dict_deserialization_failure() -> None:
|
||||||
|
"""Test _check_cache gracefully handles corrupted Redis data.
|
||||||
|
This ensures the system degrades gracefully when Redis returns corrupted
|
||||||
|
or invalid data that cannot be deserialized back to CreateResult.
|
||||||
|
"""
|
||||||
|
_, prompts, system_prompt, replay_client, _ = get_test_data()
|
||||||
|
|
||||||
|
# Mock cache store that returns None (simulating deserialization failure)
|
||||||
|
mock_store = MockCacheStore(return_value=None)
|
||||||
|
cached_client = ChatCompletionCache(replay_client, mock_store)
|
||||||
|
|
||||||
|
# Test _check_cache method directly using proper test data
|
||||||
|
messages = [system_prompt, UserMessage(content=prompts[1], source="user")]
|
||||||
|
cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore
|
||||||
|
|
||||||
|
# Should return None (cache miss) when deserialization fails
|
||||||
|
assert cached_result is None
|
||||||
|
assert cache_key is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_cache_redis_streaming_dict_deserialization() -> None:
|
||||||
|
"""Test _check_cache with Redis streaming data containing dicts that need deserialization.
|
||||||
|
This tests the streaming scenario where Redis returns a list containing
|
||||||
|
serialized CreateResult objects as dictionaries mixed with string chunks.
|
||||||
|
"""
|
||||||
|
_, prompts, system_prompt, replay_client, _ = get_test_data()
|
||||||
|
|
||||||
|
# Create a list with CreateResult objects mixed with strings (streaming scenario)
|
||||||
|
create_result = CreateResult(
|
||||||
|
content="final streaming response from redis",
|
||||||
|
usage=RequestUsage(prompt_tokens=12, completion_tokens=6),
|
||||||
|
cached=False,
|
||||||
|
finish_reason="stop",
|
||||||
|
)
|
||||||
|
|
||||||
|
cached_list: List[Union[str, CreateResult]] = [
|
||||||
|
"streaming chunk 1",
|
||||||
|
create_result, # Proper CreateResult object
|
||||||
|
"streaming chunk 2",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock cache store that returns the list with CreateResults (simulating Redis streaming)
|
||||||
|
mock_store = MockCacheStore(return_value=cached_list)
|
||||||
|
cached_client = ChatCompletionCache(replay_client, mock_store)
|
||||||
|
|
||||||
|
# Test _check_cache method directly using proper test data
|
||||||
|
messages = [system_prompt, UserMessage(content=prompts[2], source="user")]
|
||||||
|
cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore
|
||||||
|
|
||||||
|
assert cached_result is not None
|
||||||
|
assert isinstance(cached_result, list)
|
||||||
|
assert len(cached_result) == 3
|
||||||
|
assert cached_result[0] == "streaming chunk 1"
|
||||||
|
assert isinstance(cached_result[1], CreateResult)
|
||||||
|
assert cached_result[1].content == "final streaming response from redis"
|
||||||
|
assert cached_result[2] == "streaming chunk 2"
|
||||||
|
assert cache_key is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_cache_redis_streaming_deserialization_failure() -> None:
|
||||||
|
"""Test _check_cache gracefully handles corrupted Redis streaming data.
|
||||||
|
This ensures the system degrades gracefully when Redis returns streaming
|
||||||
|
data with corrupted CreateResult dictionaries that cannot be deserialized.
|
||||||
|
"""
|
||||||
|
_, prompts, system_prompt, replay_client, _ = get_test_data(num_messages=4)
|
||||||
|
|
||||||
|
# Mock cache store that returns None (simulating deserialization failure)
|
||||||
|
mock_store = MockCacheStore(return_value=None)
|
||||||
|
cached_client = ChatCompletionCache(replay_client, mock_store)
|
||||||
|
|
||||||
|
# Test _check_cache method directly using proper test data
|
||||||
|
messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
|
||||||
|
cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore
|
||||||
|
|
||||||
|
# Should return None (cache miss) when deserialization fails
|
||||||
|
assert cached_result is None
|
||||||
|
assert cache_key is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_cache_dict_reconstruction_success() -> None:
|
||||||
|
"""Test _check_cache successfully reconstructs CreateResult from a dict.
|
||||||
|
This tests the line: cached_result = CreateResult.model_validate(cached_result)
|
||||||
|
"""
|
||||||
|
_, prompts, system_prompt, replay_client, _ = get_test_data()
|
||||||
|
|
||||||
|
# Create a dict that can be successfully validated as CreateResult
|
||||||
|
valid_dict = {
|
||||||
|
"content": "reconstructed response",
|
||||||
|
"usage": {"prompt_tokens": 10, "completion_tokens": 5},
|
||||||
|
"cached": False,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create a MockCacheStore that returns the dict directly (simulating Redis)
|
||||||
|
mock_store = MockCacheStore(return_value=cast(Any, valid_dict))
|
||||||
|
cached_client = ChatCompletionCache(replay_client, mock_store)
|
||||||
|
|
||||||
|
# Test _check_cache method
|
||||||
|
messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
|
||||||
|
cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore
|
||||||
|
|
||||||
|
# Should successfully reconstruct the CreateResult from dict
|
||||||
|
assert cached_result is not None
|
||||||
|
assert isinstance(cached_result, CreateResult)
|
||||||
|
assert cached_result.content == "reconstructed response"
|
||||||
|
assert cache_key is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_cache_dict_reconstruction_failure() -> None:
|
||||||
|
"""Test _check_cache handles ValidationError when dict cannot be reconstructed.
|
||||||
|
This tests the except ValidationError block for single dicts.
|
||||||
|
"""
|
||||||
|
_, prompts, system_prompt, replay_client, _ = get_test_data()
|
||||||
|
|
||||||
|
# Create an invalid dict that will fail CreateResult validation
|
||||||
|
invalid_dict = {
|
||||||
|
"invalid_field": "value",
|
||||||
|
"missing_required_fields": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create a MockCacheStore that returns the invalid dict
|
||||||
|
mock_store = MockCacheStore(return_value=cast(Any, invalid_dict))
|
||||||
|
cached_client = ChatCompletionCache(replay_client, mock_store)
|
||||||
|
|
||||||
|
# Test _check_cache method
|
||||||
|
messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
|
||||||
|
cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore
|
||||||
|
|
||||||
|
# Should return None (cache miss) when reconstruction fails
|
||||||
|
assert cached_result is None
|
||||||
|
assert cache_key is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_cache_list_reconstruction_success() -> None:
|
||||||
|
"""Test _check_cache successfully reconstructs CreateResult objects from dicts in a list.
|
||||||
|
This tests the line: reconstructed_list.append(CreateResult.model_validate(item))
|
||||||
|
"""
|
||||||
|
_, prompts, system_prompt, replay_client, _ = get_test_data()
|
||||||
|
|
||||||
|
# Create a list with valid dicts that can be reconstructed
|
||||||
|
valid_dict1 = {
|
||||||
|
"content": "first result",
|
||||||
|
"usage": {"prompt_tokens": 8, "completion_tokens": 3},
|
||||||
|
"cached": False,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
valid_dict2 = {
|
||||||
|
"content": "second result",
|
||||||
|
"usage": {"prompt_tokens": 12, "completion_tokens": 7},
|
||||||
|
"cached": False,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
|
||||||
|
cached_list = [
|
||||||
|
"streaming chunk 1",
|
||||||
|
valid_dict1,
|
||||||
|
"streaming chunk 2",
|
||||||
|
valid_dict2,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create a MockCacheStore that returns the list with dicts
|
||||||
|
mock_store = MockCacheStore(return_value=cast(Any, cached_list))
|
||||||
|
cached_client = ChatCompletionCache(replay_client, mock_store)
|
||||||
|
|
||||||
|
# Test _check_cache method
|
||||||
|
messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
|
||||||
|
cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore
|
||||||
|
|
||||||
|
# Should successfully reconstruct the list with CreateResult objects
|
||||||
|
assert cached_result is not None
|
||||||
|
assert isinstance(cached_result, list)
|
||||||
|
assert len(cached_result) == 4
|
||||||
|
assert cached_result[0] == "streaming chunk 1"
|
||||||
|
assert isinstance(cached_result[1], CreateResult)
|
||||||
|
assert cached_result[1].content == "first result"
|
||||||
|
assert cached_result[2] == "streaming chunk 2"
|
||||||
|
assert isinstance(cached_result[3], CreateResult)
|
||||||
|
assert cached_result[3].content == "second result"
|
||||||
|
assert cache_key is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_cache_list_reconstruction_failure() -> None:
|
||||||
|
"""Test _check_cache handles ValidationError when list contains invalid dicts.
|
||||||
|
This tests the except ValidationError block for lists.
|
||||||
|
"""
|
||||||
|
_, prompts, system_prompt, replay_client, _ = get_test_data()
|
||||||
|
|
||||||
|
# Create a list with an invalid dict that will fail validation
|
||||||
|
invalid_dict = {
|
||||||
|
"invalid_field": "value",
|
||||||
|
"missing_required": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
cached_list = [
|
||||||
|
"streaming chunk 1",
|
||||||
|
invalid_dict, # This will cause ValidationError
|
||||||
|
"streaming chunk 2",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create a MockCacheStore that returns the list with invalid dict
|
||||||
|
mock_store = MockCacheStore(return_value=cast(Any, cached_list))
|
||||||
|
cached_client = ChatCompletionCache(replay_client, mock_store)
|
||||||
|
|
||||||
|
# Test _check_cache method
|
||||||
|
messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
|
||||||
|
cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore
|
||||||
|
|
||||||
|
# Should return None (cache miss) when list reconstruction fails
|
||||||
|
assert cached_result is None
|
||||||
|
assert cache_key is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_cache_already_correct_type() -> None:
|
||||||
|
"""Test _check_cache returns data as-is when it's already the correct type.
|
||||||
|
This tests the final return path when no reconstruction is needed.
|
||||||
|
"""
|
||||||
|
_, prompts, system_prompt, replay_client, _ = get_test_data()
|
||||||
|
|
||||||
|
# Create a proper CreateResult object (already correct type)
|
||||||
|
create_result = CreateResult(
|
||||||
|
content="already correct type",
|
||||||
|
usage=RequestUsage(prompt_tokens=15, completion_tokens=8),
|
||||||
|
cached=False,
|
||||||
|
finish_reason="stop",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a MockCacheStore that returns the proper type
|
||||||
|
mock_store = MockCacheStore(return_value=create_result)
|
||||||
|
cached_client = ChatCompletionCache(replay_client, mock_store)
|
||||||
|
|
||||||
|
# Test _check_cache method
|
||||||
|
messages = [system_prompt, UserMessage(content=prompts[0], source="user")]
|
||||||
|
cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore
|
||||||
|
|
||||||
|
# Should return the same object without reconstruction
|
||||||
|
assert cached_result is not None
|
||||||
|
assert cached_result is create_result # Same object reference
|
||||||
|
assert isinstance(cached_result, CreateResult)
|
||||||
|
assert cached_result.content == "already correct type"
|
||||||
|
assert cache_key is not None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user