mirror of https://github.com/HKUDS/LightRAG.git, synced 2025-08-04 06:41:55 +00:00

This commit adds multiprocessing shared memory support to the file-based storage implementations:

- JsonDocStatusStorage
- JsonKVStorage
- NanoVectorDBStorage
- NetworkXStorage

Each storage module now uses module-level global variables backed by multiprocessing.Manager() to keep data consistent across multiple uvicorn workers: all processes see updates immediately when data is modified through the ainsert function.
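The pattern itself is visible in the file below; as a quick illustration of the underlying mechanism, here is a minimal, self-contained sketch (the names writer and shared are illustrative, not from the codebase). A dict served by multiprocessing.Manager() lives in a separate server process, so every process holding a proxy to it observes the same state:

from multiprocessing import Manager, Process


def writer(shared):
    # Runs in a child process; the mutation goes through the manager's
    # server process, so the parent observes it as well
    shared["doc-1"] = {"status": "processed"}


if __name__ == "__main__":
    with Manager() as manager:
        shared = manager.dict()
        p = Process(target=writer, args=(shared,))
        p.start()
        p.join()
        print(dict(shared))  # -> {'doc-1': {'status': 'processed'}}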
128 lines
4.4 KiB
Python
from dataclasses import dataclass
import os
from typing import Any, Union, final
import threading
from multiprocessing import Manager

from lightrag.base import (
    DocProcessingStatus,
    DocStatus,
    DocStatusStorage,
)
from lightrag.utils import (
    load_json,
    logger,
    write_json,
)

# Global variables for shared memory management
_init_lock = threading.Lock()
_manager = None
_shared_doc_status_data = None


def _get_manager():
    """Get or create the global manager instance"""
    global _manager, _shared_doc_status_data
    with _init_lock:
        if _manager is None:
            try:
                _manager = Manager()
                _shared_doc_status_data = _manager.dict()
            except Exception as e:
                logger.error(f"Failed to initialize shared memory manager: {e}")
                raise RuntimeError(f"Shared memory initialization failed: {e}")
    return _manager

@final
@dataclass
class JsonDocStatusStorage(DocStatusStorage):
    """JSON implementation of document status storage"""

    def __post_init__(self):
        working_dir = self.global_config["working_dir"]
        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")

        # Ensure manager is initialized
        _get_manager()

        # Get or create namespace data
        if self.namespace not in _shared_doc_status_data:
            with _init_lock:
                if self.namespace not in _shared_doc_status_data:
                    try:
                        initial_data = load_json(self._file_name) or {}
                        # Store a managed dict rather than a plain dict: a plain
                        # dict is copied on every read from the manager, so
                        # updates made through self._data would never reach the
                        # shared store
                        _shared_doc_status_data[self.namespace] = _manager.dict(
                            initial_data
                        )
                    except Exception as e:
                        logger.error(
                            f"Failed to initialize shared data for namespace {self.namespace}: {e}"
                        )
                        raise RuntimeError(f"Shared data initialization failed: {e}")

        try:
            self._data = _shared_doc_status_data[self.namespace]
            logger.info(
                f"Loaded document status storage with {len(self._data)} records"
            )
        except Exception as e:
            logger.error(f"Failed to access shared memory: {e}")
            raise RuntimeError(f"Cannot access shared memory: {e}")

    async def filter_keys(self, keys: set[str]) -> set[str]:
        """Return keys that should be processed (not in storage or not successfully processed)"""
        return set(keys) - set(self._data.keys())

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        result: list[dict[str, Any]] = []
        for id in ids:
            data = self._data.get(id, None)
            if data:
                result.append(data)
        return result

    async def get_status_counts(self) -> dict[str, int]:
        """Get counts of documents in each status"""
        counts = {status.value: 0 for status in DocStatus}
        for doc in self._data.values():
            counts[doc["status"]] += 1
        return counts

    async def get_docs_by_status(
        self, status: DocStatus
    ) -> dict[str, DocProcessingStatus]:
        """Get all documents with a specific status"""
        result = {}
        for k, v in self._data.items():
            if v["status"] == status.value:
                try:
                    # Make a copy of the data to avoid modifying the original
                    data = v.copy()
                    # If content is missing, use content_summary as content
                    if "content" not in data and "content_summary" in data:
                        data["content"] = data["content_summary"]
                    result[k] = DocProcessingStatus(**data)
                except KeyError as e:
                    logger.error(f"Missing required field for document {k}: {e}")
                    continue
        return result

    async def index_done_callback(self) -> None:
        # Convert the manager proxy to a plain dict before serialization;
        # a DictProxy is not JSON-serializable
        write_json(dict(self._data), self._file_name)

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        logger.info(f"Inserting {len(data)} records into {self.namespace}")
        if not data:
            return

        self._data.update(data)
        await self.index_done_callback()

    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
        return self._data.get(id)

    async def delete(self, doc_ids: list[str]):
        for doc_id in doc_ids:
            self._data.pop(doc_id, None)
        await self.index_done_callback()

    async def drop(self) -> None:
        """Drop the storage"""
        self._data.clear()