From da46b341dc1b2c6c578439374ed45a30bea493db Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 25 Jun 2025 12:37:57 +0800 Subject: [PATCH] feat: Optimize document deletion performance - To enhance performance during document deletion, new batch-get methods, `get_nodes_by_chunk_ids` and `get_edges_by_chunk_ids`, have been added to the graph storage layer (`BaseGraphStorage` and its implementations). The [`adelete_by_doc_id`](lightrag/lightrag.py:1681) function now leverages these methods to avoid unnecessary iteration over the entire knowledge graph, significantly improving efficiency. - Graph storage updated: Networkx, Neo4j, Postgres AGE --- lightrag/base.py | 62 ++++++++++++++++++++ lightrag/constants.py | 3 + lightrag/kg/neo4j_impl.py | 42 ++++++++++++++ lightrag/kg/networkx_impl.py | 28 +++++++++ lightrag/kg/postgres_impl.py | 107 +++++++++++++++++++++++++++++++---- lightrag/lightrag.py | 88 ++++++++++++---------------- lightrag/operate.py | 3 +- lightrag/prompt.py | 1 - 8 files changed, 271 insertions(+), 63 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 84fc7564..add2318e 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -14,6 +14,7 @@ from typing import ( ) from .utils import EmbeddingFunc from .types import KnowledgeGraph +from .constants import GRAPH_FIELD_SEP # use the .env that is inside the current folder # allows to use different .env file for each lightrag instance @@ -456,6 +457,67 @@ class BaseGraphStorage(StorageNameSpace, ABC): result[node_id] = edges if edges is not None else [] return result + @abstractmethod + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all nodes that are associated with the given chunk_ids. + + Args: + chunk_ids (list[str]): A list of chunk IDs to find associated nodes for. + + Returns: + list[dict]: A list of nodes, where each node is a dictionary of its properties. + An empty list if no matching nodes are found. + """ + # Default implementation iterates through all nodes, which is inefficient. + # This method should be overridden by subclasses for better performance. + all_nodes = [] + all_labels = await self.get_all_labels() + for label in all_labels: + node = await self.get_node(label) + if node and "source_id" in node: + source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP)) + if not source_ids.isdisjoint(chunk_ids): + all_nodes.append(node) + return all_nodes + + @abstractmethod + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all edges that are associated with the given chunk_ids. + + Args: + chunk_ids (list[str]): A list of chunk IDs to find associated edges for. + + Returns: + list[dict]: A list of edges, where each edge is a dictionary of its properties. + An empty list if no matching edges are found. + """ + # Default implementation iterates through all nodes and their edges, which is inefficient. + # This method should be overridden by subclasses for better performance. + all_edges = [] + all_labels = await self.get_all_labels() + processed_edges = set() + + for label in all_labels: + edges = await self.get_node_edges(label) + if edges: + for src_id, tgt_id in edges: + # Avoid processing the same edge twice in an undirected graph + edge_tuple = tuple(sorted((src_id, tgt_id))) + if edge_tuple in processed_edges: + continue + processed_edges.add(edge_tuple) + + edge = await self.get_edge(src_id, tgt_id) + if edge and "source_id" in edge: + source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP)) + if not source_ids.isdisjoint(chunk_ids): + # Add source and target to the edge dict for easier processing later + edge_with_nodes = edge.copy() + edge_with_nodes["source"] = src_id + edge_with_nodes["target"] = tgt_id + all_edges.append(edge_with_nodes) + return all_edges + @abstractmethod async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """Insert a new node or update an existing node in the graph. diff --git a/lightrag/constants.py b/lightrag/constants.py index 787e1c49..f8345994 100644 --- a/lightrag/constants.py +++ b/lightrag/constants.py @@ -12,6 +12,9 @@ DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6 DEFAULT_WOKERS = 2 DEFAULT_TIMEOUT = 150 +# Separator for graph fields +GRAPH_FIELD_SEP = "" + # Logging configuration defaults DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB DEFAULT_LOG_BACKUP_COUNT = 5 # Default 5 backups diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 7fe3da15..d4fbc59c 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -16,6 +16,7 @@ import logging from ..utils import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +from ..constants import GRAPH_FIELD_SEP import pipmaster as pm if not pm.is_installed("neo4j"): @@ -725,6 +726,47 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Ensure results are fully consumed return edges_dict + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = """ + UNWIND $chunk_ids AS chunk_id + MATCH (n:base) + WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) + RETURN DISTINCT n + """ + result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) + nodes = [] + async for record in result: + node = record["n"] + node_dict = dict(node) + # Add node id (entity_id) to the dictionary for easier access + node_dict["id"] = node_dict.get("entity_id") + nodes.append(node_dict) + await result.consume() + return nodes + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = """ + UNWIND $chunk_ids AS chunk_id + MATCH (a:base)-[r]-(b:base) + WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) + RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties + """ + result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) + edges = [] + async for record in result: + edge_properties = record["properties"] + edge_properties["source"] = record["source"] + edge_properties["target"] = record["target"] + edges.append(edge_properties) + await result.consume() + return edges + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index c92bbd30..a4c46122 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -5,6 +5,7 @@ from typing import final from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.utils import logger from lightrag.base import BaseGraphStorage +from lightrag.constants import GRAPH_FIELD_SEP import pipmaster as pm @@ -357,6 +358,33 @@ class NetworkXStorage(BaseGraphStorage): ) return result + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + chunk_ids_set = set(chunk_ids) + graph = await self._get_graph() + matching_nodes = [] + for node_id, node_data in graph.nodes(data=True): + if "source_id" in node_data: + node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP)) + if not node_source_ids.isdisjoint(chunk_ids_set): + node_data_with_id = node_data.copy() + node_data_with_id["id"] = node_id + matching_nodes.append(node_data_with_id) + return matching_nodes + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + chunk_ids_set = set(chunk_ids) + graph = await self._get_graph() + matching_edges = [] + for u, v, edge_data in graph.edges(data=True): + if "source_id" in edge_data: + edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP)) + if not edge_source_ids.isdisjoint(chunk_ids_set): + edge_data_with_nodes = edge_data.copy() + edge_data_with_nodes["source"] = u + edge_data_with_nodes["target"] = v + matching_edges.append(edge_data_with_nodes) + return matching_edges + async def index_done_callback(self) -> bool: """Save data to disk""" async with self._storage_lock: diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index bacd8894..888b97c7 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -27,6 +27,7 @@ from ..base import ( ) from ..namespace import NameSpace, is_namespace from ..utils import logger +from ..constants import GRAPH_FIELD_SEP import pipmaster as pm @@ -1422,8 +1423,6 @@ class PGGraphStorage(BaseGraphStorage): # Process string result, parse it to JSON dictionary if isinstance(node_dict, str): try: - import json - node_dict = json.loads(node_dict) except json.JSONDecodeError: logger.warning(f"Failed to parse node string: {node_dict}") @@ -1479,8 +1478,6 @@ class PGGraphStorage(BaseGraphStorage): # Process string result, parse it to JSON dictionary if isinstance(result, str): try: - import json - result = json.loads(result) except json.JSONDecodeError: logger.warning(f"Failed to parse edge string: {result}") @@ -1697,8 +1694,6 @@ class PGGraphStorage(BaseGraphStorage): # Process string result, parse it to JSON dictionary if isinstance(node_dict, str): try: - import json - node_dict = json.loads(node_dict) except json.JSONDecodeError: logger.warning( @@ -1861,8 +1856,6 @@ class PGGraphStorage(BaseGraphStorage): # Process string result, parse it to JSON dictionary if isinstance(edge_props, str): try: - import json - edge_props = json.loads(edge_props) except json.JSONDecodeError: logger.warning( @@ -1879,8 +1872,6 @@ class PGGraphStorage(BaseGraphStorage): # Process string result, parse it to JSON dictionary if isinstance(edge_props, str): try: - import json - edge_props = json.loads(edge_props) except json.JSONDecodeError: logger.warning( @@ -1975,6 +1966,102 @@ class PGGraphStorage(BaseGraphStorage): labels.append(result["label"]) return labels + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """ + Retrieves nodes from the graph that are associated with a given list of chunk IDs. + This method uses a Cypher query with UNWIND to efficiently find all nodes + where the `source_id` property contains any of the specified chunk IDs. + """ + # The string representation of the list for the cypher query + chunk_ids_str = json.dumps(chunk_ids) + + query = f""" + SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND {chunk_ids_str} AS chunk_id + MATCH (n:base) + WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}') + RETURN n + $$) AS (n agtype); + """ + results = await self._query(query) + + # Build result list + nodes = [] + for result in results: + if result["n"]: + node_dict = result["n"]["properties"] + + # Process string result, parse it to JSON dictionary + if isinstance(node_dict, str): + try: + node_dict = json.loads(node_dict) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse node string in batch: {node_dict}" + ) + + node_dict["id"] = node_dict["entity_id"] + nodes.append(node_dict) + + return nodes + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """ + Retrieves edges from the graph that are associated with a given list of chunk IDs. + This method uses a Cypher query with UNWIND to efficiently find all edges + where the `source_id` property contains any of the specified chunk IDs. + """ + chunk_ids_str = json.dumps(chunk_ids) + + query = f""" + SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND {chunk_ids_str} AS chunk_id + MATCH (a:base)-[r]-(b:base) + WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}') + RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target + $$) AS (edge agtype, source agtype, target agtype); + """ + results = await self._query(query) + edges = [] + if results: + for item in results: + edge_agtype = item["edge"]["properties"] + # Process string result, parse it to JSON dictionary + if isinstance(edge_agtype, str): + try: + edge_agtype = json.loads(edge_agtype) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse edge string in batch: {edge_agtype}" + ) + + source_agtype = item["source"]["properties"] + # Process string result, parse it to JSON dictionary + if isinstance(source_agtype, str): + try: + source_agtype = json.loads(source_agtype) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse node string in batch: {source_agtype}" + ) + + target_agtype = item["target"]["properties"] + # Process string result, parse it to JSON dictionary + if isinstance(target_agtype, str): + try: + target_agtype = json.loads(target_agtype) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse node string in batch: {target_agtype}" + ) + + if edge_agtype and source_agtype and target_agtype: + edge_properties = edge_agtype + edge_properties["source"] = source_agtype["entity_id"] + edge_properties["target"] = target_agtype["entity_id"] + edges.append(edge_properties) + return edges + async def _bfs_subgraph( self, node_label: str, max_depth: int, max_nodes: int ) -> KnowledgeGraph: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index f631992d..b94709f2 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -60,7 +60,7 @@ from .operate import ( query_with_keywords, _rebuild_knowledge_from_chunks, ) -from .prompt import GRAPH_FIELD_SEP +from .constants import GRAPH_FIELD_SEP from .utils import ( Tokenizer, TiktokenTokenizer, @@ -1761,68 +1761,54 @@ class LightRAG: # Use graph database lock to ensure atomic merges and updates graph_db_lock = get_graph_db_lock(enable_logging=False) async with graph_db_lock: - # Process entities - # TODO There is performance when iterating get_all_labels for PostgresSQL - all_labels = await self.chunk_entity_relation_graph.get_all_labels() - for node_label in all_labels: - node_data = await self.chunk_entity_relation_graph.get_node( - node_label + # Get all affected nodes and edges in batch + affected_nodes = ( + await self.chunk_entity_relation_graph.get_nodes_by_chunk_ids( + list(chunk_ids) ) - if node_data and "source_id" in node_data: - # Split source_id using GRAPH_FIELD_SEP + ) + affected_edges = ( + await self.chunk_entity_relation_graph.get_edges_by_chunk_ids( + list(chunk_ids) + ) + ) + + # logger.info(f"chunk_ids: {chunk_ids}") + # logger.info(f"affected_nodes: {affected_nodes}") + # logger.info(f"affected_edges: {affected_edges}") + + # Process entities + for node_data in affected_nodes: + node_label = node_data.get("entity_id") + if node_label and "source_id" in node_data: sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP)) remaining_sources = sources - chunk_ids if not remaining_sources: entities_to_delete.add(node_label) - logger.debug( - f"Entity {node_label} marked for deletion - no remaining sources" - ) elif remaining_sources != sources: - # Entity needs to be rebuilt from remaining chunks entities_to_rebuild[node_label] = remaining_sources - logger.debug( - f"Entity {node_label} will be rebuilt from {len(remaining_sources)} remaining chunks" - ) # Process relationships - # TODO There is performance when iterating get_all_labels for PostgresSQL - for node_label in all_labels: - node_edges = await self.chunk_entity_relation_graph.get_node_edges( - node_label - ) - if node_edges: - for src, tgt in node_edges: - # To avoid processing the same edge twice in an undirected graph - if (tgt, src) in relationships_to_delete or ( - tgt, - src, - ) in relationships_to_rebuild: - continue + for edge_data in affected_edges: + src = edge_data.get("source") + tgt = edge_data.get("target") - edge_data = await self.chunk_entity_relation_graph.get_edge( - src, tgt - ) - if edge_data and "source_id" in edge_data: - # Split source_id using GRAPH_FIELD_SEP - sources = set( - edge_data["source_id"].split(GRAPH_FIELD_SEP) - ) - remaining_sources = sources - chunk_ids + if src and tgt and "source_id" in edge_data: + edge_tuple = tuple(sorted((src, tgt))) + if ( + edge_tuple in relationships_to_delete + or edge_tuple in relationships_to_rebuild + ): + continue - if not remaining_sources: - relationships_to_delete.add((src, tgt)) - logger.debug( - f"Relationship {src}-{tgt} marked for deletion - no remaining sources" - ) - elif remaining_sources != sources: - # Relationship needs to be rebuilt from remaining chunks - relationships_to_rebuild[(src, tgt)] = ( - remaining_sources - ) - logger.debug( - f"Relationship {src}-{tgt} will be rebuilt from {len(remaining_sources)} remaining chunks" - ) + sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP)) + remaining_sources = sources - chunk_ids + + if not remaining_sources: + relationships_to_delete.add(edge_tuple) + elif remaining_sources != sources: + relationships_to_rebuild[edge_tuple] = remaining_sources # 5. Delete chunks from storage if chunk_ids: diff --git a/lightrag/operate.py b/lightrag/operate.py index b19f739c..d5026203 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -33,7 +33,8 @@ from .base import ( TextChunkSchema, QueryParam, ) -from .prompt import GRAPH_FIELD_SEP, PROMPTS +from .prompt import PROMPTS +from .constants import GRAPH_FIELD_SEP import time from dotenv import load_dotenv diff --git a/lightrag/prompt.py b/lightrag/prompt.py index 5ed630f9..a4641480 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Any -GRAPH_FIELD_SEP = "" PROMPTS: dict[str, Any] = {}