diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index f4b533d8..28a86b6e 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -603,7 +603,7 @@ class PGKVStorage(BaseKVStorage):
         try:
             results = await self.db.query(sql, params, multirows=True)
-
+
             # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
             if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
                 processed_results = {}
@@ -611,19 +611,21 @@ class PGKVStorage(BaseKVStorage):
                     # Parse flattened key to extract cache_type
                     key_parts = row["id"].split(":")
                     cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"
-
+
                     # Map field names and add cache_type for compatibility
                     processed_row = {
                         **row,
-                        "return": row.get("return_value", ""),  # Map return_value to return
+                        "return": row.get(
+                            "return_value", ""
+                        ),  # Map return_value to return
                         "cache_type": cache_type,  # Add cache_type from key
                         "original_prompt": row.get("original_prompt", ""),
                         "chunk_id": row.get("chunk_id"),
-                        "mode": row.get("mode", "default")
+                        "mode": row.get("mode", "default"),
                     }
                     processed_results[row["id"]] = processed_row
                 return processed_results
-
+
             # For other namespaces, return as-is
             return {row["id"]: row for row in results}
         except Exception as e:
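The PGKVStorage hunk above reads key_parts[1] as cache_type, which implies flattened LLM-cache ids of the form "{mode}:{cache_type}:{hash}". A minimal sketch of that parsing under the same assumed layout (parse_cache_key is a hypothetical helper, not part of this patch):

    def parse_cache_key(flat_key: str) -> dict[str, str]:
        # Assumed layout: "{mode}:{cache_type}:{hash}"; shorter keys fall back
        # to the same defaults the hunk above uses ("default" mode, "unknown" type).
        parts = flat_key.split(":")
        if len(parts) >= 3:
            return {"mode": parts[0], "cache_type": parts[1], "hash": parts[2]}
        return {"mode": "default", "cache_type": "unknown", "hash": flat_key}

    # parse_cache_key("global:extract:abc123")
    # -> {"mode": "global", "cache_type": "extract", "hash": "abc123"}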
diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py
index c87a9a4b..5be9f0e6 100644
--- a/lightrag/kg/redis_impl.py
+++ b/lightrag/kg/redis_impl.py
@@ -14,7 +14,12 @@ from redis.asyncio import Redis, ConnectionPool  # type: ignore
 from redis.exceptions import RedisError, ConnectionError  # type: ignore
 from lightrag.utils import logger

-from lightrag.base import BaseKVStorage, DocStatusStorage, DocStatus, DocProcessingStatus
+from lightrag.base import (
+    BaseKVStorage,
+    DocStatusStorage,
+    DocStatus,
+    DocProcessingStatus,
+)

 import json
@@ -29,10 +34,10 @@ SOCKET_CONNECT_TIMEOUT = 3.0

 class RedisConnectionManager:
     """Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""
-
+
     _pools = {}
     _lock = threading.Lock()
-
+
     @classmethod
     def get_pool(cls, redis_url: str) -> ConnectionPool:
         """Get or create a connection pool for the given Redis URL"""
@@ -48,7 +53,7 @@ class RedisConnectionManager:
                 )
                 logger.info(f"Created shared Redis connection pool for {redis_url}")
         return cls._pools[redis_url]
-
+
     @classmethod
     def close_all_pools(cls):
         """Close all connection pools (for cleanup)"""
@@ -254,17 +259,21 @@ class RedisKVStorage(BaseKVStorage):
                     pattern = f"{self.namespace}:{mode}:*"
                     cursor = 0
                     mode_keys = []
-
+
                     while True:
-                        cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
+                        cursor, keys = await redis.scan(
+                            cursor, match=pattern, count=1000
+                        )
                         if keys:
                             mode_keys.extend(keys)
-
+
                         if cursor == 0:
                             break
-
+
                     keys_to_delete.extend(mode_keys)
-                    logger.info(f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'")
+                    logger.info(
+                        f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'"
+                    )

                 if keys_to_delete:
                     # Batch delete
@@ -296,7 +305,7 @@ class RedisKVStorage(BaseKVStorage):
                 pattern = f"{self.namespace}:*"
                 cursor = 0
                 deleted_count = 0
-
+
                 while True:
                     cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
                     if keys:
@@ -306,7 +315,7 @@ class RedisKVStorage(BaseKVStorage):
                             pipe.delete(key)
                         results = await pipe.execute()
                         deleted_count += sum(results)
-
+
                     if cursor == 0:
                         break

@@ -419,7 +428,9 @@ class RedisDocStatusStorage(DocStatusStorage):
         try:
             async with self._get_redis_connection() as redis:
                 await redis.ping()
-                logger.info(f"Connected to Redis for doc status namespace {self.namespace}")
+                logger.info(
+                    f"Connected to Redis for doc status namespace {self.namespace}"
+                )
         except Exception as e:
             logger.error(f"Failed to connect to Redis for doc status: {e}")
             raise
@@ -475,7 +486,7 @@ class RedisDocStatusStorage(DocStatusStorage):
                 for id in ids:
                     pipe.get(f"{self.namespace}:{id}")
                 results = await pipe.execute()
-
+
                 for result_data in results:
                     if result_data:
                         try:
@@ -495,14 +506,16 @@ class RedisDocStatusStorage(DocStatusStorage):
                 # Use SCAN to iterate through all keys in the namespace
                 cursor = 0
                 while True:
-                    cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
+                    cursor, keys = await redis.scan(
+                        cursor, match=f"{self.namespace}:*", count=1000
+                    )
                     if keys:
                         # Get all values in batch
                         pipe = redis.pipeline()
                         for key in keys:
                             pipe.get(key)
                         values = await pipe.execute()
-
+
                         # Count statuses
                         for value in values:
                             if value:
@@ -513,12 +526,12 @@ class RedisDocStatusStorage(DocStatusStorage):
                                     counts[status] += 1
                                 except json.JSONDecodeError:
                                     continue
-
+
                     if cursor == 0:
                         break
         except Exception as e:
             logger.error(f"Error getting status counts: {e}")
-
+
         return counts

     async def get_docs_by_status(
@@ -531,14 +544,16 @@ class RedisDocStatusStorage(DocStatusStorage):
                 # Use SCAN to iterate through all keys in the namespace
                 cursor = 0
                 while True:
-                    cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
+                    cursor, keys = await redis.scan(
+                        cursor, match=f"{self.namespace}:*", count=1000
+                    )
                     if keys:
                         # Get all values in batch
                         pipe = redis.pipeline()
                         for key in keys:
                             pipe.get(key)
                         values = await pipe.execute()
-
+
                         # Filter by status and create DocProcessingStatus objects
                         for key, value in zip(keys, values):
                             if value:
@@ -547,26 +562,31 @@ class RedisDocStatusStorage(DocStatusStorage):
                                     doc_data = json.loads(value)
                                     if doc_data.get("status") == status.value:
                                         # Extract document ID from key
                                         doc_id = key.split(":", 1)[1]
-
+
                                         # Make a copy of the data to avoid modifying the original
                                         data = doc_data.copy()
                                         # If content is missing, use content_summary as content
-                                        if "content" not in data and "content_summary" in data:
+                                        if (
+                                            "content" not in data
+                                            and "content_summary" in data
+                                        ):
                                             data["content"] = data["content_summary"]
                                         # If file_path is not in data, use document id as file path
                                         if "file_path" not in data:
                                             data["file_path"] = "no-file-path"
-
+
                                         result[doc_id] = DocProcessingStatus(**data)
                                 except (json.JSONDecodeError, KeyError) as e:
-                                    logger.error(f"Error processing document {key}: {e}")
+                                    logger.error(
+                                        f"Error processing document {key}: {e}"
+                                    )
                                     continue
-
+
                     if cursor == 0:
                         break
         except Exception as e:
             logger.error(f"Error getting docs by status: {e}")
-
+
         return result

     async def index_done_callback(self) -> None:
@@ -577,7 +597,7 @@ class RedisDocStatusStorage(DocStatusStorage):
         """Insert or update document status data"""
         if not data:
             return
-
+
         logger.debug(f"Inserting {len(data)} records to {self.namespace}")
         async with self._get_redis_connection() as redis:
             try:
@@ -602,15 +622,17 @@ class RedisDocStatusStorage(DocStatusStorage):
         """Delete specific records from storage by their IDs"""
         if not doc_ids:
             return
-
+
         async with self._get_redis_connection() as redis:
             pipe = redis.pipeline()
             for doc_id in doc_ids:
                 pipe.delete(f"{self.namespace}:{doc_id}")
-
+
             results = await pipe.execute()
             deleted_count = sum(results)
-            logger.info(f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}")
+            logger.info(
+                f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
+            )

     async def drop(self) -> dict[str, str]:
         """Drop all document status data from storage and clean up resources"""
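The RedisDocStatusStorage hunks above all reformat the same pattern: cursor-based SCAN to enumerate keys in batches without blocking the server (as KEYS would), then a pipeline to fetch each batch's values in one round-trip. A minimal self-contained sketch of that pattern, assuming a local Redis at redis://localhost:6379, a hypothetical "doc_status" namespace, and redis-py 5.x (for aclose):

    import asyncio
    import json

    from redis.asyncio import Redis

    async def count_statuses(redis: Redis, namespace: str) -> dict[str, int]:
        counts: dict[str, int] = {}
        cursor = 0
        while True:
            # SCAN returns (next_cursor, batch_of_keys); cursor 0 ends the scan
            cursor, keys = await redis.scan(cursor, match=f"{namespace}:*", count=1000)
            if keys:
                pipe = redis.pipeline()
                for key in keys:
                    pipe.get(key)
                values = await pipe.execute()
                for value in values:
                    if value:
                        try:
                            status = json.loads(value).get("status", "unknown")
                            counts[status] = counts.get(status, 0) + 1
                        except json.JSONDecodeError:
                            continue  # skip records that are not valid JSON
            if cursor == 0:
                break
        return counts

    async def main() -> None:
        redis = Redis.from_url("redis://localhost:6379")  # assumed local instance
        try:
            print(await count_statuses(redis, "doc_status"))
        finally:
            await redis.aclose()

    asyncio.run(main())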
@@ -620,7 +642,7 @@ class RedisDocStatusStorage(DocStatusStorage):
                 pattern = f"{self.namespace}:*"
                 cursor = 0
                 deleted_count = 0
-
+
                 while True:
                     cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
                     if keys:
@@ -630,11 +652,13 @@ class RedisDocStatusStorage(DocStatusStorage):
                             pipe.delete(key)
                         results = await pipe.execute()
                         deleted_count += sum(results)
-
+
                     if cursor == 0:
                         break

-                logger.info(f"Dropped {deleted_count} doc status keys from {self.namespace}")
+                logger.info(
+                    f"Dropped {deleted_count} doc status keys from {self.namespace}"
+                )
                 return {"status": "success", "message": "data dropped"}
         except Exception as e:
             logger.error(f"Error dropping doc status {self.namespace}: {e}")
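Further up the patch, RedisConnectionManager guards a class-level _pools dict with a threading.Lock so that every storage instance pointing at the same Redis URL shares one ConnectionPool. A minimal sketch of that shared-pool pattern (PoolManager and the parameter values are illustrative, not this patch's exact configuration):

    import threading

    from redis.asyncio import ConnectionPool, Redis

    class PoolManager:
        _pools: dict[str, ConnectionPool] = {}
        _lock = threading.Lock()

        @classmethod
        def get_pool(cls, redis_url: str) -> ConnectionPool:
            # Double-checked locking: the unlocked fast path serves the common
            # case; the lock only guards first-time pool creation.
            if redis_url not in cls._pools:
                with cls._lock:
                    if redis_url not in cls._pools:
                        cls._pools[redis_url] = ConnectionPool.from_url(
                            redis_url,
                            max_connections=50,          # illustrative cap
                            socket_connect_timeout=3.0,  # mirrors SOCKET_CONNECT_TIMEOUT
                        )
            return cls._pools[redis_url]

    # Usage: every client built this way shares the same pool of sockets
    # client = Redis(connection_pool=PoolManager.get_pool("redis://localhost:6379"))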