Initial commit with keyed graph lock

Arjun Rao 2025-05-08 11:35:10 +10:00
parent b7eae4d7c0
commit f8149790e4
3 changed files with 405 additions and 140 deletions

View File

@@ -1,9 +1,12 @@
from collections import defaultdict
import os
import sys
import asyncio
import multiprocessing as mp
from multiprocessing.synchronize import Lock as ProcessLock
from multiprocessing import Manager
from typing import Any, Dict, Optional, Union, TypeVar, Generic
import time
from typing import Any, Callable, Dict, List, Optional, Union, TypeVar, Generic
# Define a direct print function for critical logs that must be visible in all processes
@@ -27,8 +30,14 @@ LockType = Union[ProcessLock, asyncio.Lock]
_is_multiprocess = None
_workers = None
_manager = None
_lock_registry: Optional[Dict[str, mp.synchronize.Lock]] = None
_lock_registry_count: Optional[Dict[str, int]] = None
_lock_cleanup_data: Optional[Dict[str, float]] = None
_registry_guard = None
_initialized = None
CLEANUP_KEYED_LOCKS_AFTER_SECONDS = 300
# shared data for storage across processes
_shared_dicts: Optional[Dict[str, Any]] = None
_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
@@ -40,10 +49,31 @@ _internal_lock: Optional[LockType] = None
_pipeline_status_lock: Optional[LockType] = None
_graph_db_lock: Optional[LockType] = None
_data_init_lock: Optional[LockType] = None
_graph_db_lock_keyed: Optional["KeyedUnifiedLock"] = None
# async locks for coroutine synchronization in multiprocess mode
_async_locks: Optional[Dict[str, asyncio.Lock]] = None
DEBUG_LOCKS = False
_debug_n_locks_acquired: int = 0
def inc_debug_n_locks_acquired():
global _debug_n_locks_acquired
if DEBUG_LOCKS:
_debug_n_locks_acquired += 1
print(f"DEBUG: Keyed Lock acquired, total: {_debug_n_locks_acquired:>5}", end="\r", flush=True)
def dec_debug_n_locks_acquired():
global _debug_n_locks_acquired
if DEBUG_LOCKS:
if _debug_n_locks_acquired > 0:
_debug_n_locks_acquired -= 1
print(f"DEBUG: Keyed Lock released, total: {_debug_n_locks_acquired:>5}", end="\r", flush=True)
else:
raise RuntimeError("Attempting to release lock when no locks are acquired")
def get_debug_n_locks_acquired():
global _debug_n_locks_acquired
return _debug_n_locks_acquired
class UnifiedLock(Generic[T]):
"""Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
@@ -210,6 +240,207 @@ class UnifiedLock(Generic[T]):
)
raise
def locked(self) -> bool:
if self._is_async:
return self._lock.locked()
else:
return self._lock.locked()
# ─────────────────────────────────────────────────────────────────────────────
# 2. CROSS-PROCESS FACTORY (one manager.Lock shared by *all* processes)
# ─────────────────────────────────────────────────────────────────────────────
def _get_combined_key(factory_name: str, key: str) -> str:
"""Return the combined key for the factory and key."""
return f"{factory_name}:{key}"
def _get_or_create_shared_raw_mp_lock(factory_name: str, key: str) -> Optional[mp.synchronize.Lock]:
"""Return the *singleton* manager.Lock() proxy for *key*, creating if needed."""
if not _is_multiprocess:
return None
with _registry_guard:
combined_key = _get_combined_key(factory_name, key)
raw = _lock_registry.get(combined_key)
count = _lock_registry_count.get(combined_key)
if raw is None:
raw = _manager.Lock()
_lock_registry[combined_key] = raw
count = 0  # a new entry starts at zero; the shared increment below brings it to one
_lock_registry_count[combined_key] = 0
else:
if count is None:
raise RuntimeError(f"Shared-Data lock registry for {factory_name} is corrupted for key {key}")
count += 1
_lock_registry_count[combined_key] = count
if count == 1 and combined_key in _lock_cleanup_data:
_lock_cleanup_data.pop(combined_key)
return raw
def _release_shared_raw_mp_lock(factory_name: str, key: str):
"""Release the *singleton* manager.Lock() proxy for *key*."""
if not _is_multiprocess:
return
with _registry_guard:
combined_key = _get_combined_key(factory_name, key)
raw = _lock_registry.get(combined_key)
count = _lock_registry_count.get(combined_key)
if raw is None and count is None:
return
elif raw is None or count is None:
raise RuntimeError(f"Shared-Data lock registry for {factory_name} is corrupted for key {key}")
count -= 1
if count < 0:
raise RuntimeError(f"Attempting to remove lock for {key} but it is not in the registry")
else:
_lock_registry_count[combined_key] = count
if count == 0:
_lock_cleanup_data[combined_key] = time.time()
for combined_key, value in list(_lock_cleanup_data.items()):
if time.time() - value > CLEANUP_KEYED_LOCKS_AFTER_SECONDS:
_lock_registry.pop(combined_key)
_lock_registry_count.pop(combined_key)
_lock_cleanup_data.pop(combined_key)
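The two helpers above implement a reference-counted lock registry with delayed cleanup: acquiring bumps a per-key count, releasing records when the count hits zero, and idle entries are evicted only after CLEANUP_KEYED_LOCKS_AFTER_SECONDS. A minimal, self-contained sketch of that bookkeeping, using plain dicts and threading locks instead of Manager proxies (all names below are illustrative, not part of this module):

import threading
import time

CLEANUP_AFTER_SECONDS = 300  # mirrors CLEANUP_KEYED_LOCKS_AFTER_SECONDS

_registry: dict[str, threading.Lock] = {}   # key -> lock
_refcounts: dict[str, int] = {}              # key -> active holders
_idle_since: dict[str, float] = {}           # key -> time the refcount hit zero
_guard = threading.Lock()

def acquire_keyed(key: str) -> threading.Lock:
    """Return the lock for *key*, creating it if needed, and bump its refcount."""
    with _guard:
        lock = _registry.get(key)
        if lock is None:
            lock = threading.Lock()
            _registry[key] = lock
            _refcounts[key] = 0
        _refcounts[key] += 1
        _idle_since.pop(key, None)  # key is in use again, cancel any pending cleanup
        return lock

def release_keyed(key: str) -> None:
    """Drop one reference; keys idle longer than the grace period are evicted."""
    with _guard:
        _refcounts[key] -= 1
        if _refcounts[key] == 0:
            _idle_since[key] = time.time()
        now = time.time()
        for k, t in list(_idle_since.items()):
            if now - t > CLEANUP_AFTER_SECONDS:
                _registry.pop(k, None)
                _refcounts.pop(k, None)
                _idle_since.pop(k, None)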
# ─────────────────────────────────────────────────────────────────────────────
# 3. PARAMETER-KEYED WRAPPER (unchanged except it *accepts a factory*)
# ─────────────────────────────────────────────────────────────────────────────
class KeyedUnifiedLock:
"""
Parameter-keyed wrapper around `UnifiedLock`.
- Keeps only a table of per-key *asyncio* gates locally
- Fetches the shared process-wide mutex on *every* acquire
- Builds a fresh `UnifiedLock` each time, so `enable_logging`
(or future options) can vary per call.
"""
# ---------------- construction ----------------
def __init__(self, factory_name: str, *, default_enable_logging: bool = True) -> None:
self._factory_name = factory_name
self._default_enable_logging = default_enable_logging
self._async_lock: Dict[str, asyncio.Lock] = {} # key → asyncio.Lock
self._async_lock_count: Dict[str, int] = {} # key → asyncio.Lock count
self._async_lock_cleanup_data: Dict[str, float] = {} # key → time.time() of last release
self._mp_locks: Dict[str, mp.synchronize.Lock] = {} # key → mp.synchronize.Lock
# ---------------- public API ------------------
def __call__(self, keys: list[str], *, enable_logging: Optional[bool] = None):
"""
Ergonomic helper so you can write:
async with keyed_locks("alpha"):
...
"""
if enable_logging is None:
enable_logging = self._default_enable_logging
return _KeyedLockContext(self, factory_name=self._factory_name, keys=keys, enable_logging=enable_logging)
def _get_or_create_async_lock(self, key: str) -> asyncio.Lock:
async_lock = self._async_lock.get(key)
count = self._async_lock_count.get(key, 0)
if async_lock is None:
async_lock = asyncio.Lock()
self._async_lock[key] = async_lock
elif count == 0 and key in self._async_lock_cleanup_data:
self._async_lock_cleanup_data.pop(key)
count += 1
self._async_lock_count[key] = count
return async_lock
def _release_async_lock(self, key: str):
count = self._async_lock_count.get(key, 0)
count -= 1
if count == 0:
self._async_lock_cleanup_data[key] = time.time()
self._async_lock_count[key] = count
for key, value in list(self._async_lock_cleanup_data.items()):
if time.time() - value > CLEANUP_KEYED_LOCKS_AFTER_SECONDS:
self._async_lock.pop(key)
self._async_lock_count.pop(key)
self._async_lock_cleanup_data.pop(key)
def _get_lock_for_key(self, key: str, enable_logging: bool = False) -> UnifiedLock:
# 1. get (or create) the per-process async gate for this key
#    (this step is synchronous, so no lock is needed around it)
async_lock = self._get_or_create_async_lock(key)
# 2. fetch the shared raw lock
raw_lock = _get_or_create_shared_raw_mp_lock(self._factory_name, key)
is_multiprocess = raw_lock is not None
if not is_multiprocess:
raw_lock = async_lock
# 3. build a *fresh* UnifiedLock with the chosen logging flag
if is_multiprocess:
return UnifiedLock(
lock=raw_lock,
is_async=False, # manager.Lock is synchronous
name=f"key:{self._factory_name}:{key}",
enable_logging=enable_logging,
async_lock=async_lock, # prevents eventloop blocking
)
else:
return UnifiedLock(
lock=raw_lock,
is_async=True,
name=f"key:{self._factory_name}:{key}",
enable_logging=enable_logging,
async_lock=None, # No need for async lock in single process mode
)
def _release_lock_for_key(self, key: str):
self._release_async_lock(key)
_release_shared_raw_mp_lock(self._factory_name, key)
class _KeyedLockContext:
def __init__(
self,
parent: KeyedUnifiedLock,
factory_name: str,
keys: list[str],
enable_logging: bool,
) -> None:
self._parent = parent
self._factory_name = factory_name
# Sorting the keys gives every context a consistent acquisition and
# release order, which prevents deadlocks between overlapping key sets
self._keys = sorted(keys)
self._enable_logging = (
enable_logging if enable_logging is not None
else parent._default_enable_logging
)
self._ul: Optional[List["UnifiedLock"]] = None # set in __aenter__
# ----- enter -----
async def __aenter__(self):
if self._ul is not None:
raise RuntimeError("KeyedUnifiedLock already acquired in current context")
# acquire a lock for every key, in sorted order
self._ul = []
for key in self._keys:
lock = self._parent._get_lock_for_key(key, enable_logging=self._enable_logging)
await lock.__aenter__()
inc_debug_n_locks_acquired()
self._ul.append(lock)
return self  # or return self._keys if you prefer
# ----- exit -----
async def __aexit__(self, exc_type, exc, tb):
# The UnifiedLock takes care of proper release order
for ul, key in zip(reversed(self._ul), reversed(self._keys)):
await ul.__aexit__(exc_type, exc, tb)
self._parent._release_lock_for_key(key)
dec_debug_n_locks_acquired()
self._ul = None
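A minimal usage sketch of the wrapper above (the factory name and keys are illustrative; in this module the instance is created by initialize_share_data further down and reached through get_graph_db_lock_keyed):

keyed_locks = KeyedUnifiedLock(factory_name="demo_factory", default_enable_logging=False)

async def guarded_update() -> None:
    # __call__ returns a _KeyedLockContext that acquires one lock per key,
    # always in sorted key order, and releases them in reverse on exit.
    async with keyed_locks(["beta", "alpha"]):
        ...  # work that must be exclusive per key goes here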
def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
@@ -258,6 +489,14 @@ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
async_lock=async_lock,
)
def get_graph_db_lock_keyed(keys: str | list[str], enable_logging: bool = False) -> "_KeyedLockContext":
"""Return a keyed graph database lock context for atomic operations on the given keys"""
global _graph_db_lock_keyed
if _graph_db_lock_keyed is None:
raise RuntimeError("Shared-Data is not initialized")
if isinstance(keys, str):
keys = [keys]
return _graph_db_lock_keyed(keys, enable_logging=enable_logging)
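For orientation, an end-to-end sketch of how this getter is meant to be used, assuming the package is importable as lightrag and ignoring error handling; the key names are illustrative:

import asyncio

from lightrag.kg.shared_storage import (
    initialize_share_data,
    get_graph_db_lock_keyed,
)

async def demo() -> None:
    # Keys are sorted before acquisition, so two tasks holding overlapping
    # key sets cannot end up waiting on each other in a cycle.
    async with get_graph_db_lock_keyed(["entity_a", "entity_b"], enable_logging=False):
        ...  # graph mutations touching entity_a / entity_b go here

initialize_share_data(workers=1)  # single-process mode: plain asyncio.Lock gates
asyncio.run(demo())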
def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified data initialization lock for ensuring atomic data initialization"""
@@ -294,6 +533,10 @@ def initialize_share_data(workers: int = 1):
_workers, \
_is_multiprocess, \
_storage_lock, \
_lock_registry, \
_lock_registry_count, \
_lock_cleanup_data, \
_registry_guard, \
_internal_lock, \
_pipeline_status_lock, \
_graph_db_lock, \
@@ -302,7 +545,8 @@
_init_flags, \
_initialized, \
_update_flags, \
_async_locks
_async_locks, \
_graph_db_lock_keyed
# Check if already initialized
if _initialized:
@@ -316,6 +560,10 @@
if workers > 1:
_is_multiprocess = True
_manager = Manager()
_lock_registry = _manager.dict()
_lock_registry_count = _manager.dict()
_lock_cleanup_data = _manager.dict()
_registry_guard = _manager.RLock()
_internal_lock = _manager.Lock()
_storage_lock = _manager.Lock()
_pipeline_status_lock = _manager.Lock()
@@ -324,6 +572,10 @@
_shared_dicts = _manager.dict()
_init_flags = _manager.dict()
_update_flags = _manager.dict()
_graph_db_lock_keyed = KeyedUnifiedLock(
factory_name="graph_db_lock",
)
# Initialize async locks for multiprocess mode
_async_locks = {
@@ -348,6 +600,10 @@
_init_flags = {}
_update_flags = {}
_async_locks = None # No need for async locks in single process mode
_graph_db_lock_keyed = KeyedUnifiedLock(
factory_name="graph_db_lock",
)
direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
# Mark as initialized

View File

@@ -1024,73 +1024,73 @@ class LightRAG:
}
)
# Semaphore was released here
# Semaphore is NOT released here, however, the profile context is
if file_extraction_stage_ok:
try:
# Get chunk_results from entity_relation_task
chunk_results = await entity_relation_task
await merge_nodes_and_edges(
chunk_results=chunk_results, # result collected from entity_relation_task
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache,
current_file_number=current_file_number,
total_files=total_files,
file_path=file_path,
)
if file_extraction_stage_ok:
try:
# Get chunk_results from entity_relation_task
chunk_results = await entity_relation_task
await merge_nodes_and_edges(
chunk_results=chunk_results, # result collected from entity_relation_task
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache,
current_file_number=current_file_number,
total_files=total_files,
file_path=file_path,
)
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.PROCESSED,
"chunks_count": len(chunks),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
"file_path": file_path,
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.PROCESSED,
"chunks_count": len(chunks),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
"file_path": file_path,
}
}
}
)
)
# Call _insert_done after processing each file
await self._insert_done()
# Call _insert_done after processing each file
await self._insert_done()
async with pipeline_status_lock:
log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
async with pipeline_status_lock:
log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
except Exception as e:
# Log error and update pipeline status
error_msg = f"Merging stage failed in document {doc_id}: {traceback.format_exc()}"
logger.error(error_msg)
async with pipeline_status_lock:
pipeline_status["latest_message"] = error_msg
pipeline_status["history_messages"].append(error_msg)
except Exception as e:
# Log error and update pipeline status
error_msg = f"Merging stage failed in document {doc_id}: {traceback.format_exc()}"
logger.error(error_msg)
async with pipeline_status_lock:
pipeline_status["latest_message"] = error_msg
pipeline_status["history_messages"].append(error_msg)
# Update document status to failed
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.FAILED,
"error": str(e),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
"file_path": file_path,
# Update document status to failed
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.FAILED,
"error": str(e),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
"file_path": file_path,
}
}
}
)
)
# Create processing tasks for all documents
doc_tasks = []

View File

@@ -9,6 +9,8 @@ import os
from typing import Any, AsyncIterator
from collections import Counter, defaultdict
from .kg.shared_storage import get_graph_db_lock_keyed
from .utils import (
logger,
clean_str,
@@ -403,27 +405,31 @@ async def _merge_edges_then_upsert(
)
for need_insert_id in [src_id, tgt_id]:
if not (await knowledge_graph_inst.has_node(need_insert_id)):
# # Discard this edge if the node does not exist
# if need_insert_id == src_id:
# logger.warning(
# f"Discard edge: {src_id} - {tgt_id} | Source node missing"
# )
# else:
# logger.warning(
# f"Discard edge: {src_id} - {tgt_id} | Target node missing"
# )
# return None
await knowledge_graph_inst.upsert_node(
need_insert_id,
node_data={
"entity_id": need_insert_id,
"source_id": source_id,
"description": description,
"entity_type": "UNKNOWN",
"file_path": file_path,
},
)
if (await knowledge_graph_inst.has_node(need_insert_id)):
# Fast path: the initial existence check does not need to hold the keyed lock
continue
async with get_graph_db_lock_keyed([need_insert_id], enable_logging=False):
if not (await knowledge_graph_inst.has_node(need_insert_id)):
# # Discard this edge if the node does not exist
# if need_insert_id == src_id:
# logger.warning(
# f"Discard edge: {src_id} - {tgt_id} | Source node missing"
# )
# else:
# logger.warning(
# f"Discard edge: {src_id} - {tgt_id} | Target node missing"
# )
# return None
await knowledge_graph_inst.upsert_node(
need_insert_id,
node_data={
"entity_id": need_insert_id,
"source_id": source_id,
"description": description,
"entity_type": "UNKNOWN",
"file_path": file_path,
},
)
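The block above is a check-lock-recheck pattern: the unlocked has_node call is only a fast path, and the check is repeated under the keyed lock because another task may have inserted the node in the meantime. A condensed sketch of the idiom, with hypothetical exists/create callables standing in for has_node/upsert_node:

async def ensure_node(key: str, exists, create) -> None:
    # Fast path: observing that the node already exists needs no lock.
    if await exists(key):
        return
    async with get_graph_db_lock_keyed([key], enable_logging=False):
        # Re-check under the lock: a concurrent task may have created the
        # node between the unlocked check and lock acquisition.
        if not await exists(key):
            await create(key)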
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
@@ -523,23 +529,30 @@ async def merge_nodes_and_edges(
all_edges[sorted_edge_key].extend(edges)
# Centralized processing of all nodes and edges
entities_data = []
relationships_data = []
total_entities_count = len(all_nodes)
total_relations_count = len(all_edges)
# Merge nodes and edges
# Use graph database lock to ensure atomic merges and updates
graph_db_lock = get_graph_db_lock(enable_logging=False)
async with graph_db_lock:
async with pipeline_status_lock:
log_message = (
f"Merging stage {current_file_number}/{total_files}: {file_path}"
)
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Process and update all entities at once
log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
log_message = (
f"Merging stage {current_file_number}/{total_files}: {file_path}"
)
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Process and update all entities at once
for entity_name, entities in all_nodes.items():
async def _locked_process_entity_name(entity_name, entities):
async with get_graph_db_lock_keyed([entity_name], enable_logging=False):
entity_data = await _merge_nodes_then_upsert(
entity_name,
entities,
@@ -549,10 +562,34 @@
pipeline_status_lock,
llm_response_cache,
)
entities_data.append(entity_data)
if entity_vdb is not None:
data_for_vdb = {
compute_mdhash_id(entity_data["entity_name"], prefix="ent-"): {
"entity_name": entity_data["entity_name"],
"entity_type": entity_data["entity_type"],
"content": f"{entity_data['entity_name']}\n{entity_data['description']}",
"source_id": entity_data["source_id"],
"file_path": entity_data.get("file_path", "unknown_source"),
}
}
await entity_vdb.upsert(data_for_vdb)
return entity_data
# Process and update all relationships at once
for edge_key, edges in all_edges.items():
tasks = []
for entity_name, entities in all_nodes.items():
tasks.append(asyncio.create_task(_locked_process_entity_name(entity_name, entities)))
await asyncio.gather(*tasks)
# Process and update all relationships at once
log_message = f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
async def _locked_process_edges(edge_key, edges):
async with get_graph_db_lock_keyed(f"{edge_key[0]}-{edge_key[1]}", enable_logging=False):
edge_data = await _merge_edges_then_upsert(
edge_key[0],
edge_key[1],
@@ -563,55 +600,27 @@
pipeline_status_lock,
llm_response_cache,
)
if edge_data is not None:
relationships_data.append(edge_data)
if edge_data is None:
return None
# Update total counts
total_entities_count = len(entities_data)
total_relations_count = len(relationships_data)
log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Update vector databases with all collected data
if entity_vdb is not None and entities_data:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"entity_name": dp["entity_name"],
"entity_type": dp["entity_type"],
"content": f"{dp['entity_name']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
if relationships_vdb is not None:
data_for_vdb = {
compute_mdhash_id(edge_data["src_id"] + edge_data["tgt_id"], prefix="rel-"): {
"src_id": edge_data["src_id"],
"tgt_id": edge_data["tgt_id"],
"keywords": edge_data["keywords"],
"content": f"{edge_data['src_id']}\t{edge_data['tgt_id']}\n{edge_data['keywords']}\n{edge_data['description']}",
"source_id": edge_data["source_id"],
"file_path": edge_data.get("file_path", "unknown_source"),
}
}
for dp in entities_data
}
await entity_vdb.upsert(data_for_vdb)
log_message = f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
if relationships_vdb is not None and relationships_data:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"keywords": dp["keywords"],
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in relationships_data
}
await relationships_vdb.upsert(data_for_vdb)
await relationships_vdb.upsert(data_for_vdb)
return edge_data
tasks = []
for edge_key, edges in all_edges.items():
tasks.append(asyncio.create_task(_locked_process_edges(edge_key, edges)))
await asyncio.gather(*tasks)
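The merge stage now fans out one task per entity or edge key and gathers them, relying on the keyed lock to serialize only tasks that touch the same key. A condensed sketch of that pattern with a hypothetical worker coroutine:

async def fan_out_keyed(items: dict[str, list], worker) -> list:
    # One task per key; the keyed lock serializes workers that share a key
    # while workers for different keys proceed concurrently.
    async def _locked(key: str, payload: list):
        async with get_graph_db_lock_keyed([key], enable_logging=False):
            return await worker(key, payload)

    tasks = [asyncio.create_task(_locked(k, v)) for k, v in items.items()]
    return await asyncio.gather(*tasks)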
async def extract_entities(
chunks: dict[str, TextChunkSchema],