mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-25 16:16:37 +00:00
[Core] implement redis cache mode (#1222)
* implement redis cache mode, if redis_url is set in the llm_config then it will try to use this. also adds a test to validate both the existing and the redis cache behavior. * PR feedback, add unit tests * more PR feedback, move the new style cache to a context manager * Update agent_chat.md * more PR feedback, remove tests from contrib and have them run with the normal jobs * doc * updated * Update website/docs/Use-Cases/agent_chat.md Co-authored-by: Chi Wang <wang.chi@microsoft.com> * update docs * update docs; let openaiwrapper to use cache object * typo * Update website/docs/Use-Cases/enhanced_inference.md Co-authored-by: Chi Wang <wang.chi@microsoft.com> * save previous client cache and reset it after send/a_send * a_run_chat --------- Co-authored-by: Vijay Ramesh <vijay@regrello.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
e97b6395af
commit
ee6ad8d519
1
.github/workflows/build.yml
vendored
1
.github/workflows/build.yml
vendored
@ -54,6 +54,7 @@ jobs:
|
||||
if: matrix.python-version == '3.10'
|
||||
run: |
|
||||
pip install -e .[test]
|
||||
pip install -e .[redis]
|
||||
coverage run -a -m pytest test --ignore=test/agentchat/contrib --skip-openai
|
||||
coverage xml
|
||||
- name: Upload coverage to Codecov
|
||||
|
7
.github/workflows/openai.yml
vendored
7
.github/workflows/openai.yml
vendored
@ -21,6 +21,12 @@ jobs:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
runs-on: ${{ matrix.os }}
|
||||
environment: openai1
|
||||
services:
|
||||
redis:
|
||||
image: redis
|
||||
ports:
|
||||
- 6379:6379
|
||||
options: --entrypoint redis-server
|
||||
steps:
|
||||
# checkout to pr branch
|
||||
- name: Checkout
|
||||
@ -42,6 +48,7 @@ jobs:
|
||||
if: matrix.python-version == '3.9'
|
||||
run: |
|
||||
pip install docker
|
||||
pip install -e .[redis]
|
||||
- name: Coverage
|
||||
if: matrix.python-version == '3.9'
|
||||
env:
|
||||
|
@ -9,6 +9,7 @@ from collections import defaultdict
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
from .. import OpenAIWrapper
|
||||
from ..cache.cache import Cache
|
||||
from ..code_utils import (
|
||||
DEFAULT_MODEL,
|
||||
UNKNOWN,
|
||||
@ -135,6 +136,9 @@ class ConversableAgent(Agent):
|
||||
self.llm_config.update(llm_config)
|
||||
self.client = OpenAIWrapper(**self.llm_config)
|
||||
|
||||
# Initialize standalone client cache object.
|
||||
self.client_cache = None
|
||||
|
||||
self._code_execution_config: Union[Dict, Literal[False]] = (
|
||||
{} if code_execution_config is None else code_execution_config
|
||||
)
|
||||
@ -665,6 +669,7 @@ class ConversableAgent(Agent):
|
||||
recipient: "ConversableAgent",
|
||||
clear_history: Optional[bool] = True,
|
||||
silent: Optional[bool] = False,
|
||||
cache: Optional[Cache] = None,
|
||||
**context,
|
||||
):
|
||||
"""Initiate a chat with the recipient agent.
|
||||
@ -677,6 +682,7 @@ class ConversableAgent(Agent):
|
||||
recipient: the recipient agent.
|
||||
clear_history (bool): whether to clear the chat history with the agent.
|
||||
silent (bool or None): (Experimental) whether to print the messages for this conversation.
|
||||
cache (Cache or None): the cache client to be used for this conversation.
|
||||
**context: any context information.
|
||||
"message" needs to be provided if the `generate_init_message` method is not overridden.
|
||||
Otherwise, input() will be called to get the initial message.
|
||||
@ -686,14 +692,20 @@ class ConversableAgent(Agent):
|
||||
"""
|
||||
for agent in [self, recipient]:
|
||||
agent._raise_exception_on_async_reply_functions()
|
||||
agent.previous_cache = agent.client_cache
|
||||
agent.client_cache = cache
|
||||
self._prepare_chat(recipient, clear_history)
|
||||
self.send(self.generate_init_message(**context), recipient, silent=silent)
|
||||
for agent in [self, recipient]:
|
||||
agent.client_cache = agent.previous_cache
|
||||
agent.previous_cache = None
|
||||
|
||||
async def a_initiate_chat(
|
||||
self,
|
||||
recipient: "ConversableAgent",
|
||||
clear_history: Optional[bool] = True,
|
||||
silent: Optional[bool] = False,
|
||||
cache: Optional[Cache] = None,
|
||||
**context,
|
||||
):
|
||||
"""(async) Initiate a chat with the recipient agent.
|
||||
@ -706,12 +718,19 @@ class ConversableAgent(Agent):
|
||||
recipient: the recipient agent.
|
||||
clear_history (bool): whether to clear the chat history with the agent.
|
||||
silent (bool or None): (Experimental) whether to print the messages for this conversation.
|
||||
cache (Cache or None): the cache client to be used for this conversation.
|
||||
**context: any context information.
|
||||
"message" needs to be provided if the `generate_init_message` method is not overridden.
|
||||
Otherwise, input() will be called to get the initial message.
|
||||
"""
|
||||
self._prepare_chat(recipient, clear_history)
|
||||
for agent in [self, recipient]:
|
||||
agent.previous_cache = agent.client_cache
|
||||
agent.client_cache = cache
|
||||
await self.a_send(await self.a_generate_init_message(**context), recipient, silent=silent)
|
||||
for agent in [self, recipient]:
|
||||
agent.client_cache = agent.previous_cache
|
||||
agent.previous_cache = None
|
||||
|
||||
def reset(self):
|
||||
"""Reset the agent."""
|
||||
@ -778,7 +797,9 @@ class ConversableAgent(Agent):
|
||||
|
||||
# TODO: #1143 handle token limit exceeded error
|
||||
response = client.create(
|
||||
context=messages[-1].pop("context", None), messages=self._oai_system_message + all_messages
|
||||
context=messages[-1].pop("context", None),
|
||||
messages=self._oai_system_message + all_messages,
|
||||
cache=self.client_cache,
|
||||
)
|
||||
|
||||
extracted_response = client.extract_text_or_completion_object(response)[0]
|
||||
|
@ -349,13 +349,17 @@ class GroupChatManager(ConversableAgent):
|
||||
messages: Optional[List[Dict]] = None,
|
||||
sender: Optional[Agent] = None,
|
||||
config: Optional[GroupChat] = None,
|
||||
) -> Union[str, Dict, None]:
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""Run a group chat."""
|
||||
if messages is None:
|
||||
messages = self._oai_messages[sender]
|
||||
message = messages[-1]
|
||||
speaker = sender
|
||||
groupchat = config
|
||||
if self.client_cache is not None:
|
||||
for a in groupchat.agents:
|
||||
a.previous_cache = a.client_cache
|
||||
a.client_cache = self.client_cache
|
||||
for i in range(groupchat.max_round):
|
||||
groupchat.append(message, speaker)
|
||||
if self._is_termination_msg(message):
|
||||
@ -389,6 +393,10 @@ class GroupChatManager(ConversableAgent):
|
||||
message = self.last_message(speaker)
|
||||
if i == groupchat.max_round - 1:
|
||||
groupchat.append(message, speaker)
|
||||
if self.client_cache is not None:
|
||||
for a in groupchat.agents:
|
||||
a.client_cache = a.previous_cache
|
||||
a.previous_cache = None
|
||||
return True, None
|
||||
|
||||
async def a_run_chat(
|
||||
@ -403,6 +411,10 @@ class GroupChatManager(ConversableAgent):
|
||||
message = messages[-1]
|
||||
speaker = sender
|
||||
groupchat = config
|
||||
if self.client_cache is not None:
|
||||
for a in groupchat.agents:
|
||||
a.previous_cache = a.client_cache
|
||||
a.client_cache = self.client_cache
|
||||
for i in range(groupchat.max_round):
|
||||
groupchat.append(message, speaker)
|
||||
|
||||
@ -436,6 +448,10 @@ class GroupChatManager(ConversableAgent):
|
||||
# The speaker sends the message without requesting a reply
|
||||
await speaker.a_send(reply, self, request_reply=False)
|
||||
message = self.last_message(speaker)
|
||||
if self.client_cache is not None:
|
||||
for a in groupchat.agents:
|
||||
a.client_cache = a.previous_cache
|
||||
a.previous_cache = None
|
||||
return True, None
|
||||
|
||||
def _raise_exception_on_async_reply_functions(self) -> None:
|
||||
|
0
autogen/cache/__init__.py
vendored
Normal file
0
autogen/cache/__init__.py
vendored
Normal file
90
autogen/cache/abstract_cache_base.py
vendored
Normal file
90
autogen/cache/abstract_cache_base.py
vendored
Normal file
@ -0,0 +1,90 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class AbstractCache(ABC):
|
||||
"""
|
||||
Abstract base class for cache implementations.
|
||||
|
||||
This class defines the basic interface for cache operations.
|
||||
Implementing classes should provide concrete implementations for
|
||||
these methods to handle caching mechanisms.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key, default=None):
|
||||
"""
|
||||
Retrieve an item from the cache.
|
||||
|
||||
Abstract method that must be implemented by subclasses to
|
||||
retrieve an item from the cache.
|
||||
|
||||
Args:
|
||||
key (str): The key identifying the item in the cache.
|
||||
default (optional): The default value to return if the key is not found.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
The value associated with the key if found, else the default value.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the subclass does not implement this method.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key, value):
|
||||
"""
|
||||
Set an item in the cache.
|
||||
|
||||
Abstract method that must be implemented by subclasses to
|
||||
store an item in the cache.
|
||||
|
||||
Args:
|
||||
key (str): The key under which the item is to be stored.
|
||||
value: The value to be stored in the cache.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the subclass does not implement this method.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def close(self):
|
||||
"""
|
||||
Close the cache.
|
||||
|
||||
Abstract method that should be implemented by subclasses to
|
||||
perform any necessary cleanup, such as closing network connections or
|
||||
releasing resources.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the subclass does not implement this method.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __enter__(self):
|
||||
"""
|
||||
Enter the runtime context related to this object.
|
||||
|
||||
The with statement will bind this method’s return value to the target(s)
|
||||
specified in the as clause of the statement, if any.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the subclass does not implement this method.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
"""
|
||||
Exit the runtime context and close the cache.
|
||||
|
||||
Abstract method that should be implemented by subclasses to handle
|
||||
the exit from a with statement. It is responsible for resource
|
||||
release and cleanup.
|
||||
|
||||
Args:
|
||||
exc_type: The exception type if an exception was raised in the context.
|
||||
exc_value: The exception value if an exception was raised in the context.
|
||||
traceback: The traceback if an exception was raised in the context.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the subclass does not implement this method.
|
||||
"""
|
137
autogen/cache/cache.py
vendored
Normal file
137
autogen/cache/cache.py
vendored
Normal file
@ -0,0 +1,137 @@
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from autogen.cache.cache_factory import CacheFactory
|
||||
|
||||
|
||||
class Cache:
|
||||
"""
|
||||
A wrapper class for managing cache configuration and instances.
|
||||
|
||||
This class provides a unified interface for creating and interacting with
|
||||
different types of cache (e.g., Redis, Disk). It abstracts the underlying
|
||||
cache implementation details, providing methods for cache operations.
|
||||
|
||||
Attributes:
|
||||
config (Dict[str, Any]): A dictionary containing cache configuration.
|
||||
cache: The cache instance created based on the provided configuration.
|
||||
|
||||
Methods:
|
||||
redis(cache_seed=42, redis_url="redis://localhost:6379/0"): Static method to create a Redis cache instance.
|
||||
disk(cache_seed=42, cache_path_root=".cache"): Static method to create a Disk cache instance.
|
||||
__init__(self, config): Initializes the Cache with the given configuration.
|
||||
__enter__(self): Context management entry, returning the cache instance.
|
||||
__exit__(self, exc_type, exc_value, traceback): Context management exit.
|
||||
get(self, key, default=None): Retrieves an item from the cache.
|
||||
set(self, key, value): Sets an item in the cache.
|
||||
close(self): Closes the cache.
|
||||
"""
|
||||
|
||||
ALLOWED_CONFIG_KEYS = ["cache_seed", "redis_url", "cache_path_root"]
|
||||
|
||||
@staticmethod
|
||||
def redis(cache_seed=42, redis_url="redis://localhost:6379/0"):
|
||||
"""
|
||||
Create a Redis cache instance.
|
||||
|
||||
Args:
|
||||
cache_seed (int, optional): A seed for the cache. Defaults to 42.
|
||||
redis_url (str, optional): The URL for the Redis server. Defaults to "redis://localhost:6379/0".
|
||||
|
||||
Returns:
|
||||
Cache: A Cache instance configured for Redis.
|
||||
"""
|
||||
return Cache({"cache_seed": cache_seed, "redis_url": redis_url})
|
||||
|
||||
@staticmethod
|
||||
def disk(cache_seed=42, cache_path_root=".cache"):
|
||||
"""
|
||||
Create a Disk cache instance.
|
||||
|
||||
Args:
|
||||
cache_seed (int, optional): A seed for the cache. Defaults to 42.
|
||||
cache_path_root (str, optional): The root path for the disk cache. Defaults to ".cache".
|
||||
|
||||
Returns:
|
||||
Cache: A Cache instance configured for Disk caching.
|
||||
"""
|
||||
return Cache({"cache_seed": cache_seed, "cache_path_root": cache_path_root})
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""
|
||||
Initialize the Cache with the given configuration.
|
||||
|
||||
Validates the configuration keys and creates the cache instance.
|
||||
|
||||
Args:
|
||||
config (Dict[str, Any]): A dictionary containing the cache configuration.
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid configuration key is provided.
|
||||
"""
|
||||
self.config = config
|
||||
# validate config
|
||||
for key in self.config.keys():
|
||||
if key not in self.ALLOWED_CONFIG_KEYS:
|
||||
raise ValueError(f"Invalid config key: {key}")
|
||||
# create cache instance
|
||||
self.cache = CacheFactory.cache_factory(
|
||||
self.config.get("cache_seed", "42"),
|
||||
self.config.get("redis_url", None),
|
||||
self.config.get("cache_path_root", None),
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
Enter the runtime context related to the cache object.
|
||||
|
||||
Returns:
|
||||
The cache instance for use within a context block.
|
||||
"""
|
||||
return self.cache.__enter__()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
"""
|
||||
Exit the runtime context related to the cache object.
|
||||
|
||||
Cleans up the cache instance and handles any exceptions that occurred
|
||||
within the context.
|
||||
|
||||
Args:
|
||||
exc_type: The exception type if an exception was raised in the context.
|
||||
exc_value: The exception value if an exception was raised in the context.
|
||||
traceback: The traceback if an exception was raised in the context.
|
||||
"""
|
||||
return self.cache.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""
|
||||
Retrieve an item from the cache.
|
||||
|
||||
Args:
|
||||
key (str): The key identifying the item in the cache.
|
||||
default (optional): The default value to return if the key is not found.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
The value associated with the key if found, else the default value.
|
||||
"""
|
||||
return self.cache.get(key, default)
|
||||
|
||||
def set(self, key, value):
|
||||
"""
|
||||
Set an item in the cache.
|
||||
|
||||
Args:
|
||||
key (str): The key under which the item is to be stored.
|
||||
value: The value to be stored in the cache.
|
||||
"""
|
||||
self.cache.set(key, value)
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the cache.
|
||||
|
||||
Perform any necessary cleanup, such as closing connections or releasing resources.
|
||||
"""
|
||||
self.cache.close()
|
40
autogen/cache/cache_factory.py
vendored
Normal file
40
autogen/cache/cache_factory.py
vendored
Normal file
@ -0,0 +1,40 @@
|
||||
from autogen.cache.disk_cache import DiskCache
|
||||
|
||||
try:
|
||||
from autogen.cache.redis_cache import RedisCache
|
||||
except ImportError:
|
||||
RedisCache = None
|
||||
|
||||
|
||||
class CacheFactory:
|
||||
@staticmethod
|
||||
def cache_factory(seed, redis_url=None, cache_path_root=".cache"):
|
||||
"""
|
||||
Factory function for creating cache instances.
|
||||
|
||||
Based on the provided redis_url, this function decides whether to create a RedisCache
|
||||
or DiskCache instance. If RedisCache is available and redis_url is provided,
|
||||
a RedisCache instance is created. Otherwise, a DiskCache instance is used.
|
||||
|
||||
Args:
|
||||
seed (str): A string used as a seed or namespace for the cache.
|
||||
This could be useful for creating distinct cache instances
|
||||
or for namespacing keys in the cache.
|
||||
redis_url (str or None): The URL for the Redis server. If this is None
|
||||
or if RedisCache is not available, a DiskCache instance is created.
|
||||
|
||||
Returns:
|
||||
An instance of either RedisCache or DiskCache, depending on the availability of RedisCache
|
||||
and the provided redis_url.
|
||||
|
||||
Examples:
|
||||
Creating a Redis cache
|
||||
> redis_cache = cache_factory("myseed", "redis://localhost:6379/0")
|
||||
|
||||
Creating a Disk cache
|
||||
> disk_cache = cache_factory("myseed", None)
|
||||
"""
|
||||
if RedisCache is not None and redis_url is not None:
|
||||
return RedisCache(seed, redis_url)
|
||||
else:
|
||||
return DiskCache(f"./{cache_path_root}/{seed}")
|
88
autogen/cache/disk_cache.py
vendored
Normal file
88
autogen/cache/disk_cache.py
vendored
Normal file
@ -0,0 +1,88 @@
|
||||
import diskcache
|
||||
from .abstract_cache_base import AbstractCache
|
||||
|
||||
|
||||
class DiskCache(AbstractCache):
|
||||
"""
|
||||
Implementation of AbstractCache using the DiskCache library.
|
||||
|
||||
This class provides a concrete implementation of the AbstractCache
|
||||
interface using the diskcache library for caching data on disk.
|
||||
|
||||
Attributes:
|
||||
cache (diskcache.Cache): The DiskCache instance used for caching.
|
||||
|
||||
Methods:
|
||||
__init__(self, seed): Initializes the DiskCache with the given seed.
|
||||
get(self, key, default=None): Retrieves an item from the cache.
|
||||
set(self, key, value): Sets an item in the cache.
|
||||
close(self): Closes the cache.
|
||||
__enter__(self): Context management entry.
|
||||
__exit__(self, exc_type, exc_value, traceback): Context management exit.
|
||||
"""
|
||||
|
||||
def __init__(self, seed):
|
||||
"""
|
||||
Initialize the DiskCache instance.
|
||||
|
||||
Args:
|
||||
seed (str): A seed or namespace for the cache. This is used to create
|
||||
a unique storage location for the cache data.
|
||||
|
||||
"""
|
||||
self.cache = diskcache.Cache(seed)
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""
|
||||
Retrieve an item from the cache.
|
||||
|
||||
Args:
|
||||
key (str): The key identifying the item in the cache.
|
||||
default (optional): The default value to return if the key is not found.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
The value associated with the key if found, else the default value.
|
||||
"""
|
||||
return self.cache.get(key, default)
|
||||
|
||||
def set(self, key, value):
|
||||
"""
|
||||
Set an item in the cache.
|
||||
|
||||
Args:
|
||||
key (str): The key under which the item is to be stored.
|
||||
value: The value to be stored in the cache.
|
||||
"""
|
||||
self.cache.set(key, value)
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the cache.
|
||||
|
||||
Perform any necessary cleanup, such as closing file handles or
|
||||
releasing resources.
|
||||
"""
|
||||
self.cache.close()
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
Enter the runtime context related to the object.
|
||||
|
||||
Returns:
|
||||
self: The instance itself.
|
||||
"""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
"""
|
||||
Exit the runtime context related to the object.
|
||||
|
||||
Perform cleanup actions such as closing the cache.
|
||||
|
||||
Args:
|
||||
exc_type: The exception type if an exception was raised in the context.
|
||||
exc_value: The exception value if an exception was raised in the context.
|
||||
traceback: The traceback if an exception was raised in the context.
|
||||
"""
|
||||
self.close()
|
110
autogen/cache/redis_cache.py
vendored
Normal file
110
autogen/cache/redis_cache.py
vendored
Normal file
@ -0,0 +1,110 @@
|
||||
import pickle
|
||||
import redis
|
||||
from .abstract_cache_base import AbstractCache
|
||||
|
||||
|
||||
class RedisCache(AbstractCache):
|
||||
"""
|
||||
Implementation of AbstractCache using the Redis database.
|
||||
|
||||
This class provides a concrete implementation of the AbstractCache
|
||||
interface using the Redis database for caching data.
|
||||
|
||||
Attributes:
|
||||
seed (str): A seed or namespace used as a prefix for cache keys.
|
||||
cache (redis.Redis): The Redis client used for caching.
|
||||
|
||||
Methods:
|
||||
__init__(self, seed, redis_url): Initializes the RedisCache with the given seed and Redis URL.
|
||||
_prefixed_key(self, key): Internal method to get a namespaced cache key.
|
||||
get(self, key, default=None): Retrieves an item from the cache.
|
||||
set(self, key, value): Sets an item in the cache.
|
||||
close(self): Closes the Redis client.
|
||||
__enter__(self): Context management entry.
|
||||
__exit__(self, exc_type, exc_value, traceback): Context management exit.
|
||||
"""
|
||||
|
||||
def __init__(self, seed, redis_url):
|
||||
"""
|
||||
Initialize the RedisCache instance.
|
||||
|
||||
Args:
|
||||
seed (str): A seed or namespace for the cache. This is used as a prefix for all cache keys.
|
||||
redis_url (str): The URL for the Redis server.
|
||||
|
||||
"""
|
||||
self.seed = seed
|
||||
self.cache = redis.Redis.from_url(redis_url)
|
||||
|
||||
def _prefixed_key(self, key):
|
||||
"""
|
||||
Get a namespaced key for the cache.
|
||||
|
||||
Args:
|
||||
key (str): The original key.
|
||||
|
||||
Returns:
|
||||
str: The namespaced key.
|
||||
"""
|
||||
return f"autogen:{self.seed}:{key}"
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""
|
||||
Retrieve an item from the Redis cache.
|
||||
|
||||
Args:
|
||||
key (str): The key identifying the item in the cache.
|
||||
default (optional): The default value to return if the key is not found.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
The deserialized value associated with the key if found, else the default value.
|
||||
"""
|
||||
result = self.cache.get(self._prefixed_key(key))
|
||||
if result is None:
|
||||
return default
|
||||
return pickle.loads(result)
|
||||
|
||||
def set(self, key, value):
|
||||
"""
|
||||
Set an item in the Redis cache.
|
||||
|
||||
Args:
|
||||
key (str): The key under which the item is to be stored.
|
||||
value: The value to be stored in the cache.
|
||||
|
||||
Notes:
|
||||
The value is serialized using pickle before being stored in Redis.
|
||||
"""
|
||||
serialized_value = pickle.dumps(value)
|
||||
self.cache.set(self._prefixed_key(key), serialized_value)
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the Redis client.
|
||||
|
||||
Perform any necessary cleanup, such as closing network connections.
|
||||
"""
|
||||
self.cache.close()
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
Enter the runtime context related to the object.
|
||||
|
||||
Returns:
|
||||
self: The instance itself.
|
||||
"""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
"""
|
||||
Exit the runtime context related to the object.
|
||||
|
||||
Perform cleanup actions such as closing the Redis client.
|
||||
|
||||
Args:
|
||||
exc_type: The exception type if an exception was raised in the context.
|
||||
exc_value: The exception value if an exception was raised in the context.
|
||||
traceback: The traceback if an exception was raised in the context.
|
||||
"""
|
||||
self.close()
|
@ -9,6 +9,7 @@ from flaml.automl.logger import logger_formatter
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogen.cache.cache import Cache
|
||||
from autogen.oai import completion
|
||||
|
||||
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, get_key, OAI_PRICE1K
|
||||
@ -35,7 +36,6 @@ else:
|
||||
)
|
||||
from openai.types.completion import Completion
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
import diskcache
|
||||
|
||||
if openai.__version__ >= "1.1.0":
|
||||
TOOL_ENABLED = True
|
||||
@ -48,12 +48,15 @@ if not logger.handlers:
|
||||
_ch.setFormatter(logger_formatter)
|
||||
logger.addHandler(_ch)
|
||||
|
||||
LEGACY_DEFAULT_CACHE_SEED = 41
|
||||
LEGACY_CACHE_DIR = ".cache"
|
||||
|
||||
|
||||
class OpenAIWrapper:
|
||||
"""A wrapper class for openai client."""
|
||||
|
||||
cache_path_root: str = ".cache"
|
||||
extra_kwargs = {
|
||||
"cache",
|
||||
"cache_seed",
|
||||
"filter_func",
|
||||
"allow_format_str_template",
|
||||
@ -62,6 +65,7 @@ class OpenAIWrapper:
|
||||
"api_type",
|
||||
"tags",
|
||||
}
|
||||
|
||||
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
|
||||
aopenai_kwargs = set(inspect.getfullargspec(AzureOpenAI.__init__).kwonlyargs)
|
||||
openai_kwargs = openai_kwargs | aopenai_kwargs
|
||||
@ -205,9 +209,14 @@ class OpenAIWrapper:
|
||||
The actual prompt will be:
|
||||
"Complete the following sentence: Today I feel".
|
||||
More examples can be found at [templating](/docs/Use-Cases/enhanced_inference#templating).
|
||||
- `cache_seed` (int | None) for the cache. Default to 41.
|
||||
- cache (Cache | None): A Cache object to use for response cache. Default to None.
|
||||
Note that the cache argument overrides the legacy cache_seed argument: if this argument is provided,
|
||||
then the cache_seed argument is ignored. If this argument is not provided or None,
|
||||
then the cache_seed argument is used.
|
||||
- (Legacy) cache_seed (int | None) for using the DiskCache. Default to 41.
|
||||
An integer cache_seed is useful when implementing "controlled randomness" for the completion.
|
||||
None for no caching.
|
||||
Note: this is a legacy argument. It is only used when the cache argument is not provided.
|
||||
- filter_func (Callable | None): A function that takes in the context and the response
|
||||
and returns a boolean to indicate whether the response is valid. E.g.,
|
||||
|
||||
@ -235,13 +244,21 @@ class OpenAIWrapper:
|
||||
# construct the create params
|
||||
params = self._construct_create_params(create_config, extra_kwargs)
|
||||
# get the cache_seed, filter_func and context
|
||||
cache_seed = extra_kwargs.get("cache_seed", 41)
|
||||
cache_seed = extra_kwargs.get("cache_seed", LEGACY_DEFAULT_CACHE_SEED)
|
||||
cache = extra_kwargs.get("cache")
|
||||
filter_func = extra_kwargs.get("filter_func")
|
||||
context = extra_kwargs.get("context")
|
||||
|
||||
# Try to load the response from cache
|
||||
if cache_seed is not None:
|
||||
with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
|
||||
cache_client = None
|
||||
if cache is not None:
|
||||
# Use the cache object if provided.
|
||||
cache_client = cache
|
||||
elif cache_seed is not None:
|
||||
# Legacy cache behavior, if cache_seed is given, use DiskCache.
|
||||
cache_client = Cache.disk(cache_seed, LEGACY_CACHE_DIR)
|
||||
|
||||
if cache_client is not None:
|
||||
with cache_client as cache:
|
||||
# Try to get the response from cache
|
||||
key = get_key(params)
|
||||
response: ChatCompletion = cache.get(key, None)
|
||||
@ -276,9 +293,9 @@ class OpenAIWrapper:
|
||||
# add cost calculation before caching no matter filter is passed or not
|
||||
response.cost = self.cost(response)
|
||||
self._update_usage_summary(response, use_cache=False)
|
||||
if cache_seed is not None:
|
||||
if cache_client is not None:
|
||||
# Cache the response
|
||||
with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
|
||||
with cache_client as cache:
|
||||
cache.set(key, response)
|
||||
|
||||
# check the filter
|
||||
|
1
setup.py
1
setup.py
@ -52,6 +52,7 @@ setuptools.setup(
|
||||
"teachable": ["chromadb"],
|
||||
"lmm": ["replicate", "pillow"],
|
||||
"graphs": ["networkx~=3.2.1", "matplotlib~=3.8.1"],
|
||||
"redis": ["redis"],
|
||||
},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
|
227
test/agentchat/test_cache_agent.py
Normal file
227
test/agentchat/test_cache_agent.py
Normal file
@ -0,0 +1,227 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import autogen
|
||||
from autogen.agentchat import AssistantAgent, UserProxyAgent
|
||||
from autogen.cache.cache import Cache
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
from conftest import skip_openai, skip_redis # noqa: E402
|
||||
|
||||
try:
|
||||
from openai import OpenAI
|
||||
except ImportError:
|
||||
skip_openai_tests = True
|
||||
else:
|
||||
skip_openai_tests = False or skip_openai
|
||||
|
||||
try:
|
||||
import redis
|
||||
except ImportError:
|
||||
skip_redis_tests = True
|
||||
else:
|
||||
skip_redis_tests = False or skip_redis
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip_openai_tests, reason="openai not installed OR requested to skip")
|
||||
def test_legacy_disk_cache():
|
||||
random_cache_seed = int.from_bytes(os.urandom(2), "big")
|
||||
start_time = time.time()
|
||||
cold_cache_messages = run_conversation(
|
||||
cache_seed=random_cache_seed,
|
||||
)
|
||||
end_time = time.time()
|
||||
duration_with_cold_cache = end_time - start_time
|
||||
|
||||
start_time = time.time()
|
||||
warm_cache_messages = run_conversation(
|
||||
cache_seed=random_cache_seed,
|
||||
)
|
||||
end_time = time.time()
|
||||
duration_with_warm_cache = end_time - start_time
|
||||
assert cold_cache_messages == warm_cache_messages
|
||||
assert duration_with_warm_cache < duration_with_cold_cache
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip_openai_tests or skip_redis_tests, reason="redis not installed OR requested to skip")
|
||||
def test_redis_cache():
|
||||
random_cache_seed = int.from_bytes(os.urandom(2), "big")
|
||||
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||
start_time = time.time()
|
||||
with Cache.redis(random_cache_seed, redis_url) as cache_client:
|
||||
cold_cache_messages = run_conversation(cache_seed=None, cache=cache_client)
|
||||
end_time = time.time()
|
||||
duration_with_cold_cache = end_time - start_time
|
||||
|
||||
start_time = time.time()
|
||||
warm_cache_messages = run_conversation(cache_seed=None, cache=cache_client)
|
||||
end_time = time.time()
|
||||
duration_with_warm_cache = end_time - start_time
|
||||
assert cold_cache_messages == warm_cache_messages
|
||||
assert duration_with_warm_cache < duration_with_cold_cache
|
||||
|
||||
random_cache_seed = int.from_bytes(os.urandom(2), "big")
|
||||
with Cache.redis(random_cache_seed, redis_url) as cache_client:
|
||||
cold_cache_messages = run_groupchat_conversation(cache=cache_client)
|
||||
end_time = time.time()
|
||||
duration_with_cold_cache = end_time - start_time
|
||||
|
||||
start_time = time.time()
|
||||
warm_cache_messages = run_groupchat_conversation(cache=cache_client)
|
||||
end_time = time.time()
|
||||
duration_with_warm_cache = end_time - start_time
|
||||
assert cold_cache_messages == warm_cache_messages
|
||||
assert duration_with_warm_cache < duration_with_cold_cache
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip_openai_tests, reason="openai not installed OR requested to skip")
|
||||
def test_disk_cache():
|
||||
random_cache_seed = int.from_bytes(os.urandom(2), "big")
|
||||
start_time = time.time()
|
||||
with Cache.disk(random_cache_seed) as cache_client:
|
||||
cold_cache_messages = run_conversation(cache_seed=None, cache=cache_client)
|
||||
end_time = time.time()
|
||||
duration_with_cold_cache = end_time - start_time
|
||||
|
||||
start_time = time.time()
|
||||
warm_cache_messages = run_conversation(cache_seed=None, cache=cache_client)
|
||||
end_time = time.time()
|
||||
duration_with_warm_cache = end_time - start_time
|
||||
assert cold_cache_messages == warm_cache_messages
|
||||
assert duration_with_warm_cache < duration_with_cold_cache
|
||||
|
||||
random_cache_seed = int.from_bytes(os.urandom(2), "big")
|
||||
with Cache.disk(random_cache_seed) as cache_client:
|
||||
cold_cache_messages = run_groupchat_conversation(cache=cache_client)
|
||||
end_time = time.time()
|
||||
duration_with_cold_cache = end_time - start_time
|
||||
|
||||
start_time = time.time()
|
||||
warm_cache_messages = run_groupchat_conversation(cache=cache_client)
|
||||
end_time = time.time()
|
||||
duration_with_warm_cache = end_time - start_time
|
||||
assert cold_cache_messages == warm_cache_messages
|
||||
assert duration_with_warm_cache < duration_with_cold_cache
|
||||
|
||||
|
||||
def run_conversation(cache_seed, human_input_mode="NEVER", max_consecutive_auto_reply=5, cache=None):
|
||||
KEY_LOC = "notebook"
|
||||
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
|
||||
here = os.path.abspath(os.path.dirname(__file__))
|
||||
config_list = autogen.config_list_from_json(
|
||||
OAI_CONFIG_LIST,
|
||||
file_location=KEY_LOC,
|
||||
filter_dict={
|
||||
"model": {
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-35-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-0301",
|
||||
"chatgpt-35-turbo-0301",
|
||||
"gpt-35-turbo-v0301",
|
||||
"gpt",
|
||||
},
|
||||
},
|
||||
)
|
||||
llm_config = {
|
||||
"cache_seed": cache_seed,
|
||||
"config_list": config_list,
|
||||
"max_tokens": 1024,
|
||||
}
|
||||
assistant = AssistantAgent(
|
||||
"coding_agent",
|
||||
llm_config=llm_config,
|
||||
)
|
||||
user = UserProxyAgent(
|
||||
"user",
|
||||
human_input_mode=human_input_mode,
|
||||
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
|
||||
max_consecutive_auto_reply=max_consecutive_auto_reply,
|
||||
code_execution_config={
|
||||
"work_dir": f"{here}/test_agent_scripts",
|
||||
"use_docker": "python:3",
|
||||
"timeout": 60,
|
||||
},
|
||||
llm_config=llm_config,
|
||||
system_message="""Is code provided but not enclosed in ``` blocks?
|
||||
If so, remind that code blocks need to be enclosed in ``` blocks.
|
||||
Reply TERMINATE to end the conversation if the task is finished. Don't say appreciation.
|
||||
If "Thank you" or "You\'re welcome" are said in the conversation, then say TERMINATE and that is your last message.""",
|
||||
)
|
||||
|
||||
user.initiate_chat(assistant, message="TERMINATE", cache=cache)
|
||||
# should terminate without sending any message
|
||||
assert assistant.last_message()["content"] == assistant.last_message(user)["content"] == "TERMINATE"
|
||||
coding_task = "Print hello world to a file called hello.txt"
|
||||
|
||||
# track how long this takes
|
||||
user.initiate_chat(assistant, message=coding_task, cache=cache)
|
||||
return user.chat_messages[list(user.chat_messages.keys())[-0]]
|
||||
|
||||
|
||||
def run_groupchat_conversation(cache, human_input_mode="NEVER", max_consecutive_auto_reply=5):
|
||||
KEY_LOC = "notebook"
|
||||
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
|
||||
here = os.path.abspath(os.path.dirname(__file__))
|
||||
config_list = autogen.config_list_from_json(
|
||||
OAI_CONFIG_LIST,
|
||||
file_location=KEY_LOC,
|
||||
filter_dict={
|
||||
"model": {
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-35-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-0301",
|
||||
"chatgpt-35-turbo-0301",
|
||||
"gpt-35-turbo-v0301",
|
||||
"gpt",
|
||||
},
|
||||
},
|
||||
)
|
||||
llm_config = {
|
||||
"cache_seed": None,
|
||||
"config_list": config_list,
|
||||
"max_tokens": 1024,
|
||||
}
|
||||
assistant = AssistantAgent(
|
||||
"coding_agent",
|
||||
llm_config=llm_config,
|
||||
)
|
||||
|
||||
planner = AssistantAgent(
|
||||
"planner",
|
||||
llm_config=llm_config,
|
||||
)
|
||||
|
||||
user = UserProxyAgent(
|
||||
"user",
|
||||
human_input_mode=human_input_mode,
|
||||
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
|
||||
max_consecutive_auto_reply=max_consecutive_auto_reply,
|
||||
code_execution_config={
|
||||
"work_dir": f"{here}/test_agent_scripts",
|
||||
"use_docker": "python:3",
|
||||
"timeout": 60,
|
||||
},
|
||||
system_message="""Is code provided but not enclosed in ``` blocks?
|
||||
If so, remind that code blocks need to be enclosed in ``` blocks.
|
||||
Reply TERMINATE to end the conversation if the task is finished. Don't say appreciation.
|
||||
If "Thank you" or "You\'re welcome" are said in the conversation, then say TERMINATE and that is your last message.""",
|
||||
)
|
||||
|
||||
group_chat = autogen.GroupChat(
|
||||
agents=[planner, assistant, user],
|
||||
messages=[],
|
||||
max_round=4,
|
||||
speaker_selection_method="round_robin",
|
||||
)
|
||||
manager = autogen.GroupChatManager(groupchat=group_chat, llm_config=llm_config)
|
||||
|
||||
coding_task = "Print hello world to a file called hello.txt"
|
||||
|
||||
user.initiate_chat(manager, message=coding_task, cache=cache)
|
||||
return user.chat_messages[list(user.chat_messages.keys())[-0]]
|
53
test/cache/test_cache.py
vendored
Normal file
53
test/cache/test_cache.py
vendored
Normal file
@ -0,0 +1,53 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from autogen.cache.cache import Cache
|
||||
|
||||
|
||||
class TestCache(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.config = {"cache_seed": "test_seed", "redis_url": "redis://test", "cache_path_root": ".test_cache"}
|
||||
|
||||
@patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=MagicMock())
|
||||
def test_init(self, mock_cache_factory):
|
||||
cache = Cache(self.config)
|
||||
self.assertIsInstance(cache.cache, MagicMock)
|
||||
mock_cache_factory.assert_called_with("test_seed", "redis://test", ".test_cache")
|
||||
|
||||
@patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=MagicMock())
|
||||
def test_context_manager(self, mock_cache_factory):
|
||||
mock_cache_instance = MagicMock()
|
||||
mock_cache_factory.return_value = mock_cache_instance
|
||||
|
||||
with Cache(self.config) as cache:
|
||||
self.assertIsInstance(cache, MagicMock)
|
||||
|
||||
mock_cache_instance.__enter__.assert_called()
|
||||
mock_cache_instance.__exit__.assert_called()
|
||||
|
||||
@patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=MagicMock())
|
||||
def test_get_set(self, mock_cache_factory):
|
||||
key = "key"
|
||||
value = "value"
|
||||
mock_cache_instance = MagicMock()
|
||||
mock_cache_factory.return_value = mock_cache_instance
|
||||
|
||||
cache = Cache(self.config)
|
||||
cache.set(key, value)
|
||||
cache.get(key)
|
||||
|
||||
mock_cache_instance.set.assert_called_with(key, value)
|
||||
mock_cache_instance.get.assert_called_with(key, None)
|
||||
|
||||
@patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=MagicMock())
|
||||
def test_close(self, mock_cache_factory):
|
||||
mock_cache_instance = MagicMock()
|
||||
mock_cache_factory.return_value = mock_cache_instance
|
||||
|
||||
cache = Cache(self.config)
|
||||
cache.close()
|
||||
|
||||
mock_cache_instance.close.assert_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
51
test/cache/test_disk_cache.py
vendored
Normal file
51
test/cache/test_disk_cache.py
vendored
Normal file
@ -0,0 +1,51 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from autogen.cache.disk_cache import DiskCache
|
||||
|
||||
|
||||
class TestDiskCache(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.seed = "test_seed"
|
||||
|
||||
@patch("autogen.cache.disk_cache.diskcache.Cache", return_value=MagicMock())
|
||||
def test_init(self, mock_cache):
|
||||
cache = DiskCache(self.seed)
|
||||
self.assertIsInstance(cache.cache, MagicMock)
|
||||
mock_cache.assert_called_with(self.seed)
|
||||
|
||||
@patch("autogen.cache.disk_cache.diskcache.Cache", return_value=MagicMock())
|
||||
def test_get(self, mock_cache):
|
||||
key = "key"
|
||||
value = "value"
|
||||
cache = DiskCache(self.seed)
|
||||
cache.cache.get.return_value = value
|
||||
self.assertEqual(cache.get(key), value)
|
||||
cache.cache.get.assert_called_with(key, None)
|
||||
|
||||
cache.cache.get.return_value = None
|
||||
self.assertIsNone(cache.get(key, None))
|
||||
|
||||
@patch("autogen.cache.disk_cache.diskcache.Cache", return_value=MagicMock())
|
||||
def test_set(self, mock_cache):
|
||||
key = "key"
|
||||
value = "value"
|
||||
cache = DiskCache(self.seed)
|
||||
cache.set(key, value)
|
||||
cache.cache.set.assert_called_with(key, value)
|
||||
|
||||
@patch("autogen.cache.disk_cache.diskcache.Cache", return_value=MagicMock())
|
||||
def test_context_manager(self, mock_cache):
|
||||
with DiskCache(self.seed) as cache:
|
||||
self.assertIsInstance(cache, DiskCache)
|
||||
mock_cache_instance = cache.cache
|
||||
mock_cache_instance.close.assert_called()
|
||||
|
||||
@patch("autogen.cache.disk_cache.diskcache.Cache", return_value=MagicMock())
|
||||
def test_close(self, mock_cache):
|
||||
cache = DiskCache(self.seed)
|
||||
cache.close()
|
||||
cache.cache.close.assert_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
69
test/cache/test_redis_cache.py
vendored
Normal file
69
test/cache/test_redis_cache.py
vendored
Normal file
@ -0,0 +1,69 @@
|
||||
import unittest
|
||||
import pickle
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from autogen.cache.redis_cache import RedisCache
|
||||
|
||||
skip_redis_tests = False
|
||||
except ImportError:
|
||||
skip_redis_tests = True
|
||||
|
||||
|
||||
class TestRedisCache(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.seed = "test_seed"
|
||||
self.redis_url = "redis://localhost:6379/0"
|
||||
|
||||
@pytest.mark.skipif(skip_redis_tests, reason="redis not installed")
|
||||
@patch("autogen.cache.redis_cache.redis.Redis.from_url", return_value=MagicMock())
|
||||
def test_init(self, mock_redis_from_url):
|
||||
cache = RedisCache(self.seed, self.redis_url)
|
||||
self.assertEqual(cache.seed, self.seed)
|
||||
mock_redis_from_url.assert_called_with(self.redis_url)
|
||||
|
||||
@pytest.mark.skipif(skip_redis_tests, reason="redis not installed")
|
||||
@patch("autogen.cache.redis_cache.redis.Redis.from_url", return_value=MagicMock())
|
||||
def test_prefixed_key(self, mock_redis_from_url):
|
||||
cache = RedisCache(self.seed, self.redis_url)
|
||||
key = "test_key"
|
||||
expected_prefixed_key = f"autogen:{self.seed}:{key}"
|
||||
self.assertEqual(cache._prefixed_key(key), expected_prefixed_key)
|
||||
|
||||
@pytest.mark.skipif(skip_redis_tests, reason="redis not installed")
|
||||
@patch("autogen.cache.redis_cache.redis.Redis.from_url", return_value=MagicMock())
|
||||
def test_get(self, mock_redis_from_url):
|
||||
key = "key"
|
||||
value = "value"
|
||||
serialized_value = pickle.dumps(value)
|
||||
cache = RedisCache(self.seed, self.redis_url)
|
||||
cache.cache.get.return_value = serialized_value
|
||||
self.assertEqual(cache.get(key), value)
|
||||
cache.cache.get.assert_called_with(f"autogen:{self.seed}:{key}")
|
||||
|
||||
cache.cache.get.return_value = None
|
||||
self.assertIsNone(cache.get(key))
|
||||
|
||||
@pytest.mark.skipif(skip_redis_tests, reason="redis not installed")
|
||||
@patch("autogen.cache.redis_cache.redis.Redis.from_url", return_value=MagicMock())
|
||||
def test_set(self, mock_redis_from_url):
|
||||
key = "key"
|
||||
value = "value"
|
||||
serialized_value = pickle.dumps(value)
|
||||
cache = RedisCache(self.seed, self.redis_url)
|
||||
cache.set(key, value)
|
||||
cache.cache.set.assert_called_with(f"autogen:{self.seed}:{key}", serialized_value)
|
||||
|
||||
@pytest.mark.skipif(skip_redis_tests, reason="redis not installed")
|
||||
@patch("autogen.cache.redis_cache.redis.Redis.from_url", return_value=MagicMock())
|
||||
def test_context_manager(self, mock_redis_from_url):
|
||||
with RedisCache(self.seed, self.redis_url) as cache:
|
||||
self.assertIsInstance(cache, RedisCache)
|
||||
mock_redis_instance = cache.cache
|
||||
mock_redis_instance.close.assert_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -14,3 +14,5 @@ def pytest_addoption(parser):
|
||||
def pytest_configure(config):
|
||||
global skip_openai
|
||||
skip_openai = config.getoption("--skip-openai", False)
|
||||
global skip_redis
|
||||
skip_redis = config.getoption("--skip-redis", False)
|
||||
|
@ -1,7 +1,11 @@
|
||||
import shutil
|
||||
import time
|
||||
import pytest
|
||||
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
|
||||
from autogen.oai.client import LEGACY_CACHE_DIR, LEGACY_DEFAULT_CACHE_SEED
|
||||
import sys
|
||||
import os
|
||||
from autogen.cache.cache import Cache
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
from conftest import skip_openai # noqa: E402
|
||||
@ -151,10 +155,144 @@ def test_usage_summary():
|
||||
assert client.actual_usage_summary is None, "No actual cost should be recorded"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_legacy_cache():
|
||||
config_list = config_list_from_json(
|
||||
env_or_file=OAI_CONFIG_LIST,
|
||||
file_location=KEY_LOC,
|
||||
filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
|
||||
)
|
||||
|
||||
# Clear cache.
|
||||
if os.path.exists(LEGACY_CACHE_DIR):
|
||||
shutil.rmtree(LEGACY_CACHE_DIR)
|
||||
|
||||
# Test default cache seed.
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
start_time = time.time()
|
||||
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}])
|
||||
end_time = time.time()
|
||||
duration_with_cold_cache = end_time - start_time
|
||||
|
||||
start_time = time.time()
|
||||
warm_cache_response = client.create(messages=[{"role": "user", "content": "random()"}])
|
||||
end_time = time.time()
|
||||
duration_with_warm_cache = end_time - start_time
|
||||
assert cold_cache_response == warm_cache_response
|
||||
assert duration_with_warm_cache < duration_with_cold_cache
|
||||
assert os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(LEGACY_DEFAULT_CACHE_SEED)))
|
||||
|
||||
# Test with cache seed set through constructor
|
||||
client = OpenAIWrapper(config_list=config_list, cache_seed=13)
|
||||
start_time = time.time()
|
||||
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}])
|
||||
end_time = time.time()
|
||||
duration_with_cold_cache = end_time - start_time
|
||||
|
||||
start_time = time.time()
|
||||
warm_cache_response = client.create(messages=[{"role": "user", "content": "random()"}])
|
||||
end_time = time.time()
|
||||
duration_with_warm_cache = end_time - start_time
|
||||
assert cold_cache_response == warm_cache_response
|
||||
assert duration_with_warm_cache < duration_with_cold_cache
|
||||
assert os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(13)))
|
||||
|
||||
# Test with cache seed set through create method
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
start_time = time.time()
|
||||
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}], cache_seed=17)
|
||||
end_time = time.time()
|
||||
duration_with_cold_cache = end_time - start_time
|
||||
|
||||
start_time = time.time()
|
||||
warm_cache_response = client.create(messages=[{"role": "user", "content": "random()"}], cache_seed=17)
|
||||
end_time = time.time()
|
||||
duration_with_warm_cache = end_time - start_time
|
||||
assert cold_cache_response == warm_cache_response
|
||||
assert duration_with_warm_cache < duration_with_cold_cache
|
||||
assert os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(17)))
|
||||
|
||||
# Test using a different cache seed through create method.
|
||||
start_time = time.time()
|
||||
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}], cache_seed=21)
|
||||
end_time = time.time()
|
||||
duration_with_cold_cache = end_time - start_time
|
||||
assert duration_with_warm_cache < duration_with_cold_cache
|
||||
assert os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(21)))
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_cache():
|
||||
config_list = config_list_from_json(
|
||||
env_or_file=OAI_CONFIG_LIST,
|
||||
file_location=KEY_LOC,
|
||||
filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
|
||||
)
|
||||
|
||||
# Clear cache.
|
||||
if os.path.exists(LEGACY_CACHE_DIR):
|
||||
shutil.rmtree(LEGACY_CACHE_DIR)
|
||||
cache_dir = ".cache_test"
|
||||
assert cache_dir != LEGACY_CACHE_DIR
|
||||
if os.path.exists(cache_dir):
|
||||
shutil.rmtree(cache_dir)
|
||||
|
||||
# Test cache set through constructor.
|
||||
with Cache.disk(cache_seed=49, cache_path_root=cache_dir) as cache:
|
||||
client = OpenAIWrapper(config_list=config_list, cache=cache)
|
||||
start_time = time.time()
|
||||
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}])
|
||||
end_time = time.time()
|
||||
duration_with_cold_cache = end_time - start_time
|
||||
|
||||
start_time = time.time()
|
||||
warm_cache_response = client.create(messages=[{"role": "user", "content": "random()"}])
|
||||
end_time = time.time()
|
||||
duration_with_warm_cache = end_time - start_time
|
||||
assert cold_cache_response == warm_cache_response
|
||||
assert duration_with_warm_cache < duration_with_cold_cache
|
||||
assert os.path.exists(os.path.join(cache_dir, str(49)))
|
||||
# Test legacy cache is not used.
|
||||
assert not os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(49)))
|
||||
assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED)))
|
||||
|
||||
# Test cache set through method.
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
with Cache.disk(cache_seed=312, cache_path_root=cache_dir) as cache:
|
||||
start_time = time.time()
|
||||
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}], cache=cache)
|
||||
end_time = time.time()
|
||||
duration_with_cold_cache = end_time - start_time
|
||||
|
||||
start_time = time.time()
|
||||
warm_cache_response = client.create(messages=[{"role": "user", "content": "random()"}], cache=cache)
|
||||
end_time = time.time()
|
||||
duration_with_warm_cache = end_time - start_time
|
||||
assert cold_cache_response == warm_cache_response
|
||||
assert duration_with_warm_cache < duration_with_cold_cache
|
||||
assert os.path.exists(os.path.join(cache_dir, str(312)))
|
||||
# Test legacy cache is not used.
|
||||
assert not os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(312)))
|
||||
assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED)))
|
||||
|
||||
# Test different cache seed.
|
||||
with Cache.disk(cache_seed=123, cache_path_root=cache_dir) as cache:
|
||||
start_time = time.time()
|
||||
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}], cache=cache)
|
||||
end_time = time.time()
|
||||
duration_with_cold_cache = end_time - start_time
|
||||
assert duration_with_warm_cache < duration_with_cold_cache
|
||||
# Test legacy cache is not used.
|
||||
assert not os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(123)))
|
||||
assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_aoai_chat_completion()
|
||||
# test_oai_tool_calling_extraction()
|
||||
# test_chat_completion()
|
||||
test_completion()
|
||||
# test_completion()
|
||||
# # test_cost()
|
||||
# test_usage_summary()
|
||||
test_legacy_cache()
|
||||
test_cache()
|
||||
|
@ -209,6 +209,17 @@ Inference parameter tuning can be done via [`flaml.tune`](https://microsoft.gith
|
||||
|
||||
### Optional Dependencies
|
||||
|
||||
- #### LLM Caching
|
||||
|
||||
To use LLM caching with Redis, you need to install the Python package with
|
||||
the option `redis`:
|
||||
|
||||
```bash
|
||||
pip install "pyautogen[redis]"
|
||||
```
|
||||
|
||||
See [LLM Caching](Use-Cases/agent_chat.md#llm-caching) for details.
|
||||
|
||||
- #### blendsearch
|
||||
|
||||
`pyautogen<0.2` offers a cost-effective hyperparameter optimization technique [EcoOptiGen](https://arxiv.org/abs/2303.04673) for tuning Large Language Models. Please install with the [blendsearch] option to use it.
|
||||
|
@ -285,6 +285,33 @@ By adopting the conversation-driven control with both programming language and n
|
||||
- LLM-based function call. In this approach, LLM decides whether or not to call a particular function depending on the conversation status in each inference call.
|
||||
By messaging additional agents in the called functions, the LLM can drive dynamic multi-agent conversation. A working system showcasing this type of dynamic conversation can be found in the [multi-user math problem solving scenario](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_two_users.ipynb), where a student assistant would automatically resort to an expert using function calls.
|
||||
|
||||
### LLM Caching
|
||||
|
||||
Since version 0.2.8, a configurable context manager allows you to easily configure LLM cache, using either DiskCache or Redis. All agents inside the context manager will use the same cache.
|
||||
|
||||
```python
|
||||
from autogen.cache.cache import Cache
|
||||
|
||||
with Cache.redis(cache_seed=42, redis_url="redis://localhost:6379/0") as cache:
|
||||
user.initiate_chat(assistant, message=coding_task, cache=cache)
|
||||
|
||||
with Cache.disk(cache_seed=42, cache_dir=".cache") as cache:
|
||||
user.initiate_chat(assistant, message=coding_task, cache=cache)
|
||||
```
|
||||
|
||||
For backward compatibility, DiskCache is on by default with `cache_seed` set to 41.
|
||||
To disable caching completely, set `cache_seed` to `None` in the `llm_config` of the agent.
|
||||
|
||||
```python
|
||||
assistant = AssistantAgent(
|
||||
"coding_agent",
|
||||
llm_config={
|
||||
"cache_seed": None,
|
||||
"config_list": OAI_CONFIG_LIST,
|
||||
"max_tokens": 1024,
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
### Diverse Applications Implemented with AutoGen
|
||||
|
||||
|
@ -168,19 +168,45 @@ Total cost: 0.00027
|
||||
|
||||
## Caching
|
||||
|
||||
API call results are cached locally and reused when the same request is issued. This is useful when repeating or continuing experiments for reproducibility and cost saving. It still allows controlled randomness by setting the "cache_seed" specified in `OpenAIWrapper.create()` or the constructor of `OpenAIWrapper`.
|
||||
API call results are cached locally and reused when the same request is issued. This is useful when repeating or continuing experiments for reproducibility and cost saving.
|
||||
|
||||
Starting version 0.2.8, a configurable context manager allows you to easily configure
|
||||
the cache, using either DiskCache or Redis.
|
||||
All `OpenAIWrapper` created inside the context manager can use the same cache through the constructor.
|
||||
|
||||
```python
|
||||
client = OpenAIWrapper(cache_seed=...)
|
||||
client.create(...)
|
||||
from autogen.cache.cache import Cache
|
||||
|
||||
with Cache.redis(cache_seed=42, redis_url="redis://localhost:6379/0") as cache:
|
||||
client = OpenAIWrapper(..., cache=cache)
|
||||
client.create(...)
|
||||
|
||||
with Cache.disk(cache_seed=42, cache_dir=".cache") as cache:
|
||||
client = OpenAIWrapper(..., cache=cache)
|
||||
client.create(...)
|
||||
```
|
||||
|
||||
You can also set a cache directly in the `create()` method.
|
||||
|
||||
```python
|
||||
client = OpenAIWrapper()
|
||||
client.create(cache_seed=..., ...)
|
||||
with Cache.disk(cache_seed=42, cache_dir=".cache") as cache:
|
||||
client.create(..., cache=cache)
|
||||
```
|
||||
|
||||
Caching is enabled by default with cache_seed 41. To disable it please set `cache_seed` to None.
|
||||
You can control the randomness by setting the `cache_seed` parameter.
|
||||
|
||||
### Turnning off cache
|
||||
|
||||
For backward compatibility, DiskCache is always enabled by default
|
||||
with `cache_seed` set to 41. To fully disable it, set `cache_seed` to None.
|
||||
|
||||
```python
|
||||
# Turn off cache in constructor,
|
||||
client = OpenAIWrapper(..., cache_seed=None)
|
||||
# or directly in create().
|
||||
client.create(..., cache_seed=None)
|
||||
```
|
||||
|
||||
_NOTE_. openai v1.1 introduces a new param `seed`. The difference between autogen's `cache_seed` and openai's `seed` is that:
|
||||
* autogen uses local disk cache to guarantee the exactly same output is produced for the same input and when cache is hit, no openai api call will be made.
|
||||
|
Loading…
x
Reference in New Issue
Block a user