feat: Optimize document deletion performance

- To enhance performance during document deletion, new batch-get methods, `get_nodes_by_chunk_ids` and `get_edges_by_chunk_ids`, have been added to the graph storage layer (`BaseGraphStorage` and its implementations). The [`adelete_by_doc_id`](lightrag/lightrag.py:1681) function now leverages these methods to avoid unnecessary iteration over the entire knowledge graph, significantly improving efficiency.
- Graph storage implementations updated: NetworkX, Neo4j, PostgreSQL AGE
yangdx 2025-06-25 12:37:57 +08:00
parent ebe5b1e0d2
commit da46b341dc
8 changed files with 271 additions and 63 deletions
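
For illustration only (not part of the commit): a minimal sketch of the call pattern the new batch-get API enables, assuming an initialized BaseGraphStorage implementation bound to a `storage` variable; the helper name `collect_affected` is hypothetical.

    from lightrag.constants import GRAPH_FIELD_SEP

    async def collect_affected(storage, chunk_ids: set[str]):
        # One batched call per type instead of iterating every label in the graph
        affected_nodes = await storage.get_nodes_by_chunk_ids(list(chunk_ids))
        affected_edges = await storage.get_edges_by_chunk_ids(list(chunk_ids))
        # Entities whose source_id references only the deleted chunks can be dropped outright;
        # the rest are rebuilt from the chunks that remain.
        entities_to_delete = [
            n["entity_id"]
            for n in affected_nodes
            if not set(n["source_id"].split(GRAPH_FIELD_SEP)) - chunk_ids
        ]
        return entities_to_delete, affected_edges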

View File

@@ -14,6 +14,7 @@ from typing import (
)
from .utils import EmbeddingFunc
from .types import KnowledgeGraph
from .constants import GRAPH_FIELD_SEP

# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
@@ -456,6 +457,67 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[node_id] = edges if edges is not None else []
return result
@abstractmethod
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all nodes that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
Returns:
list[dict]: A list of nodes, where each node is a dictionary of its properties.
An empty list if no matching nodes are found.
"""
# Default implementation iterates through all nodes, which is inefficient.
# This method should be overridden by subclasses for better performance.
all_nodes = []
all_labels = await self.get_all_labels()
for label in all_labels:
node = await self.get_node(label)
if node and "source_id" in node:
source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP))
if not source_ids.isdisjoint(chunk_ids):
all_nodes.append(node)
return all_nodes
@abstractmethod
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all edges that are associated with the given chunk_ids.
Args:
chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
Returns:
list[dict]: A list of edges, where each edge is a dictionary of its properties.
An empty list if no matching edges are found.
"""
# Default implementation iterates through all nodes and their edges, which is inefficient.
# This method should be overridden by subclasses for better performance.
all_edges = []
all_labels = await self.get_all_labels()
processed_edges = set()
for label in all_labels:
edges = await self.get_node_edges(label)
if edges:
for src_id, tgt_id in edges:
# Avoid processing the same edge twice in an undirected graph
edge_tuple = tuple(sorted((src_id, tgt_id)))
if edge_tuple in processed_edges:
continue
processed_edges.add(edge_tuple)
edge = await self.get_edge(src_id, tgt_id)
if edge and "source_id" in edge:
source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP))
if not source_ids.isdisjoint(chunk_ids):
# Add source and target to the edge dict for easier processing later
edge_with_nodes = edge.copy()
edge_with_nodes["source"] = src_id
edge_with_nodes["target"] = tgt_id
all_edges.append(edge_with_nodes)
return all_edges
@abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""Insert a new node or update an existing node in the graph.

View File

@@ -12,6 +12,9 @@ DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6
DEFAULT_WOKERS = 2
DEFAULT_TIMEOUT = 150

# Separator for graph fields
GRAPH_FIELD_SEP = "<SEP>"

# Logging configuration defaults
DEFAULT_LOG_MAX_BYTES = 10485760  # Default 10MB
DEFAULT_LOG_BACKUP_COUNT = 5  # Default 5 backups
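
A usage illustration (assumed, not part of this file): entity and relationship `source_id` properties hold multiple chunk IDs joined by this separator, which is why the new batch-get methods split on it.

    source_id = GRAPH_FIELD_SEP.join(["chunk-1", "chunk-2"])   # -> "chunk-1<SEP>chunk-2"
    chunk_ids = set(source_id.split(GRAPH_FIELD_SEP))          # -> {"chunk-1", "chunk-2"}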

View File

@@ -16,6 +16,7 @@ import logging
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP

import pipmaster as pm

if not pm.is_installed("neo4j"):
@@ -725,6 +726,47 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume()  # Ensure results are fully consumed
return edges_dict
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $chunk_ids AS chunk_id
MATCH (n:base)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
RETURN DISTINCT n
"""
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
nodes = []
async for record in result:
node = record["n"]
node_dict = dict(node)
# Add node id (entity_id) to the dictionary for easier access
node_dict["id"] = node_dict.get("entity_id")
nodes.append(node_dict)
await result.consume()
return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $chunk_ids AS chunk_id
MATCH (a:base)-[r]-(b:base)
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
"""
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
edges = []
async for record in result:
edge_properties = record["properties"]
edge_properties["source"] = record["source"]
edge_properties["target"] = record["target"]
edges.append(edge_properties)
await result.consume()
return edges
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),

View File

@@ -5,6 +5,7 @@ from typing import final
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger
from lightrag.base import BaseGraphStorage
from lightrag.constants import GRAPH_FIELD_SEP

import pipmaster as pm

@@ -357,6 +358,33 @@ class NetworkXStorage(BaseGraphStorage):
)
return result
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
chunk_ids_set = set(chunk_ids)
graph = await self._get_graph()
matching_nodes = []
for node_id, node_data in graph.nodes(data=True):
if "source_id" in node_data:
node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
if not node_source_ids.isdisjoint(chunk_ids_set):
node_data_with_id = node_data.copy()
node_data_with_id["id"] = node_id
matching_nodes.append(node_data_with_id)
return matching_nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
chunk_ids_set = set(chunk_ids)
graph = await self._get_graph()
matching_edges = []
for u, v, edge_data in graph.edges(data=True):
if "source_id" in edge_data:
edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
if not edge_source_ids.isdisjoint(chunk_ids_set):
edge_data_with_nodes = edge_data.copy()
edge_data_with_nodes["source"] = u
edge_data_with_nodes["target"] = v
matching_edges.append(edge_data_with_nodes)
return matching_edges
async def index_done_callback(self) -> bool:
"""Save data to disk"""
async with self._storage_lock:

View File

@@ -27,6 +27,7 @@ from ..base import (
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..constants import GRAPH_FIELD_SEP

import pipmaster as pm

@@ -1422,8 +1423,6 @@ class PGGraphStorage(BaseGraphStorage):
# Process string result, parse it to JSON dictionary
if isinstance(node_dict, str):
try:
-import json
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(f"Failed to parse node string: {node_dict}")
@@ -1479,8 +1478,6 @@
# Process string result, parse it to JSON dictionary
if isinstance(result, str):
try:
-import json
result = json.loads(result)
except json.JSONDecodeError:
logger.warning(f"Failed to parse edge string: {result}")
@@ -1697,8 +1694,6 @@
# Process string result, parse it to JSON dictionary
if isinstance(node_dict, str):
try:
-import json
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
@@ -1861,8 +1856,6 @@
# Process string result, parse it to JSON dictionary
if isinstance(edge_props, str):
try:
-import json
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
@@ -1879,8 +1872,6 @@
# Process string result, parse it to JSON dictionary
if isinstance(edge_props, str):
try:
-import json
edge_props = json.loads(edge_props)
except json.JSONDecodeError:
logger.warning(
@@ -1975,6 +1966,102 @@
labels.append(result["label"])
return labels
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""
Retrieves nodes from the graph that are associated with a given list of chunk IDs.
This method uses a Cypher query with UNWIND to efficiently find all nodes
where the `source_id` property contains any of the specified chunk IDs.
"""
# The string representation of the list for the cypher query
chunk_ids_str = json.dumps(chunk_ids)
query = f"""
SELECT * FROM cypher('{self.graph_name}', $$
UNWIND {chunk_ids_str} AS chunk_id
MATCH (n:base)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}')
RETURN n
$$) AS (n agtype);
"""
results = await self._query(query)
# Build result list
nodes = []
for result in results:
if result["n"]:
node_dict = result["n"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(node_dict, str):
try:
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {node_dict}"
)
node_dict["id"] = node_dict["entity_id"]
nodes.append(node_dict)
return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""
Retrieves edges from the graph that are associated with a given list of chunk IDs.
This method uses a Cypher query with UNWIND to efficiently find all edges
where the `source_id` property contains any of the specified chunk IDs.
"""
chunk_ids_str = json.dumps(chunk_ids)
query = f"""
SELECT * FROM cypher('{self.graph_name}', $$
UNWIND {chunk_ids_str} AS chunk_id
MATCH (a:base)-[r]-(b:base)
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}')
RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target
$$) AS (edge agtype, source agtype, target agtype);
"""
results = await self._query(query)
edges = []
if results:
for item in results:
edge_agtype = item["edge"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(edge_agtype, str):
try:
edge_agtype = json.loads(edge_agtype)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge string in batch: {edge_agtype}"
)
source_agtype = item["source"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(source_agtype, str):
try:
source_agtype = json.loads(source_agtype)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {source_agtype}"
)
target_agtype = item["target"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(target_agtype, str):
try:
target_agtype = json.loads(target_agtype)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse node string in batch: {target_agtype}"
)
if edge_agtype and source_agtype and target_agtype:
edge_properties = edge_agtype
edge_properties["source"] = source_agtype["entity_id"]
edge_properties["target"] = target_agtype["entity_id"]
edges.append(edge_properties)
return edges
async def _bfs_subgraph(
self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph:

View File

@@ -60,7 +60,7 @@ from .operate import (
query_with_keywords,
_rebuild_knowledge_from_chunks,
)
-from .prompt import GRAPH_FIELD_SEP
from .constants import GRAPH_FIELD_SEP
from .utils import (
Tokenizer,
TiktokenTokenizer,
@@ -1761,68 +1761,54 @@ class LightRAG:
            # Use graph database lock to ensure atomic merges and updates
            graph_db_lock = get_graph_db_lock(enable_logging=False)
            async with graph_db_lock:
-               # Process entities
-               # TODO There is performance when iterating get_all_labels for PostgresSQL
-               all_labels = await self.chunk_entity_relation_graph.get_all_labels()
-               for node_label in all_labels:
-                   node_data = await self.chunk_entity_relation_graph.get_node(
-                       node_label
-                   )
-                   if node_data and "source_id" in node_data:
-                       # Split source_id using GRAPH_FIELD_SEP
-                       sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
-                       remaining_sources = sources - chunk_ids
-                       if not remaining_sources:
-                           entities_to_delete.add(node_label)
-                           logger.debug(
-                               f"Entity {node_label} marked for deletion - no remaining sources"
-                           )
-                       elif remaining_sources != sources:
-                           # Entity needs to be rebuilt from remaining chunks
-                           entities_to_rebuild[node_label] = remaining_sources
-                           logger.debug(
-                               f"Entity {node_label} will be rebuilt from {len(remaining_sources)} remaining chunks"
-                           )

-               # Process relationships
-               # TODO There is performance when iterating get_all_labels for PostgresSQL
-               for node_label in all_labels:
-                   node_edges = await self.chunk_entity_relation_graph.get_node_edges(
-                       node_label
-                   )
-                   if node_edges:
-                       for src, tgt in node_edges:
-                           # To avoid processing the same edge twice in an undirected graph
-                           if (tgt, src) in relationships_to_delete or (
-                               tgt,
-                               src,
-                           ) in relationships_to_rebuild:
-                               continue
-                           edge_data = await self.chunk_entity_relation_graph.get_edge(
-                               src, tgt
-                           )
-                           if edge_data and "source_id" in edge_data:
-                               # Split source_id using GRAPH_FIELD_SEP
-                               sources = set(
-                                   edge_data["source_id"].split(GRAPH_FIELD_SEP)
-                               )
-                               remaining_sources = sources - chunk_ids
-                               if not remaining_sources:
-                                   relationships_to_delete.add((src, tgt))
-                                   logger.debug(
-                                       f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
-                                   )
-                               elif remaining_sources != sources:
-                                   # Relationship needs to be rebuilt from remaining chunks
-                                   relationships_to_rebuild[(src, tgt)] = (
-                                       remaining_sources
-                                   )
-                                   logger.debug(
-                                       f"Relationship {src}-{tgt} will be rebuilt from {len(remaining_sources)} remaining chunks"
-                                   )
                # Get all affected nodes and edges in batch
                affected_nodes = (
                    await self.chunk_entity_relation_graph.get_nodes_by_chunk_ids(
                        list(chunk_ids)
                    )
                )
                affected_edges = (
                    await self.chunk_entity_relation_graph.get_edges_by_chunk_ids(
                        list(chunk_ids)
                    )
                )

                # logger.info(f"chunk_ids: {chunk_ids}")
                # logger.info(f"affected_nodes: {affected_nodes}")
                # logger.info(f"affected_edges: {affected_edges}")

                # Process entities
                for node_data in affected_nodes:
                    node_label = node_data.get("entity_id")
                    if node_label and "source_id" in node_data:
                        sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
                        remaining_sources = sources - chunk_ids
                        if not remaining_sources:
                            entities_to_delete.add(node_label)
                        elif remaining_sources != sources:
                            entities_to_rebuild[node_label] = remaining_sources

                # Process relationships
                for edge_data in affected_edges:
                    src = edge_data.get("source")
                    tgt = edge_data.get("target")

                    if src and tgt and "source_id" in edge_data:
                        edge_tuple = tuple(sorted((src, tgt)))
                        if (
                            edge_tuple in relationships_to_delete
                            or edge_tuple in relationships_to_rebuild
                        ):
                            continue

                        sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
                        remaining_sources = sources - chunk_ids
                        if not remaining_sources:
                            relationships_to_delete.add(edge_tuple)
                        elif remaining_sources != sources:
                            relationships_to_rebuild[edge_tuple] = remaining_sources

            # 5. Delete chunks from storage
            if chunk_ids:

View File

@@ -33,7 +33,8 @@ from .base import (
TextChunkSchema,
QueryParam,
)
-from .prompt import GRAPH_FIELD_SEP, PROMPTS
from .prompt import PROMPTS
from .constants import GRAPH_FIELD_SEP

import time
from dotenv import load_dotenv

View File

@@ -1,7 +1,6 @@
from __future__ import annotations

from typing import Any

-GRAPH_FIELD_SEP = "<SEP>"

PROMPTS: dict[str, Any] = {}