mirror of https://github.com/HKUDS/LightRAG.git
synced 2025-12-30 08:20:57 +00:00

Initial commit with keyed graph lock

This commit is contained in:
parent b7eae4d7c0
commit f8149790e4
@@ -1,9 +1,12 @@
from collections import defaultdict
import os
import sys
import asyncio
import multiprocessing as mp
from multiprocessing.synchronize import Lock as ProcessLock
from multiprocessing import Manager
from typing import Any, Dict, Optional, Union, TypeVar, Generic
import time
from typing import Any, Callable, Dict, List, Optional, Union, TypeVar, Generic


# Define a direct print function for critical logs that must be visible in all processes
@@ -27,8 +30,14 @@ LockType = Union[ProcessLock, asyncio.Lock]
_is_multiprocess = None
_workers = None
_manager = None
_lock_registry: Optional[Dict[str, mp.synchronize.Lock]] = None
_lock_registry_count: Optional[Dict[str, int]] = None
_lock_cleanup_data: Optional[Dict[str, float]] = None  # combined key -> idle-since timestamp
_registry_guard = None
_initialized = None

CLEANUP_KEYED_LOCKS_AFTER_SECONDS = 300

# shared data for storage across processes
_shared_dicts: Optional[Dict[str, Any]] = None
_init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized

@@ -40,10 +49,31 @@ _internal_lock: Optional[LockType] = None
_pipeline_status_lock: Optional[LockType] = None
_graph_db_lock: Optional[LockType] = None
_data_init_lock: Optional[LockType] = None
_graph_db_lock_keyed: Optional["KeyedUnifiedLock"] = None

# async locks for coroutine synchronization in multiprocess mode
_async_locks: Optional[Dict[str, asyncio.Lock]] = None

DEBUG_LOCKS = False
_debug_n_locks_acquired: int = 0


def inc_debug_n_locks_acquired():
    global _debug_n_locks_acquired
    if DEBUG_LOCKS:
        _debug_n_locks_acquired += 1
        print(f"DEBUG: Keyed Lock acquired, total: {_debug_n_locks_acquired:>5}", end="\r", flush=True)


def dec_debug_n_locks_acquired():
    global _debug_n_locks_acquired
    if DEBUG_LOCKS:
        if _debug_n_locks_acquired > 0:
            _debug_n_locks_acquired -= 1
            print(f"DEBUG: Keyed Lock released, total: {_debug_n_locks_acquired:>5}", end="\r", flush=True)
        else:
            raise RuntimeError("Attempting to release a keyed lock when none are acquired")


def get_debug_n_locks_acquired():
    return _debug_n_locks_acquired


class UnifiedLock(Generic[T]):
    """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
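UnifiedLock fronts either an asyncio.Lock or a multiprocessing lock behind one async context manager; the hunk below shows it being constructed with a separate async_lock precisely so a synchronous manager.Lock cannot stall the event loop. A minimal sketch of that two-level idea, using hypothetical names — not the class's actual implementation, which this diff shows only in part:

import asyncio

class TwoLevelLock:
    """Sketch: an asyncio gate in front of a (possibly blocking) sync lock."""

    def __init__(self, sync_lock, async_gate: asyncio.Lock):
        self._sync_lock = sync_lock    # e.g. a manager.Lock() proxy
        self._async_gate = async_gate  # serializes coroutines within this process

    async def __aenter__(self):
        await self._async_gate.acquire()  # cheap, never blocks the loop
        loop = asyncio.get_running_loop()
        # Run the blocking acquire off-loop so other coroutines keep running
        await loop.run_in_executor(None, self._sync_lock.acquire)
        return self

    async def __aexit__(self, exc_type, exc, tb):
        self._sync_lock.release()
        self._async_gate.release()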
@@ -210,6 +240,207 @@ class UnifiedLock(Generic[T]):
            )
            raise

    def locked(self) -> bool:
        return self._lock.locked()


# ─────────────────────────────────────────────────────────────────────────────
# 2. CROSS-PROCESS FACTORY (one manager.Lock shared by *all* processes)
# ─────────────────────────────────────────────────────────────────────────────
def _get_combined_key(factory_name: str, key: str) -> str:
    """Return the combined key for the factory and key."""
    return f"{factory_name}:{key}"


def _get_or_create_shared_raw_mp_lock(factory_name: str, key: str) -> Optional[mp.synchronize.Lock]:
    """Return the *singleton* manager.Lock() proxy for *key*, creating it if needed."""
    if not _is_multiprocess:
        return None

    with _registry_guard:
        combined_key = _get_combined_key(factory_name, key)
        raw = _lock_registry.get(combined_key)
        count = _lock_registry_count.get(combined_key)
        if raw is None:
            raw = _manager.Lock()
            _lock_registry[combined_key] = raw
            count = 0  # a new entry starts with a zero reference count
        elif count is None:
            raise RuntimeError(f"Shared-Data lock registry for {factory_name} is corrupted for key {key}")
        count += 1
        _lock_registry_count[combined_key] = count
        if count == 1 and combined_key in _lock_cleanup_data:
            _lock_cleanup_data.pop(combined_key)
        return raw


def _release_shared_raw_mp_lock(factory_name: str, key: str):
    """Release the *singleton* manager.Lock() proxy for *key*."""
    if not _is_multiprocess:
        return

    with _registry_guard:
        combined_key = _get_combined_key(factory_name, key)
        raw = _lock_registry.get(combined_key)
        count = _lock_registry_count.get(combined_key)
        if raw is None and count is None:
            return
        elif raw is None or count is None:
            raise RuntimeError(f"Shared-Data lock registry for {factory_name} is corrupted for key {key}")

        count -= 1
        if count < 0:
            raise RuntimeError(f"Attempting to remove lock for {key} but it is not in the registry")
        _lock_registry_count[combined_key] = count

        if count == 0:
            _lock_cleanup_data[combined_key] = time.time()

        # Sweep entries that have been unreferenced for longer than the cleanup window
        for combined_key, value in list(_lock_cleanup_data.items()):
            if time.time() - value > CLEANUP_KEYED_LOCKS_AFTER_SECONDS:
                _lock_registry.pop(combined_key)
                _lock_registry_count.pop(combined_key)
                _lock_cleanup_data.pop(combined_key)
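The two functions above implement a reference-counted registry with a grace period: an entry is swept only once its count has sat at zero for CLEANUP_KEYED_LOCKS_AFTER_SECONDS. A minimal single-process sketch of the same idea, with hypothetical names and threading.Lock in place of manager.Lock():

import time
import threading

CLEANUP_AFTER_SECONDS = 300  # mirrors CLEANUP_KEYED_LOCKS_AFTER_SECONDS

_registry: dict[str, threading.Lock] = {}
_counts: dict[str, int] = {}
_idle_since: dict[str, float] = {}
_guard = threading.RLock()

def acquire_entry(key: str) -> threading.Lock:
    """Fetch (or create) the lock for *key* and bump its reference count."""
    with _guard:
        lock = _registry.setdefault(key, threading.Lock())
        _counts[key] = _counts.get(key, 0) + 1
        _idle_since.pop(key, None)  # key is live again; cancel any pending cleanup
        return lock

def release_entry(key: str) -> None:
    """Drop one reference; sweep entries idle longer than the grace period."""
    with _guard:
        _counts[key] -= 1
        if _counts[key] == 0:
            _idle_since[key] = time.time()
        for k, t in list(_idle_since.items()):
            if time.time() - t > CLEANUP_AFTER_SECONDS:
                _registry.pop(k, None)
                _counts.pop(k, None)
                _idle_since.pop(k, None)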


# ─────────────────────────────────────────────────────────────────────────────
# 3. PARAMETER-KEYED WRAPPER (unchanged except it *accepts a factory*)
# ─────────────────────────────────────────────────────────────────────────────
class KeyedUnifiedLock:
    """
    Parameter-keyed wrapper around `UnifiedLock`.

    • Keeps only a table of per-key *asyncio* gates locally
    • Fetches the shared process-wide mutex on *every* acquire
    • Builds a fresh `UnifiedLock` each time, so `enable_logging`
      (or future options) can vary per call.
    """

    # ---------------- construction ----------------
    def __init__(self, factory_name: str, *, default_enable_logging: bool = True) -> None:
        self._factory_name = factory_name
        self._default_enable_logging = default_enable_logging
        self._async_lock: Dict[str, asyncio.Lock] = {}        # key -> asyncio.Lock
        self._async_lock_count: Dict[str, int] = {}           # key -> reference count
        self._async_lock_cleanup_data: Dict[str, float] = {}  # key -> idle-since timestamp
        self._mp_locks: Dict[str, mp.synchronize.Lock] = {}   # key -> mp.synchronize.Lock

    # ---------------- public API ------------------
    def __call__(self, keys: list[str], *, enable_logging: Optional[bool] = None):
        """
        Ergonomic helper so you can write:

            async with keyed_locks(["alpha"]):
                ...
        """
        if enable_logging is None:
            enable_logging = self._default_enable_logging
        return _KeyedLockContext(self, factory_name=self._factory_name, keys=keys, enable_logging=enable_logging)

    def _get_or_create_async_lock(self, key: str) -> asyncio.Lock:
        async_lock = self._async_lock.get(key)
        count = self._async_lock_count.get(key, 0)
        if async_lock is None:
            async_lock = asyncio.Lock()
            self._async_lock[key] = async_lock
        elif count == 0 and key in self._async_lock_cleanup_data:
            self._async_lock_cleanup_data.pop(key)
        count += 1
        self._async_lock_count[key] = count
        return async_lock

    def _release_async_lock(self, key: str):
        count = self._async_lock_count.get(key, 0)
        count -= 1
        if count == 0:
            self._async_lock_cleanup_data[key] = time.time()
        self._async_lock_count[key] = count

        # Sweep gates that have been idle longer than the cleanup window
        for idle_key, value in list(self._async_lock_cleanup_data.items()):
            if time.time() - value > CLEANUP_KEYED_LOCKS_AFTER_SECONDS:
                self._async_lock.pop(idle_key)
                self._async_lock_count.pop(idle_key)
                self._async_lock_cleanup_data.pop(idle_key)

    def _get_lock_for_key(self, key: str, enable_logging: bool = False) -> UnifiedLock:
        # 1. get (or create) the per-process async gate for this key.
        #    This is synchronous, so no extra lock is needed.
        async_lock = self._get_or_create_async_lock(key)

        # 2. fetch the shared raw lock (None in single-process mode)
        raw_lock = _get_or_create_shared_raw_mp_lock(self._factory_name, key)
        is_multiprocess = raw_lock is not None
        if not is_multiprocess:
            raw_lock = async_lock

        # 3. build a *fresh* UnifiedLock with the chosen logging flag
        if is_multiprocess:
            return UnifiedLock(
                lock=raw_lock,
                is_async=False,  # manager.Lock is synchronous
                name=f"key:{self._factory_name}:{key}",
                enable_logging=enable_logging,
                async_lock=async_lock,  # prevents event-loop blocking
            )
        else:
            return UnifiedLock(
                lock=raw_lock,
                is_async=True,
                name=f"key:{self._factory_name}:{key}",
                enable_logging=enable_logging,
                async_lock=None,  # no async gate needed in single-process mode
            )

    def _release_lock_for_key(self, key: str):
        self._release_async_lock(key)
        _release_shared_raw_mp_lock(self._factory_name, key)
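The class above gives per-key granularity: coroutines contend only where their key sets intersect — within a process via the asyncio gate, across workers via the shared manager lock. A hedged usage sketch, where example_lock stands in for a KeyedUnifiedLock created by initialize_share_data():

async def update_two_entities(example_lock: "KeyedUnifiedLock") -> None:
    # Locking ["alpha"] and ["beta"] separately lets unrelated updates run in
    # parallel; locking both keys serializes against holders of either one.
    async with example_lock(["beta", "alpha"], enable_logging=False):
        ...  # mutate graph rows for both keys atomically (acquired as alpha, then beta)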


class _KeyedLockContext:
    def __init__(
        self,
        parent: KeyedUnifiedLock,
        factory_name: str,
        keys: list[str],
        enable_logging: Optional[bool],
    ) -> None:
        self._parent = parent
        self._factory_name = factory_name

        # The sorting is critical: acquiring keys in a canonical order
        # prevents deadlocks between overlapping key sets
        self._keys = sorted(keys)
        self._enable_logging = (
            enable_logging if enable_logging is not None
            else parent._default_enable_logging
        )
        self._ul: Optional[List["UnifiedLock"]] = None  # set in __aenter__

    # ----- enter -----
    async def __aenter__(self):
        if self._ul is not None:
            raise RuntimeError("KeyedUnifiedLock already acquired in current context")

        # acquire each key's lock, in sorted order
        self._ul = []
        for key in self._keys:
            lock = self._parent._get_lock_for_key(key, enable_logging=self._enable_logging)
            await lock.__aenter__()
            inc_debug_n_locks_acquired()
            self._ul.append(lock)
        return self

    # ----- exit -----
    async def __aexit__(self, exc_type, exc, tb):
        # Release in reverse acquisition order; UnifiedLock handles the details
        for ul, key in zip(reversed(self._ul), reversed(self._keys)):
            await ul.__aexit__(exc_type, exc, tb)
            self._parent._release_lock_for_key(key)
            dec_debug_n_locks_acquired()
        self._ul = None
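Why the sort is critical: without a canonical order, a coroutine holding "alpha" while awaiting "beta" can deadlock against one holding "beta" while awaiting "alpha". A self-contained toy demonstration of the ordering rule, with plain asyncio.Locks standing in for the per-key gates:

import asyncio

gates = {"alpha": asyncio.Lock(), "beta": asyncio.Lock()}

async def worker(keys: list[str]) -> None:
    # Canonical (sorted) acquisition order: no circular wait is possible.
    ordered = sorted(keys)
    for key in ordered:
        await gates[key].acquire()
    try:
        await asyncio.sleep(0.01)  # critical section touching all keys
    finally:
        for key in reversed(ordered):
            gates[key].release()

async def main() -> None:
    # Opposite declaration orders could deadlock without the sort.
    await asyncio.gather(
        worker(["alpha", "beta"]),
        worker(["beta", "alpha"]),
    )

asyncio.run(main())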


def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
    """Return the unified storage lock for data consistency"""

@@ -258,6 +489,14 @@ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
        async_lock=async_lock,
    )


def get_graph_db_lock_keyed(keys: str | list[str], enable_logging: bool = False) -> "_KeyedLockContext":
    """Return a keyed graph database lock context for atomic per-key operations"""
    global _graph_db_lock_keyed
    if _graph_db_lock_keyed is None:
        raise RuntimeError("Shared-Data is not initialized")
    if isinstance(keys, str):
        keys = [keys]
    return _graph_db_lock_keyed(keys, enable_logging=enable_logging)


def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
    """Return the unified data initialization lock for ensuring atomic data initialization"""

@@ -294,6 +533,10 @@ def initialize_share_data(workers: int = 1):
        _workers, \
        _is_multiprocess, \
        _storage_lock, \
        _lock_registry, \
        _lock_registry_count, \
        _lock_cleanup_data, \
        _registry_guard, \
        _internal_lock, \
        _pipeline_status_lock, \
        _graph_db_lock, \
@@ -302,7 +545,8 @@
        _init_flags, \
        _initialized, \
        _update_flags, \
        _async_locks
        _async_locks, \
        _graph_db_lock_keyed

    # Check if already initialized
    if _initialized:
@@ -316,6 +560,10 @@
    if workers > 1:
        _is_multiprocess = True
        _manager = Manager()
        _lock_registry = _manager.dict()
        _lock_registry_count = _manager.dict()
        _lock_cleanup_data = _manager.dict()
        _registry_guard = _manager.RLock()
        _internal_lock = _manager.Lock()
        _storage_lock = _manager.Lock()
        _pipeline_status_lock = _manager.Lock()
@@ -324,6 +572,10 @@
        _shared_dicts = _manager.dict()
        _init_flags = _manager.dict()
        _update_flags = _manager.dict()

        _graph_db_lock_keyed = KeyedUnifiedLock(
            factory_name="graph_db_lock",
        )

        # Initialize async locks for multiprocess mode
        _async_locks = {
@@ -348,6 +600,10 @@
        _init_flags = {}
        _update_flags = {}
        _async_locks = None  # No need for async locks in single-process mode

        _graph_db_lock_keyed = KeyedUnifiedLock(
            factory_name="graph_db_lock",
        )
        direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")

    # Mark as initialized
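An end-to-end sketch of the wiring implied by this commit: initialize_share_data must run before any accessor, or get_graph_db_lock_keyed raises "Shared-Data is not initialized". The import path assumes this repo's lightrag/kg/shared_storage.py layout; the worker count and keys are illustrative:

import asyncio

from lightrag.kg.shared_storage import (
    initialize_share_data,
    get_graph_db_lock_keyed,
)

async def main() -> None:
    initialize_share_data(workers=1)  # single-process: pure-asyncio gates
    # Hold only the keys being mutated; other keys stay uncontended.
    async with get_graph_db_lock_keyed(["entity:Alice", "entity:Bob"]):
        ...  # graph mutations touching only these two keys

asyncio.run(main())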

@@ -1024,73 +1024,73 @@ class LightRAG:
                    }
                )

                # Semaphore was released here
                # Semaphore is NOT released here; the profile context is, however

                if file_extraction_stage_ok:
                    try:
                        # Get chunk_results from entity_relation_task
                        chunk_results = await entity_relation_task
                        await merge_nodes_and_edges(
                            chunk_results=chunk_results,  # result collected from entity_relation_task
                            knowledge_graph_inst=self.chunk_entity_relation_graph,
                            entity_vdb=self.entities_vdb,
                            relationships_vdb=self.relationships_vdb,
                            global_config=asdict(self),
                            pipeline_status=pipeline_status,
                            pipeline_status_lock=pipeline_status_lock,
                            llm_response_cache=self.llm_response_cache,
                            current_file_number=current_file_number,
                            total_files=total_files,
                            file_path=file_path,
                        )

                        await self.doc_status.upsert(
                            {
                                doc_id: {
                                    "status": DocStatus.PROCESSED,
                                    "chunks_count": len(chunks),
                                    "content": status_doc.content,
                                    "content_summary": status_doc.content_summary,
                                    "content_length": status_doc.content_length,
                                    "created_at": status_doc.created_at,
                                    "updated_at": datetime.now().isoformat(),
                                    "file_path": file_path,
                                }
                            }
                        )

                        # Call _insert_done after processing each file
                        await self._insert_done()

                        async with pipeline_status_lock:
                            log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}"
                            logger.info(log_message)
                            pipeline_status["latest_message"] = log_message
                            pipeline_status["history_messages"].append(log_message)

                    except Exception as e:
                        # Log error and update pipeline status
                        error_msg = f"Merging stage failed in document {doc_id}: {traceback.format_exc()}"
                        logger.error(error_msg)
                        async with pipeline_status_lock:
                            pipeline_status["latest_message"] = error_msg
                            pipeline_status["history_messages"].append(error_msg)

                        # Update document status to failed
                        await self.doc_status.upsert(
                            {
                                doc_id: {
                                    "status": DocStatus.FAILED,
                                    "error": str(e),
                                    "content": status_doc.content,
                                    "content_summary": status_doc.content_summary,
                                    "content_length": status_doc.content_length,
                                    "created_at": status_doc.created_at,
                                    "updated_at": datetime.now().isoformat(),
                                    "file_path": file_path,
                                }
                            }
                        )

        # Create processing tasks for all documents
        doc_tasks = []
@@ -9,6 +9,8 @@ import os
from typing import Any, AsyncIterator
from collections import Counter, defaultdict

from .kg.shared_storage import get_graph_db_lock_keyed

from .utils import (
    logger,
    clean_str,
@@ -403,27 +405,31 @@ async def _merge_edges_then_upsert(
    )

    for need_insert_id in [src_id, tgt_id]:
        if not (await knowledge_graph_inst.has_node(need_insert_id)):
            # # Discard this edge if the node does not exist
            # if need_insert_id == src_id:
            #     logger.warning(
            #         f"Discard edge: {src_id} - {tgt_id} | Source node missing"
            #     )
            # else:
            #     logger.warning(
            #         f"Discard edge: {src_id} - {tgt_id} | Target node missing"
            #     )
            # return None
            await knowledge_graph_inst.upsert_node(
                need_insert_id,
                node_data={
                    "entity_id": need_insert_id,
                    "source_id": source_id,
                    "description": description,
                    "entity_type": "UNKNOWN",
                    "file_path": file_path,
                },
            )
        if await knowledge_graph_inst.has_node(need_insert_id):
            # Fast path: when the node already exists, the initial
            # existence check need not hold the keyed lock
            continue
        async with get_graph_db_lock_keyed([need_insert_id], enable_logging=False):
            if not (await knowledge_graph_inst.has_node(need_insert_id)):
                # # Discard this edge if the node does not exist
                # if need_insert_id == src_id:
                #     logger.warning(
                #         f"Discard edge: {src_id} - {tgt_id} | Source node missing"
                #     )
                # else:
                #     logger.warning(
                #         f"Discard edge: {src_id} - {tgt_id} | Target node missing"
                #     )
                # return None
                await knowledge_graph_inst.upsert_node(
                    need_insert_id,
                    node_data={
                        "entity_id": need_insert_id,
                        "source_id": source_id,
                        "description": description,
                        "entity_type": "UNKNOWN",
                        "file_path": file_path,
                    },
                )

    force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
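The replacement above is the classic double-checked pattern: an unlocked probe skips lock acquisition in the common case, and the check is repeated under the keyed lock so two coroutines cannot both insert the same node. A minimal, hypothetical sketch of the idiom (graph and gate are stand-ins for knowledge_graph_inst and the keyed lock context):

import asyncio

async def ensure_node(graph, node_id: str, gate: asyncio.Lock) -> None:
    # 1. Unlocked fast path: most nodes already exist.
    if await graph.has_node(node_id):
        return
    async with gate:
        # 2. Re-check under the lock: another coroutine may have inserted
        #    the node between the probe and the acquire.
        if not await graph.has_node(node_id):
            await graph.upsert_node(node_id, node_data={"entity_id": node_id})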

@@ -523,23 +529,30 @@ async def merge_nodes_and_edges(
            all_edges[sorted_edge_key].extend(edges)

    # Centralized processing of all nodes and edges
    entities_data = []
    relationships_data = []
    total_entities_count = len(all_nodes)
    total_relations_count = len(all_edges)

    # Merge nodes and edges
    # Use graph database lock to ensure atomic merges and updates
    graph_db_lock = get_graph_db_lock(enable_logging=False)
    async with graph_db_lock:
        async with pipeline_status_lock:
            log_message = (
                f"Merging stage {current_file_number}/{total_files}: {file_path}"
            )
            logger.info(log_message)
            pipeline_status["latest_message"] = log_message
            pipeline_status["history_messages"].append(log_message)

        # Process and update all entities at once
        log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
        logger.info(log_message)
    if pipeline_status is not None:
        async with pipeline_status_lock:
            log_message = (
                f"Merging stage {current_file_number}/{total_files}: {file_path}"
            )
            logger.info(log_message)
            pipeline_status["latest_message"] = log_message
            pipeline_status["history_messages"].append(log_message)

    # Process and update all entities at once
    for entity_name, entities in all_nodes.items():
    async def _locked_process_entity_name(entity_name, entities):
        async with get_graph_db_lock_keyed([entity_name], enable_logging=False):
            entity_data = await _merge_nodes_then_upsert(
                entity_name,
                entities,
@@ -549,10 +562,34 @@
                pipeline_status_lock,
                llm_response_cache,
            )
            entities_data.append(entity_data)
            if entity_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(entity_data["entity_name"], prefix="ent-"): {
                        "entity_name": entity_data["entity_name"],
                        "entity_type": entity_data["entity_type"],
                        "content": f"{entity_data['entity_name']}\n{entity_data['description']}",
                        "source_id": entity_data["source_id"],
                        "file_path": entity_data.get("file_path", "unknown_source"),
                    }
                }
                await entity_vdb.upsert(data_for_vdb)
            return entity_data

    # Process and update all relationships at once
    for edge_key, edges in all_edges.items():
    tasks = []
    for entity_name, entities in all_nodes.items():
        tasks.append(asyncio.create_task(_locked_process_entity_name(entity_name, entities)))
    await asyncio.gather(*tasks)
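This replaces the single coarse graph_db_lock with one task per entity that holds only its own key, so merges of different entities overlap instead of serializing. A hedged miniature of the fan-out, where merge_one is a stand-in for _merge_nodes_then_upsert and keyed_lock for the KeyedUnifiedLock:

import asyncio

async def merge_all(all_nodes: dict[str, list], keyed_lock) -> list:
    # One task per entity; contention only between identical entity names.
    async def merge_one(name: str, fragments: list):
        async with keyed_lock([name], enable_logging=False):
            return ("merged", name, len(fragments))  # stand-in for the real merge

    tasks = [asyncio.create_task(merge_one(n, f)) for n, f in all_nodes.items()]
    return await asyncio.gather(*tasks)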

    # Process and update all relationships at once
    log_message = f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
    logger.info(log_message)
    if pipeline_status is not None:
        async with pipeline_status_lock:
            pipeline_status["latest_message"] = log_message
            pipeline_status["history_messages"].append(log_message)

    async def _locked_process_edges(edge_key, edges):
        async with get_graph_db_lock_keyed(f"{edge_key[0]}-{edge_key[1]}", enable_logging=False):
            edge_data = await _merge_edges_then_upsert(
                edge_key[0],
                edge_key[1],
@@ -563,55 +600,27 @@
                pipeline_status_lock,
                llm_response_cache,
            )
            if edge_data is not None:
                relationships_data.append(edge_data)
            if edge_data is None:
                return None

    # Update total counts
    total_entities_count = len(entities_data)
    total_relations_count = len(relationships_data)

    log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
    logger.info(log_message)
    if pipeline_status is not None:
        async with pipeline_status_lock:
            pipeline_status["latest_message"] = log_message
            pipeline_status["history_messages"].append(log_message)

    # Update vector databases with all collected data
    if entity_vdb is not None and entities_data:
        data_for_vdb = {
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "entity_name": dp["entity_name"],
                "entity_type": dp["entity_type"],
                "content": f"{dp['entity_name']}\n{dp['description']}",
                "source_id": dp["source_id"],
                "file_path": dp.get("file_path", "unknown_source"),
            if relationships_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(edge_data["src_id"] + edge_data["tgt_id"], prefix="rel-"): {
                        "src_id": edge_data["src_id"],
                        "tgt_id": edge_data["tgt_id"],
                        "keywords": edge_data["keywords"],
                        "content": f"{edge_data['src_id']}\t{edge_data['tgt_id']}\n{edge_data['keywords']}\n{edge_data['description']}",
                        "source_id": edge_data["source_id"],
                        "file_path": edge_data.get("file_path", "unknown_source"),
                    }
                }
            for dp in entities_data
        }
        await entity_vdb.upsert(data_for_vdb)

    log_message = f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
    logger.info(log_message)
    if pipeline_status is not None:
        async with pipeline_status_lock:
            pipeline_status["latest_message"] = log_message
            pipeline_status["history_messages"].append(log_message)

    if relationships_vdb is not None and relationships_data:
        data_for_vdb = {
            compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                "src_id": dp["src_id"],
                "tgt_id": dp["tgt_id"],
                "keywords": dp["keywords"],
                "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
                "source_id": dp["source_id"],
                "file_path": dp.get("file_path", "unknown_source"),
            }
            for dp in relationships_data
        }
        await relationships_vdb.upsert(data_for_vdb)
                await relationships_vdb.upsert(data_for_vdb)
            return edge_data

    tasks = []
    for edge_key, edges in all_edges.items():
        tasks.append(asyncio.create_task(_locked_process_edges(edge_key, edges)))
    await asyncio.gather(*tasks)


async def extract_entities(
    chunks: dict[str, TextChunkSchema],