import os
import json
import configparser
from typing import Any, final
from dataclasses import dataclass
from contextlib import asynccontextmanager

import pipmaster as pm

if not pm.is_installed("redis"):
    pm.install("redis")

# aioredis is a deprecated library, replaced with redis
from redis.asyncio import Redis, ConnectionPool
from redis.exceptions import RedisError, ConnectionError

from lightrag.utils import logger, compute_mdhash_id
from lightrag.base import BaseKVStorage

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")

# Constants for Redis connection pool
MAX_CONNECTIONS = 50
SOCKET_TIMEOUT = 5.0
SOCKET_CONNECT_TIMEOUT = 3.0


@final
@dataclass
class RedisKVStorage(BaseKVStorage):
    def __post_init__(self):
        redis_url = os.environ.get(
            "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
        )
        # Create a connection pool with limits
        self._pool = ConnectionPool.from_url(
            redis_url,
            max_connections=MAX_CONNECTIONS,
            decode_responses=True,
            socket_timeout=SOCKET_TIMEOUT,
            socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
        )
        self._redis = Redis(connection_pool=self._pool)
        logger.info(
            f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections"
        )

    @asynccontextmanager
    async def _get_redis_connection(self):
        """Safe context manager for Redis operations."""
        try:
            yield self._redis
        except ConnectionError as e:
            logger.error(f"Redis connection error in {self.namespace}: {e}")
            raise
        except RedisError as e:
            logger.error(f"Redis operation error in {self.namespace}: {e}")
            raise
        except Exception as e:
            logger.error(
                f"Unexpected error in Redis operation for {self.namespace}: {e}"
            )
            raise

    async def close(self):
        """Close the Redis connection pool to prevent resource leaks."""
        if hasattr(self, "_redis") and self._redis:
            await self._redis.close()
            await self._pool.disconnect()
            logger.debug(f"Closed Redis connection pool for {self.namespace}")

    async def __aenter__(self):
        """Support for async context manager."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Ensure Redis resources are cleaned up when exiting context."""
        await self.close()

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        async with self._get_redis_connection() as redis:
            try:
                data = await redis.get(f"{self.namespace}:{id}")
                return json.loads(data) if data else None
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error for id {id}: {e}")
                return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        async with self._get_redis_connection() as redis:
            try:
                pipe = redis.pipeline()
                for id in ids:
                    pipe.get(f"{self.namespace}:{id}")
                results = await pipe.execute()
                return [json.loads(result) if result else None for result in results]
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error in batch get: {e}")
                return [None] * len(ids)

    async def filter_keys(self, keys: set[str]) -> set[str]:
        async with self._get_redis_connection() as redis:
            # Pin an iteration order first; sets are not indexable
            keys_list = list(keys)
            pipe = redis.pipeline()
            for key in keys_list:
                pipe.exists(f"{self.namespace}:{key}")
            results = await pipe.execute()

            existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
            return set(keys) - existing_ids

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        if not data:
            return

        logger.info(f"Inserting {len(data)} items to {self.namespace}")
        async with self._get_redis_connection() as redis:
            try:
                pipe = redis.pipeline()
                for k, v in data.items():
                    pipe.set(f"{self.namespace}:{k}", json.dumps(v))
                await pipe.execute()

                for k in data:
                    data[k]["_id"] = k
            # json has no JSONEncodeError; json.dumps raises TypeError/ValueError
            except (TypeError, ValueError) as e:
                logger.error(f"JSON encode error during upsert: {e}")
                raise

    async def delete(self, ids: list[str]) -> None:
        """Delete entries with specified IDs"""
        if not ids:
            return

        async with self._get_redis_connection() as redis:
            pipe = redis.pipeline()
            for id in ids:
                pipe.delete(f"{self.namespace}:{id}")

            results = await pipe.execute()
            deleted_count = sum(results)
            logger.info(
                f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
            )

    async def delete_entity(self, entity_name: str) -> None:
        """Delete an entity by name"""
        try:
            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
            logger.debug(
                f"Attempting to delete entity {entity_name} with ID {entity_id}"
            )

            async with self._get_redis_connection() as redis:
                result = await redis.delete(f"{self.namespace}:{entity_id}")

                if result:
                    logger.debug(f"Successfully deleted entity {entity_name}")
                else:
                    logger.debug(f"Entity {entity_name} not found in storage")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str) -> None:
        """Delete all relations associated with an entity"""
        try:
            async with self._get_redis_connection() as redis:
                cursor = 0
                relation_keys = []
                pattern = f"{self.namespace}:*"

                while True:
                    cursor, keys = await redis.scan(cursor, match=pattern)

                    # Process keys in batches
                    pipe = redis.pipeline()
                    for key in keys:
                        pipe.get(key)
                    values = await pipe.execute()

                    for key, value in zip(keys, values):
                        if value:
                            try:
                                data = json.loads(value)
                                if (
                                    data.get("src_id") == entity_name
                                    or data.get("tgt_id") == entity_name
                                ):
                                    relation_keys.append(key)
                            except json.JSONDecodeError:
                                logger.warning(f"Invalid JSON in key {key}")
                                continue

                    if cursor == 0:
                        break

                # Delete relations in batches
                if relation_keys:
                    # Delete in chunks to avoid too many arguments
                    chunk_size = 1000
                    for i in range(0, len(relation_keys), chunk_size):
                        chunk = relation_keys[i : i + chunk_size]
                        deleted = await redis.delete(*chunk)
                        logger.debug(
                            f"Deleted {deleted} relations for {entity_name} "
                            f"(batch {i // chunk_size + 1})"
                        )
                else:
                    logger.debug(f"No relations found for entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting relations for {entity_name}: {e}")

    async def index_done_callback(self) -> None:
        # Redis handles persistence automatically
        pass
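
# Minimal usage sketch (not part of the storage implementation). It assumes a
# Redis server reachable via REDIS_URI or config.ini, and that BaseKVStorage is
# a dataclass exposing `namespace`, `global_config`, and `embedding_func`
# fields; adjust the constructor arguments to match your lightrag version.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # The storage doubles as an async context manager, so the connection
        # pool is closed automatically on exit.
        async with RedisKVStorage(
            namespace="demo_kv",
            global_config={},
            embedding_func=None,
        ) as kv:
            await kv.upsert({"doc-1": {"content": "hello world"}})
            print(await kv.get_by_id("doc-1"))
            print(await kv.filter_keys({"doc-1", "doc-2"}))
            await kv.delete(["doc-1"])

    asyncio.run(_demo())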