Make ChatCompletionCache support component config (#5658)

<!-- Thank you for your contribution! Please review
https://microsoft.github.io/autogen/docs/Contribute before opening a
pull request. -->

This PR makes makes ChatCompletionCache   support component config

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed? 

Ensures we have a path to serializing ChatCompletionCache , similar to
the ChatCompletion client that it wraps.

This PR does the following

- Makes CacheStore serializable first (part of this includes converting
from Protocol to base class). Makes it's derivatives serializable as
well (diskcache, redis)
- Makes ChatCompletionCache serializable 
- Adds some tests

<!-- Please give a short summary of the change and the problem this
solves. -->

## Related issue number

<!-- For example: "Closes #1234" -->

Closes #5141

## Checks

- [ ] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [ ] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [ ] I've made sure all auto checks have passed. 


cc @nour-bouzid
This commit is contained in:
Victor Dibia 2025-02-23 19:49:22 -08:00 committed by GitHub
parent a226966dbe
commit 170b8cc893
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 147 additions and 10 deletions

View File

@ -1,15 +1,24 @@
from typing import Dict, Generic, Optional, Protocol, TypeVar from abc import ABC, abstractmethod
from typing import Dict, Generic, Optional, TypeVar
from pydantic import BaseModel
from typing_extensions import Self
from ._component_config import Component, ComponentBase
T = TypeVar("T") T = TypeVar("T")
class CacheStore(Protocol, Generic[T]): class CacheStore(ABC, Generic[T], ComponentBase[BaseModel]):
""" """
This protocol defines the basic interface for store/cache operations. This protocol defines the basic interface for store/cache operations.
Sub-classes should handle the lifecycle of underlying storage. Sub-classes should handle the lifecycle of underlying storage.
""" """
component_type = "cache_store"
@abstractmethod
def get(self, key: str, default: Optional[T] = None) -> Optional[T]: def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
""" """
Retrieve an item from the store. Retrieve an item from the store.
@ -24,6 +33,7 @@ class CacheStore(Protocol, Generic[T]):
""" """
... ...
@abstractmethod
def set(self, key: str, value: T) -> None: def set(self, key: str, value: T) -> None:
""" """
Set an item in the store. Set an item in the store.
@ -35,7 +45,14 @@ class CacheStore(Protocol, Generic[T]):
... ...
class InMemoryStore(CacheStore[T]): class InMemoryStoreConfig(BaseModel):
pass
class InMemoryStore(CacheStore[T], Component[InMemoryStoreConfig]):
component_provider_override = "autogen_core.InMemoryStore"
component_config_schema = InMemoryStoreConfig
def __init__(self) -> None: def __init__(self) -> None:
self.store: Dict[str, T] = {} self.store: Dict[str, T] = {}
@ -44,3 +61,10 @@ class InMemoryStore(CacheStore[T]):
def set(self, key: str, value: T) -> None: def set(self, key: str, value: T) -> None:
self.store[key] = value self.store[key] = value
def _to_config(self) -> InMemoryStoreConfig:
return InMemoryStoreConfig()
@classmethod
def _from_config(cls, config: InMemoryStoreConfig) -> Self:
return cls()

View File

@ -1,12 +1,21 @@
from typing import Any, Optional, TypeVar, cast from typing import Any, Optional, TypeVar, cast
import diskcache import diskcache
from autogen_core import CacheStore from autogen_core import CacheStore, Component
from pydantic import BaseModel
from typing_extensions import Self
T = TypeVar("T") T = TypeVar("T")
class DiskCacheStore(CacheStore[T]): class DiskCacheStoreConfig(BaseModel):
"""Configuration for DiskCacheStore"""
directory: str # Path where cache is stored
# Could add other diskcache.Cache parameters like size_limit, etc.
class DiskCacheStore(CacheStore[T], Component[DiskCacheStoreConfig]):
""" """
A typed CacheStore implementation that uses diskcache as the underlying storage. A typed CacheStore implementation that uses diskcache as the underlying storage.
See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage. See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage.
@ -16,6 +25,9 @@ class DiskCacheStore(CacheStore[T]):
The user is responsible for managing the DiskCache instance's lifetime. The user is responsible for managing the DiskCache instance's lifetime.
""" """
component_config_schema = DiskCacheStoreConfig
component_provider_override = "autogen_ext.cache_store.diskcache.DiskCacheStore"
def __init__(self, cache_instance: diskcache.Cache): # type: ignore[no-any-unimported] def __init__(self, cache_instance: diskcache.Cache): # type: ignore[no-any-unimported]
self.cache = cache_instance self.cache = cache_instance
@ -24,3 +36,11 @@ class DiskCacheStore(CacheStore[T]):
def set(self, key: str, value: T) -> None: def set(self, key: str, value: T) -> None:
self.cache.set(key, cast(Any, value)) # type: ignore[reportUnknownMemberType] self.cache.set(key, cast(Any, value)) # type: ignore[reportUnknownMemberType]
def _to_config(self) -> DiskCacheStoreConfig:
# Get directory from cache instance
return DiskCacheStoreConfig(directory=self.cache.directory)
@classmethod
def _from_config(cls, config: DiskCacheStoreConfig) -> Self:
return cls(cache_instance=diskcache.Cache(config.directory)) # type: ignore[no-any-return]

View File

@ -1,12 +1,27 @@
from typing import Any, Optional, TypeVar, cast from typing import Any, Dict, Optional, TypeVar, cast
import redis import redis
from autogen_core import CacheStore from autogen_core import CacheStore, Component
from pydantic import BaseModel
from typing_extensions import Self
T = TypeVar("T") T = TypeVar("T")
class RedisStore(CacheStore[T]): class RedisStoreConfig(BaseModel):
"""Configuration for RedisStore"""
host: str = "localhost"
port: int = 6379
db: int = 0
# Add other relevant redis connection parameters
username: Optional[str] = None
password: Optional[str] = None
ssl: bool = False
socket_timeout: Optional[float] = None
class RedisStore(CacheStore[T], Component[RedisStoreConfig]):
""" """
A typed CacheStore implementation that uses redis as the underlying storage. A typed CacheStore implementation that uses redis as the underlying storage.
See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage. See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage.
@ -16,6 +31,9 @@ class RedisStore(CacheStore[T]):
The user is responsible for managing the Redis instance's lifetime. The user is responsible for managing the Redis instance's lifetime.
""" """
component_config_schema = RedisStoreConfig
component_provider_override = "autogen_ext.cache_store.redis.RedisStore"
def __init__(self, redis_instance: redis.Redis): def __init__(self, redis_instance: redis.Redis):
self.cache = redis_instance self.cache = redis_instance
@ -27,3 +45,36 @@ class RedisStore(CacheStore[T]):
def set(self, key: str, value: T) -> None: def set(self, key: str, value: T) -> None:
self.cache.set(key, cast(Any, value)) self.cache.set(key, cast(Any, value))
def _to_config(self) -> RedisStoreConfig:
# Extract connection info from redis instance
connection_pool = self.cache.connection_pool
connection_kwargs: Dict[str, Any] = connection_pool.connection_kwargs # type: ignore[reportUnknownMemberType]
username = connection_kwargs.get("username")
password = connection_kwargs.get("password")
socket_timeout = connection_kwargs.get("socket_timeout")
return RedisStoreConfig(
host=str(connection_kwargs.get("host", "localhost")),
port=int(connection_kwargs.get("port", 6379)),
db=int(connection_kwargs.get("db", 0)),
username=str(username) if username is not None else None,
password=str(password) if password is not None else None,
ssl=bool(connection_kwargs.get("ssl", False)),
socket_timeout=float(socket_timeout) if socket_timeout is not None else None,
)
@classmethod
def _from_config(cls, config: RedisStoreConfig) -> Self:
# Create new redis instance from config
redis_instance = redis.Redis(
host=config.host,
port=config.port,
db=config.db,
username=config.username,
password=config.password,
ssl=config.ssl,
socket_timeout=config.socket_timeout,
)
return cls(redis_instance=redis_instance)

View File

@ -3,7 +3,7 @@ import json
import warnings import warnings
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast
from autogen_core import CacheStore, CancellationToken, InMemoryStore from autogen_core import CacheStore, CancellationToken, Component, ComponentModel, InMemoryStore
from autogen_core.models import ( from autogen_core.models import (
ChatCompletionClient, ChatCompletionClient,
CreateResult, CreateResult,
@ -13,11 +13,20 @@ from autogen_core.models import (
RequestUsage, RequestUsage,
) )
from autogen_core.tools import Tool, ToolSchema from autogen_core.tools import Tool, ToolSchema
from pydantic import BaseModel
from typing_extensions import Self
CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]] CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]]
class ChatCompletionCache(ChatCompletionClient): class ChatCompletionCacheConfig(BaseModel):
""" """
client: ComponentModel
store: Optional[ComponentModel] = None
class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheConfig]):
""" """
A wrapper around a :class:`~autogen_ext.models.cache.ChatCompletionClient` that caches A wrapper around a :class:`~autogen_ext.models.cache.ChatCompletionClient` that caches
creation results from an underlying client. creation results from an underlying client.
@ -77,6 +86,10 @@ class ChatCompletionCache(ChatCompletionClient):
Defaults to using in-memory cache. Defaults to using in-memory cache.
""" """
component_type = "chat_completion_cache"
component_provider_override = "autogen_ext.models.cache.ChatCompletionCache"
component_config_schema = ChatCompletionCacheConfig
def __init__( def __init__(
self, self,
client: ChatCompletionClient, client: ChatCompletionClient,
@ -213,3 +226,17 @@ class ChatCompletionCache(ChatCompletionClient):
def total_usage(self) -> RequestUsage: def total_usage(self) -> RequestUsage:
return self.client.total_usage() return self.client.total_usage()
def _to_config(self) -> ChatCompletionCacheConfig:
return ChatCompletionCacheConfig(
client=self.client.dump_component(),
store=self.store.dump_component() if not isinstance(self.store, InMemoryStore) else None,
)
@classmethod
def _from_config(cls, config: ChatCompletionCacheConfig) -> Self:
client = ChatCompletionClient.load_component(config.client)
store: Optional[CacheStore[CHAT_CACHE_VALUE_TYPE]] = (
CacheStore.load_component(config.store) if config.store else InMemoryStore()
)
return cls(client=client, store=store)

View File

@ -46,3 +46,8 @@ def test_diskcache_with_different_instances() -> None:
store_2.set(test_key, test_value_2) store_2.set(test_key, test_value_2)
assert store_2.get(test_key) == test_value_2 assert store_2.get(test_key) == test_value_2
# test serialization
store_1_config = store_1.dump_component()
loaded_store_1: DiskCacheStore[int] = DiskCacheStore.load_component(store_1_config)
assert loaded_store_1.get(test_key) == test_value_1

View File

@ -51,3 +51,8 @@ def test_redis_with_different_instances() -> None:
redis_instance_2.set.assert_called_with(test_key, test_value_2) redis_instance_2.set.assert_called_with(test_key, test_value_2)
redis_instance_2.get.return_value = test_value_2 redis_instance_2.get.return_value = test_value_2
assert store_2.get(test_key) == test_value_2 assert store_2.get(test_key) == test_value_2
# test serialization
store_1_config = store_1.dump_component()
assert store_1_config.component_type == "cache_store"
assert store_1_config.component_version == 1

View File

@ -129,3 +129,8 @@ async def test_cache_create_stream() -> None:
assert not original.cached assert not original.cached
else: else:
raise ValueError(f"Unexpected types : {type(original)} and {type(cached)}") raise ValueError(f"Unexpected types : {type(original)} and {type(cached)}")
# test serialization
# cached_client_config = cached_client.dump_component()
# loaded_client = ChatCompletionCache.load_component(cached_client_config)
# assert loaded_client.client == cached_client.client