import os
from typing import Any, final, Union
from dataclasses import dataclass
import pipmaster as pm
import configparser
from contextlib import asynccontextmanager
import threading

if not pm.is_installed("redis"):
    pm.install("redis")

# aioredis is a deprecated library, replaced with redis
from redis.asyncio import Redis, ConnectionPool  # type: ignore
from redis.exceptions import RedisError, ConnectionError  # type: ignore

from lightrag.utils import logger
from lightrag.base import (
    BaseKVStorage,
    DocStatusStorage,
    DocStatus,
    DocProcessingStatus,
)
import json

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")

# Constants for Redis connection pool
MAX_CONNECTIONS = 50
SOCKET_TIMEOUT = 5.0
SOCKET_CONNECT_TIMEOUT = 3.0


class RedisConnectionManager:
    """Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""

    _pools = {}
    _lock = threading.Lock()

    @classmethod
    def get_pool(cls, redis_url: str) -> ConnectionPool:
        """Get or create a connection pool for the given Redis URL"""
        if redis_url not in cls._pools:
            with cls._lock:
                # Double-checked locking: re-test after acquiring the lock so
                # concurrent callers never create two pools for the same URL.
                if redis_url not in cls._pools:
                    cls._pools[redis_url] = ConnectionPool.from_url(
                        redis_url,
                        max_connections=MAX_CONNECTIONS,
                        decode_responses=True,
                        socket_timeout=SOCKET_TIMEOUT,
                        socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
                    )
                    logger.info(
                        f"Created shared Redis connection pool for {redis_url}"
                    )
        return cls._pools[redis_url]

    @classmethod
    async def close_all_pools(cls):
        """Close all connection pools (for cleanup)"""
        # Snapshot and clear under the lock, then disconnect outside it so the
        # synchronous lock is never held across an await.
        with cls._lock:
            pools = list(cls._pools.items())
            cls._pools.clear()
        for url, pool in pools:
            try:
                # ConnectionPool.disconnect() is a coroutine in redis.asyncio,
                # so it must be awaited to actually close the sockets.
                await pool.disconnect()
                logger.info(f"Closed Redis connection pool for {url}")
            except Exception as e:
                logger.error(f"Error closing Redis pool for {url}: {e}")
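
# Illustrative usage sketch (kept as a comment so nothing runs at import time):
# any number of clients created for the same URI share one pool, so the process
# opens at most MAX_CONNECTIONS sockets to that server. The URL below is an
# assumption for the example; the storages read theirs from REDIS_URI / config.ini.
#
#     pool = RedisConnectionManager.get_pool("redis://localhost:6379")
#     kv_client = Redis(connection_pool=pool)
#     doc_status_client = Redis(connection_pool=pool)  # reuses the same pool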

@final
@dataclass
class RedisKVStorage(BaseKVStorage):
    def __post_init__(self):
        redis_url = os.environ.get(
            "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
        )
        # Use shared connection pool
        self._pool = RedisConnectionManager.get_pool(redis_url)
        self._redis = Redis(connection_pool=self._pool)
        logger.info(
            f"Initialized Redis KV storage for {self.namespace} using shared connection pool"
        )

    async def initialize(self):
        """Initialize Redis connection and migrate legacy cache structure if needed"""
        # Test connection
        try:
            async with self._get_redis_connection() as redis:
                await redis.ping()
                logger.info(f"Connected to Redis for namespace {self.namespace}")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            raise

        # Migrate legacy cache structure if this is a cache namespace
        if self.namespace.endswith("_cache"):
            await self._migrate_legacy_cache_structure()

    @asynccontextmanager
    async def _get_redis_connection(self):
        """Safe context manager for Redis operations."""
        try:
            yield self._redis
        except ConnectionError as e:
            logger.error(f"Redis connection error in {self.namespace}: {e}")
            raise
        except RedisError as e:
            logger.error(f"Redis operation error in {self.namespace}: {e}")
            raise
        except Exception as e:
            logger.error(
                f"Unexpected error in Redis operation for {self.namespace}: {e}"
            )
            raise
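
    # Note on the context manager above: it does not check out a connection
    # itself. It simply yields the shared client (connections are leased from
    # the pool per command) and centralizes error logging for callers written
    # as `async with self._get_redis_connection() as redis: ...`.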

    async def close(self):
        """Close the Redis connection pool to prevent resource leaks."""
        if hasattr(self, "_redis") and self._redis:
            await self._redis.close()
            # Note: this also disconnects the *shared* pool, which affects any
            # other storage instance bound to the same Redis URI.
            await self._pool.disconnect()
            logger.debug(f"Closed Redis connection pool for {self.namespace}")

    async def __aenter__(self):
        """Support for async context manager."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Ensure Redis resources are cleaned up when exiting context."""
        await self.close()

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        async with self._get_redis_connection() as redis:
            try:
                data = await redis.get(f"{self.namespace}:{id}")
                return json.loads(data) if data else None
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error for id {id}: {e}")
                return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        async with self._get_redis_connection() as redis:
            try:
                pipe = redis.pipeline()
                for id in ids:
                    pipe.get(f"{self.namespace}:{id}")
                results = await pipe.execute()
                return [json.loads(result) if result else None for result in results]
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error in batch get: {e}")
                return [None] * len(ids)

    async def get_all(self) -> dict[str, Any]:
        """Get all data from storage

        Returns:
            Dictionary containing all stored data
        """
        async with self._get_redis_connection() as redis:
            try:
                # Get all keys for this namespace
                keys = await redis.keys(f"{self.namespace}:*")

                if not keys:
                    return {}

                # Get all values in batch
                pipe = redis.pipeline()
                for key in keys:
                    pipe.get(key)
                values = await pipe.execute()

                # Build result dictionary
                result = {}
                for key, value in zip(keys, values):
                    if value:
                        # Extract the ID part (after namespace:)
                        key_id = key.split(":", 1)[1]
                        try:
                            result[key_id] = json.loads(value)
                        except json.JSONDecodeError as e:
                            logger.error(f"JSON decode error for key {key}: {e}")
                            continue

                return result
            except Exception as e:
                logger.error(f"Error getting all data from Redis: {e}")
                return {}
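
    # Note: get_all() uses KEYS, which is a single blocking command on the
    # server; it is fine for small namespaces but can stall Redis on large
    # ones. The drop()/drop_cache_by_modes() methods below use cursor-based
    # SCAN for that reason, e.g. `await redis.scan(cursor, match=pattern, count=1000)`.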

    async def filter_keys(self, keys: set[str]) -> set[str]:
        async with self._get_redis_connection() as redis:
            pipe = redis.pipeline()
            keys_list = list(keys)  # Convert set to list for indexing
            for key in keys_list:
                pipe.exists(f"{self.namespace}:{key}")
            results = await pipe.execute()

            existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
            return set(keys) - existing_ids

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        if not data:
            return
        async with self._get_redis_connection() as redis:
            try:
                pipe = redis.pipeline()
                for k, v in data.items():
                    pipe.set(f"{self.namespace}:{k}", json.dumps(v))
                await pipe.execute()

                for k in data:
                    data[k]["_id"] = k
            except (TypeError, ValueError) as e:
                # json has no JSONEncodeError; json.dumps raises TypeError (or
                # ValueError for circular references) on unserializable data.
                logger.error(f"JSON encode error during upsert: {e}")
                raise
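
    # Storage layout written by upsert(): one Redis string per record, keyed as
    # "{namespace}:{id}" and holding the JSON-serialized dict. For example
    # (hypothetical namespace and id):
    #
    #     SET "full_docs:doc-123" '{"content": "...", "meta": {"...": "..."}}'
    #
    # The "_id" field is added to the in-memory dicts after the write; it is
    # not part of the serialized value.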

    async def index_done_callback(self) -> None:
        # Redis handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
        """Delete entries with specified IDs"""
        if not ids:
            return

        async with self._get_redis_connection() as redis:
            pipe = redis.pipeline()
            for id in ids:
                pipe.delete(f"{self.namespace}:{id}")

            results = await pipe.execute()
            deleted_count = sum(results)
            logger.info(
                f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
            )

    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
        """Delete specific records from storage by cache mode

        Important notes for Redis storage:
        1. This will immediately delete the specified cache modes from Redis

        Args:
            modes (list[str]): List of cache modes to be dropped from storage

        Returns:
            True: if the cache drop succeeded
            False: if the cache drop failed
        """
        if not modes:
            return False

        try:
            async with self._get_redis_connection() as redis:
                keys_to_delete = []

                # Find matching keys for each mode using SCAN
                for mode in modes:
                    # Match the flattened cache key format {namespace}:{mode}:{cache_type}:{hash}
                    pattern = f"{self.namespace}:{mode}:*"
                    cursor = 0
                    mode_keys = []

                    while True:
                        cursor, keys = await redis.scan(
                            cursor, match=pattern, count=1000
                        )
                        if keys:
                            mode_keys.extend(keys)

                        if cursor == 0:
                            break

                    keys_to_delete.extend(mode_keys)
                    logger.info(
                        f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'"
                    )

                if keys_to_delete:
                    # Batch delete
                    pipe = redis.pipeline()
                    for key in keys_to_delete:
                        pipe.delete(key)
                    results = await pipe.execute()
                    deleted_count = sum(results)
                    logger.info(
                        f"Dropped {deleted_count} cache entries for modes: {modes}"
                    )
                else:
                    logger.warning(f"No cache entries found for modes: {modes}")

            return True
        except Exception as e:
            logger.error(f"Error dropping cache by modes in Redis: {e}")
            return False
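
    # Flattened cache key layout assumed by drop_cache_by_modes(); the
    # namespace, mode, cache_type and hash values below are illustrative:
    #
    #     "llm_response_cache:global:extract:abc123"
    #      {namespace}        {mode}  {type}  {hash}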

    async def drop(self) -> dict[str, str]:
        """Drop the storage by removing all keys under the current namespace.

        Returns:
            dict[str, str]: Status of the operation with keys 'status' and 'message'
        """
        async with self._get_redis_connection() as redis:
            try:
                # Use SCAN to find all keys with the namespace prefix
                pattern = f"{self.namespace}:*"
                cursor = 0
                deleted_count = 0

                while True:
                    cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
                    if keys:
                        # Delete keys in batches
                        pipe = redis.pipeline()
                        for key in keys:
                            pipe.delete(key)
                        results = await pipe.execute()
                        deleted_count += sum(results)

                    if cursor == 0:
                        break

                logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
                return {
                    "status": "success",
                    "message": f"{deleted_count} keys dropped",
                }

            except Exception as e:
                logger.error(f"Error dropping keys from {self.namespace}: {e}")
                return {"status": "error", "message": str(e)}

    async def _migrate_legacy_cache_structure(self):
        """Migrate legacy nested cache structure to flattened structure for Redis

        Redis already stores data in a flattened way, but we need to check for
        legacy keys that might contain nested JSON structures and migrate them.

        Early exit if any flattened key is found (indicating migration already done).
        """
        from lightrag.utils import generate_cache_key

        async with self._get_redis_connection() as redis:
            # Get all keys for this namespace
            keys = await redis.keys(f"{self.namespace}:*")

            if not keys:
                return

            # Check if we have any flattened keys already - if so, skip migration
            has_flattened_keys = False
            keys_to_migrate = []

            for key in keys:
                # Extract the ID part (after namespace:)
                key_id = key.split(":", 1)[1]

                # Check if already in flattened format (mode:cache_type:hash, i.e. exactly two colons)
                if ":" in key_id and len(key_id.split(":")) == 3:
                    has_flattened_keys = True
                    break  # Early exit - migration already done

                # Get the data to check if it's a legacy nested structure
                data = await redis.get(key)
                if data:
                    try:
                        parsed_data = json.loads(data)
                        # Check if this looks like a legacy cache mode with nested structure
                        if isinstance(parsed_data, dict) and all(
                            isinstance(v, dict) and "return" in v
                            for v in parsed_data.values()
                        ):
                            keys_to_migrate.append((key, key_id, parsed_data))
                    except json.JSONDecodeError:
                        continue

            # If we found any flattened keys, assume migration is already done
            if has_flattened_keys:
                logger.debug(
                    f"Found flattened cache keys in {self.namespace}, skipping migration"
                )
                return

            if not keys_to_migrate:
                return

            # Perform migration
            pipe = redis.pipeline()
            migration_count = 0

            for old_key, mode, nested_data in keys_to_migrate:
                # Delete the old key
                pipe.delete(old_key)

                # Create new flattened keys
                for cache_hash, cache_entry in nested_data.items():
                    cache_type = cache_entry.get("cache_type", "extract")
                    flattened_key = generate_cache_key(mode, cache_type, cache_hash)
                    full_key = f"{self.namespace}:{flattened_key}"
                    pipe.set(full_key, json.dumps(cache_entry))
                    migration_count += 1

            await pipe.execute()

            if migration_count > 0:
                logger.info(
                    f"Migrated {migration_count} legacy cache entries to flattened structure in Redis"
                )
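
    # Shape handled by the migration above (field values are illustrative):
    #
    #   legacy:    "{namespace}:{mode}" -> '{"<hash>": {"return": "...", "cache_type": "extract"}, ...}'
    #   flattened: "{namespace}:{mode}:{cache_type}:{hash}" -> '{"return": "...", "cache_type": "extract"}'
    #
    # i.e. one nested JSON blob per mode becomes one key per cache entry.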


@final
@dataclass
class RedisDocStatusStorage(DocStatusStorage):
    """Redis implementation of document status storage"""

    def __post_init__(self):
        redis_url = os.environ.get(
            "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
        )
        # Use shared connection pool
        self._pool = RedisConnectionManager.get_pool(redis_url)
        self._redis = Redis(connection_pool=self._pool)
        logger.info(
            f"Initialized Redis doc status storage for {self.namespace} using shared connection pool"
        )

    async def initialize(self):
        """Initialize Redis connection"""
        try:
            async with self._get_redis_connection() as redis:
                await redis.ping()
                logger.info(
                    f"Connected to Redis for doc status namespace {self.namespace}"
                )
        except Exception as e:
            logger.error(f"Failed to connect to Redis for doc status: {e}")
            raise

    @asynccontextmanager
    async def _get_redis_connection(self):
        """Safe context manager for Redis operations."""
        try:
            yield self._redis
        except ConnectionError as e:
            logger.error(f"Redis connection error in doc status {self.namespace}: {e}")
            raise
        except RedisError as e:
            logger.error(f"Redis operation error in doc status {self.namespace}: {e}")
            raise
        except Exception as e:
            logger.error(
                f"Unexpected error in Redis doc status operation for {self.namespace}: {e}"
            )
            raise

    async def close(self):
        """Close the Redis connection."""
        if hasattr(self, "_redis") and self._redis:
            await self._redis.close()
            logger.debug(f"Closed Redis connection for doc status {self.namespace}")

    async def __aenter__(self):
        """Support for async context manager."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Ensure Redis resources are cleaned up when exiting context."""
        await self.close()

    async def filter_keys(self, keys: set[str]) -> set[str]:
        """Return keys that should be processed (not in storage or not successfully processed)"""
        async with self._get_redis_connection() as redis:
            pipe = redis.pipeline()
            keys_list = list(keys)
            for key in keys_list:
                pipe.exists(f"{self.namespace}:{key}")
            results = await pipe.execute()

            existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
            return set(keys) - existing_ids

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        result: list[dict[str, Any]] = []
        async with self._get_redis_connection() as redis:
            try:
                pipe = redis.pipeline()
                for id in ids:
                    pipe.get(f"{self.namespace}:{id}")
                results = await pipe.execute()

                for result_data in results:
                    if result_data:
                        try:
                            result.append(json.loads(result_data))
                        except json.JSONDecodeError as e:
                            logger.error(f"JSON decode error in get_by_ids: {e}")
                            continue
            except Exception as e:
                logger.error(f"Error in get_by_ids: {e}")
        return result

    async def get_status_counts(self) -> dict[str, int]:
        """Get counts of documents in each status"""
        counts = {status.value: 0 for status in DocStatus}
        async with self._get_redis_connection() as redis:
            try:
                # Use SCAN to iterate through all keys in the namespace
                cursor = 0
                while True:
                    cursor, keys = await redis.scan(
                        cursor, match=f"{self.namespace}:*", count=1000
                    )
                    if keys:
                        # Get all values in batch
                        pipe = redis.pipeline()
                        for key in keys:
                            pipe.get(key)
                        values = await pipe.execute()

                        # Count statuses
                        for value in values:
                            if value:
                                try:
                                    doc_data = json.loads(value)
                                    status = doc_data.get("status")
                                    if status in counts:
                                        counts[status] += 1
                                except json.JSONDecodeError:
                                    continue

                    if cursor == 0:
                        break
            except Exception as e:
                logger.error(f"Error getting status counts: {e}")

        return counts
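
    # Example of the mapping returned by get_status_counts() (counts are
    # illustrative; the keys are whatever values DocStatus defines):
    #
    #     {"pending": 3, "processing": 1, "processed": 42, "failed": 0}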

    async def get_docs_by_status(
        self, status: DocStatus
    ) -> dict[str, DocProcessingStatus]:
        """Get all documents with a specific status"""
        result = {}
        async with self._get_redis_connection() as redis:
            try:
                # Use SCAN to iterate through all keys in the namespace
                cursor = 0
                while True:
                    cursor, keys = await redis.scan(
                        cursor, match=f"{self.namespace}:*", count=1000
                    )
                    if keys:
                        # Get all values in batch
                        pipe = redis.pipeline()
                        for key in keys:
                            pipe.get(key)
                        values = await pipe.execute()

                        # Filter by status and create DocProcessingStatus objects
                        for key, value in zip(keys, values):
                            if value:
                                try:
                                    doc_data = json.loads(value)
                                    if doc_data.get("status") == status.value:
                                        # Extract document ID from key
                                        doc_id = key.split(":", 1)[1]

                                        # Make a copy of the data to avoid modifying the original
                                        data = doc_data.copy()
                                        # If content is missing, use content_summary as content
                                        if (
                                            "content" not in data
                                            and "content_summary" in data
                                        ):
                                            data["content"] = data["content_summary"]
                                        # If file_path is missing, fall back to a placeholder value
                                        if "file_path" not in data:
                                            data["file_path"] = "no-file-path"

                                        result[doc_id] = DocProcessingStatus(**data)
                                except (json.JSONDecodeError, KeyError) as e:
                                    logger.error(
                                        f"Error processing document {key}: {e}"
                                    )
                                    continue

                    if cursor == 0:
                        break
            except Exception as e:
                logger.error(f"Error getting docs by status: {e}")

        return result

    async def index_done_callback(self) -> None:
        """Redis handles persistence automatically"""
        pass

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """Insert or update document status data"""
        if not data:
            return

        logger.debug(f"Inserting {len(data)} records to {self.namespace}")
        async with self._get_redis_connection() as redis:
            try:
                pipe = redis.pipeline()
                for k, v in data.items():
                    pipe.set(f"{self.namespace}:{k}", json.dumps(v))
                await pipe.execute()
            except (TypeError, ValueError) as e:
                # json has no JSONEncodeError; json.dumps raises TypeError (or
                # ValueError for circular references) on unserializable data.
                logger.error(f"JSON encode error during upsert: {e}")
                raise

    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
        async with self._get_redis_connection() as redis:
            try:
                data = await redis.get(f"{self.namespace}:{id}")
                return json.loads(data) if data else None
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error for id {id}: {e}")
                return None

    async def delete(self, doc_ids: list[str]) -> None:
        """Delete specific records from storage by their IDs"""
        if not doc_ids:
            return

        async with self._get_redis_connection() as redis:
            pipe = redis.pipeline()
            for doc_id in doc_ids:
                pipe.delete(f"{self.namespace}:{doc_id}")

            results = await pipe.execute()
            deleted_count = sum(results)
            logger.info(
                f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
            )

    async def drop(self) -> dict[str, str]:
        """Drop all document status data from storage and clean up resources"""
        try:
            async with self._get_redis_connection() as redis:
                # Use SCAN to find all keys with the namespace prefix
                pattern = f"{self.namespace}:*"
                cursor = 0
                deleted_count = 0

                while True:
                    cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
                    if keys:
                        # Delete keys in batches
                        pipe = redis.pipeline()
                        for key in keys:
                            pipe.delete(key)
                        results = await pipe.execute()
                        deleted_count += sum(results)

                    if cursor == 0:
                        break

                logger.info(
                    f"Dropped {deleted_count} doc status keys from {self.namespace}"
                )
                return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping doc status {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}
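

if __name__ == "__main__":
    # Minimal connectivity sketch, not part of the storage API: it assumes a
    # reachable Redis at REDIS_URI (or the config.ini / default URL) and only
    # exercises the shared pool defined in this module.
    import asyncio

    async def _smoke_test() -> None:
        redis_url = os.environ.get(
            "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
        )
        client = Redis(connection_pool=RedisConnectionManager.get_pool(redis_url))
        try:
            # PING round-trips through the shared pool and returns True on success
            print("PING ->", await client.ping())
        finally:
            await client.close()
            await RedisConnectionManager.close_all_pools()

    asyncio.run(_smoke_test())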