[Core] implement redis cache mode (#1222)

* implement Redis cache mode: if `redis_url` is set in the llm_config, it will be used.
  Also adds a test that validates both the existing cache behavior and the new Redis
  cache behavior.

* PR feedback, add unit tests

* more PR feedback, move the new style cache to a context manager

* Update agent_chat.md

* more PR feedback, remove tests from contrib and have them run with the normal jobs

* doc

* updated

* Update website/docs/Use-Cases/agent_chat.md

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* update docs

* update docs; let openaiwrapper to use cache object

* typo

* Update website/docs/Use-Cases/enhanced_inference.md

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* save previous client cache and reset it after send/a_send

* a_run_chat

---------

Co-authored-by: Vijay Ramesh <vijay@regrello.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Vijay Ramesh 2024-01-20 09:06:29 -08:00 committed by GitHub
parent e97b6395af
commit ee6ad8d519
21 changed files with 1149 additions and 17 deletions


@ -54,6 +54,7 @@ jobs:
if: matrix.python-version == '3.10'
run: |
pip install -e .[test]
pip install -e .[redis]
coverage run -a -m pytest test --ignore=test/agentchat/contrib --skip-openai
coverage xml
- name: Upload coverage to Codecov


@ -21,6 +21,12 @@ jobs:
python-version: ["3.9", "3.10", "3.11", "3.12"]
runs-on: ${{ matrix.os }}
environment: openai1
services:
redis:
image: redis
ports:
- 6379:6379
options: --entrypoint redis-server
steps:
# checkout to pr branch
- name: Checkout
@ -42,6 +48,7 @@ jobs:
if: matrix.python-version == '3.9'
run: |
pip install docker
pip install -e .[redis]
- name: Coverage
if: matrix.python-version == '3.9'
env:


@ -9,6 +9,7 @@ from collections import defaultdict
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
from .. import OpenAIWrapper
from ..cache.cache import Cache
from ..code_utils import (
DEFAULT_MODEL,
UNKNOWN,
@ -135,6 +136,9 @@ class ConversableAgent(Agent):
self.llm_config.update(llm_config)
self.client = OpenAIWrapper(**self.llm_config)
# Initialize standalone client cache object.
self.client_cache = None
self._code_execution_config: Union[Dict, Literal[False]] = (
{} if code_execution_config is None else code_execution_config
)
@ -665,6 +669,7 @@ class ConversableAgent(Agent):
recipient: "ConversableAgent",
clear_history: Optional[bool] = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
**context,
):
"""Initiate a chat with the recipient agent.
@ -677,6 +682,7 @@ class ConversableAgent(Agent):
recipient: the recipient agent.
clear_history (bool): whether to clear the chat history with the agent.
silent (bool or None): (Experimental) whether to print the messages for this conversation.
cache (Cache or None): the cache client to be used for this conversation.
**context: any context information.
"message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
@ -686,14 +692,20 @@ class ConversableAgent(Agent):
"""
for agent in [self, recipient]:
agent._raise_exception_on_async_reply_functions()
agent.previous_cache = agent.client_cache
agent.client_cache = cache
self._prepare_chat(recipient, clear_history)
self.send(self.generate_init_message(**context), recipient, silent=silent)
for agent in [self, recipient]:
agent.client_cache = agent.previous_cache
agent.previous_cache = None
async def a_initiate_chat(
self,
recipient: "ConversableAgent",
clear_history: Optional[bool] = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
**context,
):
"""(async) Initiate a chat with the recipient agent.
@ -706,12 +718,19 @@ class ConversableAgent(Agent):
recipient: the recipient agent.
clear_history (bool): whether to clear the chat history with the agent.
silent (bool or None): (Experimental) whether to print the messages for this conversation.
cache (Cache or None): the cache client to be used for this conversation.
**context: any context information.
"message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
"""
self._prepare_chat(recipient, clear_history)
for agent in [self, recipient]:
agent.previous_cache = agent.client_cache
agent.client_cache = cache
await self.a_send(await self.a_generate_init_message(**context), recipient, silent=silent)
for agent in [self, recipient]:
agent.client_cache = agent.previous_cache
agent.previous_cache = None
def reset(self):
"""Reset the agent."""
@ -778,7 +797,9 @@ class ConversableAgent(Agent):
# TODO: #1143 handle token limit exceeded error
response = client.create(
context=messages[-1].pop("context", None), messages=self._oai_system_message + all_messages
context=messages[-1].pop("context", None),
messages=self._oai_system_message + all_messages,
cache=self.client_cache,
)
extracted_response = client.extract_text_or_completion_object(response)[0]
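The `cache` argument added to `initiate_chat` above scopes a cache to a single conversation: each agent's previous `client_cache` is saved, swapped for the supplied cache, and restored once the chat ends. A minimal usage sketch, not part of this diff; the `llm_config` below is a placeholder and needs a real model configuration and credentials to run:

```python
from autogen.agentchat import AssistantAgent, UserProxyAgent
from autogen.cache.cache import Cache

# Placeholder config; replace with a real config_list (model + credentials).
llm_config = {"cache_seed": None, "config_list": [{"model": "gpt-3.5-turbo"}]}

assistant = AssistantAgent("assistant", llm_config=llm_config)
user = UserProxyAgent(
    "user",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=1,
    code_execution_config=False,
)

# Both agents share this cache for the duration of the chat only;
# their previous client_cache values are restored afterwards.
with Cache.disk(cache_seed=7) as cache:
    user.initiate_chat(assistant, message="What is 1 + 1?", cache=cache)
```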


@ -349,13 +349,17 @@ class GroupChatManager(ConversableAgent):
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[GroupChat] = None,
) -> Union[str, Dict, None]:
) -> Tuple[bool, Optional[str]]:
"""Run a group chat."""
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
speaker = sender
groupchat = config
if self.client_cache is not None:
for a in groupchat.agents:
a.previous_cache = a.client_cache
a.client_cache = self.client_cache
for i in range(groupchat.max_round):
groupchat.append(message, speaker)
if self._is_termination_msg(message):
@ -389,6 +393,10 @@ class GroupChatManager(ConversableAgent):
message = self.last_message(speaker)
if i == groupchat.max_round - 1:
groupchat.append(message, speaker)
if self.client_cache is not None:
for a in groupchat.agents:
a.client_cache = a.previous_cache
a.previous_cache = None
return True, None
async def a_run_chat(
@ -403,6 +411,10 @@ class GroupChatManager(ConversableAgent):
message = messages[-1]
speaker = sender
groupchat = config
if self.client_cache is not None:
for a in groupchat.agents:
a.previous_cache = a.client_cache
a.client_cache = self.client_cache
for i in range(groupchat.max_round):
groupchat.append(message, speaker)
@ -436,6 +448,10 @@ class GroupChatManager(ConversableAgent):
# The speaker sends the message without requesting a reply
await speaker.a_send(reply, self, request_reply=False)
message = self.last_message(speaker)
if self.client_cache is not None:
for a in groupchat.agents:
a.client_cache = a.previous_cache
a.previous_cache = None
return True, None
def _raise_exception_on_async_reply_functions(self) -> None:
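The same pattern extends to group chats: `run_chat` and `a_run_chat` copy the manager's `client_cache` to every agent in the group and restore the previous caches afterwards, so a cache passed to `initiate_chat` on the manager covers the whole group. A sketch along the lines of the test added in this commit; the `llm_config` is a placeholder and the Redis URL assumes a local server:

```python
import autogen
from autogen.agentchat import AssistantAgent, UserProxyAgent
from autogen.cache.cache import Cache

llm_config = {"config_list": [{"model": "gpt-3.5-turbo"}]}  # placeholder config

assistant = AssistantAgent("coding_agent", llm_config=llm_config)
planner = AssistantAgent("planner", llm_config=llm_config)
user = UserProxyAgent("user", human_input_mode="NEVER", code_execution_config=False)

group_chat = autogen.GroupChat(
    agents=[planner, assistant, user],
    messages=[],
    max_round=4,
    speaker_selection_method="round_robin",
)
manager = autogen.GroupChatManager(groupchat=group_chat, llm_config=llm_config)

with Cache.redis(cache_seed=11, redis_url="redis://localhost:6379/0") as cache:
    # Every agent in the group chat reuses this cache for this conversation.
    user.initiate_chat(manager, message="Print hello world to hello.txt", cache=cache)
```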

autogen/cache/__init__.py (new file, 0 lines)

autogen/cache/abstract_cache_base.py (new file, 90 lines)

@ -0,0 +1,90 @@
from abc import ABC, abstractmethod
class AbstractCache(ABC):
"""
Abstract base class for cache implementations.
This class defines the basic interface for cache operations.
Implementing classes should provide concrete implementations for
these methods to handle caching mechanisms.
"""
@abstractmethod
def get(self, key, default=None):
"""
Retrieve an item from the cache.
Abstract method that must be implemented by subclasses to
retrieve an item from the cache.
Args:
key (str): The key identifying the item in the cache.
default (optional): The default value to return if the key is not found.
Defaults to None.
Returns:
The value associated with the key if found, else the default value.
Raises:
NotImplementedError: If the subclass does not implement this method.
"""
@abstractmethod
def set(self, key, value):
"""
Set an item in the cache.
Abstract method that must be implemented by subclasses to
store an item in the cache.
Args:
key (str): The key under which the item is to be stored.
value: The value to be stored in the cache.
Raises:
NotImplementedError: If the subclass does not implement this method.
"""
@abstractmethod
def close(self):
"""
Close the cache.
Abstract method that should be implemented by subclasses to
perform any necessary cleanup, such as closing network connections or
releasing resources.
Raises:
NotImplementedError: If the subclass does not implement this method.
"""
@abstractmethod
def __enter__(self):
"""
Enter the runtime context related to this object.
The with statement will bind this method's return value to the target(s)
specified in the as clause of the statement, if any.
Raises:
NotImplementedError: If the subclass does not implement this method.
"""
@abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
"""
Exit the runtime context and close the cache.
Abstract method that should be implemented by subclasses to handle
the exit from a with statement. It is responsible for resource
release and cleanup.
Args:
exc_type: The exception type if an exception was raised in the context.
exc_value: The exception value if an exception was raised in the context.
traceback: The traceback if an exception was raised in the context.
Raises:
NotImplementedError: If the subclass does not implement this method.
"""

autogen/cache/cache.py (new file, 137 lines)

@ -0,0 +1,137 @@
import os
from typing import Dict, Any
from autogen.cache.cache_factory import CacheFactory
class Cache:
"""
A wrapper class for managing cache configuration and instances.
This class provides a unified interface for creating and interacting with
different types of cache (e.g., Redis, Disk). It abstracts the underlying
cache implementation details, providing methods for cache operations.
Attributes:
config (Dict[str, Any]): A dictionary containing cache configuration.
cache: The cache instance created based on the provided configuration.
Methods:
redis(cache_seed=42, redis_url="redis://localhost:6379/0"): Static method to create a Redis cache instance.
disk(cache_seed=42, cache_path_root=".cache"): Static method to create a Disk cache instance.
__init__(self, config): Initializes the Cache with the given configuration.
__enter__(self): Context management entry, returning the cache instance.
__exit__(self, exc_type, exc_value, traceback): Context management exit.
get(self, key, default=None): Retrieves an item from the cache.
set(self, key, value): Sets an item in the cache.
close(self): Closes the cache.
"""
ALLOWED_CONFIG_KEYS = ["cache_seed", "redis_url", "cache_path_root"]
@staticmethod
def redis(cache_seed=42, redis_url="redis://localhost:6379/0"):
"""
Create a Redis cache instance.
Args:
cache_seed (int, optional): A seed for the cache. Defaults to 42.
redis_url (str, optional): The URL for the Redis server. Defaults to "redis://localhost:6379/0".
Returns:
Cache: A Cache instance configured for Redis.
"""
return Cache({"cache_seed": cache_seed, "redis_url": redis_url})
@staticmethod
def disk(cache_seed=42, cache_path_root=".cache"):
"""
Create a Disk cache instance.
Args:
cache_seed (int, optional): A seed for the cache. Defaults to 42.
cache_path_root (str, optional): The root path for the disk cache. Defaults to ".cache".
Returns:
Cache: A Cache instance configured for Disk caching.
"""
return Cache({"cache_seed": cache_seed, "cache_path_root": cache_path_root})
def __init__(self, config: Dict[str, Any]):
"""
Initialize the Cache with the given configuration.
Validates the configuration keys and creates the cache instance.
Args:
config (Dict[str, Any]): A dictionary containing the cache configuration.
Raises:
ValueError: If an invalid configuration key is provided.
"""
self.config = config
# validate config
for key in self.config.keys():
if key not in self.ALLOWED_CONFIG_KEYS:
raise ValueError(f"Invalid config key: {key}")
# create cache instance
self.cache = CacheFactory.cache_factory(
self.config.get("cache_seed", "42"),
self.config.get("redis_url", None),
self.config.get("cache_path_root", None),
)
def __enter__(self):
"""
Enter the runtime context related to the cache object.
Returns:
The cache instance for use within a context block.
"""
return self.cache.__enter__()
def __exit__(self, exc_type, exc_value, traceback):
"""
Exit the runtime context related to the cache object.
Cleans up the cache instance and handles any exceptions that occurred
within the context.
Args:
exc_type: The exception type if an exception was raised in the context.
exc_value: The exception value if an exception was raised in the context.
traceback: The traceback if an exception was raised in the context.
"""
return self.cache.__exit__(exc_type, exc_value, traceback)
def get(self, key, default=None):
"""
Retrieve an item from the cache.
Args:
key (str): The key identifying the item in the cache.
default (optional): The default value to return if the key is not found.
Defaults to None.
Returns:
The value associated with the key if found, else the default value.
"""
return self.cache.get(key, default)
def set(self, key, value):
"""
Set an item in the cache.
Args:
key (str): The key under which the item is to be stored.
value: The value to be stored in the cache.
"""
self.cache.set(key, value)
def close(self):
"""
Close the cache.
Perform any necessary cleanup, such as closing connections or releasing resources.
"""
self.cache.close()
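The two static constructors are the intended entry points; the wrapper's `__enter__` returns the underlying cache implementation, which exposes the same `get`/`set` interface. A short usage sketch (the first block assumes a Redis server is reachable at the given URL):

```python
from autogen.cache.cache import Cache

# Redis-backed cache; assumes a Redis server at this URL.
with Cache.redis(cache_seed=42, redis_url="redis://localhost:6379/0") as cache:
    cache.set("greeting", "hello")
    assert cache.get("greeting") == "hello"

# Disk-backed cache stored under ./.cache/42.
with Cache.disk(cache_seed=42, cache_path_root=".cache") as cache:
    cache.set("greeting", "hello")
```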

autogen/cache/cache_factory.py (new file, 40 lines)

@ -0,0 +1,40 @@
from autogen.cache.disk_cache import DiskCache
try:
from autogen.cache.redis_cache import RedisCache
except ImportError:
RedisCache = None
class CacheFactory:
@staticmethod
def cache_factory(seed, redis_url=None, cache_path_root=".cache"):
"""
Factory function for creating cache instances.
Based on the provided redis_url, this function decides whether to create a RedisCache
or DiskCache instance. If RedisCache is available and redis_url is provided,
a RedisCache instance is created. Otherwise, a DiskCache instance is used.
Args:
seed (str): A string used as a seed or namespace for the cache.
This could be useful for creating distinct cache instances
or for namespacing keys in the cache.
redis_url (str or None): The URL for the Redis server. If this is None
or if RedisCache is not available, a DiskCache instance is created.
Returns:
An instance of either RedisCache or DiskCache, depending on the availability of RedisCache
and the provided redis_url.
Examples:
Creating a Redis cache
> redis_cache = cache_factory("myseed", "redis://localhost:6379/0")
Creating a Disk cache
> disk_cache = cache_factory("myseed", None)
"""
if RedisCache is not None and redis_url is not None:
return RedisCache(seed, redis_url)
else:
return DiskCache(f"./{cache_path_root}/{seed}")
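The fallback is the key behavior here: unless a `redis_url` is supplied and the optional `redis` package is installed, the factory quietly returns a `DiskCache`. For example:

```python
from autogen.cache.cache_factory import CacheFactory

# No redis_url: falls back to a DiskCache under ./.cache/myseed.
disk_backed = CacheFactory.cache_factory("myseed", redis_url=None)

# With a redis_url: returns a RedisCache, but only if the redis package is
# installed; otherwise this too falls back to DiskCache.
maybe_redis = CacheFactory.cache_factory("myseed", redis_url="redis://localhost:6379/0")
```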

autogen/cache/disk_cache.py (new file, 88 lines)

@ -0,0 +1,88 @@
import diskcache
from .abstract_cache_base import AbstractCache
class DiskCache(AbstractCache):
"""
Implementation of AbstractCache using the DiskCache library.
This class provides a concrete implementation of the AbstractCache
interface using the diskcache library for caching data on disk.
Attributes:
cache (diskcache.Cache): The DiskCache instance used for caching.
Methods:
__init__(self, seed): Initializes the DiskCache with the given seed.
get(self, key, default=None): Retrieves an item from the cache.
set(self, key, value): Sets an item in the cache.
close(self): Closes the cache.
__enter__(self): Context management entry.
__exit__(self, exc_type, exc_value, traceback): Context management exit.
"""
def __init__(self, seed):
"""
Initialize the DiskCache instance.
Args:
seed (str): A seed or namespace for the cache. This is used to create
a unique storage location for the cache data.
"""
self.cache = diskcache.Cache(seed)
def get(self, key, default=None):
"""
Retrieve an item from the cache.
Args:
key (str): The key identifying the item in the cache.
default (optional): The default value to return if the key is not found.
Defaults to None.
Returns:
The value associated with the key if found, else the default value.
"""
return self.cache.get(key, default)
def set(self, key, value):
"""
Set an item in the cache.
Args:
key (str): The key under which the item is to be stored.
value: The value to be stored in the cache.
"""
self.cache.set(key, value)
def close(self):
"""
Close the cache.
Perform any necessary cleanup, such as closing file handles or
releasing resources.
"""
self.cache.close()
def __enter__(self):
"""
Enter the runtime context related to the object.
Returns:
self: The instance itself.
"""
return self
def __exit__(self, exc_type, exc_value, traceback):
"""
Exit the runtime context related to the object.
Perform cleanup actions such as closing the cache.
Args:
exc_type: The exception type if an exception was raised in the context.
exc_value: The exception value if an exception was raised in the context.
traceback: The traceback if an exception was raised in the context.
"""
self.close()
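Note that `DiskCache` uses its `seed` argument directly as the on-disk directory; the `Cache.disk` wrapper builds that path from `cache_path_root` and `cache_seed`. A quick sketch of using it directly:

```python
from autogen.cache.disk_cache import DiskCache

# The seed doubles as the storage directory passed to diskcache.Cache.
with DiskCache("./.cache/my_seed") as cache:
    cache.set("answer", 42)
    print(cache.get("answer"))  # -> 42
```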

autogen/cache/redis_cache.py (new file, 110 lines)

@ -0,0 +1,110 @@
import pickle
import redis
from .abstract_cache_base import AbstractCache
class RedisCache(AbstractCache):
"""
Implementation of AbstractCache using the Redis database.
This class provides a concrete implementation of the AbstractCache
interface using the Redis database for caching data.
Attributes:
seed (str): A seed or namespace used as a prefix for cache keys.
cache (redis.Redis): The Redis client used for caching.
Methods:
__init__(self, seed, redis_url): Initializes the RedisCache with the given seed and Redis URL.
_prefixed_key(self, key): Internal method to get a namespaced cache key.
get(self, key, default=None): Retrieves an item from the cache.
set(self, key, value): Sets an item in the cache.
close(self): Closes the Redis client.
__enter__(self): Context management entry.
__exit__(self, exc_type, exc_value, traceback): Context management exit.
"""
def __init__(self, seed, redis_url):
"""
Initialize the RedisCache instance.
Args:
seed (str): A seed or namespace for the cache. This is used as a prefix for all cache keys.
redis_url (str): The URL for the Redis server.
"""
self.seed = seed
self.cache = redis.Redis.from_url(redis_url)
def _prefixed_key(self, key):
"""
Get a namespaced key for the cache.
Args:
key (str): The original key.
Returns:
str: The namespaced key.
"""
return f"autogen:{self.seed}:{key}"
def get(self, key, default=None):
"""
Retrieve an item from the Redis cache.
Args:
key (str): The key identifying the item in the cache.
default (optional): The default value to return if the key is not found.
Defaults to None.
Returns:
The deserialized value associated with the key if found, else the default value.
"""
result = self.cache.get(self._prefixed_key(key))
if result is None:
return default
return pickle.loads(result)
def set(self, key, value):
"""
Set an item in the Redis cache.
Args:
key (str): The key under which the item is to be stored.
value: The value to be stored in the cache.
Notes:
The value is serialized using pickle before being stored in Redis.
"""
serialized_value = pickle.dumps(value)
self.cache.set(self._prefixed_key(key), serialized_value)
def close(self):
"""
Close the Redis client.
Perform any necessary cleanup, such as closing network connections.
"""
self.cache.close()
def __enter__(self):
"""
Enter the runtime context related to the object.
Returns:
self: The instance itself.
"""
return self
def __exit__(self, exc_type, exc_value, traceback):
"""
Exit the runtime context related to the object.
Perform cleanup actions such as closing the Redis client.
Args:
exc_type: The exception type if an exception was raised in the context.
exc_value: The exception value if an exception was raised in the context.
traceback: The traceback if an exception was raised in the context.
"""
self.close()
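Because values are pickled, arbitrary Python objects round-trip through the Redis cache unchanged, and keys are namespaced per seed. A short sketch (requires a reachable Redis server):

```python
from autogen.cache.redis_cache import RedisCache

with RedisCache("demo_seed", "redis://localhost:6379/0") as cache:
    # Stored under the key "autogen:demo_seed:completion"; the value is pickled.
    cache.set("completion", {"choices": [{"text": "hello"}]})
    print(cache.get("completion"))  # -> {'choices': [{'text': 'hello'}]}
```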


@ -9,6 +9,7 @@ from flaml.automl.logger import logger_formatter
from pydantic import BaseModel
from autogen.cache.cache import Cache
from autogen.oai import completion
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, get_key, OAI_PRICE1K
@ -35,7 +36,6 @@ else:
)
from openai.types.completion import Completion
from openai.types.completion_usage import CompletionUsage
import diskcache
if openai.__version__ >= "1.1.0":
TOOL_ENABLED = True
@ -48,12 +48,15 @@ if not logger.handlers:
_ch.setFormatter(logger_formatter)
logger.addHandler(_ch)
LEGACY_DEFAULT_CACHE_SEED = 41
LEGACY_CACHE_DIR = ".cache"
class OpenAIWrapper:
"""A wrapper class for openai client."""
cache_path_root: str = ".cache"
extra_kwargs = {
"cache",
"cache_seed",
"filter_func",
"allow_format_str_template",
@ -62,6 +65,7 @@ class OpenAIWrapper:
"api_type",
"tags",
}
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
aopenai_kwargs = set(inspect.getfullargspec(AzureOpenAI.__init__).kwonlyargs)
openai_kwargs = openai_kwargs | aopenai_kwargs
@ -205,9 +209,14 @@ class OpenAIWrapper:
The actual prompt will be:
"Complete the following sentence: Today I feel".
More examples can be found at [templating](/docs/Use-Cases/enhanced_inference#templating).
- `cache_seed` (int | None) for the cache. Default to 41.
- cache (Cache | None): A Cache object to use for response cache. Default to None.
Note that the cache argument overrides the legacy cache_seed argument: if this argument is provided,
then the cache_seed argument is ignored. If this argument is not provided or None,
then the cache_seed argument is used.
- (Legacy) cache_seed (int | None) for using the DiskCache. Default to 41.
An integer cache_seed is useful when implementing "controlled randomness" for the completion.
None for no caching.
Note: this is a legacy argument. It is only used when the cache argument is not provided.
- filter_func (Callable | None): A function that takes in the context and the response
and returns a boolean to indicate whether the response is valid. E.g.,
@ -235,13 +244,21 @@ class OpenAIWrapper:
# construct the create params
params = self._construct_create_params(create_config, extra_kwargs)
# get the cache_seed, filter_func and context
cache_seed = extra_kwargs.get("cache_seed", 41)
cache_seed = extra_kwargs.get("cache_seed", LEGACY_DEFAULT_CACHE_SEED)
cache = extra_kwargs.get("cache")
filter_func = extra_kwargs.get("filter_func")
context = extra_kwargs.get("context")
# Try to load the response from cache
if cache_seed is not None:
with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
cache_client = None
if cache is not None:
# Use the cache object if provided.
cache_client = cache
elif cache_seed is not None:
# Legacy cache behavior, if cache_seed is given, use DiskCache.
cache_client = Cache.disk(cache_seed, LEGACY_CACHE_DIR)
if cache_client is not None:
with cache_client as cache:
# Try to get the response from cache
key = get_key(params)
response: ChatCompletion = cache.get(key, None)
@ -276,9 +293,9 @@ class OpenAIWrapper:
# add cost calculation before caching no matter filter is passed or not
response.cost = self.cost(response)
self._update_usage_summary(response, use_cache=False)
if cache_seed is not None:
if cache_client is not None:
# Cache the response
with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
with cache_client as cache:
cache.set(key, response)
# check the filter
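In the `create()` flow above, an explicit `cache` object takes precedence over the legacy `cache_seed`, which in turn defaults to a DiskCache under `.cache`; passing `cache_seed=None` with no cache object disables caching. A sketch of the three modes (the `config_list` is a placeholder and needs real credentials to run):

```python
from autogen import OpenAIWrapper
from autogen.cache.cache import Cache

config_list = [{"model": "gpt-3.5-turbo"}]  # placeholder; supply a real configuration

client = OpenAIWrapper(config_list=config_list)

# Legacy behavior: DiskCache under .cache/17.
client.create(messages=[{"role": "user", "content": "hi"}], cache_seed=17)

# New behavior: an explicit cache object overrides any cache_seed.
with Cache.redis(cache_seed=17, redis_url="redis://localhost:6379/0") as cache:
    client.create(messages=[{"role": "user", "content": "hi"}], cache=cache)

# No caching at all.
client.create(messages=[{"role": "user", "content": "hi"}], cache_seed=None)
```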


@ -52,6 +52,7 @@ setuptools.setup(
"teachable": ["chromadb"],
"lmm": ["replicate", "pillow"],
"graphs": ["networkx~=3.2.1", "matplotlib~=3.8.1"],
"redis": ["redis"],
},
classifiers=[
"Programming Language :: Python :: 3",


@ -0,0 +1,227 @@
import os
import sys
import time
import pytest
import autogen
from autogen.agentchat import AssistantAgent, UserProxyAgent
from autogen.cache.cache import Cache
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import skip_openai, skip_redis # noqa: E402
try:
from openai import OpenAI
except ImportError:
skip_openai_tests = True
else:
skip_openai_tests = False or skip_openai
try:
import redis
except ImportError:
skip_redis_tests = True
else:
skip_redis_tests = False or skip_redis
@pytest.mark.skipif(skip_openai_tests, reason="openai not installed OR requested to skip")
def test_legacy_disk_cache():
random_cache_seed = int.from_bytes(os.urandom(2), "big")
start_time = time.time()
cold_cache_messages = run_conversation(
cache_seed=random_cache_seed,
)
end_time = time.time()
duration_with_cold_cache = end_time - start_time
start_time = time.time()
warm_cache_messages = run_conversation(
cache_seed=random_cache_seed,
)
end_time = time.time()
duration_with_warm_cache = end_time - start_time
assert cold_cache_messages == warm_cache_messages
assert duration_with_warm_cache < duration_with_cold_cache
@pytest.mark.skipif(skip_openai_tests or skip_redis_tests, reason="redis not installed OR requested to skip")
def test_redis_cache():
random_cache_seed = int.from_bytes(os.urandom(2), "big")
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
start_time = time.time()
with Cache.redis(random_cache_seed, redis_url) as cache_client:
cold_cache_messages = run_conversation(cache_seed=None, cache=cache_client)
end_time = time.time()
duration_with_cold_cache = end_time - start_time
start_time = time.time()
warm_cache_messages = run_conversation(cache_seed=None, cache=cache_client)
end_time = time.time()
duration_with_warm_cache = end_time - start_time
assert cold_cache_messages == warm_cache_messages
assert duration_with_warm_cache < duration_with_cold_cache
random_cache_seed = int.from_bytes(os.urandom(2), "big")
with Cache.redis(random_cache_seed, redis_url) as cache_client:
cold_cache_messages = run_groupchat_conversation(cache=cache_client)
end_time = time.time()
duration_with_cold_cache = end_time - start_time
start_time = time.time()
warm_cache_messages = run_groupchat_conversation(cache=cache_client)
end_time = time.time()
duration_with_warm_cache = end_time - start_time
assert cold_cache_messages == warm_cache_messages
assert duration_with_warm_cache < duration_with_cold_cache
@pytest.mark.skipif(skip_openai_tests, reason="openai not installed OR requested to skip")
def test_disk_cache():
random_cache_seed = int.from_bytes(os.urandom(2), "big")
start_time = time.time()
with Cache.disk(random_cache_seed) as cache_client:
cold_cache_messages = run_conversation(cache_seed=None, cache=cache_client)
end_time = time.time()
duration_with_cold_cache = end_time - start_time
start_time = time.time()
warm_cache_messages = run_conversation(cache_seed=None, cache=cache_client)
end_time = time.time()
duration_with_warm_cache = end_time - start_time
assert cold_cache_messages == warm_cache_messages
assert duration_with_warm_cache < duration_with_cold_cache
random_cache_seed = int.from_bytes(os.urandom(2), "big")
with Cache.disk(random_cache_seed) as cache_client:
cold_cache_messages = run_groupchat_conversation(cache=cache_client)
end_time = time.time()
duration_with_cold_cache = end_time - start_time
start_time = time.time()
warm_cache_messages = run_groupchat_conversation(cache=cache_client)
end_time = time.time()
duration_with_warm_cache = end_time - start_time
assert cold_cache_messages == warm_cache_messages
assert duration_with_warm_cache < duration_with_cold_cache
def run_conversation(cache_seed, human_input_mode="NEVER", max_consecutive_auto_reply=5, cache=None):
KEY_LOC = "notebook"
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
here = os.path.abspath(os.path.dirname(__file__))
config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={
"model": {
"gpt-3.5-turbo",
"gpt-35-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0301",
"chatgpt-35-turbo-0301",
"gpt-35-turbo-v0301",
"gpt",
},
},
)
llm_config = {
"cache_seed": cache_seed,
"config_list": config_list,
"max_tokens": 1024,
}
assistant = AssistantAgent(
"coding_agent",
llm_config=llm_config,
)
user = UserProxyAgent(
"user",
human_input_mode=human_input_mode,
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
max_consecutive_auto_reply=max_consecutive_auto_reply,
code_execution_config={
"work_dir": f"{here}/test_agent_scripts",
"use_docker": "python:3",
"timeout": 60,
},
llm_config=llm_config,
system_message="""Is code provided but not enclosed in ``` blocks?
If so, remind that code blocks need to be enclosed in ``` blocks.
Reply TERMINATE to end the conversation if the task is finished. Don't say appreciation.
If "Thank you" or "You\'re welcome" are said in the conversation, then say TERMINATE and that is your last message.""",
)
user.initiate_chat(assistant, message="TERMINATE", cache=cache)
# should terminate without sending any message
assert assistant.last_message()["content"] == assistant.last_message(user)["content"] == "TERMINATE"
coding_task = "Print hello world to a file called hello.txt"
# track how long this takes
user.initiate_chat(assistant, message=coding_task, cache=cache)
return user.chat_messages[list(user.chat_messages.keys())[-0]]
def run_groupchat_conversation(cache, human_input_mode="NEVER", max_consecutive_auto_reply=5):
KEY_LOC = "notebook"
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
here = os.path.abspath(os.path.dirname(__file__))
config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={
"model": {
"gpt-3.5-turbo",
"gpt-35-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0301",
"chatgpt-35-turbo-0301",
"gpt-35-turbo-v0301",
"gpt",
},
},
)
llm_config = {
"cache_seed": None,
"config_list": config_list,
"max_tokens": 1024,
}
assistant = AssistantAgent(
"coding_agent",
llm_config=llm_config,
)
planner = AssistantAgent(
"planner",
llm_config=llm_config,
)
user = UserProxyAgent(
"user",
human_input_mode=human_input_mode,
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
max_consecutive_auto_reply=max_consecutive_auto_reply,
code_execution_config={
"work_dir": f"{here}/test_agent_scripts",
"use_docker": "python:3",
"timeout": 60,
},
system_message="""Is code provided but not enclosed in ``` blocks?
If so, remind that code blocks need to be enclosed in ``` blocks.
Reply TERMINATE to end the conversation if the task is finished. Don't say appreciation.
If "Thank you" or "You\'re welcome" are said in the conversation, then say TERMINATE and that is your last message.""",
)
group_chat = autogen.GroupChat(
agents=[planner, assistant, user],
messages=[],
max_round=4,
speaker_selection_method="round_robin",
)
manager = autogen.GroupChatManager(groupchat=group_chat, llm_config=llm_config)
coding_task = "Print hello world to a file called hello.txt"
user.initiate_chat(manager, message=coding_task, cache=cache)
return user.chat_messages[list(user.chat_messages.keys())[-0]]

test/cache/test_cache.py (new file, 53 lines)

@ -0,0 +1,53 @@
import unittest
from unittest.mock import patch, MagicMock
from autogen.cache.cache import Cache
class TestCache(unittest.TestCase):
def setUp(self):
self.config = {"cache_seed": "test_seed", "redis_url": "redis://test", "cache_path_root": ".test_cache"}
@patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=MagicMock())
def test_init(self, mock_cache_factory):
cache = Cache(self.config)
self.assertIsInstance(cache.cache, MagicMock)
mock_cache_factory.assert_called_with("test_seed", "redis://test", ".test_cache")
@patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=MagicMock())
def test_context_manager(self, mock_cache_factory):
mock_cache_instance = MagicMock()
mock_cache_factory.return_value = mock_cache_instance
with Cache(self.config) as cache:
self.assertIsInstance(cache, MagicMock)
mock_cache_instance.__enter__.assert_called()
mock_cache_instance.__exit__.assert_called()
@patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=MagicMock())
def test_get_set(self, mock_cache_factory):
key = "key"
value = "value"
mock_cache_instance = MagicMock()
mock_cache_factory.return_value = mock_cache_instance
cache = Cache(self.config)
cache.set(key, value)
cache.get(key)
mock_cache_instance.set.assert_called_with(key, value)
mock_cache_instance.get.assert_called_with(key, None)
@patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=MagicMock())
def test_close(self, mock_cache_factory):
mock_cache_instance = MagicMock()
mock_cache_factory.return_value = mock_cache_instance
cache = Cache(self.config)
cache.close()
mock_cache_instance.close.assert_called()
if __name__ == "__main__":
unittest.main()

test/cache/test_disk_cache.py (new file, 51 lines)

@ -0,0 +1,51 @@
import unittest
from unittest.mock import patch, MagicMock
from autogen.cache.disk_cache import DiskCache
class TestDiskCache(unittest.TestCase):
def setUp(self):
self.seed = "test_seed"
@patch("autogen.cache.disk_cache.diskcache.Cache", return_value=MagicMock())
def test_init(self, mock_cache):
cache = DiskCache(self.seed)
self.assertIsInstance(cache.cache, MagicMock)
mock_cache.assert_called_with(self.seed)
@patch("autogen.cache.disk_cache.diskcache.Cache", return_value=MagicMock())
def test_get(self, mock_cache):
key = "key"
value = "value"
cache = DiskCache(self.seed)
cache.cache.get.return_value = value
self.assertEqual(cache.get(key), value)
cache.cache.get.assert_called_with(key, None)
cache.cache.get.return_value = None
self.assertIsNone(cache.get(key, None))
@patch("autogen.cache.disk_cache.diskcache.Cache", return_value=MagicMock())
def test_set(self, mock_cache):
key = "key"
value = "value"
cache = DiskCache(self.seed)
cache.set(key, value)
cache.cache.set.assert_called_with(key, value)
@patch("autogen.cache.disk_cache.diskcache.Cache", return_value=MagicMock())
def test_context_manager(self, mock_cache):
with DiskCache(self.seed) as cache:
self.assertIsInstance(cache, DiskCache)
mock_cache_instance = cache.cache
mock_cache_instance.close.assert_called()
@patch("autogen.cache.disk_cache.diskcache.Cache", return_value=MagicMock())
def test_close(self, mock_cache):
cache = DiskCache(self.seed)
cache.close()
cache.cache.close.assert_called()
if __name__ == "__main__":
unittest.main()

test/cache/test_redis_cache.py (new file, 69 lines)

@ -0,0 +1,69 @@
import unittest
import pickle
from unittest.mock import patch, MagicMock
import pytest
try:
from autogen.cache.redis_cache import RedisCache
skip_redis_tests = False
except ImportError:
skip_redis_tests = True
class TestRedisCache(unittest.TestCase):
def setUp(self):
self.seed = "test_seed"
self.redis_url = "redis://localhost:6379/0"
@pytest.mark.skipif(skip_redis_tests, reason="redis not installed")
@patch("autogen.cache.redis_cache.redis.Redis.from_url", return_value=MagicMock())
def test_init(self, mock_redis_from_url):
cache = RedisCache(self.seed, self.redis_url)
self.assertEqual(cache.seed, self.seed)
mock_redis_from_url.assert_called_with(self.redis_url)
@pytest.mark.skipif(skip_redis_tests, reason="redis not installed")
@patch("autogen.cache.redis_cache.redis.Redis.from_url", return_value=MagicMock())
def test_prefixed_key(self, mock_redis_from_url):
cache = RedisCache(self.seed, self.redis_url)
key = "test_key"
expected_prefixed_key = f"autogen:{self.seed}:{key}"
self.assertEqual(cache._prefixed_key(key), expected_prefixed_key)
@pytest.mark.skipif(skip_redis_tests, reason="redis not installed")
@patch("autogen.cache.redis_cache.redis.Redis.from_url", return_value=MagicMock())
def test_get(self, mock_redis_from_url):
key = "key"
value = "value"
serialized_value = pickle.dumps(value)
cache = RedisCache(self.seed, self.redis_url)
cache.cache.get.return_value = serialized_value
self.assertEqual(cache.get(key), value)
cache.cache.get.assert_called_with(f"autogen:{self.seed}:{key}")
cache.cache.get.return_value = None
self.assertIsNone(cache.get(key))
@pytest.mark.skipif(skip_redis_tests, reason="redis not installed")
@patch("autogen.cache.redis_cache.redis.Redis.from_url", return_value=MagicMock())
def test_set(self, mock_redis_from_url):
key = "key"
value = "value"
serialized_value = pickle.dumps(value)
cache = RedisCache(self.seed, self.redis_url)
cache.set(key, value)
cache.cache.set.assert_called_with(f"autogen:{self.seed}:{key}", serialized_value)
@pytest.mark.skipif(skip_redis_tests, reason="redis not installed")
@patch("autogen.cache.redis_cache.redis.Redis.from_url", return_value=MagicMock())
def test_context_manager(self, mock_redis_from_url):
with RedisCache(self.seed, self.redis_url) as cache:
self.assertIsInstance(cache, RedisCache)
mock_redis_instance = cache.cache
mock_redis_instance.close.assert_called()
if __name__ == "__main__":
unittest.main()


@ -14,3 +14,5 @@ def pytest_addoption(parser):
def pytest_configure(config):
global skip_openai
skip_openai = config.getoption("--skip-openai", False)
global skip_redis
skip_redis = config.getoption("--skip-redis", False)


@ -1,7 +1,11 @@
import shutil
import time
import pytest
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
from autogen.oai.client import LEGACY_CACHE_DIR, LEGACY_DEFAULT_CACHE_SEED
import sys
import os
from autogen.cache.cache import Cache
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import skip_openai # noqa: E402
@ -151,10 +155,144 @@ def test_usage_summary():
assert client.actual_usage_summary is None, "No actual cost should be recorded"
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_legacy_cache():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
)
# Clear cache.
if os.path.exists(LEGACY_CACHE_DIR):
shutil.rmtree(LEGACY_CACHE_DIR)
# Test default cache seed.
client = OpenAIWrapper(config_list=config_list)
start_time = time.time()
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}])
end_time = time.time()
duration_with_cold_cache = end_time - start_time
start_time = time.time()
warm_cache_response = client.create(messages=[{"role": "user", "content": "random()"}])
end_time = time.time()
duration_with_warm_cache = end_time - start_time
assert cold_cache_response == warm_cache_response
assert duration_with_warm_cache < duration_with_cold_cache
assert os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(LEGACY_DEFAULT_CACHE_SEED)))
# Test with cache seed set through constructor
client = OpenAIWrapper(config_list=config_list, cache_seed=13)
start_time = time.time()
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}])
end_time = time.time()
duration_with_cold_cache = end_time - start_time
start_time = time.time()
warm_cache_response = client.create(messages=[{"role": "user", "content": "random()"}])
end_time = time.time()
duration_with_warm_cache = end_time - start_time
assert cold_cache_response == warm_cache_response
assert duration_with_warm_cache < duration_with_cold_cache
assert os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(13)))
# Test with cache seed set through create method
client = OpenAIWrapper(config_list=config_list)
start_time = time.time()
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}], cache_seed=17)
end_time = time.time()
duration_with_cold_cache = end_time - start_time
start_time = time.time()
warm_cache_response = client.create(messages=[{"role": "user", "content": "random()"}], cache_seed=17)
end_time = time.time()
duration_with_warm_cache = end_time - start_time
assert cold_cache_response == warm_cache_response
assert duration_with_warm_cache < duration_with_cold_cache
assert os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(17)))
# Test using a different cache seed through create method.
start_time = time.time()
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}], cache_seed=21)
end_time = time.time()
duration_with_cold_cache = end_time - start_time
assert duration_with_warm_cache < duration_with_cold_cache
assert os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(21)))
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_cache():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
)
# Clear cache.
if os.path.exists(LEGACY_CACHE_DIR):
shutil.rmtree(LEGACY_CACHE_DIR)
cache_dir = ".cache_test"
assert cache_dir != LEGACY_CACHE_DIR
if os.path.exists(cache_dir):
shutil.rmtree(cache_dir)
# Test cache set through constructor.
with Cache.disk(cache_seed=49, cache_path_root=cache_dir) as cache:
client = OpenAIWrapper(config_list=config_list, cache=cache)
start_time = time.time()
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}])
end_time = time.time()
duration_with_cold_cache = end_time - start_time
start_time = time.time()
warm_cache_response = client.create(messages=[{"role": "user", "content": "random()"}])
end_time = time.time()
duration_with_warm_cache = end_time - start_time
assert cold_cache_response == warm_cache_response
assert duration_with_warm_cache < duration_with_cold_cache
assert os.path.exists(os.path.join(cache_dir, str(49)))
# Test legacy cache is not used.
assert not os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(49)))
assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED)))
# Test cache set through method.
client = OpenAIWrapper(config_list=config_list)
with Cache.disk(cache_seed=312, cache_path_root=cache_dir) as cache:
start_time = time.time()
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}], cache=cache)
end_time = time.time()
duration_with_cold_cache = end_time - start_time
start_time = time.time()
warm_cache_response = client.create(messages=[{"role": "user", "content": "random()"}], cache=cache)
end_time = time.time()
duration_with_warm_cache = end_time - start_time
assert cold_cache_response == warm_cache_response
assert duration_with_warm_cache < duration_with_cold_cache
assert os.path.exists(os.path.join(cache_dir, str(312)))
# Test legacy cache is not used.
assert not os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(312)))
assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED)))
# Test different cache seed.
with Cache.disk(cache_seed=123, cache_path_root=cache_dir) as cache:
start_time = time.time()
cold_cache_response = client.create(messages=[{"role": "user", "content": "random()"}], cache=cache)
end_time = time.time()
duration_with_cold_cache = end_time - start_time
assert duration_with_warm_cache < duration_with_cold_cache
# Test legacy cache is not used.
assert not os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(123)))
assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED)))
if __name__ == "__main__":
# test_aoai_chat_completion()
# test_oai_tool_calling_extraction()
# test_chat_completion()
test_completion()
# test_completion()
# # test_cost()
# test_usage_summary()
test_legacy_cache()
test_cache()


@ -209,6 +209,17 @@ Inference parameter tuning can be done via [`flaml.tune`](https://microsoft.gith
### Optional Dependencies
- #### LLM Caching
To use LLM caching with Redis, you need to install the Python package with
the option `redis`:
```bash
pip install "pyautogen[redis]"
```
See [LLM Caching](Use-Cases/agent_chat.md#llm-caching) for details.
- #### blendsearch
`pyautogen<0.2` offers a cost-effective hyperparameter optimization technique [EcoOptiGen](https://arxiv.org/abs/2303.04673) for tuning Large Language Models. Please install with the [blendsearch] option to use it.


@ -285,6 +285,33 @@ By adopting the conversation-driven control with both programming language and n
- LLM-based function call. In this approach, LLM decides whether or not to call a particular function depending on the conversation status in each inference call.
By messaging additional agents in the called functions, the LLM can drive dynamic multi-agent conversation. A working system showcasing this type of dynamic conversation can be found in the [multi-user math problem solving scenario](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_two_users.ipynb), where a student assistant would automatically resort to an expert using function calls.
### LLM Caching
Since version 0.2.8, a configurable context manager allows you to easily configure the LLM cache, using either DiskCache or Redis. All agents in a conversation started with the cache (via the `cache` argument of `initiate_chat`) will use the same cache.
```python
from autogen.cache.cache import Cache
with Cache.redis(cache_seed=42, redis_url="redis://localhost:6379/0") as cache:
user.initiate_chat(assistant, message=coding_task, cache=cache)
with Cache.disk(cache_seed=42, cache_path_root=".cache") as cache:
user.initiate_chat(assistant, message=coding_task, cache=cache)
```
For backward compatibility, DiskCache is on by default with `cache_seed` set to 41.
To disable caching completely, set `cache_seed` to `None` in the `llm_config` of the agent.
```python
assistant = AssistantAgent(
"coding_agent",
llm_config={
"cache_seed": None,
"config_list": OAI_CONFIG_LIST,
"max_tokens": 1024,
},
)
```
### Diverse Applications Implemented with AutoGen


@ -168,19 +168,45 @@ Total cost: 0.00027
## Caching
API call results are cached locally and reused when the same request is issued. This is useful when repeating or continuing experiments for reproducibility and cost saving. It still allows controlled randomness by setting the "cache_seed" specified in `OpenAIWrapper.create()` or the constructor of `OpenAIWrapper`.
API call results are cached locally and reused when the same request is issued. This is useful when repeating or continuing experiments for reproducibility and cost saving.
Starting with version 0.2.8, a configurable context manager allows you to easily configure
the cache, using either DiskCache or Redis.
All `OpenAIWrapper` instances created inside the context manager can use the same cache through the constructor.
```python
client = OpenAIWrapper(cache_seed=...)
client.create(...)
from autogen.cache.cache import Cache
with Cache.redis(cache_seed=42, redis_url="redis://localhost:6379/0") as cache:
client = OpenAIWrapper(..., cache=cache)
client.create(...)
with Cache.disk(cache_seed=42, cache_path_root=".cache") as cache:
client = OpenAIWrapper(..., cache=cache)
client.create(...)
```
You can also set a cache directly in the `create()` method.
```python
client = OpenAIWrapper()
client.create(cache_seed=..., ...)
with Cache.disk(cache_seed=42, cache_path_root=".cache") as cache:
client.create(..., cache=cache)
```
Caching is enabled by default with cache_seed 41. To disable it please set `cache_seed` to None.
You can control the randomness by setting the `cache_seed` parameter.
### Turning off cache
For backward compatibility, DiskCache is always enabled by default
with `cache_seed` set to 41. To fully disable it, set `cache_seed` to None.
```python
# Turn off cache in constructor,
client = OpenAIWrapper(..., cache_seed=None)
# or directly in create().
client.create(..., cache_seed=None)
```
_NOTE_: openai v1.1 introduces a new parameter `seed`. The difference between autogen's `cache_seed` and openai's `seed` is that:
* autogen uses a local disk cache to guarantee that exactly the same output is produced for the same input; when the cache is hit, no OpenAI API call is made.