feat: Optimize document deletion performance

- To enhance performance during document deletion, new batch-get methods, `get_nodes_by_chunk_ids` and `get_edges_by_chunk_ids`, have been added to the graph storage layer (`BaseGraphStorage` and its implementations). The [`adelete_by_doc_id`](lightrag/lightrag.py:1681) function now leverages these methods to avoid unnecessary iteration over the entire knowledge graph, significantly improving efficiency.
- Graph storage updated: Networkx, Neo4j, Postgres AGE
This commit is contained in:
yangdx 2025-06-25 12:37:57 +08:00
parent ebe5b1e0d2
commit da46b341dc
8 changed files with 271 additions and 63 deletions

View File

@ -14,6 +14,7 @@ from typing import (
)
from .utils import EmbeddingFunc
from .types import KnowledgeGraph
from .constants import GRAPH_FIELD_SEP
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
@ -456,6 +457,67 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[node_id] = edges if edges is not None else []
return result
@abstractmethod
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all nodes that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
Returns:
list[dict]: A list of nodes, where each node is a dictionary of its properties.
An empty list if no matching nodes are found.
"""
# Default implementation iterates through all nodes, which is inefficient.
# This method should be overridden by subclasses for better performance.
all_nodes = []
all_labels = await self.get_all_labels()
for label in all_labels:
node = await self.get_node(label)
if node and "source_id" in node:
source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP))
if not source_ids.isdisjoint(chunk_ids):
all_nodes.append(node)
return all_nodes
@abstractmethod
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all edges that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
Returns:
list[dict]: A list of edges, where each edge is a dictionary of its properties.
An empty list if no matching edges are found.
"""
# Default implementation iterates through all nodes and their edges, which is inefficient.
# This method should be overridden by subclasses for better performance.
all_edges = []
all_labels = await self.get_all_labels()
processed_edges = set()
for label in all_labels:
edges = await self.get_node_edges(label)
if edges:
for src_id, tgt_id in edges:
# Avoid processing the same edge twice in an undirected graph
edge_tuple = tuple(sorted((src_id, tgt_id)))
if edge_tuple in processed_edges:
continue
processed_edges.add(edge_tuple)
edge = await self.get_edge(src_id, tgt_id)
if edge and "source_id" in edge:
source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP))
if not source_ids.isdisjoint(chunk_ids):
# Add source and target to the edge dict for easier processing later
edge_with_nodes = edge.copy()
edge_with_nodes["source"] = src_id
edge_with_nodes["target"] = tgt_id
all_edges.append(edge_with_nodes)
return all_edges
@abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""Insert a new node or update an existing node in the graph.

View File

@ -12,6 +12,9 @@ DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6
DEFAULT_WOKERS = 2
DEFAULT_TIMEOUT = 150
# Separator for graph fields
GRAPH_FIELD_SEP = "<SEP>"
# Logging configuration defaults
DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB
DEFAULT_LOG_BACKUP_COUNT = 5 # Default 5 backups

View File

@ -16,6 +16,7 @@ import logging
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
import pipmaster as pm
if not pm.is_installed("neo4j"):
@ -725,6 +726,47 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are fully consumed
return edges_dict
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $chunk_ids AS chunk_id
MATCH (n:base)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
RETURN DISTINCT n
"""
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
nodes = []
async for record in result:
node = record["n"]
node_dict = dict(node)
# Add node id (entity_id) to the dictionary for easier access
node_dict["id"] = node_dict.get("entity_id")
nodes.append(node_dict)
await result.consume()
return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $chunk_ids AS chunk_id
MATCH (a:base)-[r]-(b:base)
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
"""
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
edges = []
async for record in result:
edge_properties = record["properties"]
edge_properties["source"] = record["source"]
edge_properties["target"] = record["target"]
edges.append(edge_properties)
await result.consume()
return edges
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),

View File

@ -5,6 +5,7 @@ from typing import final
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger
from lightrag.base import BaseGraphStorage
from lightrag.constants import GRAPH_FIELD_SEP
import pipmaster as pm
@ -357,6 +358,33 @@ class NetworkXStorage(BaseGraphStorage):
)
return result
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
chunk_ids_set = set(chunk_ids)
graph = await self._get_graph()
matching_nodes = []
for node_id, node_data in graph.nodes(data=True):
if "source_id" in node_data:
node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
if not node_source_ids.isdisjoint(chunk_ids_set):
node_data_with_id = node_data.copy()
node_data_with_id["id"] = node_id
matching_nodes.append(node_data_with_id)
return matching_nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
chunk_ids_set = set(chunk_ids)
graph = await self._get_graph()
matching_edges = []
for u, v, edge_data in graph.edges(data=True):
if "source_id" in edge_data:
edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
if not edge_source_ids.isdisjoint(chunk_ids_set):
edge_data_with_nodes = edge_data.copy()
edge_data_with_nodes["source"] = u
edge_data_with_nodes["target"] = v
matching_edges.append(edge_data_with_nodes)
return matching_edges
async def index_done_callback(self) -> bool:
"""Save data to disk"""
async with self._storage_lock:

View File

@ -27,6 +27,7 @@ from ..base import (
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..constants import GRAPH_FIELD_SEP
import pipmaster as pm
@ -1422,8 +1423,6 @@ class PGGraphStorage(BaseGraphStorage):
# Process string result, parse it to JSON dictionary
if isinstance(node_dict, str):
try:
import json
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(f"Failed to parse node string: {node_dict}")
@ -1479,8 +1478,6 @@ class PGGraphStorage(BaseGraphStorage):
# Process string result, parse it to JSON dictionary
if isinstance(result, str):
try:
import json
result = json.loads(result)
except json.JSONDecodeError:
logger.warning(f"Failed to parse edge string: {result}")
@ -1697,8 +1694,6 @@ class PGGraphStorage(BaseGraphStorage):
# Process string result, parse it to JSON dictionary
if isinstance(node_dict, str):
try:
import json
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
@ -1861,8 +1856,6 @@ class PGGraphStorage(BaseGraphStorage):
# Process string result, parse it to JSON dictionary
if isinstance(edge_props, str):
try:
import json
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
@ -1879,8 +1872,6 @@ class PGGraphStorage(BaseGraphStorage):
# Process string result, parse it to JSON dictionary
if isinstance(edge_props, str):
try:
import json
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
@ -1975,6 +1966,102 @@ class PGGraphStorage(BaseGraphStorage):
labels.append(result["label"])
return labels
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""
Retrieves nodes from the graph that are associated with a given list of chunk IDs.
This method uses a Cypher query with UNWIND to efficiently find all nodes
where the `source_id` property contains any of the specified chunk IDs.
"""
# The string representation of the list for the cypher query
chunk_ids_str = json.dumps(chunk_ids)
query = f"""
SELECT * FROM cypher('{self.graph_name}', $$
UNWIND {chunk_ids_str} AS chunk_id
MATCH (n:base)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}')
RETURN n
$$) AS (n agtype);
"""
results = await self._query(query)
# Build result list
nodes = []
for result in results:
if result["n"]:
node_dict = result["n"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(node_dict, str):
try:
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {node_dict}"
)
node_dict["id"] = node_dict["entity_id"]
nodes.append(node_dict)
return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""
Retrieves edges from the graph that are associated with a given list of chunk IDs.
This method uses a Cypher query with UNWIND to efficiently find all edges
where the `source_id` property contains any of the specified chunk IDs.
"""
chunk_ids_str = json.dumps(chunk_ids)
query = f"""
SELECT * FROM cypher('{self.graph_name}', $$
UNWIND {chunk_ids_str} AS chunk_id
MATCH (a:base)-[r]-(b:base)
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}')
RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target
$$) AS (edge agtype, source agtype, target agtype);
"""
results = await self._query(query)
edges = []
if results:
for item in results:
edge_agtype = item["edge"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(edge_agtype, str):
try:
edge_agtype = json.loads(edge_agtype)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge string in batch: {edge_agtype}"
)
source_agtype = item["source"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(source_agtype, str):
try:
source_agtype = json.loads(source_agtype)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {source_agtype}"
)
target_agtype = item["target"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(target_agtype, str):
try:
target_agtype = json.loads(target_agtype)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {target_agtype}"
)
if edge_agtype and source_agtype and target_agtype:
edge_properties = edge_agtype
edge_properties["source"] = source_agtype["entity_id"]
edge_properties["target"] = target_agtype["entity_id"]
edges.append(edge_properties)
return edges
async def _bfs_subgraph(
self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph:

View File

@ -60,7 +60,7 @@ from .operate import (
query_with_keywords,
_rebuild_knowledge_from_chunks,
)
from .prompt import GRAPH_FIELD_SEP
from .constants import GRAPH_FIELD_SEP
from .utils import (
Tokenizer,
TiktokenTokenizer,
@ -1761,68 +1761,54 @@ class LightRAG:
# Use graph database lock to ensure atomic merges and updates
graph_db_lock = get_graph_db_lock(enable_logging=False)
async with graph_db_lock:
# Process entities
# TODO There is performance when iterating get_all_labels for PostgresSQL
all_labels = await self.chunk_entity_relation_graph.get_all_labels()
for node_label in all_labels:
node_data = await self.chunk_entity_relation_graph.get_node(
node_label
# Get all affected nodes and edges in batch
affected_nodes = (
await self.chunk_entity_relation_graph.get_nodes_by_chunk_ids(
list(chunk_ids)
)
if node_data and "source_id" in node_data:
# Split source_id using GRAPH_FIELD_SEP
)
affected_edges = (
await self.chunk_entity_relation_graph.get_edges_by_chunk_ids(
list(chunk_ids)
)
)
# logger.info(f"chunk_ids: {chunk_ids}")
# logger.info(f"affected_nodes: {affected_nodes}")
# logger.info(f"affected_edges: {affected_edges}")
# Process entities
for node_data in affected_nodes:
node_label = node_data.get("entity_id")
if node_label and "source_id" in node_data:
sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
remaining_sources = sources - chunk_ids
if not remaining_sources:
entities_to_delete.add(node_label)
logger.debug(
f"Entity {node_label} marked for deletion - no remaining sources"
)
elif remaining_sources != sources:
# Entity needs to be rebuilt from remaining chunks
entities_to_rebuild[node_label] = remaining_sources
logger.debug(
f"Entity {node_label} will be rebuilt from {len(remaining_sources)} remaining chunks"
)
# Process relationships
# TODO There is performance when iterating get_all_labels for PostgresSQL
for node_label in all_labels:
node_edges = await self.chunk_entity_relation_graph.get_node_edges(
node_label
)
if node_edges:
for src, tgt in node_edges:
# To avoid processing the same edge twice in an undirected graph
if (tgt, src) in relationships_to_delete or (
tgt,
src,
) in relationships_to_rebuild:
continue
for edge_data in affected_edges:
src = edge_data.get("source")
tgt = edge_data.get("target")
edge_data = await self.chunk_entity_relation_graph.get_edge(
src, tgt
)
if edge_data and "source_id" in edge_data:
# Split source_id using GRAPH_FIELD_SEP
sources = set(
edge_data["source_id"].split(GRAPH_FIELD_SEP)
)
remaining_sources = sources - chunk_ids
if src and tgt and "source_id" in edge_data:
edge_tuple = tuple(sorted((src, tgt)))
if (
edge_tuple in relationships_to_delete
or edge_tuple in relationships_to_rebuild
):
continue
if not remaining_sources:
relationships_to_delete.add((src, tgt))
logger.debug(
f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
)
elif remaining_sources != sources:
# Relationship needs to be rebuilt from remaining chunks
relationships_to_rebuild[(src, tgt)] = (
remaining_sources
)
logger.debug(
f"Relationship {src}-{tgt} will be rebuilt from {len(remaining_sources)} remaining chunks"
)
sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
remaining_sources = sources - chunk_ids
if not remaining_sources:
relationships_to_delete.add(edge_tuple)
elif remaining_sources != sources:
relationships_to_rebuild[edge_tuple] = remaining_sources
# 5. Delete chunks from storage
if chunk_ids:

View File

@ -33,7 +33,8 @@ from .base import (
TextChunkSchema,
QueryParam,
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
from .prompt import PROMPTS
from .constants import GRAPH_FIELD_SEP
import time
from dotenv import load_dotenv

View File

@ -1,7 +1,6 @@
from __future__ import annotations
from typing import Any
GRAPH_FIELD_SEP = "<SEP>"
PROMPTS: dict[str, Any] = {}