mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-12 23:41:28 +00:00
Make ChatCompletionCache support component config (#5658)
<!-- Thank you for your contribution! Please review https://microsoft.github.io/autogen/docs/Contribute before opening a pull request. --> This PR makes makes ChatCompletionCache support component config <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? Ensures we have a path to serializing ChatCompletionCache , similar to the ChatCompletion client that it wraps. This PR does the following - Makes CacheStore serializable first (part of this includes converting from Protocol to base class). Makes it's derivatives serializable as well (diskcache, redis) - Makes ChatCompletionCache serializable - Adds some tests <!-- Please give a short summary of the change and the problem this solves. --> ## Related issue number <!-- For example: "Closes #1234" --> Closes #5141 ## Checks - [ ] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed. cc @nour-bouzid
This commit is contained in:
parent
a226966dbe
commit
170b8cc893
@ -1,15 +1,24 @@
|
||||
from typing import Dict, Generic, Optional, Protocol, TypeVar
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Generic, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._component_config import Component, ComponentBase
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class CacheStore(Protocol, Generic[T]):
|
||||
class CacheStore(ABC, Generic[T], ComponentBase[BaseModel]):
|
||||
"""
|
||||
This protocol defines the basic interface for store/cache operations.
|
||||
|
||||
Sub-classes should handle the lifecycle of underlying storage.
|
||||
"""
|
||||
|
||||
component_type = "cache_store"
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
|
||||
"""
|
||||
Retrieve an item from the store.
|
||||
@ -24,6 +33,7 @@ class CacheStore(Protocol, Generic[T]):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key: str, value: T) -> None:
|
||||
"""
|
||||
Set an item in the store.
|
||||
@ -35,7 +45,14 @@ class CacheStore(Protocol, Generic[T]):
|
||||
...
|
||||
|
||||
|
||||
class InMemoryStore(CacheStore[T]):
|
||||
class InMemoryStoreConfig(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class InMemoryStore(CacheStore[T], Component[InMemoryStoreConfig]):
|
||||
component_provider_override = "autogen_core.InMemoryStore"
|
||||
component_config_schema = InMemoryStoreConfig
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.store: Dict[str, T] = {}
|
||||
|
||||
@ -44,3 +61,10 @@ class InMemoryStore(CacheStore[T]):
|
||||
|
||||
def set(self, key: str, value: T) -> None:
|
||||
self.store[key] = value
|
||||
|
||||
def _to_config(self) -> InMemoryStoreConfig:
|
||||
return InMemoryStoreConfig()
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: InMemoryStoreConfig) -> Self:
|
||||
return cls()
|
||||
|
||||
@ -1,12 +1,21 @@
|
||||
from typing import Any, Optional, TypeVar, cast
|
||||
|
||||
import diskcache
|
||||
from autogen_core import CacheStore
|
||||
from autogen_core import CacheStore, Component
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class DiskCacheStore(CacheStore[T]):
|
||||
class DiskCacheStoreConfig(BaseModel):
|
||||
"""Configuration for DiskCacheStore"""
|
||||
|
||||
directory: str # Path where cache is stored
|
||||
# Could add other diskcache.Cache parameters like size_limit, etc.
|
||||
|
||||
|
||||
class DiskCacheStore(CacheStore[T], Component[DiskCacheStoreConfig]):
|
||||
"""
|
||||
A typed CacheStore implementation that uses diskcache as the underlying storage.
|
||||
See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage.
|
||||
@ -16,6 +25,9 @@ class DiskCacheStore(CacheStore[T]):
|
||||
The user is responsible for managing the DiskCache instance's lifetime.
|
||||
"""
|
||||
|
||||
component_config_schema = DiskCacheStoreConfig
|
||||
component_provider_override = "autogen_ext.cache_store.diskcache.DiskCacheStore"
|
||||
|
||||
def __init__(self, cache_instance: diskcache.Cache): # type: ignore[no-any-unimported]
|
||||
self.cache = cache_instance
|
||||
|
||||
@ -24,3 +36,11 @@ class DiskCacheStore(CacheStore[T]):
|
||||
|
||||
def set(self, key: str, value: T) -> None:
|
||||
self.cache.set(key, cast(Any, value)) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
def _to_config(self) -> DiskCacheStoreConfig:
|
||||
# Get directory from cache instance
|
||||
return DiskCacheStoreConfig(directory=self.cache.directory)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: DiskCacheStoreConfig) -> Self:
|
||||
return cls(cache_instance=diskcache.Cache(config.directory)) # type: ignore[no-any-return]
|
||||
|
||||
@ -1,12 +1,27 @@
|
||||
from typing import Any, Optional, TypeVar, cast
|
||||
from typing import Any, Dict, Optional, TypeVar, cast
|
||||
|
||||
import redis
|
||||
from autogen_core import CacheStore
|
||||
from autogen_core import CacheStore, Component
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RedisStore(CacheStore[T]):
|
||||
class RedisStoreConfig(BaseModel):
|
||||
"""Configuration for RedisStore"""
|
||||
|
||||
host: str = "localhost"
|
||||
port: int = 6379
|
||||
db: int = 0
|
||||
# Add other relevant redis connection parameters
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
ssl: bool = False
|
||||
socket_timeout: Optional[float] = None
|
||||
|
||||
|
||||
class RedisStore(CacheStore[T], Component[RedisStoreConfig]):
|
||||
"""
|
||||
A typed CacheStore implementation that uses redis as the underlying storage.
|
||||
See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage.
|
||||
@ -16,6 +31,9 @@ class RedisStore(CacheStore[T]):
|
||||
The user is responsible for managing the Redis instance's lifetime.
|
||||
"""
|
||||
|
||||
component_config_schema = RedisStoreConfig
|
||||
component_provider_override = "autogen_ext.cache_store.redis.RedisStore"
|
||||
|
||||
def __init__(self, redis_instance: redis.Redis):
|
||||
self.cache = redis_instance
|
||||
|
||||
@ -27,3 +45,36 @@ class RedisStore(CacheStore[T]):
|
||||
|
||||
def set(self, key: str, value: T) -> None:
|
||||
self.cache.set(key, cast(Any, value))
|
||||
|
||||
def _to_config(self) -> RedisStoreConfig:
|
||||
# Extract connection info from redis instance
|
||||
connection_pool = self.cache.connection_pool
|
||||
connection_kwargs: Dict[str, Any] = connection_pool.connection_kwargs # type: ignore[reportUnknownMemberType]
|
||||
|
||||
username = connection_kwargs.get("username")
|
||||
password = connection_kwargs.get("password")
|
||||
socket_timeout = connection_kwargs.get("socket_timeout")
|
||||
|
||||
return RedisStoreConfig(
|
||||
host=str(connection_kwargs.get("host", "localhost")),
|
||||
port=int(connection_kwargs.get("port", 6379)),
|
||||
db=int(connection_kwargs.get("db", 0)),
|
||||
username=str(username) if username is not None else None,
|
||||
password=str(password) if password is not None else None,
|
||||
ssl=bool(connection_kwargs.get("ssl", False)),
|
||||
socket_timeout=float(socket_timeout) if socket_timeout is not None else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: RedisStoreConfig) -> Self:
|
||||
# Create new redis instance from config
|
||||
redis_instance = redis.Redis(
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
db=config.db,
|
||||
username=config.username,
|
||||
password=config.password,
|
||||
ssl=config.ssl,
|
||||
socket_timeout=config.socket_timeout,
|
||||
)
|
||||
return cls(redis_instance=redis_instance)
|
||||
|
||||
@ -3,7 +3,7 @@ import json
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast
|
||||
|
||||
from autogen_core import CacheStore, CancellationToken, InMemoryStore
|
||||
from autogen_core import CacheStore, CancellationToken, Component, ComponentModel, InMemoryStore
|
||||
from autogen_core.models import (
|
||||
ChatCompletionClient,
|
||||
CreateResult,
|
||||
@ -13,11 +13,20 @@ from autogen_core.models import (
|
||||
RequestUsage,
|
||||
)
|
||||
from autogen_core.tools import Tool, ToolSchema
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]]
|
||||
|
||||
|
||||
class ChatCompletionCache(ChatCompletionClient):
|
||||
class ChatCompletionCacheConfig(BaseModel):
|
||||
""" """
|
||||
|
||||
client: ComponentModel
|
||||
store: Optional[ComponentModel] = None
|
||||
|
||||
|
||||
class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheConfig]):
|
||||
"""
|
||||
A wrapper around a :class:`~autogen_ext.models.cache.ChatCompletionClient` that caches
|
||||
creation results from an underlying client.
|
||||
@ -77,6 +86,10 @@ class ChatCompletionCache(ChatCompletionClient):
|
||||
Defaults to using in-memory cache.
|
||||
"""
|
||||
|
||||
component_type = "chat_completion_cache"
|
||||
component_provider_override = "autogen_ext.models.cache.ChatCompletionCache"
|
||||
component_config_schema = ChatCompletionCacheConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: ChatCompletionClient,
|
||||
@ -213,3 +226,17 @@ class ChatCompletionCache(ChatCompletionClient):
|
||||
|
||||
def total_usage(self) -> RequestUsage:
|
||||
return self.client.total_usage()
|
||||
|
||||
def _to_config(self) -> ChatCompletionCacheConfig:
|
||||
return ChatCompletionCacheConfig(
|
||||
client=self.client.dump_component(),
|
||||
store=self.store.dump_component() if not isinstance(self.store, InMemoryStore) else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: ChatCompletionCacheConfig) -> Self:
|
||||
client = ChatCompletionClient.load_component(config.client)
|
||||
store: Optional[CacheStore[CHAT_CACHE_VALUE_TYPE]] = (
|
||||
CacheStore.load_component(config.store) if config.store else InMemoryStore()
|
||||
)
|
||||
return cls(client=client, store=store)
|
||||
|
||||
@ -46,3 +46,8 @@ def test_diskcache_with_different_instances() -> None:
|
||||
|
||||
store_2.set(test_key, test_value_2)
|
||||
assert store_2.get(test_key) == test_value_2
|
||||
|
||||
# test serialization
|
||||
store_1_config = store_1.dump_component()
|
||||
loaded_store_1: DiskCacheStore[int] = DiskCacheStore.load_component(store_1_config)
|
||||
assert loaded_store_1.get(test_key) == test_value_1
|
||||
|
||||
@ -51,3 +51,8 @@ def test_redis_with_different_instances() -> None:
|
||||
redis_instance_2.set.assert_called_with(test_key, test_value_2)
|
||||
redis_instance_2.get.return_value = test_value_2
|
||||
assert store_2.get(test_key) == test_value_2
|
||||
|
||||
# test serialization
|
||||
store_1_config = store_1.dump_component()
|
||||
assert store_1_config.component_type == "cache_store"
|
||||
assert store_1_config.component_version == 1
|
||||
|
||||
@ -129,3 +129,8 @@ async def test_cache_create_stream() -> None:
|
||||
assert not original.cached
|
||||
else:
|
||||
raise ValueError(f"Unexpected types : {type(original)} and {type(cached)}")
|
||||
|
||||
# test serialization
|
||||
# cached_client_config = cached_client.dump_component()
|
||||
# loaded_client = ChatCompletionCache.load_component(cached_client_config)
|
||||
# assert loaded_client.client == cached_client.client
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user