LightRAG/lightrag/kg/neo4j_impl.py

1047 lines
41 KiB
Python
Raw Normal View History

import inspect
2024-10-26 19:29:45 -04:00
import os
import re
2024-10-26 19:29:45 -04:00
from dataclasses import dataclass
from typing import Any, final, Optional
import numpy as np
import configparser
2025-01-27 23:21:34 +08:00
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
import logging
from ..utils import logger
from ..base import BaseGraphStorage
2025-02-20 14:29:36 +01:00
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
2025-02-16 16:04:07 +01:00
import pipmaster as pm
2025-02-16 16:04:07 +01:00
if not pm.is_installed("neo4j"):
pm.install("neo4j")
2025-02-16 16:04:35 +01:00
from neo4j import ( # type: ignore
2025-02-19 19:32:23 +01:00
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
)
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
# Set neo4j logger level to ERROR to suppress warning logs
logging.getLogger("neo4j").setLevel(logging.ERROR)
2025-02-11 03:29:40 +08:00
@final
@dataclass
class Neo4JStorage(BaseGraphStorage):
    """Neo4j-backed graph storage.

    Every node carries the ``base`` label plus its ``entity_type`` as an
    extra label, and is uniquely identified by the ``entity_id`` property.
    The async driver is created in ``initialize()`` and torn down in
    ``finalize()``.
    """

    def __init__(self, namespace, global_config, embedding_func):
        # Delegate the common storage wiring to the base class.
        super().__init__(
            namespace=namespace,
            global_config=global_config,
            embedding_func=embedding_func,
        )
        # The AsyncDriver is created lazily in initialize(), not here.
        self._driver = None

    def __post_init__(self):
        # Registry of node-embedding algorithms; only node2vec is declared
        # and it is a stub (embed_nodes raises NotImplementedError).
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }
async def initialize(self):
    """Create the Neo4j async driver and select (or create) the database.

    Connection settings: environment variables take precedence over the
    ``[neo4j]`` section of config.ini. Tries the namespace-derived database
    first, then falls back to the server default database (``None``).
    """
    URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
    USERNAME = os.environ.get(
        "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
    )
    PASSWORD = os.environ.get(
        "NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
    )
    MAX_CONNECTION_POOL_SIZE = int(
        os.environ.get(
            "NEO4J_MAX_CONNECTION_POOL_SIZE",
            config.get("neo4j", "connection_pool_size", fallback=50),
        )
    )
    CONNECTION_TIMEOUT = float(
        os.environ.get(
            "NEO4J_CONNECTION_TIMEOUT",
            config.get("neo4j", "connection_timeout", fallback=30.0),
        ),
    )
    CONNECTION_ACQUISITION_TIMEOUT = float(
        os.environ.get(
            "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
            config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
        ),
    )
    MAX_TRANSACTION_RETRY_TIME = float(
        os.environ.get(
            "NEO4J_MAX_TRANSACTION_RETRY_TIME",
            config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
        ),
    )
    # Neo4j database names only allow alphanumerics and dashes, so the
    # namespace is sanitized before being used as the database name.
    DATABASE = os.environ.get(
        "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
    )

    self._driver: AsyncDriver = AsyncGraphDatabase.driver(
        URI,
        auth=(USERNAME, PASSWORD),
        max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
        connection_timeout=CONNECTION_TIMEOUT,
        connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
        max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
    )

    # Try to connect to the database and create it if it doesn't exist;
    # on the second pass (database=None) use the server's default database.
    for database in (DATABASE, None):
        self._DATABASE = database
        connected = False

        try:
            async with self._driver.session(database=database) as session:
                try:
                    # Cheap no-op query that forces an actual connection.
                    result = await session.run("MATCH (n) RETURN n LIMIT 0")
                    await result.consume()  # Ensure result is consumed
                    logger.info(f"Connected to {database} at {URI}")
                    connected = True
                except neo4jExceptions.ServiceUnavailable as e:
                    logger.error(
                        f"{database} at {URI} is not available".capitalize()
                    )
                    raise e
        except neo4jExceptions.AuthError as e:
            logger.error(f"Authentication failed for {database} at {URI}")
            raise e
        except neo4jExceptions.ClientError as e:
            if e.code == "Neo.ClientError.Database.DatabaseNotFound":
                logger.info(
                    f"{database} at {URI} not found. Try to create specified database.".capitalize()
                )
                try:
                    # CREATE DATABASE must run on the default (system-routed)
                    # session, hence no explicit database here.
                    async with self._driver.session() as session:
                        result = await session.run(
                            f"CREATE DATABASE `{database}` IF NOT EXISTS"
                        )
                        await result.consume()  # Ensure result is consumed
                        logger.info(f"{database} at {URI} created".capitalize())
                        connected = True
                except (
                    neo4jExceptions.ClientError,
                    neo4jExceptions.DatabaseError,
                ) as e:
                    # Community edition cannot create databases; warn and let
                    # the loop fall through to the default database.
                    if (
                        e.code
                        == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
                    ) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
                        if database is not None:
                            logger.warning(
                                "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
                            )
                    if database is None:
                        logger.error(f"Failed to create {database} at {URI}")
                        raise e

        if connected:
            break
async def finalize(self):
    """Close the Neo4j driver and release all resources"""
    driver, self._driver = self._driver, None
    if driver:
        await driver.close()

async def __aexit__(self, exc_type, exc, tb):
    """Ensure driver is closed when context manager exits"""
    await self.finalize()
async def index_done_callback(self) -> None:
    # Neo4j handles persistence automatically; nothing to flush here.
    pass
async def has_node(self, node_id: str) -> bool:
    """
    Check if a node with the given label exists in the database

    Args:
        node_id: Label of the node to check

    Returns:
        bool: True if node exists, False otherwise

    Raises:
        Exception: If there is an error executing the query
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        result = None
        try:
            query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
            result = await session.run(query, entity_id=node_id)
            single_result = await result.single()
            return single_result["node_exists"]
        except Exception as e:
            logger.error(f"Error checking node existence for {node_id}: {str(e)}")
            raise
        finally:
            # BUG FIX: consume in a guarded finally. Previously the except
            # branch did `await result.consume()`, but if session.run()
            # itself raised, `result` was unbound and the resulting
            # UnboundLocalError masked the original exception.
            if result is not None:
                await result.consume()
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """
    Check if an edge exists between two nodes (direction-insensitive match)

    Args:
        source_node_id: Label of the source node
        target_node_id: Label of the target node

    Returns:
        bool: True if edge exists, False otherwise

    Raises:
        Exception: If there is an error executing the query
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        result = None
        try:
            query = (
                "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
                "RETURN COUNT(r) > 0 AS edgeExists"
            )
            result = await session.run(
                query,
                source_entity_id=source_node_id,
                target_entity_id=target_node_id,
            )
            single_result = await result.single()
            return single_result["edgeExists"]
        except Exception as e:
            logger.error(
                f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
            )
            raise
        finally:
            # BUG FIX: consume in a guarded finally; `result` could be
            # unbound in the old except path if session.run() raised,
            # masking the real error with UnboundLocalError.
            if result is not None:
                await result.consume()
async def get_node(self, node_id: str) -> dict[str, str] | None:
    """Get node by its label identifier.

    Args:
        node_id: The node label to look up

    Returns:
        dict: Node properties if found
        None: If node not found

    Raises:
        Exception: If there is an error executing the query
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        try:
            query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
            result = await session.run(query, entity_id=node_id)
            try:
                # Fetch up to two records so duplicates can be detected.
                matches = await result.fetch(2)
                if len(matches) > 1:
                    logger.warning(
                        f"Multiple nodes found with label '{node_id}'. Using first node."
                    )
                if not matches:
                    return None
                node_dict = dict(matches[0]["n"])
                # Strip the internal "base" label from any stored labels list.
                if "labels" in node_dict:
                    node_dict["labels"] = [
                        lbl for lbl in node_dict["labels"] if lbl != "base"
                    ]
                logger.debug(f"Neo4j query node {query} return: {node_dict}")
                return node_dict
            finally:
                await result.consume()  # Ensure result is fully consumed
        except Exception as e:
            logger.error(f"Error getting node for {node_id}: {str(e)}")
            raise
async def node_degree(self, node_id: str) -> int:
    """Get the degree (number of relationships) of a node with the given label.
    If multiple nodes have the same label, returns the degree of the first node.
    If no node is found, returns 0.

    Args:
        node_id: The label of the node

    Returns:
        int: The number of relationships the node has, or 0 if no node found

    Raises:
        Exception: If there is an error executing the query
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        try:
            # OPTIONAL MATCH keeps a row (degree 0) for isolated nodes.
            query = """
                MATCH (n:base {entity_id: $entity_id})
                OPTIONAL MATCH (n)-[r]-()
                RETURN COUNT(r) AS degree
            """
            result = await session.run(query, entity_id=node_id)
            try:
                record = await result.single()

                if not record:
                    logger.warning(f"No node found with label '{node_id}'")
                    return 0

                degree = record["degree"]
                # BUG FIX: this message was missing its f-prefix and logged
                # the literal text "{node_id}"/"{degree}".
                logger.debug(
                    f"Neo4j query node degree for {node_id} return: {degree}"
                )
                return degree
            finally:
                await result.consume()  # Ensure result is fully consumed
        except Exception as e:
            logger.error(f"Error getting node degree for {node_id}: {str(e)}")
            raise
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Get the total degree (sum of relationships) of two nodes.

    Args:
        src_id: Label of the source node
        tgt_id: Label of the target node

    Returns:
        int: Sum of the degrees of both nodes
    """
    source_degree = await self.node_degree(src_id)
    target_degree = await self.node_degree(tgt_id)

    # node_degree may return None; treat that as zero before summing.
    return int(source_degree or 0) + int(target_degree or 0)
async def get_edge(
    self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
    """Get edge properties between two nodes.

    Args:
        source_node_id: Label of the source node
        target_node_id: Label of the target node

    Returns:
        dict: Edge properties if found; a dict of default properties
        (weight=0.0, source_id/description/keywords=None) if no edge is
        found or the record cannot be parsed

    Raises:
        Exception: If there is an error executing the query
    """
    try:
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            # Undirected match: returns properties of any relationship
            # between the two entities regardless of direction.
            query = """
            MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
            RETURN properties(r) as edge_properties
            """
            result = await session.run(
                query,
                source_entity_id=source_node_id,
                target_entity_id=target_node_id,
            )
            try:
                # Fetch two records so multi-edge cases can be warned about.
                records = await result.fetch(2)
                if len(records) > 1:
                    logger.warning(
                        f"Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge."
                    )
                if records:
                    try:
                        edge_result = dict(records[0]["edge_properties"])
                        logger.debug(f"Result: {edge_result}")
                        # Ensure required keys exist with defaults
                        required_keys = {
                            "weight": 0.0,
                            "source_id": None,
                            "description": None,
                            "keywords": None,
                        }
                        for key, default_value in required_keys.items():
                            if key not in edge_result:
                                edge_result[key] = default_value
                                logger.warning(
                                    f"Edge between {source_node_id} and {target_node_id} "
                                    f"missing {key}, using default: {default_value}"
                                )

                        logger.debug(
                            f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}"
                        )
                        return edge_result
                    except (KeyError, TypeError, ValueError) as e:
                        logger.error(
                            f"Error processing edge properties between {source_node_id} "
                            f"and {target_node_id}: {str(e)}"
                        )
                        # Return default edge properties on error
                        return {
                            "weight": 0.0,
                            "source_id": None,
                            "description": None,
                            "keywords": None,
                        }

                logger.debug(
                    f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
                )
                # Return default edge properties when no edge found
                return {
                    "weight": 0.0,
                    "source_id": None,
                    "description": None,
                    "keywords": None,
                }
            finally:
                await result.consume()  # Ensure result is fully consumed
    except Exception as e:
        logger.error(
            f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
        )
        raise
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
    """Retrieves all edges (relationships) for a particular node identified by its label.

    Args:
        source_node_id: Label of the node to get edges for

    Returns:
        list[tuple[str, str]]: List of (source_label, target_label) tuples
        representing edges (empty list when the node has no edges)

    Raises:
        Exception: If there is an error executing the query
    """
    try:
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            results = None
            try:
                query = """MATCH (n:base {entity_id: $entity_id})
                        OPTIONAL MATCH (n)-[r]-(connected:base)
                        WHERE connected.entity_id IS NOT NULL
                        RETURN n, r, connected"""
                results = await session.run(query, entity_id=source_node_id)

                edges = []
                async for record in results:
                    source_node = record["n"]
                    connected_node = record["connected"]

                    # Skip if either node is None (no edges for this node)
                    if not source_node or not connected_node:
                        continue

                    source_label = source_node.get("entity_id") or None
                    target_label = connected_node.get("entity_id") or None

                    if source_label and target_label:
                        edges.append((source_label, target_label))

                return edges
            except Exception as e:
                logger.error(
                    f"Error getting edges for node {source_node_id}: {str(e)}"
                )
                raise
            finally:
                # BUG FIX: consume in a guarded finally; previously the
                # except branch consumed `results`, which is unbound when
                # session.run() itself raises (UnboundLocalError would
                # mask the original exception).
                if results is not None:
                    await results.consume()
    except Exception as e:
        logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
        raise
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.WriteServiceUnavailable,
            neo4jExceptions.ClientError,
        )
    ),
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
    """
    Upsert a node in the Neo4j database.

    The node is keyed by the ``entity_id`` property (with the ``base``
    label) and additionally tagged with its ``entity_type`` as a label.

    Args:
        node_id: The unique identifier for the node (used as label)
        node_data: Dictionary of node properties; must contain
            'entity_id' and 'entity_type'

    Raises:
        ValueError: If node_data lacks an 'entity_id' field
    """
    properties = node_data
    # BUG FIX: validate BEFORE reading the keys. Previously
    # properties["entity_id"] was accessed first, so a missing key raised
    # KeyError and the intended ValueError below could never fire.
    if "entity_id" not in properties:
        raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
    entity_type = properties["entity_type"]
    entity_id = properties["entity_id"]

    try:
        async with self._driver.session(database=self._DATABASE) as session:

            async def execute_upsert(tx: AsyncManagedTransaction):
                # entity_type is interpolated into the query text because
                # Cypher parameters cannot be used for labels. The value
                # comes from project-internal extraction, not end users;
                # do not pass untrusted input here.
                query = (
                    """
                MERGE (n:base {entity_id: $properties.entity_id})
                SET n += $properties
                SET n:`%s`
                """
                    % entity_type
                )
                result = await tx.run(query, properties=properties)
                logger.debug(
                    f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
                )
                await result.consume()  # Ensure result is fully consumed

            await session.execute_write(execute_upsert)
    except Exception as e:
        logger.error(f"Error during upsert: {str(e)}")
        raise
2024-11-02 18:35:07 -04:00
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.WriteServiceUnavailable,
            neo4jExceptions.ClientError,
        )
    ),
)
async def upsert_edge(
    self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
    """
    Upsert an edge and its properties between two nodes identified by their labels.

    Uses the entity_id property to locate the nodes. Note that if either
    endpoint does not exist, the MATCH simply yields no rows and the edge
    is silently not created (no exception is raised for missing nodes).

    Args:
        source_node_id (str): Label of the source node (used as identifier)
        target_node_id (str): Label of the target node (used as identifier)
        edge_data (dict): Dictionary of properties to set on the edge
    """
    try:
        edge_properties = edge_data
        async with self._driver.session(database=self._DATABASE) as session:

            async def execute_upsert(tx: AsyncManagedTransaction):
                # MERGE with an undirected pattern: A-B and B-A map to the
                # same relationship; repeated upserts overwrite properties.
                query = """
                MATCH (source:base {entity_id: $source_entity_id})
                WITH source
                MATCH (target:base {entity_id: $target_entity_id})
                MERGE (source)-[r:DIRECTED]-(target)
                SET r += $properties
                RETURN r, source, target
                """
                result = await tx.run(
                    query,
                    source_entity_id=source_node_id,
                    target_entity_id=target_node_id,
                    properties=edge_properties,
                )
                try:
                    records = await result.fetch(2)
                    if records:
                        # BUG FIX: added the missing space before "with";
                        # the concatenated message read "...'X'with ...".
                        logger.debug(
                            f"Upserted edge from '{source_node_id}' to '{target_node_id}' "
                            f"with properties: {edge_properties}"
                        )
                finally:
                    await result.consume()  # Ensure result is consumed

            await session.execute_write(execute_upsert)
    except Exception as e:
        logger.error(f"Error during edge upsert: {str(e)}")
        raise
async def _node2vec_embed(self):
    # Placeholder registered in self._node_embed_algorithms; never invoked
    # in practice because embed_nodes() raises NotImplementedError.
    print("Implemented but never called.")
async def get_knowledge_graph(
    self,
    node_label: str,
    max_depth: int = 3,
    min_degree: int = 0,
    inclusive: bool = False,
) -> KnowledgeGraph:
    """
    Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
    Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).

    When reducing the number of nodes, the prioritization criteria are as follows:
        1. min_degree does not affect nodes directly connected to the matching nodes
        2. Label matching nodes take precedence
        3. Followed by nodes directly connected to the matching nodes
        4. Finally, the degree of the nodes

    Requires the APOC plugin for the non-"*" path; falls back to
    _robust_fallback (plain Cypher BFS) when APOC is missing.

    Args:
        node_label: Label of the starting node ("*" selects the whole graph)
        max_depth: Maximum depth of the subgraph
        min_degree: Minimum degree of nodes to include. Defaults to 0
        inclusive: Do an inclusive (substring) search if true

    Returns:
        KnowledgeGraph: Complete connected subgraph for specified node
    """
    result = KnowledgeGraph()
    seen_nodes = set()
    seen_edges = set()

    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        try:
            if node_label == "*":
                # Whole-graph mode: keep the highest-degree nodes (up to
                # $max_nodes) and every edge between the kept nodes.
                main_query = """
                MATCH (n)
                OPTIONAL MATCH (n)-[r]-()
                WITH n, COALESCE(count(r), 0) AS degree
                WHERE degree >= $min_degree
                ORDER BY degree DESC
                LIMIT $max_nodes
                WITH collect({node: n}) AS filtered_nodes
                UNWIND filtered_nodes AS node_info
                WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                OPTIONAL MATCH (a)-[r]-(b)
                WHERE a IN kept_nodes AND b IN kept_nodes
                RETURN filtered_nodes AS node_info,
                       collect(DISTINCT r) AS relationships
                """
                result_set = await session.run(
                    main_query,
                    {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
                )

            else:
                # Main query uses partial matching when $inclusive is set,
                # then expands via apoc.path.subgraphAll up to $max_depth.
                main_query = """
                MATCH (start)
                WHERE
                  CASE
                      WHEN $inclusive THEN start.entity_id CONTAINS $entity_id
                      ELSE start.entity_id = $entity_id
                  END
                WITH start
                CALL apoc.path.subgraphAll(start, {
                    relationshipFilter: '',
                    minLevel: 0,
                    maxLevel: $max_depth,
                    bfs: true
                })
                YIELD nodes, relationships
                WITH start, nodes, relationships
                UNWIND nodes AS node
                OPTIONAL MATCH (node)-[r]-()
                WITH node, COALESCE(count(r), 0) AS degree, start, nodes, relationships
                WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
                ORDER BY
                    CASE
                        WHEN node = start THEN 3
                        WHEN EXISTS((start)--(node)) THEN 2
                        ELSE 1
                    END DESC,
                    degree DESC
                LIMIT $max_nodes
                WITH collect({node: node}) AS filtered_nodes
                UNWIND filtered_nodes AS node_info
                WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                OPTIONAL MATCH (a)-[r]-(b)
                WHERE a IN kept_nodes AND b IN kept_nodes
                RETURN filtered_nodes AS node_info,
                       collect(DISTINCT r) AS relationships
                """
                result_set = await session.run(
                    main_query,
                    {
                        "max_nodes": MAX_GRAPH_NODES,
                        "entity_id": node_label,
                        "inclusive": inclusive,
                        "max_depth": max_depth,
                        "min_degree": min_degree,
                    },
                )

            try:
                record = await result_set.single()

                if record:
                    # Handle nodes (compatible with multi-label cases)
                    for node_info in record["node_info"]:
                        node = node_info["node"]
                        # NOTE(review): node.id is the legacy numeric id,
                        # deprecated in Neo4j 5 in favor of element_id —
                        # consider migrating; verify driver version.
                        node_id = node.id
                        if node_id not in seen_nodes:
                            result.nodes.append(
                                KnowledgeGraphNode(
                                    id=f"{node_id}",
                                    labels=[node.get("entity_id")],
                                    properties=dict(node),
                                )
                            )
                            seen_nodes.add(node_id)

                    # Handle relationships (including direction information)
                    for rel in record["relationships"]:
                        edge_id = rel.id
                        if edge_id not in seen_edges:
                            start = rel.start_node
                            end = rel.end_node
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=f"{edge_id}",
                                    type=rel.type,
                                    source=f"{start.id}",
                                    target=f"{end.id}",
                                    properties=dict(rel),
                                )
                            )
                            seen_edges.add(edge_id)

                logger.info(
                    f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
                )
            finally:
                await result_set.consume()  # Ensure result set is consumed

        except neo4jExceptions.ClientError as e:
            # Most likely the APOC plugin is missing; retry without it.
            logger.warning(f"APOC plugin error: {str(e)}")
            if node_label != "*":
                logger.warning(
                    "Neo4j: falling back to basic Cypher recursive search..."
                )
                if inclusive:
                    logger.warning(
                        "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
                    )
                return await self._robust_fallback(
                    node_label, max_depth, min_degree
                )

    return result
async def _robust_fallback(
    self, node_label: str, max_depth: int, min_degree: int = 0
) -> KnowledgeGraph:
    """
    Fallback implementation when APOC plugin is not available or incompatible.

    This method implements the same functionality as get_knowledge_graph but
    uses only basic Cypher queries and recursive BFS traversal instead of
    APOC procedures.

    Args:
        node_label: entity_id of the starting node
        max_depth: Maximum traversal depth
        min_degree: Minimum degree for nodes not adjacent to the start node

    Returns:
        KnowledgeGraph: reachable subgraph, bounded by MAX_GRAPH_NODES
    """
    result = KnowledgeGraph()
    visited_nodes = set()
    visited_edges = set()

    async def traverse(
        node: KnowledgeGraphNode,
        edge: Optional[KnowledgeGraphEdge],
        current_depth: int,
    ):
        # Check traversal limits
        if current_depth > max_depth:
            logger.debug(f"Reached max depth: {max_depth}")
            return
        if len(visited_nodes) >= MAX_GRAPH_NODES:
            logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
            return

        # Check if node already visited
        if node.id in visited_nodes:
            return

        # Get all edges and target nodes
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            MATCH (a:base {entity_id: $entity_id})-[r]-(b)
            WITH r, b, id(r) as edge_id, id(b) as target_id
            RETURN r, b, edge_id, target_id
            """
            results = await session.run(query, entity_id=node.id)

            # Get all records and release database connection
            records = await results.fetch(1000)  # Max neighbour nodes we can handled
            await results.consume()  # Ensure results are consumed

        # Nodes not connected to start node need to check degree
        if current_depth > 1 and len(records) < min_degree:
            return

        # Add current node to result
        result.nodes.append(node)
        visited_nodes.add(node.id)

        # Add edge to result if it exists and not already added
        if edge and edge.id not in visited_edges:
            result.edges.append(edge)
            visited_edges.add(edge.id)

        # Prepare nodes and edges for recursive processing
        nodes_to_process = []
        for record in records:
            rel = record["r"]
            edge_id = str(record["edge_id"])
            if edge_id not in visited_edges:
                b_node = record["b"]
                target_id = b_node.get("entity_id")

                if target_id:  # Only process if target node has entity_id
                    # BUG FIX: labels must be a one-element list. The old
                    # code used list(f"{target_id}"), which splits the id
                    # into single characters. Also use dict(node) instead
                    # of the non-public .properties attribute (absent on
                    # neo4j v5 driver Node objects).
                    target_node = KnowledgeGraphNode(
                        id=f"{target_id}",
                        labels=[f"{target_id}"],
                        properties=dict(b_node),
                    )
                    # Create KnowledgeGraphEdge
                    target_edge = KnowledgeGraphEdge(
                        id=f"{edge_id}",
                        type=rel.type,
                        source=f"{node.id}",
                        target=f"{target_id}",
                        properties=dict(rel),
                    )
                    nodes_to_process.append((target_node, target_edge))
                else:
                    logger.warning(
                        f"Skipping edge {edge_id} due to missing labels on target node"
                    )

        # Process nodes after releasing database connection
        for target_node, target_edge in nodes_to_process:
            await traverse(target_node, target_edge, current_depth + 1)

    # Get the starting node's data
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        query = """
        MATCH (n:base {entity_id: $entity_id})
        RETURN id(n) as node_id, n
        """
        node_result = await session.run(query, entity_id=node_label)
        try:
            node_record = await node_result.single()
            if not node_record:
                return result

            # Create initial KnowledgeGraphNode (same label/properties
            # fixes as above).
            start_entity_id = node_record["n"].get("entity_id")
            start_node = KnowledgeGraphNode(
                id=f"{start_entity_id}",
                labels=[f"{start_entity_id}"],
                properties=dict(node_record["n"]),
            )
        finally:
            await node_result.consume()  # Ensure results are consumed

    # Start traversal with the initial node
    await traverse(start_node, None, 0)

    return result
async def get_all_labels(self) -> list[str]:
    """
    Get all existing node labels in the database

    Returns:
        ["Person", "Company", ...] # Alphabetically sorted label list
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # Method 1: Direct metadata query (Available for Neo4j 4.3+)
        # query = "CALL db.labels() YIELD label RETURN label"

        # Method 2: Query compatible with older versions
        query = """
        MATCH (n)
        WHERE n.entity_id IS NOT NULL
        RETURN DISTINCT n.entity_id AS label
        ORDER BY label
        """
        result = await session.run(query)
        try:
            # Stream the sorted labels straight into a list.
            return [record["label"] async for record in result]
        finally:
            await result.consume()  # Consume even if processing fails
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.WriteServiceUnavailable,
            neo4jExceptions.ClientError,
        )
    ),
)
async def delete_node(self, node_id: str) -> None:
    """Delete the node whose entity_id matches, together with all its edges.

    Args:
        node_id: The label of the node to delete
    """

    async def _do_delete(tx: AsyncManagedTransaction):
        # DETACH DELETE removes the node and every attached relationship.
        delete_query = """
        MATCH (n:base {entity_id: $entity_id})
        DETACH DELETE n
        """
        cursor = await tx.run(delete_query, entity_id=node_id)
        logger.debug(f"Deleted node with label '{node_id}'")
        await cursor.consume()  # Ensure result is fully consumed

    try:
        async with self._driver.session(database=self._DATABASE) as session:
            await session.execute_write(_do_delete)
    except Exception as e:
        logger.error(f"Error during node deletion: {str(e)}")
        raise
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.WriteServiceUnavailable,
            neo4jExceptions.ClientError,
        )
    ),
)
async def remove_nodes(self, nodes: list[str]):
    """Delete multiple nodes, one at a time via delete_node.

    Args:
        nodes: List of node labels to be deleted
    """
    for label in nodes:
        await self.delete_node(label)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.WriteServiceUnavailable,
            neo4jExceptions.ClientError,
        )
    ),
)
async def remove_edges(self, edges: list[tuple[str, str]]):
    """Delete multiple edges

    Args:
        edges: List of edges to be deleted, each edge is a (source, target) tuple
    """
    for source, target in edges:
        # The transaction callback is redefined per iteration so it closes
        # over the current (source, target) pair.
        async def _do_delete_edge(tx: AsyncManagedTransaction):
            # Undirected match: deletes the relationship whichever way it points.
            query = """
            MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
            DELETE r
            """
            result = await tx.run(
                query, source_entity_id=source, target_entity_id=target
            )
            logger.debug(f"Deleted edge from '{source}' to '{target}'")
            await result.consume()  # Ensure result is fully consumed

        try:
            async with self._driver.session(database=self._DATABASE) as session:
                await session.execute_write(_do_delete_edge)
        except Exception as e:
            logger.error(f"Error during edge deletion: {str(e)}")
            raise
async def embed_nodes(
    self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
    """Node embedding is not supported by the Neo4j backend.

    Raises:
        NotImplementedError: always
    """
    raise NotImplementedError
async def drop(self) -> dict[str, str]:
    """Drop all data from storage and clean up resources

    This method will delete all nodes and relationships in the Neo4j database.

    Returns:
        dict[str, str]: Operation status and message
        - On success: {"status": "success", "message": "data dropped"}
        - On failure: {"status": "error", "message": "<error details>"}
    """
    try:
        async with self._driver.session(database=self._DATABASE) as session:
            # Delete all nodes and relationships in one sweep.
            result = await session.run("MATCH (n) DETACH DELETE n")
            await result.consume()  # Ensure result is fully consumed

            logger.info(
                f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
            )
            return {"status": "success", "message": "data dropped"}
    except Exception as e:
        logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}")
        return {"status": "error", "message": str(e)}