Implement default in-memory store for ChatCompletionCache (#5188)

This commit is contained in:
Sachin Joglekar 2025-01-25 13:07:58 -08:00 committed by GitHub
parent 67029853ec
commit 8926206479
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 11 additions and 8 deletions

View File

@ -1,8 +1,8 @@
import asyncio
import functools
import warnings
from textwrap import dedent
from typing import Any, Callable, Sequence
import warnings
from pydantic import BaseModel
from typing_extensions import Self

View File

@ -3,7 +3,7 @@ import json
import warnings
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast
from autogen_core import CacheStore, CancellationToken
from autogen_core import CacheStore, CancellationToken, InMemoryStore
from autogen_core.models import (
ChatCompletionClient,
CreateResult,
@ -74,11 +74,16 @@ class ChatCompletionCache(ChatCompletionClient):
client (ChatCompletionClient): The original ChatCompletionClient to wrap.
store (CacheStore): A store object that implements get and set methods.
The user is responsible for managing the store's lifecycle & clearing it (if needed).
Defaults to an in-memory cache (a new InMemoryStore) when omitted.
"""
def __init__(
    self,
    client: ChatCompletionClient,
    store: Optional[CacheStore[CHAT_CACHE_VALUE_TYPE]] = None,
):
    """Wrap ``client`` so that its responses are served through ``store``.

    Args:
        client: The original ChatCompletionClient to wrap.
        store: The cache backend to use. When ``None`` (the default), a new
            ``InMemoryStore`` is created for this instance.
    """
    self.client = client
    # Test against None explicitly instead of relying on truthiness: a
    # caller-supplied store that defines __len__/__bool__ and happens to be
    # empty is falsy, and ``store or ...`` would silently discard it in
    # favour of a fresh in-memory store.
    self.store = store if store is not None else InMemoryStore[CHAT_CACHE_VALUE_TYPE]()
def _check_cache(
self,

View File

@ -2,7 +2,6 @@ import copy
from typing import List, Tuple, Union
import pytest
from autogen_core import InMemoryStore
from autogen_core.models import (
ChatCompletionClient,
CreateResult,
@ -10,7 +9,7 @@ from autogen_core.models import (
SystemMessage,
UserMessage,
)
from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache
from autogen_ext.models.cache import ChatCompletionCache
from autogen_ext.models.replay import ReplayChatCompletionClient
@ -21,8 +20,7 @@ def get_test_data() -> Tuple[list[str], list[str], SystemMessage, ChatCompletion
system_prompt = SystemMessage(content="This is a system prompt")
replay_client = ReplayChatCompletionClient(responses)
replay_client.set_cached_bool_value(False)
store = InMemoryStore[CHAT_CACHE_VALUE_TYPE]()
cached_client = ChatCompletionCache(replay_client, store)
cached_client = ChatCompletionCache(replay_client)
return responses, prompts, system_prompt, replay_client, cached_client