Mirror of https://github.com/HKUDS/LightRAG.git (synced 2025-11-04 11:49:29 +00:00)

	feat: Optimize document deletion performance
- To enhance performance during document deletion, new batch-get methods, `get_nodes_by_chunk_ids` and `get_edges_by_chunk_ids`, have been added to the graph storage layer (`BaseGraphStorage` and its implementations). The [`adelete_by_doc_id`](lightrag/lightrag.py:1681) function now leverages these methods to avoid unnecessary iteration over the entire knowledge graph, significantly improving efficiency.
- Graph storage implementations updated: NetworkX, Neo4j, Postgres AGE
This commit is contained in: parent ebe5b1e0d2, commit da46b341dc
lightrag/base.py
@@ -14,6 +14,7 @@ from typing import (
 )
 from .utils import EmbeddingFunc
 from .types import KnowledgeGraph
+from .constants import GRAPH_FIELD_SEP

 # use the .env that is inside the current folder
 # allows to use different .env file for each lightrag instance
@@ -456,6 +457,67 @@ class BaseGraphStorage(StorageNameSpace, ABC):
             result[node_id] = edges if edges is not None else []
         return result

+    @abstractmethod
+    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        """Get all nodes that are associated with the given chunk_ids.
+
+        Args:
+            chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
+
+        Returns:
+            list[dict]: A list of nodes, where each node is a dictionary of its properties.
+                        An empty list if no matching nodes are found.
+        """
+        # Default implementation iterates through all nodes, which is inefficient.
+        # This method should be overridden by subclasses for better performance.
+        all_nodes = []
+        all_labels = await self.get_all_labels()
+        for label in all_labels:
+            node = await self.get_node(label)
+            if node and "source_id" in node:
+                source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP))
+                if not source_ids.isdisjoint(chunk_ids):
+                    all_nodes.append(node)
+        return all_nodes
+
+    @abstractmethod
+    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        """Get all edges that are associated with the given chunk_ids.
+
+        Args:
+            chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
+
+        Returns:
+            list[dict]: A list of edges, where each edge is a dictionary of its properties.
+                        An empty list if no matching edges are found.
+        """
+        # Default implementation iterates through all nodes and their edges, which is inefficient.
+        # This method should be overridden by subclasses for better performance.
+        all_edges = []
+        all_labels = await self.get_all_labels()
+        processed_edges = set()
+
+        for label in all_labels:
+            edges = await self.get_node_edges(label)
+            if edges:
+                for src_id, tgt_id in edges:
+                    # Avoid processing the same edge twice in an undirected graph
+                    edge_tuple = tuple(sorted((src_id, tgt_id)))
+                    if edge_tuple in processed_edges:
+                        continue
+                    processed_edges.add(edge_tuple)
+
+                    edge = await self.get_edge(src_id, tgt_id)
+                    if edge and "source_id" in edge:
+                        source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP))
+                        if not source_ids.isdisjoint(chunk_ids):
+                            # Add source and target to the edge dict for easier processing later
+                            edge_with_nodes = edge.copy()
+                            edge_with_nodes["source"] = src_id
+                            edge_with_nodes["target"] = tgt_id
+                            all_edges.append(edge_with_nodes)
+        return all_edges
+
     @abstractmethod
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """Insert a new node or update an existing node in the graph.
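For illustration, the two primitives these default implementations rest on, as a self-contained Python sketch with hypothetical IDs: the chunk-membership test on the packed source_id field, and the sorted-tuple key that deduplicates undirected edges.

GRAPH_FIELD_SEP = "<SEP>"

# Membership: a node/edge is affected when its source_id, split back into
# individual chunk IDs, intersects the chunk IDs being deleted.
node = {"entity_id": "Alice", "source_id": "chunk-1<SEP>chunk-3"}
source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP))
print(not source_ids.isdisjoint({"chunk-3"}))  # True: node is associated

# Dedup: ("B", "A") and ("A", "B") normalize to the same key, so each
# undirected edge is processed once despite being reachable from both ends.
processed_edges = set()
for src_id, tgt_id in [("A", "B"), ("B", "A"), ("A", "C")]:
    edge_tuple = tuple(sorted((src_id, tgt_id)))
    if edge_tuple in processed_edges:
        continue
    processed_edges.add(edge_tuple)
print(sorted(processed_edges))  # [('A', 'B'), ('A', 'C')]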
lightrag/constants.py
@@ -12,6 +12,9 @@ DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6
 DEFAULT_WOKERS = 2
 DEFAULT_TIMEOUT = 150

+# Separator for graph fields
+GRAPH_FIELD_SEP = "<SEP>"
+
 # Logging configuration defaults
 DEFAULT_LOG_MAX_BYTES = 10485760  # Default 10MB
 DEFAULT_LOG_BACKUP_COUNT = 5  # Default 5 backups
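GRAPH_FIELD_SEP now lives in constants, so storage backends can import it without pulling in the prompt module. The field it separates is a plain packed string; a quick round-trip with illustrative values:

GRAPH_FIELD_SEP = "<SEP>"
packed = GRAPH_FIELD_SEP.join(["chunk-1", "chunk-2"])
print(packed)                         # chunk-1<SEP>chunk-2
print(packed.split(GRAPH_FIELD_SEP))  # ['chunk-1', 'chunk-2']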
lightrag/kg/neo4j_impl.py
@@ -16,6 +16,7 @@ import logging
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
+from ..constants import GRAPH_FIELD_SEP
 import pipmaster as pm

 if not pm.is_installed("neo4j"):
@@ -725,6 +726,47 @@ class Neo4JStorage(BaseGraphStorage):
             await result.consume()  # Ensure results are fully consumed
             return edges_dict

+    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+            UNWIND $chunk_ids AS chunk_id
+            MATCH (n:base)
+            WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
+            RETURN DISTINCT n
+            """
+            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
+            nodes = []
+            async for record in result:
+                node = record["n"]
+                node_dict = dict(node)
+                # Add node id (entity_id) to the dictionary for easier access
+                node_dict["id"] = node_dict.get("entity_id")
+                nodes.append(node_dict)
+            await result.consume()
+            return nodes
+
+    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+            UNWIND $chunk_ids AS chunk_id
+            MATCH (a:base)-[r]-(b:base)
+            WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
+            RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
+            """
+            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
+            edges = []
+            async for record in result:
+                edge_properties = record["properties"]
+                edge_properties["source"] = record["source"]
+                edge_properties["target"] = record["target"]
+                edges.append(edge_properties)
+            await result.consume()
+            return edges
+
     @retry(
         stop=stop_after_attempt(3),
         wait=wait_exponential(multiplier=1, min=4, max=10),
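The Neo4j variant pushes the membership test into Cypher: UNWIND turns the chunk-ID list into rows, and split() unpacks source_id server-side so only matching nodes cross the wire. A standalone sketch of running the same node query with the official async driver; the URI, credentials, and database name are assumptions to adapt:

import asyncio
from neo4j import AsyncGraphDatabase

QUERY = """
UNWIND $chunk_ids AS chunk_id
MATCH (n:base)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
RETURN DISTINCT n
"""

async def main() -> None:
    # Placeholder connection details -- adjust for your deployment.
    driver = AsyncGraphDatabase.driver(
        "neo4j://localhost:7687", auth=("neo4j", "password")
    )
    async with driver.session(database="neo4j") as session:
        result = await session.run(QUERY, chunk_ids=["chunk-1"], sep="<SEP>")
        async for record in result:
            print(dict(record["n"]))  # node properties
        await result.consume()
    await driver.close()

asyncio.run(main())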
lightrag/kg/networkx_impl.py
@@ -5,6 +5,7 @@ from typing import final
 from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 from lightrag.utils import logger
 from lightrag.base import BaseGraphStorage
+from lightrag.constants import GRAPH_FIELD_SEP

 import pipmaster as pm
@@ -357,6 +358,33 @@ class NetworkXStorage(BaseGraphStorage):
         )
         return result

+    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        chunk_ids_set = set(chunk_ids)
+        graph = await self._get_graph()
+        matching_nodes = []
+        for node_id, node_data in graph.nodes(data=True):
+            if "source_id" in node_data:
+                node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
+                if not node_source_ids.isdisjoint(chunk_ids_set):
+                    node_data_with_id = node_data.copy()
+                    node_data_with_id["id"] = node_id
+                    matching_nodes.append(node_data_with_id)
+        return matching_nodes
+
+    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        chunk_ids_set = set(chunk_ids)
+        graph = await self._get_graph()
+        matching_edges = []
+        for u, v, edge_data in graph.edges(data=True):
+            if "source_id" in edge_data:
+                edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
+                if not edge_source_ids.isdisjoint(chunk_ids_set):
+                    edge_data_with_nodes = edge_data.copy()
+                    edge_data_with_nodes["source"] = u
+                    edge_data_with_nodes["target"] = v
+                    matching_edges.append(edge_data_with_nodes)
+        return matching_edges
+
     async def index_done_callback(self) -> bool:
         """Save data to disk"""
         async with self._storage_lock:
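The NetworkX version is a straight in-memory scan. A runnable toy example of the node path (hypothetical graph and IDs), mirroring the copy-and-annotate step above:

import networkx as nx

GRAPH_FIELD_SEP = "<SEP>"

g = nx.Graph()
g.add_node("Alice", source_id=f"chunk-1{GRAPH_FIELD_SEP}chunk-2")
g.add_node("Bob", source_id="chunk-9")
g.add_edge("Alice", "Bob", source_id="chunk-2")

chunk_ids_set = {"chunk-2"}
matching_nodes = []
for node_id, node_data in g.nodes(data=True):
    if "source_id" in node_data:
        node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
        if not node_source_ids.isdisjoint(chunk_ids_set):
            node_data_with_id = node_data.copy()
            node_data_with_id["id"] = node_id  # annotate with the node's ID
            matching_nodes.append(node_data_with_id)

print([n["id"] for n in matching_nodes])  # ['Alice']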
lightrag/kg/postgres_impl.py
@@ -27,6 +27,7 @@ from ..base import (
 )
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
+from ..constants import GRAPH_FIELD_SEP

 import pipmaster as pm
@@ -1422,8 +1423,6 @@ class PGGraphStorage(BaseGraphStorage):
             # Process string result, parse it to JSON dictionary
             if isinstance(node_dict, str):
                 try:
-                    import json
-
                     node_dict = json.loads(node_dict)
                 except json.JSONDecodeError:
                     logger.warning(f"Failed to parse node string: {node_dict}")
@@ -1479,8 +1478,6 @@ class PGGraphStorage(BaseGraphStorage):
             # Process string result, parse it to JSON dictionary
             if isinstance(result, str):
                 try:
-                    import json
-
                     result = json.loads(result)
                 except json.JSONDecodeError:
                     logger.warning(f"Failed to parse edge string: {result}")
@@ -1697,8 +1694,6 @@ class PGGraphStorage(BaseGraphStorage):
                 # Process string result, parse it to JSON dictionary
                 if isinstance(node_dict, str):
                     try:
-                        import json
-
                         node_dict = json.loads(node_dict)
                     except json.JSONDecodeError:
                         logger.warning(
@@ -1861,8 +1856,6 @@ class PGGraphStorage(BaseGraphStorage):
                 # Process string result, parse it to JSON dictionary
                 if isinstance(edge_props, str):
                     try:
-                        import json
-
                         edge_props = json.loads(edge_props)
                     except json.JSONDecodeError:
                         logger.warning(
@@ -1879,8 +1872,6 @@ class PGGraphStorage(BaseGraphStorage):
                 # Process string result, parse it to JSON dictionary
                 if isinstance(edge_props, str):
                     try:
-                        import json
-
                         edge_props = json.loads(edge_props)
                     except json.JSONDecodeError:
                         logger.warning(
@@ -1975,6 +1966,102 @@ class PGGraphStorage(BaseGraphStorage):
                 labels.append(result["label"])
         return labels

+    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        """
+        Retrieves nodes from the graph that are associated with a given list of chunk IDs.
+        This method uses a Cypher query with UNWIND to efficiently find all nodes
+        where the `source_id` property contains any of the specified chunk IDs.
+        """
+        # The string representation of the list for the cypher query
+        chunk_ids_str = json.dumps(chunk_ids)
+
+        query = f"""
+            SELECT * FROM cypher('{self.graph_name}', $$
+                UNWIND {chunk_ids_str} AS chunk_id
+                MATCH (n:base)
+                WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}')
+                RETURN n
+            $$) AS (n agtype);
+        """
+        results = await self._query(query)
+
+        # Build result list
+        nodes = []
+        for result in results:
+            if result["n"]:
+                node_dict = result["n"]["properties"]
+
+                # Process string result, parse it to JSON dictionary
+                if isinstance(node_dict, str):
+                    try:
+                        node_dict = json.loads(node_dict)
+                    except json.JSONDecodeError:
+                        logger.warning(
+                            f"Failed to parse node string in batch: {node_dict}"
+                        )
+
+                node_dict["id"] = node_dict["entity_id"]
+                nodes.append(node_dict)
+
+        return nodes
+
+    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        """
+        Retrieves edges from the graph that are associated with a given list of chunk IDs.
+        This method uses a Cypher query with UNWIND to efficiently find all edges
+        where the `source_id` property contains any of the specified chunk IDs.
+        """
+        chunk_ids_str = json.dumps(chunk_ids)
+
+        query = f"""
+            SELECT * FROM cypher('{self.graph_name}', $$
+                UNWIND {chunk_ids_str} AS chunk_id
+                MATCH (a:base)-[r]-(b:base)
+                WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}')
+                RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target
+            $$) AS (edge agtype, source agtype, target agtype);
+        """
+        results = await self._query(query)
+        edges = []
+        if results:
+            for item in results:
+                edge_agtype = item["edge"]["properties"]
+                # Process string result, parse it to JSON dictionary
+                if isinstance(edge_agtype, str):
+                    try:
+                        edge_agtype = json.loads(edge_agtype)
+                    except json.JSONDecodeError:
+                        logger.warning(
+                            f"Failed to parse edge string in batch: {edge_agtype}"
+                        )
+
+                source_agtype = item["source"]["properties"]
+                # Process string result, parse it to JSON dictionary
+                if isinstance(source_agtype, str):
+                    try:
+                        source_agtype = json.loads(source_agtype)
+                    except json.JSONDecodeError:
+                        logger.warning(
+                            f"Failed to parse node string in batch: {source_agtype}"
+                        )
+
+                target_agtype = item["target"]["properties"]
+                # Process string result, parse it to JSON dictionary
+                if isinstance(target_agtype, str):
+                    try:
+                        target_agtype = json.loads(target_agtype)
+                    except json.JSONDecodeError:
+                        logger.warning(
+                            f"Failed to parse node string in batch: {target_agtype}"
+                        )
+
+                if edge_agtype and source_agtype and target_agtype:
+                    edge_properties = edge_agtype
+                    edge_properties["source"] = source_agtype["entity_id"]
+                    edge_properties["target"] = target_agtype["entity_id"]
+                    edges.append(edge_properties)
+        return edges
+
     async def _bfs_subgraph(
         self, node_label: str, max_depth: int, max_nodes: int
     ) -> KnowledgeGraph:
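Unlike the Neo4j driver, the AGE path inlines the chunk-ID list into the query text rather than binding a parameter; json.dumps happens to render a Python list of strings as a valid Cypher list literal. A quick check (this design assumes chunk IDs contain no quote characters):

import json

chunk_ids = ["chunk-1", "chunk-2"]
print(json.dumps(chunk_ids))  # ["chunk-1", "chunk-2"] -- usable after UNWIND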
lightrag/lightrag.py
@@ -60,7 +60,7 @@ from .operate import (
     query_with_keywords,
     _rebuild_knowledge_from_chunks,
 )
-from .prompt import GRAPH_FIELD_SEP
+from .constants import GRAPH_FIELD_SEP
 from .utils import (
     Tokenizer,
     TiktokenTokenizer,
@@ -1761,68 +1761,54 @@
             # Use graph database lock to ensure atomic merges and updates
             graph_db_lock = get_graph_db_lock(enable_logging=False)
             async with graph_db_lock:
-                # Process entities
-                # TODO There is performance when iterating get_all_labels for PostgresSQL
-                all_labels = await self.chunk_entity_relation_graph.get_all_labels()
-                for node_label in all_labels:
-                    node_data = await self.chunk_entity_relation_graph.get_node(
-                        node_label
+                # Get all affected nodes and edges in batch
+                affected_nodes = (
+                    await self.chunk_entity_relation_graph.get_nodes_by_chunk_ids(
+                        list(chunk_ids)
                     )
-                    if node_data and "source_id" in node_data:
-                        # Split source_id using GRAPH_FIELD_SEP
+                )
+                affected_edges = (
+                    await self.chunk_entity_relation_graph.get_edges_by_chunk_ids(
+                        list(chunk_ids)
+                    )
+                )
+
+                # logger.info(f"chunk_ids: {chunk_ids}")
+                # logger.info(f"affected_nodes: {affected_nodes}")
+                # logger.info(f"affected_edges: {affected_edges}")
+
+                # Process entities
+                for node_data in affected_nodes:
+                    node_label = node_data.get("entity_id")
+                    if node_label and "source_id" in node_data:
                         sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
                         remaining_sources = sources - chunk_ids

                         if not remaining_sources:
                             entities_to_delete.add(node_label)
                             logger.debug(
                                 f"Entity {node_label} marked for deletion - no remaining sources"
                             )
                         elif remaining_sources != sources:
                             # Entity needs to be rebuilt from remaining chunks
                             entities_to_rebuild[node_label] = remaining_sources
                             logger.debug(
                                 f"Entity {node_label} will be rebuilt from {len(remaining_sources)} remaining chunks"
                             )

                 # Process relationships
-                # TODO There is performance when iterating get_all_labels for PostgresSQL
-                for node_label in all_labels:
-                    node_edges = await self.chunk_entity_relation_graph.get_node_edges(
-                        node_label
-                    )
-                    if node_edges:
-                        for src, tgt in node_edges:
-                            # To avoid processing the same edge twice in an undirected graph
-                            if (tgt, src) in relationships_to_delete or (
-                                tgt,
-                                src,
-                            ) in relationships_to_rebuild:
-                                continue
-
-                            edge_data = await self.chunk_entity_relation_graph.get_edge(
-                                src, tgt
-                            )
-                            if edge_data and "source_id" in edge_data:
-                                # Split source_id using GRAPH_FIELD_SEP
-                                sources = set(
-                                    edge_data["source_id"].split(GRAPH_FIELD_SEP)
-                                )
-                                remaining_sources = sources - chunk_ids
-
-                                if not remaining_sources:
-                                    relationships_to_delete.add((src, tgt))
-                                    logger.debug(
-                                        f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
-                                    )
-                                elif remaining_sources != sources:
-                                    # Relationship needs to be rebuilt from remaining chunks
-                                    relationships_to_rebuild[(src, tgt)] = (
-                                        remaining_sources
-                                    )
-                                    logger.debug(
-                                        f"Relationship {src}-{tgt} will be rebuilt from {len(remaining_sources)} remaining chunks"
-                                    )
+                for edge_data in affected_edges:
+                    src = edge_data.get("source")
+                    tgt = edge_data.get("target")
+
+                    if src and tgt and "source_id" in edge_data:
+                        edge_tuple = tuple(sorted((src, tgt)))
+                        if (
+                            edge_tuple in relationships_to_delete
+                            or edge_tuple in relationships_to_rebuild
+                        ):
+                            continue
+
+                        sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
+                        remaining_sources = sources - chunk_ids
+
+                        if not remaining_sources:
+                            relationships_to_delete.add(edge_tuple)
+                        elif remaining_sources != sources:
+                            relationships_to_rebuild[edge_tuple] = remaining_sources

                 # 5. Delete chunks from storage
                 if chunk_ids:
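The per-item decision rule above, extracted into a standalone function (a hypothetical helper for illustration, not part of LightRAG): fully covered sources mean deletion, a partial overlap means a rebuild from the surviving chunks.

def classify(sources: set[str], deleted_chunks: set[str]) -> str:
    remaining_sources = sources - deleted_chunks
    if not remaining_sources:
        return "delete"   # no remaining sources
    if remaining_sources != sources:
        return "rebuild"  # rebuilt from the remaining chunks
    return "keep"         # untouched by this document

print(classify({"c1"}, {"c1"}))        # delete
print(classify({"c1", "c2"}, {"c1"}))  # rebuild
print(classify({"c2"}, {"c1"}))        # keep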
lightrag/operate.py
@@ -33,7 +33,8 @@ from .base import (
     TextChunkSchema,
     QueryParam,
 )
-from .prompt import GRAPH_FIELD_SEP, PROMPTS
+from .prompt import PROMPTS
+from .constants import GRAPH_FIELD_SEP
 import time
 from dotenv import load_dotenv
lightrag/prompt.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 from typing import Any

-GRAPH_FIELD_SEP = "<SEP>"

 PROMPTS: dict[str, Any] = {}