# LightRAG/lightrag/kg/redis_impl.py

import os
from typing import Any, final
from dataclasses import dataclass
import json
import configparser
from contextlib import asynccontextmanager

import pipmaster as pm

if not pm.is_installed("redis"):
pm.install("redis")

# aioredis is a deprecated library, replaced with redis
from redis.asyncio import Redis, ConnectionPool # type: ignore
from redis.exceptions import RedisError, ConnectionError # type: ignore
from lightrag.utils import logger
from lightrag.base import BaseKVStorage

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Constants for Redis connection pool
MAX_CONNECTIONS = 50
SOCKET_TIMEOUT = 5.0
SOCKET_CONNECT_TIMEOUT = 3.0
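# The limits above cap the shared connection pool and bound how long a slow
# or unreachable Redis server can stall a single operation.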


@final
@dataclass
class RedisKVStorage(BaseKVStorage):
def __post_init__(self):
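        # The REDIS_URI environment variable takes precedence over the
        # [redis] "uri" entry in config.ini, defaulting to a local instance.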
redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
)
# Create a connection pool with limits
self._pool = ConnectionPool.from_url(
redis_url,
max_connections=MAX_CONNECTIONS,
decode_responses=True,
socket_timeout=SOCKET_TIMEOUT,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
)
self._redis = Redis(connection_pool=self._pool)
logger.info(
f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections"
)

    @asynccontextmanager
async def _get_redis_connection(self):
"""Safe context manager for Redis operations."""
try:
yield self._redis
except ConnectionError as e:
logger.error(f"Redis connection error in {self.namespace}: {e}")
raise
except RedisError as e:
logger.error(f"Redis operation error in {self.namespace}: {e}")
raise
except Exception as e:
logger.error(
f"Unexpected error in Redis operation for {self.namespace}: {e}"
)
raise

    async def close(self):
"""Close the Redis connection pool to prevent resource leaks."""
if hasattr(self, "_redis") and self._redis:
await self._redis.close()
await self._pool.disconnect()
logger.debug(f"Closed Redis connection pool for {self.namespace}")

    async def __aenter__(self):
"""Support for async context manager."""
return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Ensure Redis resources are cleaned up when exiting context."""
await self.close()

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
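        """Fetch a value stored under this namespace by id.

        The special id "default" instead scans the namespace and returns every
        cache entry whose cache_type is "extract", keyed by its cache key.
        """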
if id == "default":
# Find all cache entries with cache_type == "extract"
async with self._get_redis_connection() as redis:
try:
result = {}
pattern = f"{self.namespace}:*"
cursor = 0
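                    # Iterate with cursor-based SCAN rather than KEYS so a
                    # large namespace does not block the Redis server.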
while True:
cursor, keys = await redis.scan(
cursor, match=pattern, count=100
)
if keys:
# Batch get values for these keys
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
# Check each value for cache_type == "extract"
for key, value in zip(keys, values):
if value:
try:
data = json.loads(value)
if (
isinstance(data, dict)
and data.get("cache_type") == "extract"
):
# Extract cache key (remove namespace prefix)
cache_key = key.replace(
f"{self.namespace}:", ""
)
result[cache_key] = data
except json.JSONDecodeError:
continue
if cursor == 0:
break
return result if result else None
except Exception as e:
logger.error(f"Error scanning Redis for extract cache entries: {e}")
return None
else:
# Original behavior for non-"default" ids
async with self._get_redis_connection() as redis:
try:
data = await redis.get(f"{self.namespace}:{id}")
return json.loads(data) if data else None
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for id {id}: {e}")
return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any] | None]:
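        """Batch-fetch values for the given ids via one pipelined round-trip;
        missing ids yield None in the corresponding positions."""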
async with self._get_redis_connection() as redis:
try:
pipe = redis.pipeline()
for id in ids:
pipe.get(f"{self.namespace}:{id}")
results = await pipe.execute()
return [json.loads(result) if result else None for result in results]
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in batch get: {e}")
return [None] * len(ids)

    async def filter_keys(self, keys: set[str]) -> set[str]:
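        """Return the subset of keys that do not already exist in storage."""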
async with self._get_redis_connection() as redis:
            # Materialize a list so the EXISTS results can be matched back to
            # their keys by position.
            keys_list = list(keys)
            pipe = redis.pipeline()
            for key in keys_list:
                pipe.exists(f"{self.namespace}:{key}")
            results = await pipe.execute()

            existing_ids = {
                keys_list[i] for i, exists in enumerate(results) if exists
            }
            return keys - existing_ids

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
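        """JSON-serialize each value and store it under its namespaced key,
        then tag each input dict with its "_id" after the write succeeds."""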
if not data:
return
logger.info(f"Inserting {len(data)} items to {self.namespace}")
async with self._get_redis_connection() as redis:
try:
pipe = redis.pipeline()
for k, v in data.items():
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
await pipe.execute()
for k in data:
data[k]["_id"] = k
            except (TypeError, ValueError) as e:
logger.error(f"JSON encode error during upsert: {e}")
raise

    async def index_done_callback(self) -> None:
        # Redis handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
"""Delete entries with specified IDs"""
if not ids:
return
async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
for id in ids:
pipe.delete(f"{self.namespace}:{id}")
results = await pipe.execute()
deleted_count = sum(results)
logger.info(
f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
)

    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by by cache mode
2025-03-04 15:53:20 +08:00
Importance notes for Redis storage:
1. This will immediately delete the specified cache modes from Redis
2025-03-04 15:53:20 +08:00
Args:
modes (list[str]): List of cache mode to be drop from storage
2025-03-31 23:22:27 +08:00
Returns:
True: if the cache drop successfully
False: if the cache drop failed
"""
if not modes:
return False
try:
await self.delete(modes)
return True
except Exception:
return False

    async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all keys under the current namespace.
2025-03-31 23:22:27 +08:00
2025-03-31 01:40:14 +08:00
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
2025-04-05 15:27:59 -07:00
async with self._get_redis_connection() as redis:
try:
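                # KEYS is O(N) and blocks the server while it runs; tolerable
                # here since drop() is an explicit full-namespace wipe.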
keys = await redis.keys(f"{self.namespace}:*")
if keys:
pipe = redis.pipeline()
for key in keys:
pipe.delete(key)
results = await pipe.execute()
deleted_count = sum(results)
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
return {
"status": "success",
"message": f"{deleted_count} keys dropped",
}
else:
logger.info(f"No keys found to drop in {self.namespace}")
return {"status": "success", "message": "no keys to drop"}
except Exception as e:
logger.error(f"Error dropping keys from {self.namespace}: {e}")
return {"status": "error", "message": str(e)}