autogen/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py
Victor Dibia 170b8cc893
Make ChatCompletionCache support component config (#5658)

This PR makes ChatCompletionCache support component config.


## Why are these changes needed? 

Ensures we have a path to serializing ChatCompletionCache, similar to
the ChatCompletionClient that it wraps.

This PR does the following (a brief serialization sketch follows the list):

- Makes CacheStore serializable first (this includes converting it from a
Protocol to a base class), and makes its derivatives (diskcache, redis)
serializable as well
- Makes ChatCompletionCache serializable
- Adds some tests
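
As a rough illustration of the intended round trip — a minimal sketch, assuming the standard autogen-core component API (`dump_component()` / `load_component()`) and using `ReplayChatCompletionClient` as the wrapped client; variable names are illustrative only:

```python
# Sketch only: assumes the autogen-core component API (dump_component /
# load_component); names below are illustrative, not part of this PR's tests.
from autogen_ext.models.cache import ChatCompletionCache
from autogen_ext.models.replay import ReplayChatCompletionClient

# Wrap a serializable client in the cache, as the tests below do.
replay_client = ReplayChatCompletionClient(["dummy response"])
cached_client = ChatCompletionCache(replay_client)

# Dump the wrapper to a component config, then rebuild an equivalent client.
config = cached_client.dump_component()
loaded_client = ChatCompletionCache.load_component(config)
```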


## Related issue number


Closes #5141

## Checks

- [ ] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [ ] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [ ] I've made sure all auto checks have passed. 


cc @nour-bouzid
2025-02-23 19:49:22 -08:00


import copy
from typing import List, Tuple, Union

import pytest
from autogen_core.models import (
    ChatCompletionClient,
    CreateResult,
    LLMMessage,
    SystemMessage,
    UserMessage,
)
from autogen_ext.models.cache import ChatCompletionCache
from autogen_ext.models.replay import ReplayChatCompletionClient


def get_test_data() -> Tuple[list[str], list[str], SystemMessage, ChatCompletionClient, ChatCompletionCache]:
    num_messages = 3
    responses = [f"This is dummy message number {i}" for i in range(num_messages)]
    prompts = [f"This is dummy prompt number {i}" for i in range(num_messages)]
    system_prompt = SystemMessage(content="This is a system prompt")
    replay_client = ReplayChatCompletionClient(responses)
    replay_client.set_cached_bool_value(False)
    cached_client = ChatCompletionCache(replay_client)
    return responses, prompts, system_prompt, replay_client, cached_client


@pytest.mark.asyncio
async def test_cache_basic_with_args() -> None:
    responses, prompts, system_prompt, _, cached_client = get_test_data()
    response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0, CreateResult)
    assert not response0.cached
    assert response0.content == responses[0]
    response1 = await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")])
    assert not response1.cached
    assert response1.content == responses[1]
    # Cached output.
    response0_cached = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0_cached, CreateResult)
    assert response0_cached.cached
    assert response0_cached.content == responses[0]
    # Cache miss if args change.
    response2 = await cached_client.create(
        [system_prompt, UserMessage(content=prompts[0], source="user")], json_output=True
    )
    assert isinstance(response2, CreateResult)
    assert not response2.cached
    assert response2.content == responses[2]


@pytest.mark.asyncio
async def test_cache_model_and_count_api() -> None:
    _, prompts, system_prompt, replay_client, cached_client = get_test_data()
    assert replay_client.model_info == cached_client.model_info
    assert replay_client.capabilities == cached_client.capabilities
    messages: List[LLMMessage] = [system_prompt, UserMessage(content=prompts[0], source="user")]
    assert replay_client.count_tokens(messages) == cached_client.count_tokens(messages)
    assert replay_client.remaining_tokens(messages) == cached_client.remaining_tokens(messages)


@pytest.mark.asyncio
async def test_cache_token_usage() -> None:
    responses, prompts, system_prompt, replay_client, cached_client = get_test_data()
    response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0, CreateResult)
    assert not response0.cached
    assert response0.content == responses[0]
    actual_usage0 = copy.copy(cached_client.actual_usage())
    total_usage0 = copy.copy(cached_client.total_usage())
    response1 = await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")])
    assert not response1.cached
    assert response1.content == responses[1]
    actual_usage1 = copy.copy(cached_client.actual_usage())
    total_usage1 = copy.copy(cached_client.total_usage())
    assert total_usage1.prompt_tokens > total_usage0.prompt_tokens
    assert total_usage1.completion_tokens > total_usage0.completion_tokens
    assert actual_usage1.prompt_tokens == actual_usage0.prompt_tokens
    assert actual_usage1.completion_tokens == actual_usage0.completion_tokens
    # Cached output.
    response0_cached = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0_cached, CreateResult)
    assert response0_cached.cached
    assert response0_cached.content == responses[0]
    total_usage2 = copy.copy(cached_client.total_usage())
    assert total_usage2.prompt_tokens == total_usage1.prompt_tokens
    assert total_usage2.completion_tokens == total_usage1.completion_tokens
    assert cached_client.actual_usage() == replay_client.actual_usage()
    assert cached_client.total_usage() == replay_client.total_usage()


@pytest.mark.asyncio
async def test_cache_create_stream() -> None:
    _, prompts, system_prompt, _, cached_client = get_test_data()
    original_streamed_results: List[Union[str, CreateResult]] = []
    async for completion in cached_client.create_stream(
        [system_prompt, UserMessage(content=prompts[0], source="user")]
    ):
        original_streamed_results.append(copy.copy(completion))
    total_usage0 = copy.copy(cached_client.total_usage())
    cached_completion_results: List[Union[str, CreateResult]] = []
    async for completion in cached_client.create_stream(
        [system_prompt, UserMessage(content=prompts[0], source="user")]
    ):
        cached_completion_results.append(copy.copy(completion))
    total_usage1 = copy.copy(cached_client.total_usage())
    assert total_usage1.prompt_tokens == total_usage0.prompt_tokens
    assert total_usage1.completion_tokens == total_usage0.completion_tokens
    for original, cached in zip(original_streamed_results, cached_completion_results, strict=False):
        if isinstance(original, str):
            assert original == cached
        elif isinstance(original, CreateResult) and isinstance(cached, CreateResult):
            assert original.content == cached.content
            assert cached.cached
            assert not original.cached
        else:
            raise ValueError(f"Unexpected types : {type(original)} and {type(cached)}")

# test serialization
# cached_client_config = cached_client.dump_component()
# loaded_client = ChatCompletionCache.load_component(cached_client_config)
# assert loaded_client.client == cached_client.client