Mirror of https://github.com/HKUDS/LightRAG.git (synced 2025-08-06 15:51:48 +00:00)

Fix linting

Commit 86c9a0cda2 (parent 271722405f)
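The changes in this commit are formatting-only, touching the PGKVStorage (PostgreSQL) and Redis storage implementations: over-long calls and f-string log messages are wrapped across lines, a multi-name import is parenthesized, a compound condition is split, and trailing commas are added.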
@@ -603,7 +603,7 @@ class PGKVStorage(BaseKVStorage):
         try:
             results = await self.db.query(sql, params, multirows=True)

             # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
             if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
                 processed_results = {}
@@ -611,19 +611,21 @@ class PGKVStorage(BaseKVStorage):
                     # Parse flattened key to extract cache_type
                     key_parts = row["id"].split(":")
                     cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"

                     # Map field names and add cache_type for compatibility
                     processed_row = {
                         **row,
-                        "return": row.get("return_value", ""),  # Map return_value to return
+                        "return": row.get(
+                            "return_value", ""
+                        ),  # Map return_value to return
                         "cache_type": cache_type,  # Add cache_type from key
                         "original_prompt": row.get("original_prompt", ""),
                         "chunk_id": row.get("chunk_id"),
-                        "mode": row.get("mode", "default")
+                        "mode": row.get("mode", "default"),
                     }
                     processed_results[row["id"]] = processed_row
                 return processed_results

             # For other namespaces, return as-is
             return {row["id"]: row for row in results}
         except Exception as e:
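For context, the cache handling above assumes flattened cache keys of the form `mode:cache_type:hash`. A minimal standalone sketch of that key parsing and field mapping (not code from this commit; `row` is hypothetical sample data):

```python
# Hypothetical sample row, shaped like the PGKVStorage query result above.
row = {"id": "default:extract:abc123", "return_value": "{}", "mode": "default"}

key_parts = row["id"].split(":")  # ["default", "extract", "abc123"]
cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"

processed_row = {
    **row,
    "return": row.get("return_value", ""),  # map return_value -> return
    "cache_type": cache_type,               # recovered from the key
    "original_prompt": row.get("original_prompt", ""),
    "chunk_id": row.get("chunk_id"),
    "mode": row.get("mode", "default"),
}
assert processed_row["cache_type"] == "extract"
```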
@@ -14,7 +14,12 @@ from redis.asyncio import Redis, ConnectionPool  # type: ignore
 from redis.exceptions import RedisError, ConnectionError  # type: ignore
 from lightrag.utils import logger

-from lightrag.base import BaseKVStorage, DocStatusStorage, DocStatus, DocProcessingStatus
+from lightrag.base import (
+    BaseKVStorage,
+    DocStatusStorage,
+    DocStatus,
+    DocProcessingStatus,
+)
 import json


@@ -29,10 +34,10 @@ SOCKET_CONNECT_TIMEOUT = 3.0

 class RedisConnectionManager:
     """Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""

     _pools = {}
     _lock = threading.Lock()

     @classmethod
     def get_pool(cls, redis_url: str) -> ConnectionPool:
         """Get or create a connection pool for the given Redis URL"""
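The class above keeps one ConnectionPool per Redis URL behind a class-level lock. A minimal sketch of that pattern (illustrative, not the repository's class; the real get_pool presumably also applies the timeout constants when creating the pool):

```python
import threading

from redis.asyncio import ConnectionPool


class SharedPoolManager:
    """Sketch of the shared-pool pattern shown in the diff above."""

    _pools: dict[str, ConnectionPool] = {}
    _lock = threading.Lock()

    @classmethod
    def get_pool(cls, redis_url: str) -> ConnectionPool:
        # Double-checked locking: at most one pool is created per URL,
        # even if several storage instances initialize concurrently.
        if redis_url not in cls._pools:
            with cls._lock:
                if redis_url not in cls._pools:
                    cls._pools[redis_url] = ConnectionPool.from_url(redis_url)
        return cls._pools[redis_url]
```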
@@ -48,7 +53,7 @@ class RedisConnectionManager:
             )
             logger.info(f"Created shared Redis connection pool for {redis_url}")
         return cls._pools[redis_url]

     @classmethod
     def close_all_pools(cls):
         """Close all connection pools (for cleanup)"""
@@ -254,17 +259,21 @@ class RedisKVStorage(BaseKVStorage):
                     pattern = f"{self.namespace}:{mode}:*"
                     cursor = 0
                     mode_keys = []

                     while True:
-                        cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
+                        cursor, keys = await redis.scan(
+                            cursor, match=pattern, count=1000
+                        )
                         if keys:
                             mode_keys.extend(keys)

                         if cursor == 0:
                             break

                     keys_to_delete.extend(mode_keys)
-                    logger.info(f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'")
+                    logger.info(
+                        f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'"
+                    )

                 if keys_to_delete:
                     # Batch delete
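The loop being reflowed here is the standard cursor-based SCAN iteration. A self-contained sketch of the pattern, assuming a `redis.asyncio` client (the helper name is illustrative):

```python
from redis.asyncio import Redis


async def collect_keys(redis: Redis, pattern: str) -> list:
    """Gather all keys matching `pattern` without blocking the server."""
    cursor, found = 0, []
    while True:
        # SCAN returns (next_cursor, batch); COUNT is only a hint for
        # how much work the server does per call.
        cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
        found.extend(keys)
        # A returned cursor of 0 means the iteration is complete.
        if cursor == 0:
            break
    return found
```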
@@ -296,7 +305,7 @@ class RedisKVStorage(BaseKVStorage):
                 pattern = f"{self.namespace}:*"
                 cursor = 0
                 deleted_count = 0

                 while True:
                     cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
                     if keys:
@@ -306,7 +315,7 @@ class RedisKVStorage(BaseKVStorage):
                             pipe.delete(key)
                         results = await pipe.execute()
                         deleted_count += sum(results)

                     if cursor == 0:
                         break

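Each SCAN batch is removed through a pipeline, so a batch costs a single round trip. A sketch of that batched delete (illustrative helper, assuming `redis.asyncio`):

```python
from redis.asyncio import Redis


async def delete_batch(redis: Redis, keys: list) -> int:
    """Delete `keys` in one round trip and return how many existed."""
    pipe = redis.pipeline()
    for key in keys:
        pipe.delete(key)
    # DEL returns 1 per key actually removed, so summing the pipeline
    # results counts deleted keys (missing keys contribute 0).
    results = await pipe.execute()
    return sum(results)
```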
@@ -419,7 +428,9 @@ class RedisDocStatusStorage(DocStatusStorage):
         try:
             async with self._get_redis_connection() as redis:
                 await redis.ping()
-                logger.info(f"Connected to Redis for doc status namespace {self.namespace}")
+                logger.info(
+                    f"Connected to Redis for doc status namespace {self.namespace}"
+                )
         except Exception as e:
             logger.error(f"Failed to connect to Redis for doc status: {e}")
             raise
@@ -475,7 +486,7 @@ class RedisDocStatusStorage(DocStatusStorage):
                 for id in ids:
                     pipe.get(f"{self.namespace}:{id}")
                 results = await pipe.execute()

                 for result_data in results:
                     if result_data:
                         try:
@@ -495,14 +506,16 @@ class RedisDocStatusStorage(DocStatusStorage):
                 # Use SCAN to iterate through all keys in the namespace
                 cursor = 0
                 while True:
-                    cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
+                    cursor, keys = await redis.scan(
+                        cursor, match=f"{self.namespace}:*", count=1000
+                    )
                     if keys:
                         # Get all values in batch
                         pipe = redis.pipeline()
                         for key in keys:
                             pipe.get(key)
                         values = await pipe.execute()

                         # Count statuses
                         for value in values:
                             if value:
@@ -513,12 +526,12 @@ class RedisDocStatusStorage(DocStatusStorage):
                                     counts[status] += 1
                                 except json.JSONDecodeError:
                                     continue

                     if cursor == 0:
                         break
         except Exception as e:
             logger.error(f"Error getting status counts: {e}")

         return counts

     async def get_docs_by_status(
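The counting logic above skips values that are not valid JSON. A tiny standalone sketch of that behavior (the sample values are hypothetical):

```python
import json

values = ['{"status": "processed"}', '{"status": "pending"}', "not-json"]

counts: dict = {}
for value in values:
    try:
        status = json.loads(value)["status"]
        counts[status] = counts.get(status, 0) + 1
    except json.JSONDecodeError:
        continue  # ignore undecodable entries, as in the diff above

assert counts == {"processed": 1, "pending": 1}
```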
@@ -531,14 +544,16 @@ class RedisDocStatusStorage(DocStatusStorage):
                 # Use SCAN to iterate through all keys in the namespace
                 cursor = 0
                 while True:
-                    cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
+                    cursor, keys = await redis.scan(
+                        cursor, match=f"{self.namespace}:*", count=1000
+                    )
                     if keys:
                         # Get all values in batch
                         pipe = redis.pipeline()
                         for key in keys:
                             pipe.get(key)
                         values = await pipe.execute()

                         # Filter by status and create DocProcessingStatus objects
                         for key, value in zip(keys, values):
                             if value:
@@ -547,26 +562,31 @@ class RedisDocStatusStorage(DocStatusStorage):
                                     if doc_data.get("status") == status.value:
                                         # Extract document ID from key
                                         doc_id = key.split(":", 1)[1]

                                         # Make a copy of the data to avoid modifying the original
                                         data = doc_data.copy()
                                         # If content is missing, use content_summary as content
-                                        if "content" not in data and "content_summary" in data:
+                                        if (
+                                            "content" not in data
+                                            and "content_summary" in data
+                                        ):
                                             data["content"] = data["content_summary"]
                                         # If file_path is not in data, use document id as file path
                                         if "file_path" not in data:
                                             data["file_path"] = "no-file-path"

                                         result[doc_id] = DocProcessingStatus(**data)
                                 except (json.JSONDecodeError, KeyError) as e:
-                                    logger.error(f"Error processing document {key}: {e}")
+                                    logger.error(
+                                        f"Error processing document {key}: {e}"
+                                    )
                                     continue

                     if cursor == 0:
                         break
         except Exception as e:
             logger.error(f"Error getting docs by status: {e}")

         return result

     async def index_done_callback(self) -> None:
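The block above normalizes stored records before constructing DocProcessingStatus: content falls back to content_summary, and file_path gets a placeholder. A standalone sketch of the same fallbacks (the dict is hypothetical sample data):

```python
doc_data = {"status": "processed", "content_summary": "a short summary"}

data = doc_data.copy()  # avoid mutating the cached record
if "content" not in data and "content_summary" in data:
    data["content"] = data["content_summary"]
if "file_path" not in data:
    data["file_path"] = "no-file-path"

assert data["content"] == "a short summary"
assert data["file_path"] == "no-file-path"
```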
@@ -577,7 +597,7 @@ class RedisDocStatusStorage(DocStatusStorage):
         """Insert or update document status data"""
         if not data:
             return

         logger.debug(f"Inserting {len(data)} records to {self.namespace}")
         async with self._get_redis_connection() as redis:
             try:
@@ -602,15 +622,17 @@ class RedisDocStatusStorage(DocStatusStorage):
         """Delete specific records from storage by their IDs"""
         if not doc_ids:
             return

         async with self._get_redis_connection() as redis:
             pipe = redis.pipeline()
             for doc_id in doc_ids:
                 pipe.delete(f"{self.namespace}:{doc_id}")

             results = await pipe.execute()
             deleted_count = sum(results)
-            logger.info(f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}")
+            logger.info(
+                f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
+            )

     async def drop(self) -> dict[str, str]:
         """Drop all document status data from storage and clean up resources"""
@@ -620,7 +642,7 @@ class RedisDocStatusStorage(DocStatusStorage):
                 pattern = f"{self.namespace}:*"
                 cursor = 0
                 deleted_count = 0

                 while True:
                     cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
                     if keys:
@@ -630,11 +652,13 @@ class RedisDocStatusStorage(DocStatusStorage):
                             pipe.delete(key)
                         results = await pipe.execute()
                         deleted_count += sum(results)

                     if cursor == 0:
                         break

-                logger.info(f"Dropped {deleted_count} doc status keys from {self.namespace}")
+                logger.info(
+                    f"Dropped {deleted_count} doc status keys from {self.namespace}"
+                )
                 return {"status": "success", "message": "data dropped"}
             except Exception as e:
                 logger.error(f"Error dropping doc status {self.namespace}: {e}")