LightRAG/lightrag/kg/neo4j_impl.py


import inspect
import os
import re
from dataclasses import dataclass
from typing import Any, final, Optional
import numpy as np
import configparser
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
import logging
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
import pipmaster as pm
if not pm.is_installed("neo4j"):
pm.install("neo4j")
from neo4j import ( # type: ignore
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
)
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
# Set neo4j logger level to ERROR to suppress warning logs
logging.getLogger("neo4j").setLevel(logging.ERROR)
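
# Example configuration, assuming a local Neo4j instance. Environment variables
# take precedence over the [neo4j] section of config.ini, and the pool/timeout
# values below mirror the fallbacks used in initialize(); the URI and
# credentials are illustrative:
#
#   [neo4j]
#   uri = neo4j://localhost:7687
#   username = neo4j
#   password = your-password
#   connection_pool_size = 50
#   connection_timeout = 30.0
#   connection_acquisition_timeout = 30.0
#   max_transaction_retry_time = 30.0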
@final
@dataclass
class Neo4JStorage(BaseGraphStorage):
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
self._driver = None
def __post_init__(self):
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def initialize(self):
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
USERNAME = os.environ.get(
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
)
PASSWORD = os.environ.get(
"NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
)
MAX_CONNECTION_POOL_SIZE = int(
os.environ.get(
"NEO4J_MAX_CONNECTION_POOL_SIZE",
config.get("neo4j", "connection_pool_size", fallback=50),
)
)
CONNECTION_TIMEOUT = float(
os.environ.get(
"NEO4J_CONNECTION_TIMEOUT",
config.get("neo4j", "connection_timeout", fallback=30.0),
),
)
CONNECTION_ACQUISITION_TIMEOUT = float(
os.environ.get(
"NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
),
)
MAX_TRANSACTION_RETRY_TIME = float(
os.environ.get(
"NEO4J_MAX_TRANSACTION_RETRY_TIME",
config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
),
)
DATABASE = os.environ.get(
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
)
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI,
auth=(USERNAME, PASSWORD),
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
connection_timeout=CONNECTION_TIMEOUT,
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
)
# Try to connect to the database and create it if it doesn't exist
for database in (DATABASE, None):
self._DATABASE = database
connected = False
try:
async with self._driver.session(database=database) as session:
try:
result = await session.run("MATCH (n) RETURN n LIMIT 0")
await result.consume() # Ensure result is consumed
logger.info(f"Connected to {database} at {URI}")
connected = True
except neo4jExceptions.ServiceUnavailable as e:
                        logger.error(f"{database} at {URI} is not available")
raise e
except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {database} at {URI}")
raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
                    logger.info(
                        f"{database} at {URI} not found. Trying to create the specified database."
                    )
try:
async with self._driver.session() as session:
result = await session.run(
f"CREATE DATABASE `{database}` IF NOT EXISTS"
)
await result.consume() # Ensure result is consumed
logger.info(f"{database} at {URI} created".capitalize())
connected = True
except (
neo4jExceptions.ClientError,
neo4jExceptions.DatabaseError,
) as e:
if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
if database is not None:
                                    logger.warning(
                                        "This Neo4j instance does not support creating databases. "
                                        "Use Neo4j Desktop/Enterprise or DozerDB instead; "
                                        "falling back to the default database."
                                    )
if database is None:
logger.error(f"Failed to create {database} at {URI}")
raise e
if connected:
# Create index for base nodes on entity_id if it doesn't exist
try:
async with self._driver.session(database=database) as session:
# Check if index exists first
check_query = """
CALL db.indexes() YIELD name, labelsOrTypes, properties
WHERE labelsOrTypes = ['base'] AND properties = ['entity_id']
RETURN count(*) > 0 AS exists
"""
try:
check_result = await session.run(check_query)
record = await check_result.single()
await check_result.consume()
index_exists = record and record.get("exists", False)
if not index_exists:
# Create index only if it doesn't exist
result = await session.run(
"CREATE INDEX FOR (n:base) ON (n.entity_id)"
)
await result.consume()
logger.info(f"Created index for base nodes on entity_id in {database}")
except Exception:
# Fallback if db.indexes() is not supported in this Neo4j version
result = await session.run(
"CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)"
)
await result.consume()
except Exception as e:
logger.warning(f"Failed to create index: {str(e)}")
break
async def finalize(self):
"""Close the Neo4j driver and release all resources"""
if self._driver:
await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
"""Ensure driver is closed when context manager exits"""
await self.finalize()
async def index_done_callback(self) -> None:
        # Neo4j handles persistence automatically
pass
async def has_node(self, node_id: str) -> bool:
"""
Check if a node with the given label exists in the database
Args:
node_id: Label of the node to check
Returns:
bool: True if node exists, False otherwise
Raises:
ValueError: If node_id is invalid
Exception: If there is an error executing the query
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
            result = None
            try:
                query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
                result = await session.run(query, entity_id=node_id)
                single_result = await result.single()
                await result.consume()  # Ensure result is fully consumed
                return single_result["node_exists"]
            except Exception as e:
                logger.error(f"Error checking node existence for {node_id}: {str(e)}")
                if result is not None:
                    await result.consume()  # Ensure result is consumed even on error
                raise
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""
Check if an edge exists between two nodes
Args:
source_node_id: Label of the source node
target_node_id: Label of the target node
Returns:
bool: True if edge exists, False otherwise
Raises:
ValueError: If either node_id is invalid
Exception: If there is an error executing the query
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
            result = None
            try:
                query = (
                    "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
                    "RETURN COUNT(r) > 0 AS edgeExists"
                )
                result = await session.run(
                    query,
                    source_entity_id=source_node_id,
                    target_entity_id=target_node_id,
                )
                single_result = await result.single()
                await result.consume()  # Ensure result is fully consumed
                return single_result["edgeExists"]
            except Exception as e:
                logger.error(
                    f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
                )
                if result is not None:
                    await result.consume()  # Ensure result is consumed even on error
                raise
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier.
Args:
node_id: The node label to look up
Returns:
dict: Node properties if found
None: If node not found
Raises:
ValueError: If node_id is invalid
Exception: If there is an error executing the query
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
result = await session.run(query, entity_id=node_id)
try:
records = await result.fetch(
2
) # Get 2 records for duplication check
if len(records) > 1:
logger.warning(
f"Multiple nodes found with label '{node_id}'. Using first node."
)
if records:
node = records[0]["n"]
node_dict = dict(node)
# Remove base label from labels list if it exists
if "labels" in node_dict:
node_dict["labels"] = [
label
for label in node_dict["labels"]
if label != "base"
]
logger.debug(f"Neo4j query node {query} return: {node_dict}")
return node_dict
return None
finally:
await result.consume() # Ensure result is fully consumed
except Exception as e:
logger.error(f"Error getting node for {node_id}: {str(e)}")
raise
async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node.
If no node is found, returns 0.
Args:
node_id: The label of the node
Returns:
int: The number of relationships the node has, or 0 if no node found
Raises:
ValueError: If node_id is invalid
Exception: If there is an error executing the query
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
            try:
                query = """
                    MATCH (n:base {entity_id: $entity_id})
                    OPTIONAL MATCH (n)-[r]-()
                    RETURN COUNT(r) AS degree
                """
                result = await session.run(query, entity_id=node_id)
                try:
                    record = await result.single()

                    if not record:
                        logger.warning(f"No node found with label '{node_id}'")
                        return 0

                    degree = record["degree"]
                    logger.debug(
                        f"Neo4j query node degree for {node_id} returned: {degree}"
                    )
                    return degree
                finally:
                    await result.consume()  # Ensure result is fully consumed
            except Exception as e:
                logger.error(f"Error getting node degree for {node_id}: {str(e)}")
                raise
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""Get the total degree (sum of relationships) of two nodes.
Args:
src_id: Label of the source node
tgt_id: Label of the target node
Returns:
int: Sum of the degrees of both nodes
"""
src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
return degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
"""Get edge properties between two nodes.
Args:
source_node_id: Label of the source node
target_node_id: Label of the target node
Returns:
dict: Edge properties if found, default properties if not found or on error
Raises:
ValueError: If either node_id is invalid
Exception: If there is an error executing the query
"""
try:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
RETURN properties(r) as edge_properties
"""
result = await session.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
)
try:
records = await result.fetch(2)
if len(records) > 1:
logger.warning(
f"Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge."
)
if records:
try:
edge_result = dict(records[0]["edge_properties"])
logger.debug(f"Result: {edge_result}")
# Ensure required keys exist with defaults
required_keys = {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
for key, default_value in required_keys.items():
if key not in edge_result:
edge_result[key] = default_value
logger.warning(
f"Edge between {source_node_id} and {target_node_id} "
f"missing {key}, using default: {default_value}"
)
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}"
)
return edge_result
except (KeyError, TypeError, ValueError) as e:
logger.error(
f"Error processing edge properties between {source_node_id} "
f"and {target_node_id}: {str(e)}"
)
# Return default edge properties on error
return {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
logger.debug(
f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
)
# Return default edge properties when no edge found
return {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
finally:
await result.consume() # Ensure result is fully consumed
except Exception as e:
logger.error(
f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
)
raise
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node identified by its label.
Args:
source_node_id: Label of the node to get edges for
Returns:
list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
None: If no edges found
Raises:
ValueError: If source_node_id is invalid
Exception: If there is an error executing the query
"""
try:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
query = """MATCH (n:base {entity_id: $entity_id})
OPTIONAL MATCH (n)-[r]-(connected:base)
WHERE connected.entity_id IS NOT NULL
RETURN n, r, connected"""
results = await session.run(query, entity_id=source_node_id)
edges = []
async for record in results:
source_node = record["n"]
connected_node = record["connected"]
# Skip if either node is None
if not source_node or not connected_node:
continue
                        source_label = source_node.get("entity_id")
                        target_label = connected_node.get("entity_id")
if source_label and target_label:
edges.append((source_label, target_label))
await results.consume() # Ensure results are consumed
return edges
                except Exception as e:
                    logger.error(
                        f"Error getting edges for node {source_node_id}: {str(e)}"
                    )
                    raise
except Exception as e:
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Upsert a node in the Neo4j database.
Args:
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
properties = node_data
entity_type = properties["entity_type"]
if "entity_id" not in properties:
raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
try:
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction):
query = (
"""
MERGE (n:base {entity_id: $entity_id})
SET n += $properties
SET n:`%s`
"""
% entity_type
)
result = await tx.run(
query, entity_id=node_id, properties=properties
)
logger.debug(
f"Upserted node with entity_id '{node_id}' and properties: {properties}"
)
await result.consume() # Ensure result is fully consumed
await session.execute_write(execute_upsert)
except Exception as e:
logger.error(f"Error during upsert: {str(e)}")
raise
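
    # A minimal sketch of the payload upsert_node expects: "entity_id" is
    # required (validated above) and "entity_type" becomes an extra node label;
    # the entity name and description below are illustrative.
    #
    #   await storage.upsert_node(
    #       "Alan Turing",
    #       {
    #           "entity_id": "Alan Turing",
    #           "entity_type": "Person",
    #           "description": "Mathematician and computer scientist",
    #       },
    #   )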
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge and its properties between two nodes identified by their labels.
Ensures both source and target nodes exist and are unique before creating the edge.
Uses entity_id property to uniquely identify nodes.
Args:
source_node_id (str): Label of the source node (used as identifier)
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
Raises:
ValueError: If either source or target node does not exist or is not unique
"""
try:
edge_properties = edge_data
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction):
query = """
MATCH (source:base {entity_id: $source_entity_id})
WITH source
MATCH (target:base {entity_id: $target_entity_id})
MERGE (source)-[r:DIRECTED]-(target)
SET r += $properties
RETURN r, source, target
"""
result = await tx.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
properties=edge_properties,
)
try:
records = await result.fetch(2)
if records:
                            logger.debug(
                                f"Upserted edge from '{source_node_id}' to '{target_node_id}' "
                                f"with properties: {edge_properties}"
                            )
finally:
await result.consume() # Ensure result is consumed
await session.execute_write(execute_upsert)
except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}")
raise
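
    # A minimal sketch of the payload upsert_edge expects. The keys mirror the
    # defaults that get_edge backfills (weight, source_id, description,
    # keywords); the values below are illustrative.
    #
    #   await storage.upsert_edge(
    #       "Alan Turing",
    #       "Turing Machine",
    #       {
    #           "weight": 1.0,
    #           "source_id": "chunk-1",
    #           "description": "invented",
    #           "keywords": "computation",
    #       },
    #   )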
async def _node2vec_embed(self):
print("Implemented but never called.")
async def get_knowledge_graph(
self,
node_label: str,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
Args:
node_label: Label of the starting node
max_depth: Maximum depth of the subgraph
            max_nodes: Maximum number of nodes to return (default: MAX_GRAPH_NODES)
Returns:
KnowledgeGraph: Complete connected subgraph for specified node
"""
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
if node_label == "*":
main_query = """
MATCH (n)
OPTIONAL MATCH (n)-[r]-()
WITH n, COALESCE(count(r), 0) AS degree
ORDER BY degree DESC
LIMIT $max_nodes
WITH collect({node: n}) AS filtered_nodes
UNWIND filtered_nodes AS node_info
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
OPTIONAL MATCH (a)-[r]-(b)
WHERE a IN kept_nodes AND b IN kept_nodes
RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships
"""
result_set = await session.run(
main_query,
{"max_nodes": max_nodes},
)
else:
                    # Main query matches the start node by entity_id
main_query = """
MATCH (start)
WHERE start.entity_id = $entity_id
WITH start
CALL apoc.path.subgraphAll(start, {
relationshipFilter: '',
minLevel: 0,
maxLevel: $max_depth,
bfs: true
})
YIELD nodes, relationships
WITH start, nodes, relationships
UNWIND nodes AS node
OPTIONAL MATCH (node)-[r]-()
WITH node, COALESCE(count(r), 0) AS degree, start, nodes, relationships
ORDER BY
CASE
WHEN node = start THEN 0
ELSE length(shortestPath((start)--(node)))
END ASC,
degree DESC
LIMIT $max_nodes
WITH collect({node: node}) AS filtered_nodes
UNWIND filtered_nodes AS node_info
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
OPTIONAL MATCH (a)-[r]-(b)
WHERE a IN kept_nodes AND b IN kept_nodes
RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships
"""
result_set = await session.run(
main_query,
{
"max_nodes": max_nodes,
"entity_id": node_label,
"max_depth": max_depth,
},
)
try:
record = await result_set.single()
if record:
# Handle nodes (compatible with multi-label cases)
for node_info in record["node_info"]:
node = node_info["node"]
node_id = node.id
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[node.get("entity_id")],
properties=dict(node),
)
)
seen_nodes.add(node_id)
# Handle relationships (including direction information)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
seen_edges.add(edge_id)
logger.info(
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
)
finally:
await result_set.consume() # Ensure result set is consumed
except neo4jExceptions.ClientError as e:
logger.warning(f"APOC plugin error: {str(e)}")
if node_label != "*":
logger.warning(
"Neo4j: falling back to basic Cypher recursive search..."
)
return await self._robust_fallback(
node_label, max_depth, max_nodes
)
return result
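
    # Usage sketch: "*" returns the highest-degree nodes across the whole graph
    # (capped at max_nodes), while an entity label returns its neighbourhood up
    # to max_depth hops; the entity name below is illustrative.
    #
    #   kg = await storage.get_knowledge_graph("*")
    #   kg = await storage.get_knowledge_graph("Alan Turing", max_depth=2)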
    async def _robust_fallback(
        self, node_label: str, max_depth: int, max_nodes: int
    ) -> KnowledgeGraph:
"""
Fallback implementation when APOC plugin is not available or incompatible.
This method implements the same functionality as get_knowledge_graph but uses
only basic Cypher queries and recursive traversal instead of APOC procedures.
"""
result = KnowledgeGraph()
visited_nodes = set()
visited_edges = set()
async def traverse(
node: KnowledgeGraphNode,
edge: Optional[KnowledgeGraphEdge],
current_depth: int,
):
# Check traversal limits
if current_depth > max_depth:
logger.debug(f"Reached max depth: {max_depth}")
return
            if len(visited_nodes) >= max_nodes:
                logger.debug(f"Reached max nodes limit: {max_nodes}")
                return
# Check if node already visited
if node.id in visited_nodes:
return
# Get all edges and target nodes
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
"""
results = await session.run(query, entity_id=node.id)
# Get all records and release database connection
                records = await results.fetch(
                    1000
                )  # Max number of neighbour nodes we can handle
await results.consume() # Ensure results are consumed
# Nodes not connected to start node need to check degree
# if current_depth > 1 and len(records) < min_degree:
# return
# Add current node to result
result.nodes.append(node)
visited_nodes.add(node.id)
# Add edge to result if it exists and not already added
if edge and edge.id not in visited_edges:
result.edges.append(edge)
visited_edges.add(edge.id)
# Prepare nodes and edges for recursive processing
nodes_to_process = []
for record in records:
rel = record["r"]
edge_id = str(record["edge_id"])
if edge_id not in visited_edges:
b_node = record["b"]
target_id = b_node.get("entity_id")
if target_id: # Only process if target node has entity_id
# Create KnowledgeGraphNode for target
                        target_node = KnowledgeGraphNode(
                            id=f"{target_id}",
                            labels=[target_id],
                            properties=dict(b_node),
                        )
# Create KnowledgeGraphEdge
target_edge = KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{node.id}",
target=f"{target_id}",
properties=dict(rel),
)
nodes_to_process.append((target_node, target_edge))
                    else:
                        logger.warning(
                            f"Skipping edge {edge_id} due to missing entity_id on target node"
                        )
# Process nodes after releasing database connection
for target_node, target_edge in nodes_to_process:
await traverse(target_node, target_edge, current_depth + 1)
# Get the starting node's data
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
MATCH (n:base {entity_id: $entity_id})
RETURN id(n) as node_id, n
"""
node_result = await session.run(query, entity_id=node_label)
try:
node_record = await node_result.single()
if not node_record:
return result
# Create initial KnowledgeGraphNode
                start_node = KnowledgeGraphNode(
                    id=f"{node_record['n'].get('entity_id')}",
                    labels=[node_record["n"].get("entity_id")],
                    properties=dict(node_record["n"]),
                )
finally:
await node_result.consume() # Ensure results are consumed
# Start traversal with the initial node
await traverse(start_node, None, 0)
return result
async def get_all_labels(self) -> list[str]:
"""
Get all existing node labels in the database
Returns:
["Person", "Company", ...] # Alphabetically sorted label list
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
# Method 1: Direct metadata query (Available for Neo4j 4.3+)
# query = "CALL db.labels() YIELD label RETURN label"
# Method 2: Query compatible with older versions
query = """
MATCH (n:base)
WHERE n.entity_id IS NOT NULL
RETURN DISTINCT n.entity_id AS label
ORDER BY label
"""
result = await session.run(query)
labels = []
try:
async for record in result:
labels.append(record["label"])
finally:
                await result.consume()  # Ensure results are consumed even if processing fails
return labels
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
async def delete_node(self, node_id: str) -> None:
"""Delete a node with the specified label
Args:
node_id: The label of the node to delete
"""
async def _do_delete(tx: AsyncManagedTransaction):
query = """
MATCH (n:base {entity_id: $entity_id})
DETACH DELETE n
"""
result = await tx.run(query, entity_id=node_id)
logger.debug(f"Deleted node with label '{node_id}'")
await result.consume() # Ensure result is fully consumed
try:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete)
except Exception as e:
logger.error(f"Error during node deletion: {str(e)}")
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node labels to be deleted
"""
for node in nodes:
await self.delete_node(node)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
for source, target in edges:
async def _do_delete_edge(tx: AsyncManagedTransaction):
query = """
MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
DELETE r
"""
result = await tx.run(
query, source_entity_id=source, target_entity_id=target
)
logger.debug(f"Deleted edge from '{source}' to '{target}'")
await result.consume() # Ensure result is fully consumed
try:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete_edge)
except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}")
raise
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def drop(self) -> dict[str, str]:
"""Drop all data from storage and clean up resources
This method will delete all nodes and relationships in the Neo4j database.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
async with self._driver.session(database=self._DATABASE) as session:
# Delete all nodes and relationships
query = "MATCH (n) DETACH DELETE n"
result = await session.run(query)
await result.consume() # Ensure result is fully consumed
logger.info(
f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}")
return {"status": "error", "message": str(e)}