2024-10-26 19:29:45 -04:00
import asyncio
2025-01-04 21:47:52 +08:00
import inspect
2024-10-26 19:29:45 -04:00
import os
2025-02-08 16:06:07 +08:00
import re
2024-10-26 19:29:45 -04:00
from dataclasses import dataclass
2025-03-08 04:28:54 +08:00
from typing import Any , final , Optional
2025-02-16 13:53:59 +01:00
import numpy as np
2025-02-11 00:55:52 +08:00
import configparser
2025-01-27 23:21:34 +08:00
2025-01-04 21:47:52 +08:00
2024-10-30 17:48:14 -04:00
from tenacity import (
retry ,
stop_after_attempt ,
wait_exponential ,
retry_if_exception_type ,
)
2025-03-07 16:56:48 +08:00
import logging
2025-02-08 16:06:07 +08:00
from . . utils import logger
2025-01-04 21:47:52 +08:00
from . . base import BaseGraphStorage
2025-02-20 14:29:36 +01:00
from . . types import KnowledgeGraph , KnowledgeGraphNode , KnowledgeGraphEdge
2025-02-16 16:04:07 +01:00
import pipmaster as pm
2025-01-04 21:47:52 +08:00
2025-02-16 16:04:07 +01:00
if not pm . is_installed ( " neo4j " ) :
pm . install ( " neo4j " )
2025-02-16 16:04:35 +01:00
2025-03-02 15:39:14 +08:00
from neo4j import ( # type: ignore
2025-02-19 19:32:23 +01:00
AsyncGraphDatabase ,
exceptions as neo4jExceptions ,
AsyncDriver ,
AsyncManagedTransaction ,
GraphDatabase ,
)
2025-02-16 14:38:09 +01:00
2025-02-11 00:55:52 +08:00
# Optional settings file; ConfigParser.read silently skips a missing file.
config = configparser.ConfigParser()
config.read("config.ini", encoding="utf-8")

# Upper bound on nodes returned by graph queries; override with the
# MAX_GRAPH_NODES environment variable (default: 1000).
MAX_GRAPH_NODES = int(os.environ.get("MAX_GRAPH_NODES", "1000"))

# The neo4j driver is chatty at WARNING level; surface only errors.
logging.getLogger("neo4j").setLevel(logging.ERROR)
2025-02-11 03:29:40 +08:00
2025-03-08 01:20:36 +08:00
2025-02-16 14:38:09 +01:00
@final
2024-10-26 19:29:45 -04:00
@dataclass
2024-11-02 18:35:07 -04:00
class Neo4JStorage ( BaseGraphStorage ) :
2024-12-02 02:44:47 +05:30
    def __init__(self, namespace, global_config, embedding_func):
        """Initialize the async Neo4j driver and verify connectivity.

        Connection settings are resolved from environment variables first,
        then from the ``[neo4j]`` section of ``config.ini``. The target
        database name defaults to a sanitized form of ``namespace``. A
        short-lived *sync* driver is used up front to probe the database and,
        if it does not exist, to try creating it; on instances that cannot
        create databases (Community edition) we fall back to the default
        database (``None``).

        Args:
            namespace: Storage namespace; also the default database name
                (non-alphanumeric characters replaced with ``-``).
            global_config: Passed through to BaseGraphStorage.
            embedding_func: Passed through to BaseGraphStorage.
        """
        super().__init__(
            namespace=namespace,
            global_config=global_config,
            embedding_func=embedding_func,
        )
        self._driver = None
        self._driver_lock = asyncio.Lock()

        # Env vars win over config.ini; every setting has a fallback.
        URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
        USERNAME = os.environ.get(
            "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
        )
        PASSWORD = os.environ.get(
            "NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
        )
        MAX_CONNECTION_POOL_SIZE = int(
            os.environ.get(
                "NEO4J_MAX_CONNECTION_POOL_SIZE",
                config.get("neo4j", "connection_pool_size", fallback=50),
            )
        )
        CONNECTION_TIMEOUT = float(
            os.environ.get(
                "NEO4J_CONNECTION_TIMEOUT",
                config.get("neo4j", "connection_timeout", fallback=30.0),
            ),
        )
        CONNECTION_ACQUISITION_TIMEOUT = float(
            os.environ.get(
                "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
                config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
            ),
        )
        MAX_TRANSACTION_RETRY_TIME = float(
            os.environ.get(
                "NEO4J_MAX_TRANSACTION_RETRY_TIME",
                config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
            ),
        )
        # Neo4j database names only allow alphanumerics and dashes.
        DATABASE = os.environ.get(
            "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
        )

        self._driver: AsyncDriver = AsyncGraphDatabase.driver(
            URI,
            auth=(USERNAME, PASSWORD),
            max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
            connection_timeout=CONNECTION_TIMEOUT,
            connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
            max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
        )

        # Try to connect to the database (sync probe, so __init__ can fail fast).
        with GraphDatabase.driver(
            URI,
            auth=(USERNAME, PASSWORD),
            max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
            connection_timeout=CONNECTION_TIMEOUT,
            connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
        ) as _sync_driver:
            # First try the namespaced database, then the server default (None).
            for database in (DATABASE, None):
                self._DATABASE = database
                connected = False
                try:
                    with _sync_driver.session(database=database) as session:
                        try:
                            # LIMIT 0: cheapest possible round-trip connectivity check.
                            session.run("MATCH (n) RETURN n LIMIT 0")
                            logger.info(f"Connected to {database} at {URI}")
                            connected = True
                        except neo4jExceptions.ServiceUnavailable as e:
                            logger.error(
                                f"{database} at {URI} is not available".capitalize()
                            )
                            raise e
                except neo4jExceptions.AuthError as e:
                    logger.error(f"Authentication failed for {database} at {URI}")
                    raise e
                except neo4jExceptions.ClientError as e:
                    if e.code == "Neo.ClientError.Database.DatabaseNotFound":
                        logger.info(
                            f"{database} at {URI} not found. Try to create specified database.".capitalize()
                        )
                        try:
                            # CREATE DATABASE must run against the default database.
                            with _sync_driver.session() as session:
                                session.run(
                                    f"CREATE DATABASE `{database}` IF NOT EXISTS"
                                )
                                logger.info(f"{database} at {URI} created".capitalize())
                                connected = True
                        except (
                            neo4jExceptions.ClientError,
                            neo4jExceptions.DatabaseError,
                        ) as e:
                            # Community edition rejects CREATE DATABASE; warn and
                            # let the loop fall through to the default database.
                            if (
                                e.code
                                == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
                            ) or (
                                e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
                            ):
                                if database is not None:
                                    logger.warning(
                                        "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
                                    )
                            # Even the default database failed: nothing left to try.
                            if database is None:
                                logger.error(f"Failed to create {database} at {URI}")
                                raise e
                if connected:
                    break
2024-10-26 19:29:45 -04:00
def __post_init__ ( self ) :
self . _node_embed_algorithms = {
" node2vec " : self . _node2vec_embed ,
}
2024-11-02 18:35:07 -04:00
async def close ( self ) :
2025-03-08 10:19:20 +08:00
""" Close the Neo4j driver and release all resources """
2024-11-02 18:35:07 -04:00
if self . _driver :
await self . _driver . close ( )
self . _driver = None
async def __aexit__ ( self , exc_type , exc , tb ) :
2025-03-08 10:19:20 +08:00
""" Ensure driver is closed when context manager exits """
await self . close ( )
2024-11-02 18:35:07 -04:00
2025-02-16 14:38:09 +01:00
    async def index_done_callback(self) -> None:
        """Hook called after indexing finishes; no-op for this backend."""
        # Neo4j handles persistence automatically, so there is nothing to flush.
        pass
2024-11-02 18:35:07 -04:00
2025-03-08 10:23:27 +08:00
def _ensure_label ( self , label : str ) - > str :
2025-03-07 16:43:18 +08:00
""" Ensure a label is valid
2025-03-08 01:20:36 +08:00
2025-03-07 16:43:18 +08:00
Args :
label : The label to validate
2025-03-09 01:00:42 +08:00
2025-03-08 10:23:27 +08:00
Returns :
str : The cleaned label
2025-03-09 01:00:42 +08:00
2025-03-08 10:23:27 +08:00
Raises :
ValueError : If label is empty after cleaning
2025-03-07 16:43:18 +08:00
"""
2025-02-14 16:04:06 +01:00
clean_label = label . strip ( ' " ' )
2025-03-08 01:20:36 +08:00
if not clean_label :
raise ValueError ( " Neo4j: Label cannot be empty " )
2025-02-14 16:04:06 +01:00
return clean_label
async def has_node ( self , node_id : str ) - > bool :
2025-03-08 10:19:20 +08:00
"""
Check if a node with the given label exists in the database
Args :
node_id : Label of the node to check
Returns :
bool : True if node exists , False otherwise
Raises :
ValueError : If node_id is invalid
Exception : If there is an error executing the query
"""
2025-03-08 10:23:27 +08:00
entity_name_label = self . _ensure_label ( node_id )
2025-03-08 02:39:51 +08:00
async with self . _driver . session (
database = self . _DATABASE , default_access_mode = " READ "
) as session :
2025-03-08 10:19:20 +08:00
try :
query = f " MATCH (n:` { entity_name_label } `) RETURN count(n) > 0 AS node_exists "
result = await session . run ( query )
single_result = await result . single ( )
await result . consume ( ) # Ensure result is fully consumed
return single_result [ " node_exists " ]
except Exception as e :
logger . error (
f " Error checking node existence for { entity_name_label } : { str ( e ) } "
)
await result . consume ( ) # Ensure results are consumed even on error
raise
2024-11-06 11:18:14 -05:00
2024-10-29 15:36:07 -04:00
async def has_edge ( self , source_node_id : str , target_node_id : str ) - > bool :
2025-03-08 10:19:20 +08:00
"""
Check if an edge exists between two nodes
Args :
source_node_id : Label of the source node
target_node_id : Label of the target node
Returns :
bool : True if edge exists , False otherwise
Raises :
ValueError : If either node_id is invalid
Exception : If there is an error executing the query
"""
2025-03-08 10:23:27 +08:00
entity_name_label_source = self . _ensure_label ( source_node_id )
entity_name_label_target = self . _ensure_label ( target_node_id )
2024-11-06 11:18:14 -05:00
2025-03-08 02:39:51 +08:00
async with self . _driver . session (
database = self . _DATABASE , default_access_mode = " READ "
) as session :
2025-03-08 10:19:20 +08:00
try :
query = (
f " MATCH (a:` { entity_name_label_source } `)-[r]-(b:` { entity_name_label_target } `) "
" RETURN COUNT(r) > 0 AS edgeExists "
)
result = await session . run ( query )
single_result = await result . single ( )
await result . consume ( ) # Ensure result is fully consumed
return single_result [ " edgeExists " ]
except Exception as e :
logger . error (
f " Error checking edge existence between { entity_name_label_source } and { entity_name_label_target } : { str ( e ) } "
)
await result . consume ( ) # Ensure results are consumed even on error
raise
2024-10-26 19:29:45 -04:00
2025-02-16 13:53:59 +01:00
async def get_node ( self , node_id : str ) - > dict [ str , str ] | None :
2025-02-14 16:04:06 +01:00
""" Get node by its label identifier.
Args :
node_id : The node label to look up
Returns :
dict : Node properties if found
None : If node not found
2025-03-08 10:19:20 +08:00
Raises :
ValueError : If node_id is invalid
Exception : If there is an error executing the query
2025-02-14 16:04:06 +01:00
"""
2025-03-08 10:23:27 +08:00
entity_name_label = self . _ensure_label ( node_id )
2025-03-08 02:39:51 +08:00
async with self . _driver . session (
database = self . _DATABASE , default_access_mode = " READ "
) as session :
2025-03-08 10:19:20 +08:00
try :
2025-03-09 00:24:55 +08:00
query = f " MATCH (n:` { entity_name_label } ` {{ entity_id: $entity_id }} ) RETURN n "
result = await session . run ( query , entity_id = entity_name_label )
2025-03-08 10:19:20 +08:00
try :
2025-03-09 01:00:42 +08:00
records = await result . fetch (
2
) # Get 2 records for duplication check
2025-03-08 10:19:20 +08:00
if len ( records ) > 1 :
logger . warning (
f " Multiple nodes found with label ' { entity_name_label } ' . Using first node. "
)
if records :
node = records [ 0 ] [ " n " ]
node_dict = dict ( node )
logger . debug (
f " { inspect . currentframe ( ) . f_code . co_name } : query: { query } , result: { node_dict } "
)
return node_dict
return None
finally :
await result . consume ( ) # Ensure result is fully consumed
except Exception as e :
logger . error ( f " Error getting node for { entity_name_label } : { str ( e ) } " )
raise
2024-10-26 19:29:45 -04:00
async def node_degree ( self , node_id : str ) - > int :
2025-03-08 01:20:36 +08:00
""" Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label , returns the degree of the first node .
If no node is found , returns 0.
2025-03-08 02:39:51 +08:00
2025-03-08 01:20:36 +08:00
Args :
node_id : The label of the node
2025-03-08 02:39:51 +08:00
2025-03-08 01:20:36 +08:00
Returns :
int : The number of relationships the node has , or 0 if no node found
2025-03-08 10:19:20 +08:00
Raises :
ValueError : If node_id is invalid
Exception : If there is an error executing the query
2025-03-08 01:20:36 +08:00
"""
2025-03-08 10:23:27 +08:00
entity_name_label = self . _ensure_label ( node_id )
2024-10-29 15:36:07 -04:00
2025-03-08 02:39:51 +08:00
async with self . _driver . session (
database = self . _DATABASE , default_access_mode = " READ "
) as session :
2025-03-08 10:19:20 +08:00
try :
query = f """
MATCH ( n : ` { entity_name_label } ` )
OPTIONAL MATCH ( n ) - [ r ] - ( )
RETURN n , COUNT ( r ) AS degree
"""
result = await session . run ( query )
try :
records = await result . fetch ( 100 )
2025-03-08 02:39:51 +08:00
2025-03-08 10:19:20 +08:00
if not records :
logger . warning (
f " No node found with label ' { entity_name_label } ' "
)
return 0
2025-03-08 02:39:51 +08:00
2025-03-08 10:19:20 +08:00
if len ( records ) > 1 :
logger . warning (
f " Multiple nodes ( { len ( records ) } ) found with label ' { entity_name_label } ' , using first node ' s degree "
)
2025-03-08 02:39:51 +08:00
2025-03-08 10:19:20 +08:00
degree = records [ 0 ] [ " degree " ]
logger . debug (
f " { inspect . currentframe ( ) . f_code . co_name } :query: { query } :result: { degree } "
)
return degree
finally :
await result . consume ( ) # Ensure result is fully consumed
except Exception as e :
logger . error (
f " Error getting node degree for { entity_name_label } : { str ( e ) } "
)
raise
2024-10-26 19:29:45 -04:00
async def edge_degree ( self , src_id : str , tgt_id : str ) - > int :
2025-03-08 10:19:20 +08:00
""" Get the total degree (sum of relationships) of two nodes.
Args :
src_id : Label of the source node
tgt_id : Label of the target node
Returns :
int : Sum of the degrees of both nodes
"""
2025-03-08 10:23:27 +08:00
entity_name_label_source = self . _ensure_label ( src_id )
entity_name_label_target = self . _ensure_label ( tgt_id )
2025-03-08 10:19:20 +08:00
2024-11-02 18:35:07 -04:00
src_degree = await self . node_degree ( entity_name_label_source )
trg_degree = await self . node_degree ( entity_name_label_target )
2024-11-06 11:18:14 -05:00
2024-11-02 18:35:07 -04:00
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int ( src_degree ) + int ( trg_degree )
return degrees
2024-11-06 11:18:14 -05:00
async def get_edge (
2025-01-04 22:33:35 +08:00
self , source_node_id : str , target_node_id : str
2025-02-16 13:53:59 +01:00
) - > dict [ str , str ] | None :
2025-03-08 10:19:20 +08:00
""" Get edge properties between two nodes.
Args :
source_node_id : Label of the source node
target_node_id : Label of the target node
Returns :
dict : Edge properties if found , default properties if not found or on error
Raises :
ValueError : If either node_id is invalid
Exception : If there is an error executing the query
"""
2025-02-14 16:04:06 +01:00
try :
2025-03-08 10:23:27 +08:00
entity_name_label_source = self . _ensure_label ( source_node_id )
entity_name_label_target = self . _ensure_label ( target_node_id )
2025-02-14 16:04:06 +01:00
2025-03-08 02:39:51 +08:00
async with self . _driver . session (
database = self . _DATABASE , default_access_mode = " READ "
) as session :
2025-02-14 16:04:06 +01:00
query = f """
2025-03-08 01:20:36 +08:00
MATCH ( start : ` { entity_name_label_source } ` ) - [ r ] - ( end : ` { entity_name_label_target } ` )
2025-02-14 16:04:06 +01:00
RETURN properties ( r ) as edge_properties
2025-03-01 17:45:06 +08:00
"""
2025-02-14 16:04:06 +01:00
result = await session . run ( query )
2025-03-08 10:19:20 +08:00
try :
2025-03-08 11:36:24 +08:00
records = await result . fetch ( 2 )
2025-03-08 04:28:54 +08:00
2025-03-08 10:19:20 +08:00
if len ( records ) > 1 :
logger . warning (
f " Multiple edges found between ' { entity_name_label_source } ' and ' { entity_name_label_target } ' . Using first edge. "
2025-03-08 04:28:54 +08:00
)
2025-03-08 10:19:20 +08:00
if records :
try :
edge_result = dict ( records [ 0 ] [ " edge_properties " ] )
logger . debug ( f " Result: { edge_result } " )
# Ensure required keys exist with defaults
required_keys = {
" weight " : 0.0 ,
" source_id " : None ,
" description " : None ,
" keywords " : None ,
}
for key , default_value in required_keys . items ( ) :
if key not in edge_result :
edge_result [ key ] = default_value
logger . warning (
f " Edge between { entity_name_label_source } and { entity_name_label_target } "
f " missing { key } , using default: { default_value } "
)
2024-11-06 11:18:14 -05:00
2025-03-08 10:19:20 +08:00
logger . debug (
f " { inspect . currentframe ( ) . f_code . co_name } :query: { query } :result: { edge_result } "
)
return edge_result
except ( KeyError , TypeError , ValueError ) as e :
logger . error (
f " Error processing edge properties between { entity_name_label_source } "
f " and { entity_name_label_target } : { str ( e ) } "
)
# Return default edge properties on error
return {
" weight " : 0.0 ,
" source_id " : None ,
" description " : None ,
" keywords " : None ,
}
logger . debug (
f " { inspect . currentframe ( ) . f_code . co_name } : No edge found between { entity_name_label_source } and { entity_name_label_target } "
)
# Return default edge properties when no edge found
return {
" weight " : 0.0 ,
" source_id " : None ,
" description " : None ,
" keywords " : None ,
}
finally :
await result . consume ( ) # Ensure result is fully consumed
2025-02-14 16:04:06 +01:00
except Exception as e :
logger . error (
f " Error in get_edge between { source_node_id } and { target_node_id } : { str ( e ) } "
)
2025-03-08 10:19:20 +08:00
raise
2024-10-29 15:36:07 -04:00
2025-02-16 13:53:59 +01:00
async def get_node_edges ( self , source_node_id : str ) - > list [ tuple [ str , str ] ] | None :
2025-03-08 10:19:20 +08:00
""" Retrieves all edges (relationships) for a particular node identified by its label.
2024-11-06 11:18:14 -05:00
2025-03-08 10:19:20 +08:00
Args :
source_node_id : Label of the node to get edges for
Returns :
list [ tuple [ str , str ] ] : List of ( source_label , target_label ) tuples representing edges
None : If no edges found
Raises :
ValueError : If source_node_id is invalid
Exception : If there is an error executing the query
2024-10-29 15:36:07 -04:00
"""
2025-03-08 10:19:20 +08:00
try :
2025-03-08 10:23:27 +08:00
node_label = self . _ensure_label ( source_node_id )
2024-11-06 11:18:14 -05:00
2025-03-08 10:19:20 +08:00
query = f """ MATCH (n:` { node_label } `)
OPTIONAL MATCH ( n ) - [ r ] - ( connected )
RETURN n , r , connected """
2024-11-06 11:18:14 -05:00
2025-03-08 10:19:20 +08:00
async with self . _driver . session (
database = self . _DATABASE , default_access_mode = " READ "
) as session :
try :
results = await session . run ( query )
edges = [ ]
2024-10-29 15:36:07 -04:00
2025-03-08 10:19:20 +08:00
async for record in results :
source_node = record [ " n " ]
connected_node = record [ " connected " ]
source_label = (
list ( source_node . labels ) [ 0 ] if source_node . labels else None
)
target_label = (
list ( connected_node . labels ) [ 0 ]
if connected_node and connected_node . labels
else None
)
if source_label and target_label :
edges . append ( ( source_label , target_label ) )
await results . consume ( ) # Ensure results are consumed
2025-03-08 11:20:22 +08:00
return edges
2025-03-08 10:19:20 +08:00
except Exception as e :
logger . error ( f " Error getting edges for node { node_label } : { str ( e ) } " )
await results . consume ( ) # Ensure results are consumed even on error
raise
except Exception as e :
logger . error ( f " Error in get_node_edges for { source_node_id } : { str ( e ) } " )
raise
2024-10-29 15:36:07 -04:00
2024-10-30 17:48:14 -04:00
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
                neo4jExceptions.ClientError,
            )
        ),
    )
    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        """
        Upsert a node in the Neo4j database.

        Retried up to 3 times with exponential backoff on transient driver
        errors (see the decorator above).

        Args:
            node_id: The unique identifier for the node (used as label)
            node_data: Dictionary of node properties; must contain 'entity_id'

        Raises:
            ValueError: If node_id is invalid or node_data lacks 'entity_id'
        """
        label = self._ensure_label(node_id)
        properties = node_data
        # entity_id is the MERGE key below, so it is mandatory.
        if "entity_id" not in properties:
            raise ValueError("Neo4j: node properties must contain an 'entity_id' field")

        try:
            async with self._driver.session(database=self._DATABASE) as session:

                async def execute_upsert(tx: AsyncManagedTransaction):
                    # MERGE on (label, entity_id): repeated upserts update the
                    # same node in place instead of creating duplicates.
                    query = f"""
                    MERGE (n:`{label}` {{entity_id: $properties.entity_id}})
                    SET n += $properties
                    """
                    result = await tx.run(query, properties=properties)
                    logger.debug(
                        f"Upserted node with label '{label}' and properties: {properties}"
                    )
                    await result.consume()  # Ensure result is fully consumed

                await session.execute_write(execute_upsert)
        except Exception as e:
            logger.error(f"Error during upsert: {str(e)}")
            raise
2024-11-06 11:18:14 -05:00
2025-03-09 00:24:55 +08:00
@retry (
stop = stop_after_attempt ( 3 ) ,
wait = wait_exponential ( multiplier = 1 , min = 4 , max = 10 ) ,
retry = retry_if_exception_type (
(
neo4jExceptions . ServiceUnavailable ,
neo4jExceptions . TransientError ,
neo4jExceptions . WriteServiceUnavailable ,
neo4jExceptions . ClientError ,
)
) ,
)
async def _get_unique_node_entity_id ( self , node_label : str ) - > str :
"""
Get the entity_id of a node with the given label , ensuring the node is unique .
Args :
node_label ( str ) : Label of the node to check
Returns :
str : The entity_id of the unique node
Raises :
ValueError : If no node with the given label exists or if multiple nodes have the same label
"""
async with self . _driver . session (
database = self . _DATABASE , default_access_mode = " READ "
) as session :
query = f """
MATCH ( n : ` { node_label } ` )
RETURN n , count ( n ) as node_count
"""
result = await session . run ( query )
try :
2025-03-09 01:00:42 +08:00
records = await result . fetch (
2
) # We only need to know if there are 0, 1, or >1 nodes
2025-03-09 00:24:55 +08:00
if not records or records [ 0 ] [ " node_count " ] == 0 :
2025-03-09 01:00:42 +08:00
raise ValueError (
f " Neo4j: node with label ' { node_label } ' does not exist "
)
2025-03-09 00:24:55 +08:00
if records [ 0 ] [ " node_count " ] > 1 :
2025-03-09 01:00:42 +08:00
raise ValueError (
f " Neo4j: multiple nodes found with label ' { node_label } ' , cannot determine unique node "
)
2025-03-09 00:24:55 +08:00
node = records [ 0 ] [ " n " ]
if " entity_id " not in node :
2025-03-09 01:00:42 +08:00
raise ValueError (
f " Neo4j: node with label ' { node_label } ' does not have an entity_id property "
)
2025-03-09 00:24:55 +08:00
return node [ " entity_id " ]
finally :
await result . consume ( ) # Ensure result is fully consumed
2024-11-02 18:35:07 -04:00
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
                neo4jExceptions.ClientError,
            )
        ),
    )
    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ) -> None:
        """
        Upsert an edge and its properties between two nodes identified by their labels.
        Ensures both source and target nodes exist and are unique before creating the edge.
        Uses entity_id property to uniquely identify nodes.

        Args:
            source_node_id (str): Label of the source node (used as identifier)
            target_node_id (str): Label of the target node (used as identifier)
            edge_data (dict): Dictionary of properties to set on the edge

        Raises:
            ValueError: If either source or target node does not exist or is not unique
        """
        source_label = self._ensure_label(source_node_id)
        target_label = self._ensure_label(target_node_id)
        edge_properties = edge_data

        # Get entity_ids for source and target nodes, ensuring they are unique
        # (raises ValueError otherwise — validation happens before any write).
        source_entity_id = await self._get_unique_node_entity_id(source_label)
        target_entity_id = await self._get_unique_node_entity_id(target_label)

        try:
            async with self._driver.session(database=self._DATABASE) as session:

                async def execute_upsert(tx: AsyncManagedTransaction):
                    # MERGE keeps the relationship unique; SET += merges new
                    # properties into it. NOTE(review): the relationship type
                    # is hard-coded as DIRECTED and matched undirected.
                    query = f"""
                    MATCH (source:`{source_label}` {{entity_id: $source_entity_id}})
                    WITH source
                    MATCH (target:`{target_label}` {{entity_id: $target_entity_id}})
                    MERGE (source)-[r:DIRECTED]-(target)
                    SET r += $properties
                    RETURN r, source, target
                    """
                    result = await tx.run(
                        query,
                        source_entity_id=source_entity_id,
                        target_entity_id=target_entity_id,
                        properties=edge_properties,
                    )
                    try:
                        records = await result.fetch(100)
                        if records:
                            logger.debug(
                                f"Upserted edge from '{source_label}' (entity_id: {source_entity_id}) "
                                f"to '{target_label}' (entity_id: {target_entity_id}) "
                                f"with properties: {edge_properties}"
                            )
                    finally:
                        await result.consume()  # Ensure result is consumed

                await session.execute_write(execute_upsert)
        except Exception as e:
            logger.error(f"Error during edge upsert: {str(e)}")
            raise
2024-11-06 11:18:14 -05:00
2024-10-26 19:29:45 -04:00
    async def _node2vec_embed(self):
        """Placeholder node2vec embedding; registered in __post_init__ but not implemented."""
        print("Implemented but never called.")
feat: Added webui management, including file upload, text upload, Q&A query, graph database management (can view tags, view knowledge graph based on tags), system status (whether it is good, data storage status, model status, path),request /webui/index.html
2025-01-25 18:38:46 +08:00
2025-02-20 14:29:36 +01:00
    async def get_knowledge_graph(
        self,
        node_label: str,
        max_depth: int = 3,
        min_degree: int = 0,
        inclusive: bool = False,
    ) -> KnowledgeGraph:
        """
        Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
        Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
        When reducing the number of nodes, the prioritization criteria are as follows:
            1. min_degree does not affect nodes directly connected to the matching nodes
            2. Label matching nodes take precedence
            3. Followed by nodes directly connected to the matching nodes
            4. Finally, the degree of the nodes

        Args:
            node_label: Label of the starting node
            max_depth: Maximum depth of the subgraph
            min_degree: Minimum degree of nodes to include. Defaults to 0
            inclusive: Do an inclusive search if true

        Returns:
            KnowledgeGraph: Complete connected subgraph for specified node
        """
        label = node_label.strip('"')
        result = KnowledgeGraph()
        seen_nodes = set()
        seen_edges = set()

        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            try:
                if label == "*":
                    # "*" = whole-graph view: keep the highest-degree nodes
                    # (above min_degree), then only edges between kept nodes.
                    main_query = """
                    MATCH (n)
                    OPTIONAL MATCH (n)-[r]-()
                    WITH n, count(r) AS degree
                    WHERE degree >= $min_degree
                    ORDER BY degree DESC
                    LIMIT $max_nodes
                    WITH collect({node: n}) AS filtered_nodes
                    UNWIND filtered_nodes AS node_info
                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                    MATCH (a)-[r]-(b)
                    WHERE a IN kept_nodes AND b IN kept_nodes
                    RETURN filtered_nodes AS node_info,
                           collect(DISTINCT r) AS relationships
                    """
                    result_set = await session.run(
                        main_query,
                        {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
                    )
                else:
                    # Main query uses partial matching; the subgraph expansion
                    # relies on the APOC plugin (apoc.path.subgraphAll).
                    main_query = """
                    MATCH (start)
                    WHERE any(label IN labels(start) WHERE
                      CASE
                          WHEN $inclusive THEN label CONTAINS $label
                          ELSE label = $label
                      END
                    )
                    WITH start
                    CALL apoc.path.subgraphAll(start, {
                        relationshipFilter: '',
                        minLevel: 0,
                        maxLevel: $max_depth,
                        bfs: true
                    })
                    YIELD nodes, relationships
                    WITH start, nodes, relationships
                    UNWIND nodes AS node
                    OPTIONAL MATCH (node)-[r]-()
                    WITH node, count(r) AS degree, start, nodes, relationships
                    WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
                    ORDER BY
                        CASE
                            WHEN node = start THEN 3
                            WHEN EXISTS((start)--(node)) THEN 2
                            ELSE 1
                        END DESC,
                        degree DESC
                    LIMIT $max_nodes
                    WITH collect({node: node}) AS filtered_nodes
                    UNWIND filtered_nodes AS node_info
                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                    MATCH (a)-[r]-(b)
                    WHERE a IN kept_nodes AND b IN kept_nodes
                    RETURN filtered_nodes AS node_info,
                           collect(DISTINCT r) AS relationships
                    """
                    result_set = await session.run(
                        main_query,
                        {
                            "max_nodes": MAX_GRAPH_NODES,
                            "label": label,
                            "inclusive": inclusive,
                            "max_depth": max_depth,
                            "min_degree": min_degree,
                        },
                    )

                try:
                    record = await result_set.single()
                    if record:
                        # Handle nodes (compatible with multi-label cases)
                        for node_info in record["node_info"]:
                            node = node_info["node"]
                            node_id = node.id
                            if node_id not in seen_nodes:
                                result.nodes.append(
                                    KnowledgeGraphNode(
                                        id=f"{node_id}",
                                        labels=list(node.labels),
                                        properties=dict(node),
                                    )
                                )
                                seen_nodes.add(node_id)

                        # Handle relationships (including direction information)
                        for rel in record["relationships"]:
                            edge_id = rel.id
                            if edge_id not in seen_edges:
                                start = rel.start_node
                                end = rel.end_node
                                result.edges.append(
                                    KnowledgeGraphEdge(
                                        id=f"{edge_id}",
                                        type=rel.type,
                                        source=f"{start.id}",
                                        target=f"{end.id}",
                                        properties=dict(rel),
                                    )
                                )
                                seen_edges.add(edge_id)

                    logger.info(
                        f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
                    )
                finally:
                    await result_set.consume()  # Ensure result set is consumed
            except neo4jExceptions.ClientError as e:
                # Typically raised when APOC is missing; for a specific label we
                # can still answer with a plain-Cypher recursive fallback.
                logger.warning(f"APOC plugin error: {str(e)}")
                if label != "*":
                    logger.warning(
                        "Neo4j: falling back to basic Cypher recursive search..."
                    )
                    if inclusive:
                        logger.warning(
                            "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
                        )
                    return await self._robust_fallback(label, max_depth, min_degree)

        return result
feat: Added webui management, including file upload, text upload, Q&A query, graph database management (can view tags, view knowledge graph based on tags), system status (whether it is good, data storage status, model status, path),request /webui/index.html
2025-01-25 18:38:46 +08:00
async def _robust_fallback (
2025-03-08 01:20:36 +08:00
self , label : str , max_depth : int , min_degree : int = 0
2025-03-08 04:28:54 +08:00
) - > KnowledgeGraph :
2025-03-08 01:20:36 +08:00
"""
Fallback implementation when APOC plugin is not available or incompatible .
This method implements the same functionality as get_knowledge_graph but uses
only basic Cypher queries and recursive traversal instead of APOC procedures .
"""
2025-03-08 04:28:54 +08:00
result = KnowledgeGraph ( )
feat: Added webui management, including file upload, text upload, Q&A query, graph database management (can view tags, view knowledge graph based on tags), system status (whether it is good, data storage status, model status, path),request /webui/index.html
2025-01-25 18:38:46 +08:00
visited_nodes = set ( )
visited_edges = set ( )
2025-03-02 17:32:25 +08:00
2025-03-08 04:28:54 +08:00
async def traverse (
node : KnowledgeGraphNode ,
edge : Optional [ KnowledgeGraphEdge ] ,
current_depth : int ,
) :
2025-03-08 01:20:36 +08:00
# Check traversal limits
feat: Added webui management, including file upload, text upload, Q&A query, graph database management (can view tags, view knowledge graph based on tags), system status (whether it is good, data storage status, model status, path),request /webui/index.html
2025-01-25 18:38:46 +08:00
if current_depth > max_depth :
2025-03-08 01:20:36 +08:00
logger . debug ( f " Reached max depth: { max_depth } " )
return
if len ( visited_nodes ) > = MAX_GRAPH_NODES :
logger . debug ( f " Reached max nodes limit: { MAX_GRAPH_NODES } " )
feat: Added webui management, including file upload, text upload, Q&A query, graph database management (can view tags, view knowledge graph based on tags), system status (whether it is good, data storage status, model status, path),request /webui/index.html
2025-01-25 18:38:46 +08:00
return
2025-03-08 04:28:54 +08:00
# Check if node already visited
if node . id in visited_nodes :
feat: Added webui management, including file upload, text upload, Q&A query, graph database management (can view tags, view knowledge graph based on tags), system status (whether it is good, data storage status, model status, path),request /webui/index.html
2025-01-25 18:38:46 +08:00
return
2025-03-08 04:28:54 +08:00
# Get all edges and target nodes
2025-03-08 02:39:51 +08:00
async with self . _driver . session (
database = self . _DATABASE , default_access_mode = " READ "
) as session :
2025-03-08 04:28:54 +08:00
query = """
MATCH ( a ) - [ r ] - ( b )
WHERE id ( a ) = toInteger ( $ node_id )
WITH r , b , id ( r ) as edge_id , id ( b ) as target_id
RETURN r , b , edge_id , target_id
"""
results = await session . run ( query , { " node_id " : node . id } )
# Get all records and release database connection
2025-03-09 01:00:42 +08:00
records = await results . fetch (
1000
) # Max neighbour nodes we can handled
2025-03-08 04:28:54 +08:00
await results . consume ( ) # Ensure results are consumed
# Nodes not connected to start node need to check degree
if current_depth > 1 and len ( records ) < min_degree :
return
# Add current node to result
result . nodes . append ( node )
visited_nodes . add ( node . id )
# Add edge to result if it exists and not already added
if edge and edge . id not in visited_edges :
result . edges . append ( edge )
visited_edges . add ( edge . id )
# Prepare nodes and edges for recursive processing
nodes_to_process = [ ]
for record in records :
feat: Added webui management, including file upload, text upload, Q&A query, graph database management (can view tags, view knowledge graph based on tags), system status (whether it is good, data storage status, model status, path),request /webui/index.html
2025-01-25 18:38:46 +08:00
rel = record [ " r " ]
2025-03-08 04:28:54 +08:00
edge_id = str ( record [ " edge_id " ] )
feat: Added webui management, including file upload, text upload, Q&A query, graph database management (can view tags, view knowledge graph based on tags), system status (whether it is good, data storage status, model status, path),request /webui/index.html
2025-01-25 18:38:46 +08:00
if edge_id not in visited_edges :
2025-03-08 01:20:36 +08:00
b_node = record [ " b " ]
2025-03-08 04:28:54 +08:00
target_id = str ( record [ " target_id " ] )
2025-03-08 01:20:36 +08:00
if b_node . labels : # Only process if target node has labels
2025-03-08 04:28:54 +08:00
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode (
2025-03-08 10:20:10 +08:00
id = f " { target_id } " ,
2025-03-08 04:28:54 +08:00
labels = list ( b_node . labels ) ,
properties = dict ( b_node ) ,
2025-03-08 02:39:51 +08:00
)
2025-03-08 01:20:36 +08:00
2025-03-08 04:28:54 +08:00
# Create KnowledgeGraphEdge
target_edge = KnowledgeGraphEdge (
2025-03-08 10:20:10 +08:00
id = f " { edge_id } " ,
2025-03-08 04:28:54 +08:00
type = rel . type ,
2025-03-08 10:20:10 +08:00
source = f " { node . id } " ,
target = f " { target_id } " ,
2025-03-08 04:28:54 +08:00
properties = dict ( rel ) ,
)
nodes_to_process . append ( ( target_node , target_edge ) )
2025-03-08 01:20:36 +08:00
else :
2025-03-08 02:39:51 +08:00
logger . warning (
f " Skipping edge { edge_id } due to missing labels on target node "
)
feat: Added webui management, including file upload, text upload, Q&A query, graph database management (can view tags, view knowledge graph based on tags), system status (whether it is good, data storage status, model status, path),request /webui/index.html
2025-01-25 18:38:46 +08:00
2025-03-08 04:28:54 +08:00
# Process nodes after releasing database connection
for target_node , target_edge in nodes_to_process :
await traverse ( target_node , target_edge , current_depth + 1 )
# Get the starting node's data
async with self . _driver . session (
database = self . _DATABASE , default_access_mode = " READ "
) as session :
query = f """
MATCH ( n : ` { label } ` )
RETURN id ( n ) as node_id , n
"""
node_result = await session . run ( query )
try :
node_record = await node_result . single ( )
if not node_record :
return result
# Create initial KnowledgeGraphNode
start_node = KnowledgeGraphNode (
2025-03-08 10:20:10 +08:00
id = f " { node_record [ ' node_id ' ] } " ,
2025-03-08 04:28:54 +08:00
labels = list ( node_record [ " n " ] . labels ) ,
properties = dict ( node_record [ " n " ] ) ,
)
finally :
await node_result . consume ( ) # Ensure results are consumed
# Start traversal with the initial node
await traverse ( start_node , None , 0 )
feat: Added webui management, including file upload, text upload, Q&A query, graph database management (can view tags, view knowledge graph based on tags), system status (whether it is good, data storage status, model status, path),request /webui/index.html
2025-01-25 18:38:46 +08:00
return result
2025-02-20 15:09:43 +01:00
async def get_all_labels ( self ) - > list [ str ] :
"""
Get all existing node labels in the database
Returns :
[ " Person " , " Company " , . . . ] # Alphabetically sorted label list
"""
2025-03-08 02:39:51 +08:00
async with self . _driver . session (
database = self . _DATABASE , default_access_mode = " READ "
) as session :
2025-02-20 15:09:43 +01:00
# Method 1: Direct metadata query (Available for Neo4j 4.3+)
# query = "CALL db.labels() YIELD label RETURN label"
# Method 2: Query compatible with older versions
query = """
MATCH ( n )
WITH DISTINCT labels ( n ) AS node_labels
UNWIND node_labels AS label
RETURN DISTINCT label
ORDER BY label
"""
result = await session . run ( query )
labels = [ ]
2025-03-08 01:20:36 +08:00
try :
async for record in result :
labels . append ( record [ " label " ] )
finally :
2025-03-08 02:39:51 +08:00
await (
result . consume ( )
) # Ensure results are consumed even if processing fails
2025-02-20 15:09:43 +01:00
return labels
2025-03-04 14:20:55 +08:00
@retry (
stop = stop_after_attempt ( 3 ) ,
wait = wait_exponential ( multiplier = 1 , min = 4 , max = 10 ) ,
retry = retry_if_exception_type (
(
neo4jExceptions . ServiceUnavailable ,
neo4jExceptions . TransientError ,
neo4jExceptions . WriteServiceUnavailable ,
neo4jExceptions . ClientError ,
)
) ,
)
2025-02-16 13:53:59 +01:00
async def delete_node ( self , node_id : str ) - > None :
2025-03-04 14:20:55 +08:00
""" Delete a node with the specified label
Args :
node_id : The label of the node to delete
"""
2025-03-08 10:23:27 +08:00
label = self . _ensure_label ( node_id )
2025-03-04 14:20:55 +08:00
async def _do_delete ( tx : AsyncManagedTransaction ) :
query = f """
MATCH ( n : ` { label } ` )
DETACH DELETE n
"""
2025-03-08 02:39:51 +08:00
result = await tx . run ( query )
2025-03-04 14:20:55 +08:00
logger . debug ( f " Deleted node with label ' { label } ' " )
2025-03-08 02:39:51 +08:00
await result . consume ( ) # Ensure result is fully consumed
2025-03-04 14:20:55 +08:00
try :
async with self . _driver . session ( database = self . _DATABASE ) as session :
await session . execute_write ( _do_delete )
except Exception as e :
logger . error ( f " Error during node deletion: { str ( e ) } " )
raise
@retry (
stop = stop_after_attempt ( 3 ) ,
wait = wait_exponential ( multiplier = 1 , min = 4 , max = 10 ) ,
retry = retry_if_exception_type (
(
neo4jExceptions . ServiceUnavailable ,
neo4jExceptions . TransientError ,
neo4jExceptions . WriteServiceUnavailable ,
neo4jExceptions . ClientError ,
)
) ,
)
async def remove_nodes ( self , nodes : list [ str ] ) :
""" Delete multiple nodes
Args :
nodes : List of node labels to be deleted
"""
for node in nodes :
await self . delete_node ( node )
@retry (
stop = stop_after_attempt ( 3 ) ,
wait = wait_exponential ( multiplier = 1 , min = 4 , max = 10 ) ,
retry = retry_if_exception_type (
(
neo4jExceptions . ServiceUnavailable ,
neo4jExceptions . TransientError ,
neo4jExceptions . WriteServiceUnavailable ,
neo4jExceptions . ClientError ,
)
) ,
)
async def remove_edges ( self , edges : list [ tuple [ str , str ] ] ) :
""" Delete multiple edges
Args :
edges : List of edges to be deleted , each edge is a ( source , target ) tuple
"""
for source , target in edges :
2025-03-08 10:23:27 +08:00
source_label = self . _ensure_label ( source )
target_label = self . _ensure_label ( target )
2025-03-04 14:20:55 +08:00
async def _do_delete_edge ( tx : AsyncManagedTransaction ) :
query = f """
2025-03-08 01:20:36 +08:00
MATCH ( source : ` { source_label } ` ) - [ r ] - ( target : ` { target_label } ` )
2025-03-04 14:20:55 +08:00
DELETE r
"""
2025-03-08 02:39:51 +08:00
result = await tx . run ( query )
2025-03-04 14:20:55 +08:00
logger . debug ( f " Deleted edge from ' { source_label } ' to ' { target_label } ' " )
2025-03-08 02:39:51 +08:00
await result . consume ( ) # Ensure result is fully consumed
2025-03-04 14:20:55 +08:00
try :
async with self . _driver . session ( database = self . _DATABASE ) as session :
await session . execute_write ( _do_delete_edge )
except Exception as e :
logger . error ( f " Error during edge deletion: { str ( e ) } " )
raise
2025-02-16 13:55:30 +01:00
2025-02-16 13:53:59 +01:00
    async def embed_nodes(
        self, algorithm: str
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
        """Embed graph nodes using the named algorithm.

        Not supported by the Neo4j backend.

        Args:
            algorithm: Name of the node-embedding algorithm to use.

        Raises:
            NotImplementedError: Always; node embedding is not implemented here.
        """
        raise NotImplementedError