Mirror of https://github.com/HKUDS/LightRAG.git (synced 2025-11-03 11:20:13 +00:00)
feat: Optimize document deletion performance
- To improve performance during document deletion, new batch-get methods, `get_nodes_by_chunk_ids` and `get_edges_by_chunk_ids`, have been added to the graph storage layer (`BaseGraphStorage` and its implementations). The [`adelete_by_doc_id`](lightrag/lightrag.py:1681) function now uses these methods instead of iterating over the entire knowledge graph, significantly improving efficiency.
- Graph storage implementations updated: NetworkX, Neo4j, PostgreSQL AGE
This commit is contained in:
parent
ebe5b1e0d2
commit
da46b341dc
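For orientation, here is a minimal sketch (not part of the commit; the helper name `collect_affected` and the call pattern are assumptions) of the deletion flow described above, using the new batch-get methods on any updated graph storage backend. Edges are handled analogously.

# Hypothetical illustration of the new deletion flow: fetch only the nodes that
# reference the deleted document's chunks, instead of scanning every graph label.
GRAPH_FIELD_SEP = "<SEP>"  # mirrors the constant added to lightrag/constants.py

async def collect_affected(graph_storage, chunk_ids: set[str]):
    # graph_storage is any BaseGraphStorage implementation updated in this commit
    # (NetworkX, Neo4j, or PostgreSQL AGE).
    affected_nodes = await graph_storage.get_nodes_by_chunk_ids(list(chunk_ids))

    entities_to_delete, entities_to_rebuild = set(), {}
    for node in affected_nodes:
        sources = set(node["source_id"].split(GRAPH_FIELD_SEP))
        remaining = sources - chunk_ids
        if not remaining:
            entities_to_delete.add(node["entity_id"])  # no other chunk references it
        elif remaining != sources:
            entities_to_rebuild[node["entity_id"]] = remaining  # rebuild from leftovers
    return entities_to_delete, entities_to_rebuild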
@@ -14,6 +14,7 @@ from typing import (
 )
 from .utils import EmbeddingFunc
 from .types import KnowledgeGraph
+from .constants import GRAPH_FIELD_SEP

 # use the .env that is inside the current folder
 # allows to use different .env file for each lightrag instance
@@ -456,6 +457,67 @@ class BaseGraphStorage(StorageNameSpace, ABC):
             result[node_id] = edges if edges is not None else []
         return result

+    @abstractmethod
+    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        """Get all nodes that are associated with the given chunk_ids.
+
+        Args:
+            chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
+
+        Returns:
+            list[dict]: A list of nodes, where each node is a dictionary of its properties.
+                An empty list if no matching nodes are found.
+        """
+        # Default implementation iterates through all nodes, which is inefficient.
+        # This method should be overridden by subclasses for better performance.
+        all_nodes = []
+        all_labels = await self.get_all_labels()
+        for label in all_labels:
+            node = await self.get_node(label)
+            if node and "source_id" in node:
+                source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP))
+                if not source_ids.isdisjoint(chunk_ids):
+                    all_nodes.append(node)
+        return all_nodes
+
+    @abstractmethod
+    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        """Get all edges that are associated with the given chunk_ids.
+
+        Args:
+            chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
+
+        Returns:
+            list[dict]: A list of edges, where each edge is a dictionary of its properties.
+                An empty list if no matching edges are found.
+        """
+        # Default implementation iterates through all nodes and their edges, which is inefficient.
+        # This method should be overridden by subclasses for better performance.
+        all_edges = []
+        all_labels = await self.get_all_labels()
+        processed_edges = set()
+
+        for label in all_labels:
+            edges = await self.get_node_edges(label)
+            if edges:
+                for src_id, tgt_id in edges:
+                    # Avoid processing the same edge twice in an undirected graph
+                    edge_tuple = tuple(sorted((src_id, tgt_id)))
+                    if edge_tuple in processed_edges:
+                        continue
+                    processed_edges.add(edge_tuple)
+
+                    edge = await self.get_edge(src_id, tgt_id)
+                    if edge and "source_id" in edge:
+                        source_ids = set(edge["source_id"].split(GRAPH_FIELD_SEP))
+                        if not source_ids.isdisjoint(chunk_ids):
+                            # Add source and target to the edge dict for easier processing later
+                            edge_with_nodes = edge.copy()
+                            edge_with_nodes["source"] = src_id
+                            edge_with_nodes["target"] = tgt_id
+                            all_edges.append(edge_with_nodes)
+        return all_edges
+
     @abstractmethod
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """Insert a new node or update an existing node in the graph.
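All backends share the same affinity test: split the stored `source_id` on `GRAPH_FIELD_SEP` and check for overlap with the chunk IDs being deleted. A standalone sketch of that check (the sample IDs below are made up):

GRAPH_FIELD_SEP = "<SEP>"

def touches_chunks(source_id: str, chunk_ids: set[str]) -> bool:
    # source_id stores every chunk that contributed to the node or edge,
    # joined with GRAPH_FIELD_SEP; any overlap means the record is affected.
    return not set(source_id.split(GRAPH_FIELD_SEP)).isdisjoint(chunk_ids)

print(touches_chunks("chunk-1<SEP>chunk-7", {"chunk-7"}))  # True
print(touches_chunks("chunk-2<SEP>chunk-3", {"chunk-7"}))  # False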
@@ -12,6 +12,9 @@ DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6
 DEFAULT_WOKERS = 2
 DEFAULT_TIMEOUT = 150

+# Separator for graph fields
+GRAPH_FIELD_SEP = "<SEP>"
+
 # Logging configuration defaults
 DEFAULT_LOG_MAX_BYTES = 10485760  # Default 10MB
 DEFAULT_LOG_BACKUP_COUNT = 5  # Default 5 backups
@@ -16,6 +16,7 @@ import logging
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
+from ..constants import GRAPH_FIELD_SEP
 import pipmaster as pm

 if not pm.is_installed("neo4j"):
@@ -725,6 +726,47 @@ class Neo4JStorage(BaseGraphStorage):
             await result.consume()  # Ensure results are fully consumed
         return edges_dict

+    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+            UNWIND $chunk_ids AS chunk_id
+            MATCH (n:base)
+            WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
+            RETURN DISTINCT n
+            """
+            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
+            nodes = []
+            async for record in result:
+                node = record["n"]
+                node_dict = dict(node)
+                # Add node id (entity_id) to the dictionary for easier access
+                node_dict["id"] = node_dict.get("entity_id")
+                nodes.append(node_dict)
+            await result.consume()
+            return nodes
+
+    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+            UNWIND $chunk_ids AS chunk_id
+            MATCH (a:base)-[r]-(b:base)
+            WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
+            RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
+            """
+            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
+            edges = []
+            async for record in result:
+                edge_properties = record["properties"]
+                edge_properties["source"] = record["source"]
+                edge_properties["target"] = record["target"]
+                edges.append(edge_properties)
+            await result.consume()
+            return edges
+
     @retry(
         stop=stop_after_attempt(3),
         wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -5,6 +5,7 @@ from typing import final
 from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 from lightrag.utils import logger
 from lightrag.base import BaseGraphStorage
+from lightrag.constants import GRAPH_FIELD_SEP

 import pipmaster as pm

@@ -357,6 +358,33 @@ class NetworkXStorage(BaseGraphStorage):
         )
         return result

+    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        chunk_ids_set = set(chunk_ids)
+        graph = await self._get_graph()
+        matching_nodes = []
+        for node_id, node_data in graph.nodes(data=True):
+            if "source_id" in node_data:
+                node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
+                if not node_source_ids.isdisjoint(chunk_ids_set):
+                    node_data_with_id = node_data.copy()
+                    node_data_with_id["id"] = node_id
+                    matching_nodes.append(node_data_with_id)
+        return matching_nodes
+
+    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        chunk_ids_set = set(chunk_ids)
+        graph = await self._get_graph()
+        matching_edges = []
+        for u, v, edge_data in graph.edges(data=True):
+            if "source_id" in edge_data:
+                edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
+                if not edge_source_ids.isdisjoint(chunk_ids_set):
+                    edge_data_with_nodes = edge_data.copy()
+                    edge_data_with_nodes["source"] = u
+                    edge_data_with_nodes["target"] = v
+                    matching_edges.append(edge_data_with_nodes)
+        return matching_edges
+
     async def index_done_callback(self) -> bool:
         """Save data to disk"""
         async with self._storage_lock:
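For the in-memory backend the lookup is a linear scan over node and edge attributes. A self-contained toy example of the same pattern against a plain networkx graph (the entity and chunk IDs below are hypothetical):

import networkx as nx

GRAPH_FIELD_SEP = "<SEP>"
g = nx.Graph()
g.add_node("EntityA", source_id="chunk-1<SEP>chunk-2")
g.add_node("EntityB", source_id="chunk-3")
g.add_edge("EntityA", "EntityB", source_id="chunk-2")

chunk_ids = {"chunk-2"}
# Keep any node/edge whose source_id shares at least one chunk with chunk_ids.
hit_nodes = [
    {**data, "id": node_id}
    for node_id, data in g.nodes(data=True)
    if not set(data["source_id"].split(GRAPH_FIELD_SEP)).isdisjoint(chunk_ids)
]
hit_edges = [
    {**data, "source": u, "target": v}
    for u, v, data in g.edges(data=True)
    if not set(data["source_id"].split(GRAPH_FIELD_SEP)).isdisjoint(chunk_ids)
]
print(hit_nodes)  # [{'source_id': 'chunk-1<SEP>chunk-2', 'id': 'EntityA'}]
print(hit_edges)  # [{'source_id': 'chunk-2', 'source': 'EntityA', 'target': 'EntityB'}]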
@@ -27,6 +27,7 @@ from ..base import (
 )
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
+from ..constants import GRAPH_FIELD_SEP

 import pipmaster as pm

@@ -1422,8 +1423,6 @@ class PGGraphStorage(BaseGraphStorage):
        # Process string result, parse it to JSON dictionary
        if isinstance(node_dict, str):
            try:
-                import json
-
                node_dict = json.loads(node_dict)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse node string: {node_dict}")
@@ -1479,8 +1478,6 @@ class PGGraphStorage(BaseGraphStorage):
        # Process string result, parse it to JSON dictionary
        if isinstance(result, str):
            try:
-                import json
-
                result = json.loads(result)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse edge string: {result}")
@@ -1697,8 +1694,6 @@ class PGGraphStorage(BaseGraphStorage):
        # Process string result, parse it to JSON dictionary
        if isinstance(node_dict, str):
            try:
-                import json
-
                node_dict = json.loads(node_dict)
            except json.JSONDecodeError:
                logger.warning(
@@ -1861,8 +1856,6 @@ class PGGraphStorage(BaseGraphStorage):
        # Process string result, parse it to JSON dictionary
        if isinstance(edge_props, str):
            try:
-                import json
-
                edge_props = json.loads(edge_props)
            except json.JSONDecodeError:
                logger.warning(
@@ -1879,8 +1872,6 @@ class PGGraphStorage(BaseGraphStorage):
        # Process string result, parse it to JSON dictionary
        if isinstance(edge_props, str):
            try:
-                import json
-
                edge_props = json.loads(edge_props)
            except json.JSONDecodeError:
                logger.warning(
@@ -1975,6 +1966,102 @@ class PGGraphStorage(BaseGraphStorage):
                labels.append(result["label"])
        return labels

+    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        """
+        Retrieves nodes from the graph that are associated with a given list of chunk IDs.
+        This method uses a Cypher query with UNWIND to efficiently find all nodes
+        where the `source_id` property contains any of the specified chunk IDs.
+        """
+        # The string representation of the list for the cypher query
+        chunk_ids_str = json.dumps(chunk_ids)
+
+        query = f"""
+        SELECT * FROM cypher('{self.graph_name}', $$
+            UNWIND {chunk_ids_str} AS chunk_id
+            MATCH (n:base)
+            WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}')
+            RETURN n
+        $$) AS (n agtype);
+        """
+        results = await self._query(query)
+
+        # Build result list
+        nodes = []
+        for result in results:
+            if result["n"]:
+                node_dict = result["n"]["properties"]
+
+                # Process string result, parse it to JSON dictionary
+                if isinstance(node_dict, str):
+                    try:
+                        node_dict = json.loads(node_dict)
+                    except json.JSONDecodeError:
+                        logger.warning(
+                            f"Failed to parse node string in batch: {node_dict}"
+                        )
+
+                node_dict["id"] = node_dict["entity_id"]
+                nodes.append(node_dict)
+
+        return nodes
+
+    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        """
+        Retrieves edges from the graph that are associated with a given list of chunk IDs.
+        This method uses a Cypher query with UNWIND to efficiently find all edges
+        where the `source_id` property contains any of the specified chunk IDs.
+        """
+        chunk_ids_str = json.dumps(chunk_ids)
+
+        query = f"""
+        SELECT * FROM cypher('{self.graph_name}', $$
+            UNWIND {chunk_ids_str} AS chunk_id
+            MATCH (a:base)-[r]-(b:base)
+            WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}')
+            RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target
+        $$) AS (edge agtype, source agtype, target agtype);
+        """
+        results = await self._query(query)
+        edges = []
+        if results:
+            for item in results:
+                edge_agtype = item["edge"]["properties"]
+                # Process string result, parse it to JSON dictionary
+                if isinstance(edge_agtype, str):
+                    try:
+                        edge_agtype = json.loads(edge_agtype)
+                    except json.JSONDecodeError:
+                        logger.warning(
+                            f"Failed to parse edge string in batch: {edge_agtype}"
+                        )
+
+                source_agtype = item["source"]["properties"]
+                # Process string result, parse it to JSON dictionary
+                if isinstance(source_agtype, str):
+                    try:
+                        source_agtype = json.loads(source_agtype)
+                    except json.JSONDecodeError:
+                        logger.warning(
+                            f"Failed to parse node string in batch: {source_agtype}"
+                        )
+
+                target_agtype = item["target"]["properties"]
+                # Process string result, parse it to JSON dictionary
+                if isinstance(target_agtype, str):
+                    try:
+                        target_agtype = json.loads(target_agtype)
+                    except json.JSONDecodeError:
+                        logger.warning(
+                            f"Failed to parse node string in batch: {target_agtype}"
+                        )
+
+                if edge_agtype and source_agtype and target_agtype:
+                    edge_properties = edge_agtype
+                    edge_properties["source"] = source_agtype["entity_id"]
+                    edge_properties["target"] = target_agtype["entity_id"]
+                    edges.append(edge_properties)
+
+        return edges
+
    async def _bfs_subgraph(
        self, node_label: str, max_depth: int, max_nodes: int
    ) -> KnowledgeGraph:
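A minimal sketch of how the AGE query text above is assembled (the graph name and chunk IDs are placeholders). Unlike the Neo4j version, the chunk list is embedded in the Cypher body as a JSON literal via `json.dumps` rather than passed as a bound parameter:

import json

GRAPH_FIELD_SEP = "<SEP>"
graph_name = "lightrag_graph"  # placeholder graph name
chunk_ids = ["chunk-1", "chunk-7"]

chunk_ids_str = json.dumps(chunk_ids)  # -> ["chunk-1", "chunk-7"]
query = f"""
SELECT * FROM cypher('{graph_name}', $$
    UNWIND {chunk_ids_str} AS chunk_id
    MATCH (n:base)
    WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}')
    RETURN n
$$) AS (n agtype);
"""
print(query)  # the SQL text that would be sent to PostgreSQL AGE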
@@ -60,7 +60,7 @@ from .operate import (
     query_with_keywords,
     _rebuild_knowledge_from_chunks,
 )
-from .prompt import GRAPH_FIELD_SEP
+from .constants import GRAPH_FIELD_SEP
 from .utils import (
     Tokenizer,
     TiktokenTokenizer,
@@ -1761,68 +1761,54 @@ class LightRAG:
         # Use graph database lock to ensure atomic merges and updates
         graph_db_lock = get_graph_db_lock(enable_logging=False)
         async with graph_db_lock:
-            # Process entities
-            # TODO There is performance when iterating get_all_labels for PostgresSQL
-            all_labels = await self.chunk_entity_relation_graph.get_all_labels()
-            for node_label in all_labels:
-                node_data = await self.chunk_entity_relation_graph.get_node(
-                    node_label
-                )
-                if node_data and "source_id" in node_data:
-                    # Split source_id using GRAPH_FIELD_SEP
+            # Get all affected nodes and edges in batch
+            affected_nodes = (
+                await self.chunk_entity_relation_graph.get_nodes_by_chunk_ids(
+                    list(chunk_ids)
+                )
+            )
+            affected_edges = (
+                await self.chunk_entity_relation_graph.get_edges_by_chunk_ids(
+                    list(chunk_ids)
+                )
+            )
+
+            # logger.info(f"chunk_ids: {chunk_ids}")
+            # logger.info(f"affected_nodes: {affected_nodes}")
+            # logger.info(f"affected_edges: {affected_edges}")
+
+            # Process entities
+            for node_data in affected_nodes:
+                node_label = node_data.get("entity_id")
+                if node_label and "source_id" in node_data:
                     sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
                     remaining_sources = sources - chunk_ids

                     if not remaining_sources:
                         entities_to_delete.add(node_label)
-                        logger.debug(
-                            f"Entity {node_label} marked for deletion - no remaining sources"
-                        )
                     elif remaining_sources != sources:
-                        # Entity needs to be rebuilt from remaining chunks
                         entities_to_rebuild[node_label] = remaining_sources
-                        logger.debug(
-                            f"Entity {node_label} will be rebuilt from {len(remaining_sources)} remaining chunks"
-                        )

             # Process relationships
-            # TODO There is performance when iterating get_all_labels for PostgresSQL
-            for node_label in all_labels:
-                node_edges = await self.chunk_entity_relation_graph.get_node_edges(
-                    node_label
-                )
-                if node_edges:
-                    for src, tgt in node_edges:
-                        # To avoid processing the same edge twice in an undirected graph
-                        if (tgt, src) in relationships_to_delete or (
-                            tgt,
-                            src,
-                        ) in relationships_to_rebuild:
-                            continue
-
-                        edge_data = await self.chunk_entity_relation_graph.get_edge(
-                            src, tgt
-                        )
-                        if edge_data and "source_id" in edge_data:
-                            # Split source_id using GRAPH_FIELD_SEP
-                            sources = set(
-                                edge_data["source_id"].split(GRAPH_FIELD_SEP)
-                            )
-                            remaining_sources = sources - chunk_ids
-
-                            if not remaining_sources:
-                                relationships_to_delete.add((src, tgt))
-                                logger.debug(
-                                    f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
-                                )
-                            elif remaining_sources != sources:
-                                # Relationship needs to be rebuilt from remaining chunks
-                                relationships_to_rebuild[(src, tgt)] = (
-                                    remaining_sources
-                                )
-                                logger.debug(
-                                    f"Relationship {src}-{tgt} will be rebuilt from {len(remaining_sources)} remaining chunks"
-                                )
+            for edge_data in affected_edges:
+                src = edge_data.get("source")
+                tgt = edge_data.get("target")
+                if src and tgt and "source_id" in edge_data:
+                    edge_tuple = tuple(sorted((src, tgt)))
+                    if (
+                        edge_tuple in relationships_to_delete
+                        or edge_tuple in relationships_to_rebuild
+                    ):
+                        continue
+
+                    sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
+                    remaining_sources = sources - chunk_ids
+
+                    if not remaining_sources:
+                        relationships_to_delete.add(edge_tuple)
+                    elif remaining_sources != sources:
+                        relationships_to_rebuild[edge_tuple] = remaining_sources

             # 5. Delete chunks from storage
             if chunk_ids:
@@ -33,7 +33,8 @@ from .base import (
     TextChunkSchema,
     QueryParam,
 )
-from .prompt import GRAPH_FIELD_SEP, PROMPTS
+from .prompt import PROMPTS
+from .constants import GRAPH_FIELD_SEP
 import time
 from dotenv import load_dotenv

@@ -1,7 +1,6 @@
 from __future__ import annotations
 from typing import Any

-GRAPH_FIELD_SEP = "<SEP>"

 PROMPTS: dict[str, Any] = {}
