mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-11-02 18:59:32 +00:00
feat: Optimize document deletion performance
- To enhance performance during document deletion, new batch-get methods, `get_nodes_by_chunk_ids` and `get_edges_by_chunk_ids`, have been added to the graph storage layer (`BaseGraphStorage` and its implementations). The [`adelete_by_doc_id`](lightrag/lightrag.py:1681) function now leverages these methods to avoid unnecessary iteration over the entire knowledge graph, significantly improving efficiency. - Graph storage updated: Networkx, Neo4j, Postgres AGE
This commit is contained in:
parent
ebe5b1e0d2
commit
da46b341dc
@ -14,6 +14,7 @@ from typing import (
|
||||
)
|
||||
from .utils import EmbeddingFunc
|
||||
from .types import KnowledgeGraph
|
||||
from .constants import GRAPH_FIELD_SEP
|
||||
|
||||
# use the .env that is inside the current folder
|
||||
# allows to use different .env file for each lightrag instance
|
||||
@ -456,6 +457,67 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
result[node_id] = edges if edges is not None else []
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""Get all nodes that are associated with the given chunk_ids.
|
||||
|
||||
Args:
|
||||
chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of nodes, where each node is a dictionary of its properties.
|
||||
An empty list if no matching nodes are found.
|
||||
"""
|
||||
# Default implementation iterates through all nodes, which is inefficient.
|
||||
# This method should be overridden by subclasses for better performance.
|
||||
all_nodes = []
|
||||
all_labels = await self.get_all_labels()
|
||||
for label in all_labels:
|
||||
node = await self.get_node(label)
|
||||
if node and "source_id" in node:
|
||||
source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP))
|
||||
if not source_ids.isdisjoint(chunk_ids):
|
||||
all_nodes.append(node)
|
||||
return all_nodes
|
||||
|
||||
@abstractmethod
|
||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""Get all edges that are associated with the given chunk_ids.
|
||||
|
||||
Args:
|
||||
chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of edges, where each edge is a dictionary of its properties.
|
||||
An empty list if no matching edges are found.
|
||||
"""
|
||||
# Default implementation iterates through all nodes and their edges, which is inefficient.
|
||||
# This method should be overridden by subclasses for better performance.
|
||||
all_edges = []
|
||||
all_labels = await self.get_all_labels()
|
||||
processed_edges = set()
|
||||
|
||||
for label in all_labels:
|
||||
edges = await self.get_node_edges(label)
|
||||
if edges:
|
||||
for src_id, tgt_id in edges:
|
||||
# Avoid processing the same edge twice in an undirected graph
|
||||
edge_tuple = tuple(sorted((src_id, tgt_id)))
|
||||
if edge_tuple in processed_edges:
|
||||
continue
|
||||
processed_edges.add(edge_tuple)
|
||||
|
||||
edge = await self.get_edge(src_id, tgt_id)
|
||||
if edge and "source_id" in edge:
|
||||
source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP))
|
||||
if not source_ids.isdisjoint(chunk_ids):
|
||||
# Add source and target to the edge dict for easier processing later
|
||||
edge_with_nodes = edge.copy()
|
||||
edge_with_nodes["source"] = src_id
|
||||
edge_with_nodes["target"] = tgt_id
|
||||
all_edges.append(edge_with_nodes)
|
||||
return all_edges
|
||||
|
||||
@abstractmethod
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
"""Insert a new node or update an existing node in the graph.
|
||||
|
||||
@ -12,6 +12,9 @@ DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6
|
||||
DEFAULT_WOKERS = 2
|
||||
DEFAULT_TIMEOUT = 150
|
||||
|
||||
# Separator for graph fields
|
||||
GRAPH_FIELD_SEP = "<SEP>"
|
||||
|
||||
# Logging configuration defaults
|
||||
DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB
|
||||
DEFAULT_LOG_BACKUP_COUNT = 5 # Default 5 backups
|
||||
|
||||
@ -16,6 +16,7 @@ import logging
|
||||
from ..utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
from ..constants import GRAPH_FIELD_SEP
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("neo4j"):
|
||||
@ -725,6 +726,47 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
await result.consume() # Ensure results are fully consumed
|
||||
return edges_dict
|
||||
|
||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = """
|
||||
UNWIND $chunk_ids AS chunk_id
|
||||
MATCH (n:base)
|
||||
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
|
||||
RETURN DISTINCT n
|
||||
"""
|
||||
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
||||
nodes = []
|
||||
async for record in result:
|
||||
node = record["n"]
|
||||
node_dict = dict(node)
|
||||
# Add node id (entity_id) to the dictionary for easier access
|
||||
node_dict["id"] = node_dict.get("entity_id")
|
||||
nodes.append(node_dict)
|
||||
await result.consume()
|
||||
return nodes
|
||||
|
||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = """
|
||||
UNWIND $chunk_ids AS chunk_id
|
||||
MATCH (a:base)-[r]-(b:base)
|
||||
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
|
||||
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
|
||||
"""
|
||||
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
||||
edges = []
|
||||
async for record in result:
|
||||
edge_properties = record["properties"]
|
||||
edge_properties["source"] = record["source"]
|
||||
edge_properties["target"] = record["target"]
|
||||
edges.append(edge_properties)
|
||||
await result.consume()
|
||||
return edges
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import final
|
||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
from lightrag.utils import logger
|
||||
from lightrag.base import BaseGraphStorage
|
||||
from lightrag.constants import GRAPH_FIELD_SEP
|
||||
|
||||
import pipmaster as pm
|
||||
|
||||
@ -357,6 +358,33 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
)
|
||||
return result
|
||||
|
||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
chunk_ids_set = set(chunk_ids)
|
||||
graph = await self._get_graph()
|
||||
matching_nodes = []
|
||||
for node_id, node_data in graph.nodes(data=True):
|
||||
if "source_id" in node_data:
|
||||
node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
|
||||
if not node_source_ids.isdisjoint(chunk_ids_set):
|
||||
node_data_with_id = node_data.copy()
|
||||
node_data_with_id["id"] = node_id
|
||||
matching_nodes.append(node_data_with_id)
|
||||
return matching_nodes
|
||||
|
||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
chunk_ids_set = set(chunk_ids)
|
||||
graph = await self._get_graph()
|
||||
matching_edges = []
|
||||
for u, v, edge_data in graph.edges(data=True):
|
||||
if "source_id" in edge_data:
|
||||
edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
|
||||
if not edge_source_ids.isdisjoint(chunk_ids_set):
|
||||
edge_data_with_nodes = edge_data.copy()
|
||||
edge_data_with_nodes["source"] = u
|
||||
edge_data_with_nodes["target"] = v
|
||||
matching_edges.append(edge_data_with_nodes)
|
||||
return matching_edges
|
||||
|
||||
async def index_done_callback(self) -> bool:
|
||||
"""Save data to disk"""
|
||||
async with self._storage_lock:
|
||||
|
||||
@ -27,6 +27,7 @@ from ..base import (
|
||||
)
|
||||
from ..namespace import NameSpace, is_namespace
|
||||
from ..utils import logger
|
||||
from ..constants import GRAPH_FIELD_SEP
|
||||
|
||||
import pipmaster as pm
|
||||
|
||||
@ -1422,8 +1423,6 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(node_dict, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
node_dict = json.loads(node_dict)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse node string: {node_dict}")
|
||||
@ -1479,8 +1478,6 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(result, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
result = json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse edge string: {result}")
|
||||
@ -1697,8 +1694,6 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(node_dict, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
node_dict = json.loads(node_dict)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
@ -1861,8 +1856,6 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(edge_props, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
edge_props = json.loads(edge_props)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
@ -1879,8 +1872,6 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(edge_props, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
edge_props = json.loads(edge_props)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
@ -1975,6 +1966,102 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
labels.append(result["label"])
|
||||
return labels
|
||||
|
||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""
|
||||
Retrieves nodes from the graph that are associated with a given list of chunk IDs.
|
||||
This method uses a Cypher query with UNWIND to efficiently find all nodes
|
||||
where the `source_id` property contains any of the specified chunk IDs.
|
||||
"""
|
||||
# The string representation of the list for the cypher query
|
||||
chunk_ids_str = json.dumps(chunk_ids)
|
||||
|
||||
query = f"""
|
||||
SELECT * FROM cypher('{self.graph_name}', $$
|
||||
UNWIND {chunk_ids_str} AS chunk_id
|
||||
MATCH (n:base)
|
||||
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}')
|
||||
RETURN n
|
||||
$$) AS (n agtype);
|
||||
"""
|
||||
results = await self._query(query)
|
||||
|
||||
# Build result list
|
||||
nodes = []
|
||||
for result in results:
|
||||
if result["n"]:
|
||||
node_dict = result["n"]["properties"]
|
||||
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(node_dict, str):
|
||||
try:
|
||||
node_dict = json.loads(node_dict)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse node string in batch: {node_dict}"
|
||||
)
|
||||
|
||||
node_dict["id"] = node_dict["entity_id"]
|
||||
nodes.append(node_dict)
|
||||
|
||||
return nodes
|
||||
|
||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""
|
||||
Retrieves edges from the graph that are associated with a given list of chunk IDs.
|
||||
This method uses a Cypher query with UNWIND to efficiently find all edges
|
||||
where the `source_id` property contains any of the specified chunk IDs.
|
||||
"""
|
||||
chunk_ids_str = json.dumps(chunk_ids)
|
||||
|
||||
query = f"""
|
||||
SELECT * FROM cypher('{self.graph_name}', $$
|
||||
UNWIND {chunk_ids_str} AS chunk_id
|
||||
MATCH (a:base)-[r]-(b:base)
|
||||
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}')
|
||||
RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target
|
||||
$$) AS (edge agtype, source agtype, target agtype);
|
||||
"""
|
||||
results = await self._query(query)
|
||||
edges = []
|
||||
if results:
|
||||
for item in results:
|
||||
edge_agtype = item["edge"]["properties"]
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(edge_agtype, str):
|
||||
try:
|
||||
edge_agtype = json.loads(edge_agtype)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse edge string in batch: {edge_agtype}"
|
||||
)
|
||||
|
||||
source_agtype = item["source"]["properties"]
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(source_agtype, str):
|
||||
try:
|
||||
source_agtype = json.loads(source_agtype)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse node string in batch: {source_agtype}"
|
||||
)
|
||||
|
||||
target_agtype = item["target"]["properties"]
|
||||
# Process string result, parse it to JSON dictionary
|
||||
if isinstance(target_agtype, str):
|
||||
try:
|
||||
target_agtype = json.loads(target_agtype)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse node string in batch: {target_agtype}"
|
||||
)
|
||||
|
||||
if edge_agtype and source_agtype and target_agtype:
|
||||
edge_properties = edge_agtype
|
||||
edge_properties["source"] = source_agtype["entity_id"]
|
||||
edge_properties["target"] = target_agtype["entity_id"]
|
||||
edges.append(edge_properties)
|
||||
return edges
|
||||
|
||||
async def _bfs_subgraph(
|
||||
self, node_label: str, max_depth: int, max_nodes: int
|
||||
) -> KnowledgeGraph:
|
||||
|
||||
@ -60,7 +60,7 @@ from .operate import (
|
||||
query_with_keywords,
|
||||
_rebuild_knowledge_from_chunks,
|
||||
)
|
||||
from .prompt import GRAPH_FIELD_SEP
|
||||
from .constants import GRAPH_FIELD_SEP
|
||||
from .utils import (
|
||||
Tokenizer,
|
||||
TiktokenTokenizer,
|
||||
@ -1761,68 +1761,54 @@ class LightRAG:
|
||||
# Use graph database lock to ensure atomic merges and updates
|
||||
graph_db_lock = get_graph_db_lock(enable_logging=False)
|
||||
async with graph_db_lock:
|
||||
# Process entities
|
||||
# TODO There is performance when iterating get_all_labels for PostgresSQL
|
||||
all_labels = await self.chunk_entity_relation_graph.get_all_labels()
|
||||
for node_label in all_labels:
|
||||
node_data = await self.chunk_entity_relation_graph.get_node(
|
||||
node_label
|
||||
# Get all affected nodes and edges in batch
|
||||
affected_nodes = (
|
||||
await self.chunk_entity_relation_graph.get_nodes_by_chunk_ids(
|
||||
list(chunk_ids)
|
||||
)
|
||||
if node_data and "source_id" in node_data:
|
||||
# Split source_id using GRAPH_FIELD_SEP
|
||||
)
|
||||
affected_edges = (
|
||||
await self.chunk_entity_relation_graph.get_edges_by_chunk_ids(
|
||||
list(chunk_ids)
|
||||
)
|
||||
)
|
||||
|
||||
# logger.info(f"chunk_ids: {chunk_ids}")
|
||||
# logger.info(f"affected_nodes: {affected_nodes}")
|
||||
# logger.info(f"affected_edges: {affected_edges}")
|
||||
|
||||
# Process entities
|
||||
for node_data in affected_nodes:
|
||||
node_label = node_data.get("entity_id")
|
||||
if node_label and "source_id" in node_data:
|
||||
sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
|
||||
remaining_sources = sources - chunk_ids
|
||||
|
||||
if not remaining_sources:
|
||||
entities_to_delete.add(node_label)
|
||||
logger.debug(
|
||||
f"Entity {node_label} marked for deletion - no remaining sources"
|
||||
)
|
||||
elif remaining_sources != sources:
|
||||
# Entity needs to be rebuilt from remaining chunks
|
||||
entities_to_rebuild[node_label] = remaining_sources
|
||||
logger.debug(
|
||||
f"Entity {node_label} will be rebuilt from {len(remaining_sources)} remaining chunks"
|
||||
)
|
||||
|
||||
# Process relationships
|
||||
# TODO There is performance when iterating get_all_labels for PostgresSQL
|
||||
for node_label in all_labels:
|
||||
node_edges = await self.chunk_entity_relation_graph.get_node_edges(
|
||||
node_label
|
||||
)
|
||||
if node_edges:
|
||||
for src, tgt in node_edges:
|
||||
# To avoid processing the same edge twice in an undirected graph
|
||||
if (tgt, src) in relationships_to_delete or (
|
||||
tgt,
|
||||
src,
|
||||
) in relationships_to_rebuild:
|
||||
continue
|
||||
for edge_data in affected_edges:
|
||||
src = edge_data.get("source")
|
||||
tgt = edge_data.get("target")
|
||||
|
||||
edge_data = await self.chunk_entity_relation_graph.get_edge(
|
||||
src, tgt
|
||||
)
|
||||
if edge_data and "source_id" in edge_data:
|
||||
# Split source_id using GRAPH_FIELD_SEP
|
||||
sources = set(
|
||||
edge_data["source_id"].split(GRAPH_FIELD_SEP)
|
||||
)
|
||||
remaining_sources = sources - chunk_ids
|
||||
if src and tgt and "source_id" in edge_data:
|
||||
edge_tuple = tuple(sorted((src, tgt)))
|
||||
if (
|
||||
edge_tuple in relationships_to_delete
|
||||
or edge_tuple in relationships_to_rebuild
|
||||
):
|
||||
continue
|
||||
|
||||
if not remaining_sources:
|
||||
relationships_to_delete.add((src, tgt))
|
||||
logger.debug(
|
||||
f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
|
||||
)
|
||||
elif remaining_sources != sources:
|
||||
# Relationship needs to be rebuilt from remaining chunks
|
||||
relationships_to_rebuild[(src, tgt)] = (
|
||||
remaining_sources
|
||||
)
|
||||
logger.debug(
|
||||
f"Relationship {src}-{tgt} will be rebuilt from {len(remaining_sources)} remaining chunks"
|
||||
)
|
||||
sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
|
||||
remaining_sources = sources - chunk_ids
|
||||
|
||||
if not remaining_sources:
|
||||
relationships_to_delete.add(edge_tuple)
|
||||
elif remaining_sources != sources:
|
||||
relationships_to_rebuild[edge_tuple] = remaining_sources
|
||||
|
||||
# 5. Delete chunks from storage
|
||||
if chunk_ids:
|
||||
|
||||
@ -33,7 +33,8 @@ from .base import (
|
||||
TextChunkSchema,
|
||||
QueryParam,
|
||||
)
|
||||
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
||||
from .prompt import PROMPTS
|
||||
from .constants import GRAPH_FIELD_SEP
|
||||
import time
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
GRAPH_FIELD_SEP = "<SEP>"
|
||||
|
||||
PROMPTS: dict[str, Any] = {}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user