Fix linting

yangdx 2025-07-02 16:29:43 +08:00
parent 271722405f
commit 86c9a0cda2
2 changed files with 63 additions and 37 deletions

View File

@@ -603,7 +603,7 @@ class PGKVStorage(BaseKVStorage):
         try:
             results = await self.db.query(sql, params, multirows=True)
             # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
             if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
                 processed_results = {}
@@ -611,19 +611,21 @@ class PGKVStorage(BaseKVStorage):
                     # Parse flattened key to extract cache_type
                     key_parts = row["id"].split(":")
                     cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"

                     # Map field names and add cache_type for compatibility
                     processed_row = {
                         **row,
-                        "return": row.get("return_value", ""),  # Map return_value to return
+                        "return": row.get(
+                            "return_value", ""
+                        ),  # Map return_value to return
                         "cache_type": cache_type,  # Add cache_type from key
                         "original_prompt": row.get("original_prompt", ""),
                         "chunk_id": row.get("chunk_id"),
-                        "mode": row.get("mode", "default")
+                        "mode": row.get("mode", "default"),
                     }
                     processed_results[row["id"]] = processed_row
                 return processed_results

             # For other namespaces, return as-is
             return {row["id"]: row for row in results}
         except Exception as e:
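
For reference, the cache branch reformatted above relies on flattened LLM-cache keys of the form "mode:cache_type:hash". A minimal standalone sketch of that post-processing step (the helper name and the sample row are illustrative, not part of the codebase):

# Hypothetical helper mirroring the row post-processing in the hunk above.
def process_cache_row(row: dict) -> dict:
    # Flattened key layout assumed here: "<mode>:<cache_type>:<hash>"
    key_parts = row["id"].split(":")
    cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"
    return {
        **row,
        "return": row.get("return_value", ""),  # map return_value to return
        "cache_type": cache_type,  # recovered from the flattened key
        "original_prompt": row.get("original_prompt", ""),
        "chunk_id": row.get("chunk_id"),
        "mode": row.get("mode", "default"),
    }

# Illustrative row; real rows come from the PostgreSQL query above.
row = {"id": "global:extract:abc123", "return_value": "{}", "mode": "global"}
assert process_cache_row(row)["cache_type"] == "extract"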

View File

@@ -14,7 +14,12 @@ from redis.asyncio import Redis, ConnectionPool  # type: ignore
 from redis.exceptions import RedisError, ConnectionError  # type: ignore
 from lightrag.utils import logger
-from lightrag.base import BaseKVStorage, DocStatusStorage, DocStatus, DocProcessingStatus
+from lightrag.base import (
+    BaseKVStorage,
+    DocStatusStorage,
+    DocStatus,
+    DocProcessingStatus,
+)
 import json
@@ -29,10 +34,10 @@ SOCKET_CONNECT_TIMEOUT = 3.0
 class RedisConnectionManager:
     """Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""

     _pools = {}
     _lock = threading.Lock()

     @classmethod
     def get_pool(cls, redis_url: str) -> ConnectionPool:
         """Get or create a connection pool for the given Redis URL"""
@@ -48,7 +53,7 @@ class RedisConnectionManager:
                     )
                     logger.info(f"Created shared Redis connection pool for {redis_url}")
         return cls._pools[redis_url]

     @classmethod
     def close_all_pools(cls):
         """Close all connection pools (for cleanup)"""
@@ -254,17 +259,21 @@ class RedisKVStorage(BaseKVStorage):
                 pattern = f"{self.namespace}:{mode}:*"
                 cursor = 0
                 mode_keys = []

                 while True:
-                    cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
+                    cursor, keys = await redis.scan(
+                        cursor, match=pattern, count=1000
+                    )
                     if keys:
                         mode_keys.extend(keys)

                     if cursor == 0:
                         break

                 keys_to_delete.extend(mode_keys)
-                logger.info(f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'")
+                logger.info(
+                    f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'"
+                )

             if keys_to_delete:
                 # Batch delete
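
The hunk above is one instance of the SCAN-then-pipeline pattern this file uses for bulk deletes. A self-contained sketch of the pattern, assuming an async redis client (the function name is illustrative):

from redis.asyncio import Redis

async def delete_by_pattern(redis: Redis, pattern: str) -> int:
    """Collect keys with SCAN, then delete them in one pipelined batch."""
    keys_to_delete: list = []
    cursor = 0
    while True:
        # SCAN is incremental, so it avoids blocking Redis the way KEYS would.
        cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
        if keys:
            keys_to_delete.extend(keys)
        if cursor == 0:
            break

    if not keys_to_delete:
        return 0

    # Queue one DELETE per key and send them in a single round trip.
    pipe = redis.pipeline()
    for key in keys_to_delete:
        pipe.delete(key)
    results = await pipe.execute()
    return sum(results)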
@@ -296,7 +305,7 @@ class RedisKVStorage(BaseKVStorage):
                 pattern = f"{self.namespace}:*"
                 cursor = 0
                 deleted_count = 0

                 while True:
                     cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
                     if keys:
@@ -306,7 +315,7 @@ class RedisKVStorage(BaseKVStorage):
                             pipe.delete(key)
                         results = await pipe.execute()
                         deleted_count += sum(results)

                     if cursor == 0:
                         break
@@ -419,7 +428,9 @@ class RedisDocStatusStorage(DocStatusStorage):
         try:
             async with self._get_redis_connection() as redis:
                 await redis.ping()
-                logger.info(f"Connected to Redis for doc status namespace {self.namespace}")
+                logger.info(
+                    f"Connected to Redis for doc status namespace {self.namespace}"
+                )
         except Exception as e:
             logger.error(f"Failed to connect to Redis for doc status: {e}")
             raise
@@ -475,7 +486,7 @@ class RedisDocStatusStorage(DocStatusStorage):
                 for id in ids:
                     pipe.get(f"{self.namespace}:{id}")
                 results = await pipe.execute()

                 for result_data in results:
                     if result_data:
                         try:
@@ -495,14 +506,16 @@ class RedisDocStatusStorage(DocStatusStorage):
                 # Use SCAN to iterate through all keys in the namespace
                 cursor = 0
                 while True:
-                    cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
+                    cursor, keys = await redis.scan(
+                        cursor, match=f"{self.namespace}:*", count=1000
+                    )
                     if keys:
                         # Get all values in batch
                         pipe = redis.pipeline()
                         for key in keys:
                             pipe.get(key)
                         values = await pipe.execute()

                         # Count statuses
                         for value in values:
                             if value:
@@ -513,12 +526,12 @@ class RedisDocStatusStorage(DocStatusStorage):
                                     counts[status] += 1
                                 except json.JSONDecodeError:
                                     continue

                     if cursor == 0:
                         break

         except Exception as e:
             logger.error(f"Error getting status counts: {e}")

         return counts

     async def get_docs_by_status(
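
As a reading aid for the counting loop above, here is the same idea reduced to a standalone coroutine; the one-JSON-document-per-"namespace:doc_id" key layout comes from the diff, while the function name and the Counter return type are assumptions:

import json
from collections import Counter

from redis.asyncio import Redis

async def count_statuses(redis: Redis, namespace: str) -> Counter:
    counts: Counter = Counter()
    cursor = 0
    while True:
        cursor, keys = await redis.scan(cursor, match=f"{namespace}:*", count=1000)
        if keys:
            # Fetch the batch of JSON documents in one pipelined round trip.
            pipe = redis.pipeline()
            for key in keys:
                pipe.get(key)
            values = await pipe.execute()
            for value in values:
                if not value:
                    continue
                try:
                    counts[json.loads(value)["status"]] += 1
                except (json.JSONDecodeError, KeyError):
                    continue
        if cursor == 0:
            break
    return counts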
@@ -531,14 +544,16 @@ class RedisDocStatusStorage(DocStatusStorage):
                 # Use SCAN to iterate through all keys in the namespace
                 cursor = 0
                 while True:
-                    cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
+                    cursor, keys = await redis.scan(
+                        cursor, match=f"{self.namespace}:*", count=1000
+                    )
                     if keys:
                         # Get all values in batch
                         pipe = redis.pipeline()
                         for key in keys:
                             pipe.get(key)
                         values = await pipe.execute()

                         # Filter by status and create DocProcessingStatus objects
                         for key, value in zip(keys, values):
                             if value:
@@ -547,26 +562,31 @@ class RedisDocStatusStorage(DocStatusStorage):
                                     if doc_data.get("status") == status.value:
                                         # Extract document ID from key
                                         doc_id = key.split(":", 1)[1]

                                         # Make a copy of the data to avoid modifying the original
                                         data = doc_data.copy()
                                         # If content is missing, use content_summary as content
-                                        if "content" not in data and "content_summary" in data:
+                                        if (
+                                            "content" not in data
+                                            and "content_summary" in data
+                                        ):
                                             data["content"] = data["content_summary"]
                                         # If file_path is not in data, use document id as file path
                                         if "file_path" not in data:
                                             data["file_path"] = "no-file-path"

                                         result[doc_id] = DocProcessingStatus(**data)
                                 except (json.JSONDecodeError, KeyError) as e:
-                                    logger.error(f"Error processing document {key}: {e}")
+                                    logger.error(
+                                        f"Error processing document {key}: {e}"
+                                    )
                                     continue

                     if cursor == 0:
                         break

         except Exception as e:
             logger.error(f"Error getting docs by status: {e}")

         return result

     async def index_done_callback(self) -> None:
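
The reflowed condition above is part of a small normalisation step applied to each stored document before constructing DocProcessingStatus. A minimal sketch of just that step (the helper name and sample dict are made up):

def normalize_doc_data(doc_data: dict) -> dict:
    """Apply the same field fallbacks as the hunk above before building DocProcessingStatus."""
    data = doc_data.copy()  # keep the parsed original untouched
    # Fall back to content_summary when the full content was not stored.
    if "content" not in data and "content_summary" in data:
        data["content"] = data["content_summary"]
    # Older records may lack file_path; use a placeholder instead of failing.
    if "file_path" not in data:
        data["file_path"] = "no-file-path"
    return data

stored = {"status": "processed", "content_summary": "short summary"}
assert normalize_doc_data(stored)["content"] == "short summary"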
@@ -577,7 +597,7 @@ class RedisDocStatusStorage(DocStatusStorage):
         """Insert or update document status data"""
         if not data:
             return

         logger.debug(f"Inserting {len(data)} records to {self.namespace}")
         async with self._get_redis_connection() as redis:
             try:
@@ -602,15 +622,17 @@ class RedisDocStatusStorage(DocStatusStorage):
         """Delete specific records from storage by their IDs"""
         if not doc_ids:
             return

         async with self._get_redis_connection() as redis:
             pipe = redis.pipeline()
             for doc_id in doc_ids:
                 pipe.delete(f"{self.namespace}:{doc_id}")
             results = await pipe.execute()
             deleted_count = sum(results)
-            logger.info(f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}")
+            logger.info(
+                f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
+            )

     async def drop(self) -> dict[str, str]:
         """Drop all document status data from storage and clean up resources"""
@@ -620,7 +642,7 @@ class RedisDocStatusStorage(DocStatusStorage):
                 pattern = f"{self.namespace}:*"
                 cursor = 0
                 deleted_count = 0

                 while True:
                     cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
                     if keys:
@@ -630,11 +652,13 @@ class RedisDocStatusStorage(DocStatusStorage):
                             pipe.delete(key)
                         results = await pipe.execute()
                         deleted_count += sum(results)

                     if cursor == 0:
                         break

-                logger.info(f"Dropped {deleted_count} doc status keys from {self.namespace}")
+                logger.info(
+                    f"Dropped {deleted_count} doc status keys from {self.namespace}"
+                )
                 return {"status": "success", "message": "data dropped"}
         except Exception as e:
             logger.error(f"Error dropping doc status {self.namespace}: {e}")