diff --git a/python/packages/autogen-ext/src/autogen_ext/models/cache/_chat_completion_cache.py b/python/packages/autogen-ext/src/autogen_ext/models/cache/_chat_completion_cache.py index 07e3d89e8..279cffb81 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/cache/_chat_completion_cache.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/cache/_chat_completion_cache.py @@ -1,7 +1,7 @@ import hashlib import json import warnings -from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Sequence, Union, cast +from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Sequence, Union from autogen_core import CacheStore, CancellationToken, Component, ComponentModel, InMemoryStore from autogen_core.models import ( @@ -13,7 +13,7 @@ from autogen_core.models import ( RequestUsage, ) from autogen_core.tools import Tool, ToolSchema -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from typing_extensions import Self CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]] @@ -126,8 +126,29 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon serialized_data = json.dumps(data, sort_keys=True) cache_key = hashlib.sha256(serialized_data.encode()).hexdigest() - cached_result = cast(Optional[CreateResult], self.store.get(cache_key)) + cached_result = self.store.get(cache_key) if cached_result is not None: + # Handle case where cache store returns dict instead of CreateResult (e.g., Redis) + if isinstance(cached_result, dict): + try: + cached_result = CreateResult.model_validate(cached_result) + except ValidationError: + # If reconstruction fails, treat as cache miss + return None, cache_key + elif isinstance(cached_result, list): + # Handle streaming results - reconstruct CreateResult instances from dicts + try: + reconstructed_list: List[Union[str, CreateResult]] = [] + for item in cached_result: + if isinstance(item, dict): + reconstructed_list.append(CreateResult.model_validate(item)) + else: + reconstructed_list.append(item) + cached_result = reconstructed_list + except ValidationError: + # If reconstruction fails, treat as cache miss + return None, cache_key + # If it's already the right type (CreateResult or list), return as-is return cached_result, cache_key return None, cache_key diff --git a/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py b/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py index adf2da476..3a125307a 100644 --- a/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py +++ b/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py @@ -1,15 +1,17 @@ import copy -from typing import List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast import pytest +from autogen_core import CacheStore from autogen_core.models import ( ChatCompletionClient, CreateResult, LLMMessage, + RequestUsage, SystemMessage, UserMessage, ) -from autogen_ext.models.cache import ChatCompletionCache +from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache from autogen_ext.models.replay import ReplayChatCompletionClient from pydantic import BaseModel @@ -184,3 +186,297 @@ async def test_cache_create_stream() -> None: # cached_client_config = cached_client.dump_component() # loaded_client = ChatCompletionCache.load_component(cached_client_config) # assert loaded_client.client == cached_client.client + + +class MockCacheStore(CacheStore[CHAT_CACHE_VALUE_TYPE]): + """Mock cache store for testing deserialization scenarios.""" + + def __init__(self, return_value: Optional[CHAT_CACHE_VALUE_TYPE] = None) -> None: + self._return_value = return_value + self._storage: Dict[str, CHAT_CACHE_VALUE_TYPE] = {} + + def get(self, key: str, default: Optional[CHAT_CACHE_VALUE_TYPE] = None) -> Optional[CHAT_CACHE_VALUE_TYPE]: + return self._return_value # type: ignore + + def set(self, key: str, value: CHAT_CACHE_VALUE_TYPE) -> None: + self._storage[key] = value + + def _to_config(self) -> BaseModel: + """Dummy implementation for testing.""" + return BaseModel() + + @classmethod + def _from_config(cls, _config: BaseModel) -> "MockCacheStore": + """Dummy implementation for testing.""" + return cls() + + +def test_check_cache_redis_dict_deserialization_success() -> None: + """Test _check_cache when Redis cache returns a dict that can be deserialized to CreateResult. + This tests the core Redis deserialization fix where Redis returns serialized Pydantic + models as dictionaries instead of the original objects. + """ + _, prompts, system_prompt, replay_client, _ = get_test_data() + + # Create a CreateResult instance (simulating deserialized Redis data) + create_result = CreateResult( + content="test response from redis", + usage=RequestUsage(prompt_tokens=15, completion_tokens=8), + cached=False, + finish_reason="stop", + ) + + # Mock cache store that returns a CreateResult (simulating Redis behavior) + mock_store = MockCacheStore(return_value=create_result) + cached_client = ChatCompletionCache(replay_client, mock_store) + + # Test _check_cache method directly using proper test data + messages = [system_prompt, UserMessage(content=prompts[0], source="user")] + cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore + + assert cached_result is not None + assert isinstance(cached_result, CreateResult) + assert cached_result.content == "test response from redis" + assert cache_key is not None + + +def test_check_cache_redis_dict_deserialization_failure() -> None: + """Test _check_cache gracefully handles corrupted Redis data. + This ensures the system degrades gracefully when Redis returns corrupted + or invalid data that cannot be deserialized back to CreateResult. + """ + _, prompts, system_prompt, replay_client, _ = get_test_data() + + # Mock cache store that returns None (simulating deserialization failure) + mock_store = MockCacheStore(return_value=None) + cached_client = ChatCompletionCache(replay_client, mock_store) + + # Test _check_cache method directly using proper test data + messages = [system_prompt, UserMessage(content=prompts[1], source="user")] + cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore + + # Should return None (cache miss) when deserialization fails + assert cached_result is None + assert cache_key is not None + + +def test_check_cache_redis_streaming_dict_deserialization() -> None: + """Test _check_cache with Redis streaming data containing dicts that need deserialization. + This tests the streaming scenario where Redis returns a list containing + serialized CreateResult objects as dictionaries mixed with string chunks. + """ + _, prompts, system_prompt, replay_client, _ = get_test_data() + + # Create a list with CreateResult objects mixed with strings (streaming scenario) + create_result = CreateResult( + content="final streaming response from redis", + usage=RequestUsage(prompt_tokens=12, completion_tokens=6), + cached=False, + finish_reason="stop", + ) + + cached_list: List[Union[str, CreateResult]] = [ + "streaming chunk 1", + create_result, # Proper CreateResult object + "streaming chunk 2", + ] + + # Mock cache store that returns the list with CreateResults (simulating Redis streaming) + mock_store = MockCacheStore(return_value=cached_list) + cached_client = ChatCompletionCache(replay_client, mock_store) + + # Test _check_cache method directly using proper test data + messages = [system_prompt, UserMessage(content=prompts[2], source="user")] + cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore + + assert cached_result is not None + assert isinstance(cached_result, list) + assert len(cached_result) == 3 + assert cached_result[0] == "streaming chunk 1" + assert isinstance(cached_result[1], CreateResult) + assert cached_result[1].content == "final streaming response from redis" + assert cached_result[2] == "streaming chunk 2" + assert cache_key is not None + + +def test_check_cache_redis_streaming_deserialization_failure() -> None: + """Test _check_cache gracefully handles corrupted Redis streaming data. + This ensures the system degrades gracefully when Redis returns streaming + data with corrupted CreateResult dictionaries that cannot be deserialized. + """ + _, prompts, system_prompt, replay_client, _ = get_test_data(num_messages=4) + + # Mock cache store that returns None (simulating deserialization failure) + mock_store = MockCacheStore(return_value=None) + cached_client = ChatCompletionCache(replay_client, mock_store) + + # Test _check_cache method directly using proper test data + messages = [system_prompt, UserMessage(content=prompts[0], source="user")] + cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore + + # Should return None (cache miss) when deserialization fails + assert cached_result is None + assert cache_key is not None + + +def test_check_cache_dict_reconstruction_success() -> None: + """Test _check_cache successfully reconstructs CreateResult from a dict. + This tests the line: cached_result = CreateResult.model_validate(cached_result) + """ + _, prompts, system_prompt, replay_client, _ = get_test_data() + + # Create a dict that can be successfully validated as CreateResult + valid_dict = { + "content": "reconstructed response", + "usage": {"prompt_tokens": 10, "completion_tokens": 5}, + "cached": False, + "finish_reason": "stop", + } + + # Create a MockCacheStore that returns the dict directly (simulating Redis) + mock_store = MockCacheStore(return_value=cast(Any, valid_dict)) + cached_client = ChatCompletionCache(replay_client, mock_store) + + # Test _check_cache method + messages = [system_prompt, UserMessage(content=prompts[0], source="user")] + cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore + + # Should successfully reconstruct the CreateResult from dict + assert cached_result is not None + assert isinstance(cached_result, CreateResult) + assert cached_result.content == "reconstructed response" + assert cache_key is not None + + +def test_check_cache_dict_reconstruction_failure() -> None: + """Test _check_cache handles ValidationError when dict cannot be reconstructed. + This tests the except ValidationError block for single dicts. + """ + _, prompts, system_prompt, replay_client, _ = get_test_data() + + # Create an invalid dict that will fail CreateResult validation + invalid_dict = { + "invalid_field": "value", + "missing_required_fields": True, + } + + # Create a MockCacheStore that returns the invalid dict + mock_store = MockCacheStore(return_value=cast(Any, invalid_dict)) + cached_client = ChatCompletionCache(replay_client, mock_store) + + # Test _check_cache method + messages = [system_prompt, UserMessage(content=prompts[0], source="user")] + cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore + + # Should return None (cache miss) when reconstruction fails + assert cached_result is None + assert cache_key is not None + + +def test_check_cache_list_reconstruction_success() -> None: + """Test _check_cache successfully reconstructs CreateResult objects from dicts in a list. + This tests the line: reconstructed_list.append(CreateResult.model_validate(item)) + """ + _, prompts, system_prompt, replay_client, _ = get_test_data() + + # Create a list with valid dicts that can be reconstructed + valid_dict1 = { + "content": "first result", + "usage": {"prompt_tokens": 8, "completion_tokens": 3}, + "cached": False, + "finish_reason": "stop", + } + valid_dict2 = { + "content": "second result", + "usage": {"prompt_tokens": 12, "completion_tokens": 7}, + "cached": False, + "finish_reason": "stop", + } + + cached_list = [ + "streaming chunk 1", + valid_dict1, + "streaming chunk 2", + valid_dict2, + ] + + # Create a MockCacheStore that returns the list with dicts + mock_store = MockCacheStore(return_value=cast(Any, cached_list)) + cached_client = ChatCompletionCache(replay_client, mock_store) + + # Test _check_cache method + messages = [system_prompt, UserMessage(content=prompts[0], source="user")] + cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore + + # Should successfully reconstruct the list with CreateResult objects + assert cached_result is not None + assert isinstance(cached_result, list) + assert len(cached_result) == 4 + assert cached_result[0] == "streaming chunk 1" + assert isinstance(cached_result[1], CreateResult) + assert cached_result[1].content == "first result" + assert cached_result[2] == "streaming chunk 2" + assert isinstance(cached_result[3], CreateResult) + assert cached_result[3].content == "second result" + assert cache_key is not None + + +def test_check_cache_list_reconstruction_failure() -> None: + """Test _check_cache handles ValidationError when list contains invalid dicts. + This tests the except ValidationError block for lists. + """ + _, prompts, system_prompt, replay_client, _ = get_test_data() + + # Create a list with an invalid dict that will fail validation + invalid_dict = { + "invalid_field": "value", + "missing_required": True, + } + + cached_list = [ + "streaming chunk 1", + invalid_dict, # This will cause ValidationError + "streaming chunk 2", + ] + + # Create a MockCacheStore that returns the list with invalid dict + mock_store = MockCacheStore(return_value=cast(Any, cached_list)) + cached_client = ChatCompletionCache(replay_client, mock_store) + + # Test _check_cache method + messages = [system_prompt, UserMessage(content=prompts[0], source="user")] + cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore + + # Should return None (cache miss) when list reconstruction fails + assert cached_result is None + assert cache_key is not None + + +def test_check_cache_already_correct_type() -> None: + """Test _check_cache returns data as-is when it's already the correct type. + This tests the final return path when no reconstruction is needed. + """ + _, prompts, system_prompt, replay_client, _ = get_test_data() + + # Create a proper CreateResult object (already correct type) + create_result = CreateResult( + content="already correct type", + usage=RequestUsage(prompt_tokens=15, completion_tokens=8), + cached=False, + finish_reason="stop", + ) + + # Create a MockCacheStore that returns the proper type + mock_store = MockCacheStore(return_value=create_result) + cached_client = ChatCompletionCache(replay_client, mock_store) + + # Test _check_cache method + messages = [system_prompt, UserMessage(content=prompts[0], source="user")] + cached_result, cache_key = cached_client._check_cache(messages, [], None, {}) # type: ignore + + # Should return the same object without reconstruction + assert cached_result is not None + assert cached_result is create_result # Same object reference + assert isinstance(cached_result, CreateResult) + assert cached_result.content == "already correct type" + assert cache_key is not None