import asyncio
import inspect
import os
import re
from dataclasses import dataclass
from typing import Any, List, Dict, final
import numpy as np
import configparser

from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

import logging
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
import pipmaster as pm

if not pm.is_installed("neo4j"):
    pm.install("neo4j")

from neo4j import (  # type: ignore
    AsyncGraphDatabase,
    exceptions as neo4jExceptions,
    AsyncDriver,
    AsyncManagedTransaction,
    GraphDatabase,
)

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")

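# For reference, a minimal `config.ini` consumed above might contain a
# `[neo4j]` section like the following (illustrative values, not shipped
# defaults; matching NEO4J_* environment variables take precedence):
#
#   [neo4j]
#   uri = neo4j://localhost:7687
#   username = neo4j
#   password = your-password
#   connection_pool_size = 50
#   connection_timeout = 30.0
#   connection_acquisition_timeout = 30.0
#   max_transaction_retry_time = 30.0
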
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))

# Set neo4j logger level to ERROR to suppress warning logs
logging.getLogger("neo4j").setLevel(logging.ERROR)

@final
@dataclass
class Neo4JStorage(BaseGraphStorage):
    def __init__(self, namespace, global_config, embedding_func):
        super().__init__(
            namespace=namespace,
            global_config=global_config,
            embedding_func=embedding_func,
        )
        self._driver = None
        self._driver_lock = asyncio.Lock()

        URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
        USERNAME = os.environ.get(
            "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
        )
        PASSWORD = os.environ.get(
            "NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
        )
        MAX_CONNECTION_POOL_SIZE = int(
            os.environ.get(
                "NEO4J_MAX_CONNECTION_POOL_SIZE",
                config.get("neo4j", "connection_pool_size", fallback=50),  # Reduced from 800
            )
        )
        CONNECTION_TIMEOUT = float(
            os.environ.get(
                "NEO4J_CONNECTION_TIMEOUT",
                config.get("neo4j", "connection_timeout", fallback=30.0),  # Reduced from 60.0
            ),
        )
        CONNECTION_ACQUISITION_TIMEOUT = float(
            os.environ.get(
                "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
                config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),  # Reduced from 60.0
            ),
        )
        MAX_TRANSACTION_RETRY_TIME = float(
            os.environ.get(
                "NEO4J_MAX_TRANSACTION_RETRY_TIME",
                config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
            ),
        )
        DATABASE = os.environ.get(
            "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
        )
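        # e.g. a namespace like "chunk_entity_relation" (hypothetical) maps to
        # the database name "chunk-entity-relation": every character outside
        # [a-zA-Z0-9-] is replaced with "-".
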
        self._driver: AsyncDriver = AsyncGraphDatabase.driver(
            URI,
            auth=(USERNAME, PASSWORD),
            max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
            connection_timeout=CONNECTION_TIMEOUT,
            connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
            max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
        )

        # Try to connect to the database
        with GraphDatabase.driver(
            URI,
            auth=(USERNAME, PASSWORD),
            max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
            connection_timeout=CONNECTION_TIMEOUT,
            connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
        ) as _sync_driver:
            for database in (DATABASE, None):
                self._DATABASE = database
                connected = False

                try:
                    with _sync_driver.session(database=database) as session:
                        try:
                            session.run("MATCH (n) RETURN n LIMIT 0")
                            logger.info(f"Connected to {database} at {URI}")
                            connected = True
                        except neo4jExceptions.ServiceUnavailable as e:
                            logger.error(
                                f"{database} at {URI} is not available".capitalize()
                            )
                            raise e
                except neo4jExceptions.AuthError as e:
                    logger.error(f"Authentication failed for {database} at {URI}")
                    raise e
                except neo4jExceptions.ClientError as e:
                    if e.code == "Neo.ClientError.Database.DatabaseNotFound":
                        logger.info(
                            f"{database} at {URI} not found. Try to create specified database.".capitalize()
                        )
                        try:
                            with _sync_driver.session() as session:
                                session.run(
                                    f"CREATE DATABASE `{database}` IF NOT EXISTS"
                                )
                                logger.info(f"{database} at {URI} created".capitalize())
                                connected = True
                        except (
                            neo4jExceptions.ClientError,
                            neo4jExceptions.DatabaseError,
                        ) as e:
                            if (
                                e.code
                                == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
                            ) or (
                                e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
                            ):
                                if database is not None:
                                    logger.warning(
                                        "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
                                    )
                            if database is None:
                                logger.error(f"Failed to create {database} at {URI}")
                                raise e

                if connected:
                    break

    def __post_init__(self):
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    async def close(self):
        if self._driver:
            await self._driver.close()
            self._driver = None

    async def __aexit__(self, exc_type, exc, tb):
        if self._driver:
            await self._driver.close()

    async def index_done_callback(self) -> None:
        # Neo4j handles persistence automatically
        pass

2024-11-02 18:35:07 -04:00
2025-02-14 16:04:06 +01:00
async def _ensure_label ( self , label : str ) - > str :
2025-03-07 16:43:18 +08:00
""" Ensure a label is valid
2025-03-08 01:20:36 +08:00
2025-03-07 16:43:18 +08:00
Args :
label : The label to validate
"""
2025-02-14 16:04:06 +01:00
clean_label = label . strip ( ' " ' )
2025-03-08 01:20:36 +08:00
if not clean_label :
raise ValueError ( " Neo4j: Label cannot be empty " )
2025-02-14 16:04:06 +01:00
return clean_label
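    # Usage sketch (hypothetical value): `await self._ensure_label('"Alice"')`
    # returns "Alice", while an empty or quotes-only label raises ValueError.
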
    async def has_node(self, node_id: str) -> bool:
        entity_name_label = await self._ensure_label(node_id)
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = (
                f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
            )
            result = await session.run(query)
            single_result = await result.single()
            await result.consume()  # Ensure result is fully consumed
            logger.debug(
                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
            )
            return single_result["node_exists"]

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        entity_name_label_source = source_node_id.strip('"')
        entity_name_label_target = target_node_id.strip('"')
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = (
                f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
                "RETURN COUNT(r) > 0 AS edgeExists"
            )
            result = await session.run(query)
            single_result = await result.single()
            await result.consume()  # Ensure result is fully consumed
            logger.debug(
                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
            )
            return single_result["edgeExists"]

    async def get_node(self, node_id: str) -> dict[str, str] | None:
        """Get node by its label identifier.

        Args:
            node_id: The node label to look up

        Returns:
            dict: Node properties if found
            None: If node not found
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            entity_name_label = await self._ensure_label(node_id)
            query = f"MATCH (n:`{entity_name_label}`) RETURN n"
            result = await session.run(query)
            records = await result.fetch(2)  # Get up to 2 records to check for duplicates
            await result.consume()  # Ensure result is fully consumed

            if len(records) > 1:
                logger.warning(
                    f"Multiple nodes found with label '{entity_name_label}'. Using first node."
                )
            if records:
                node = records[0]["n"]
                node_dict = dict(node)
                logger.debug(
                    f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
                )
                return node_dict
            return None

    async def node_degree(self, node_id: str) -> int:
        """Get the degree (number of relationships) of a node with the given label.
        If multiple nodes have the same label, returns the degree of the first node.
        If no node is found, returns 0.

        Args:
            node_id: The label of the node

        Returns:
            int: The number of relationships the node has, or 0 if no node found
        """
        entity_name_label = node_id.strip('"')

        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = f"""
                MATCH (n:`{entity_name_label}`)
                OPTIONAL MATCH (n)-[r]-()
                RETURN n, COUNT(r) AS degree
            """
            result = await session.run(query)
            records = await result.fetch(100)
            await result.consume()  # Ensure result is fully consumed

            if not records:
                logger.warning(f"No node found with label '{entity_name_label}'")
                return 0
            if len(records) > 1:
                logger.warning(
                    f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree"
                )
            degree = records[0]["degree"]
            logger.debug(
                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
            )
            return degree

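    # e.g. (hypothetical): `await storage.node_degree("Alice")` returns the
    # relationship count of the first node labeled `Alice`, or 0 if none exists.
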
    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        entity_name_label_source = src_id.strip('"')
        entity_name_label_target = tgt_id.strip('"')
        src_degree = await self.node_degree(entity_name_label_source)
        trg_degree = await self.node_degree(entity_name_label_target)

        # Convert None to 0 for addition
        src_degree = 0 if src_degree is None else src_degree
        trg_degree = 0 if trg_degree is None else trg_degree

        degrees = int(src_degree) + int(trg_degree)
        logger.debug(
            f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
        )
        return degrees

    async def check_duplicate_nodes(self) -> list[tuple[str, int]]:
        """Find all labels that have multiple nodes

        Returns:
            list[tuple[str, int]]: List of (label, node_count) tuples for labels with multiple nodes
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
                MATCH (n)
                WITH labels(n) as nodeLabels
                UNWIND nodeLabels as label
                WITH label, count(*) as node_count
                WHERE node_count > 1
                RETURN label, node_count
                ORDER BY node_count DESC
            """
            result = await session.run(query)
            duplicates = []
            async for record in result:
                label = record["label"]
                count = record["node_count"]
                logger.info(f"Found {count} nodes with label: {label}")
                duplicates.append((label, count))
            return duplicates

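    # Usage sketch (hypothetical output): `await storage.check_duplicate_nodes()`
    # might return [("Person", 3), ("City", 2)] when those labels are shared by
    # several nodes.
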
    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> dict[str, str] | None:
        try:
            entity_name_label_source = source_node_id.strip('"')
            entity_name_label_target = target_node_id.strip('"')
            async with self._driver.session(
                database=self._DATABASE, default_access_mode="READ"
            ) as session:
                query = f"""
                    MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
                    RETURN properties(r) as edge_properties
                """
                result = await session.run(query)
                records = await result.fetch(2)  # Get up to 2 records to check for duplicates
                if len(records) > 1:
                    logger.warning(
                        f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
                    )
                if records:
                    try:
                        result = dict(records[0]["edge_properties"])
                        logger.debug(f"Result: {result}")
                        # Ensure required keys exist with defaults
                        required_keys = {
                            "weight": 0.0,
                            "source_id": None,
                            "description": None,
                            "keywords": None,
                        }
                        for key, default_value in required_keys.items():
                            if key not in result:
                                result[key] = default_value
                                logger.warning(
                                    f"Edge between {entity_name_label_source} and {entity_name_label_target} "
                                    f"missing {key}, using default: {default_value}"
                                )

                        logger.debug(
                            f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
                        )
                        return result
                    except (KeyError, TypeError, ValueError) as e:
                        logger.error(
                            f"Error processing edge properties between {entity_name_label_source} "
                            f"and {entity_name_label_target}: {str(e)}"
                        )
                        # Return default edge properties on error
                        return {
                            "weight": 0.0,
                            "description": None,
                            "keywords": None,
                            "source_id": None,
                        }

                logger.debug(
                    f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
                )
                # Return default edge properties when no edge found
                return {
                    "weight": 0.0,
                    "description": None,
                    "keywords": None,
                    "source_id": None,
                }
        except Exception as e:
            logger.error(
                f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
            )
            # Return default edge properties on error
            return {
                "weight": 0.0,
                "description": None,
                "keywords": None,
                "source_id": None,
            }

    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
        """
        Retrieves all edges (relationships) for a particular node identified by its label.

        :return: List of (source_label, target_label) tuples for the node's edges
        """
        node_label = source_node_id.strip('"')
        query = f"""MATCH (n:`{node_label}`)
                OPTIONAL MATCH (n)-[r]-(connected)
                RETURN n, r, connected"""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            results = await session.run(query)
            edges = []
            try:
                async for record in results:
                    source_node = record["n"]
                    connected_node = record["connected"]

                    source_label = (
                        list(source_node.labels)[0] if source_node.labels else None
                    )
                    target_label = (
                        list(connected_node.labels)[0]
                        if connected_node and connected_node.labels
                        else None
                    )

                    if source_label and target_label:
                        edges.append((source_label, target_label))
            finally:
                await results.consume()  # Ensure results are consumed even if processing fails

            return edges

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
                neo4jExceptions.ClientError,
            )
        ),
    )
    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        """
        Upsert a node in the Neo4j database.

        Args:
            node_id: The unique identifier for the node (used as label)
            node_data: Dictionary of node properties
        """
        label = await self._ensure_label(node_id)
        properties = node_data

        async def _do_upsert(tx: AsyncManagedTransaction):
            query = f"""
            MERGE (n:`{label}`)
            SET n += $properties
            """
            await tx.run(query, properties=properties)
            logger.debug(
                f"Upserted node with label '{label}' and properties: {properties}"
            )

        try:
            async with self._driver.session(database=self._DATABASE) as session:
                await session.execute_write(_do_upsert)
        except Exception as e:
            logger.error(f"Error during upsert: {str(e)}")
            raise

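    # Usage sketch (hypothetical values): `await storage.upsert_node("Alice",
    # {"description": "a person"})` merges a node labeled `Alice` and applies
    # the given properties to it.
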
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
                neo4jExceptions.ClientError,
            )
        ),
    )
    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ) -> None:
        """
        Upsert an edge and its properties between two nodes identified by their labels.
        Checks if both source and target nodes exist before creating the edge.

        Args:
            source_node_id (str): Label of the source node (used as identifier)
            target_node_id (str): Label of the target node (used as identifier)
            edge_data (dict): Dictionary of properties to set on the edge

        Raises:
            ValueError: If either source or target node does not exist
        """
        source_label = await self._ensure_label(source_node_id)
        target_label = await self._ensure_label(target_node_id)
        edge_properties = edge_data

        # Check if both nodes exist
        source_exists = await self.has_node(source_label)
        target_exists = await self.has_node(target_label)
        if not source_exists:
            raise ValueError(
                f"Neo4j: source node with label '{source_label}' does not exist"
            )
        if not target_exists:
            raise ValueError(
                f"Neo4j: target node with label '{target_label}' does not exist"
            )

        async def _do_upsert_edge(tx: AsyncManagedTransaction):
            query = f"""
            MATCH (source:`{source_label}`)
            WITH source
            MATCH (target:`{target_label}`)
            MERGE (source)-[r:DIRECTED]-(target)
            SET r += $properties
            RETURN r
            """
            result = await tx.run(query, properties=edge_properties)
            try:
                record = await result.single()
                logger.debug(
                    f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
                )
            finally:
                await result.consume()  # Ensure result is consumed

        try:
            async with self._driver.session(database=self._DATABASE) as session:
                await session.execute_write(_do_upsert_edge)
        except Exception as e:
            logger.error(f"Error during edge upsert: {str(e)}")
            raise

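    # Usage sketch (hypothetical values): `await storage.upsert_edge("Alice",
    # "Bob", {"weight": 1.0})` merges a single :DIRECTED relationship between
    # the two labeled nodes; a missing endpoint raises ValueError first.
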
    async def _node2vec_embed(self):
        print("Implemented but never called.")

    async def get_knowledge_graph(
        self,
        node_label: str,
        max_depth: int = 3,
        min_degree: int = 0,
        inclusive: bool = False,
    ) -> KnowledgeGraph:
        """
        Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
        Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
        When reducing the number of nodes, the prioritization criteria are as follows:
            1. min_degree does not affect nodes directly connected to the matching nodes
            2. Label matching nodes take precedence
            3. Followed by nodes directly connected to the matching nodes
            4. Finally, the degree of the nodes

        Args:
            node_label: Label of the starting node
            max_depth: Maximum depth of the subgraph
            min_degree: Minimum degree of nodes to include. Defaults to 0
            inclusive: Do an inclusive search if true

        Returns:
            KnowledgeGraph: Complete connected subgraph for specified node
        """
        label = node_label.strip('"')
        result = KnowledgeGraph()
        seen_nodes = set()
        seen_edges = set()

        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            try:
                if label == "*":
                    main_query = """
                    MATCH (n)
                    OPTIONAL MATCH (n)-[r]-()
                    WITH n, count(r) AS degree
                    WHERE degree >= $min_degree
                    ORDER BY degree DESC
                    LIMIT $max_nodes
                    WITH collect({node: n}) AS filtered_nodes
                    UNWIND filtered_nodes AS node_info
                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                    MATCH (a)-[r]-(b)
                    WHERE a IN kept_nodes AND b IN kept_nodes
                    RETURN filtered_nodes AS node_info,
                           collect(DISTINCT r) AS relationships
                    """
                    result_set = await session.run(
                        main_query,
                        {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
                    )
                else:
                    # Main query uses partial matching
                    main_query = """
                    MATCH (start)
                    WHERE any(label IN labels(start) WHERE
                        CASE
                            WHEN $inclusive THEN label CONTAINS $label
                            ELSE label = $label
                        END
                    )
                    WITH start
                    CALL apoc.path.subgraphAll(start, {
                        relationshipFilter: '',
                        minLevel: 0,
                        maxLevel: $max_depth,
                        bfs: true
                    })
                    YIELD nodes, relationships
                    WITH start, nodes, relationships
                    UNWIND nodes AS node
                    OPTIONAL MATCH (node)-[r]-()
                    WITH node, count(r) AS degree, start, nodes, relationships
                    WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
                    ORDER BY
                        CASE
                            WHEN node = start THEN 3
                            WHEN EXISTS((start)--(node)) THEN 2
                            ELSE 1
                        END DESC,
                        degree DESC
                    LIMIT $max_nodes
                    WITH collect({node: node}) AS filtered_nodes
                    UNWIND filtered_nodes AS node_info
                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                    MATCH (a)-[r]-(b)
                    WHERE a IN kept_nodes AND b IN kept_nodes
                    RETURN filtered_nodes AS node_info,
                           collect(DISTINCT r) AS relationships
                    """
                    result_set = await session.run(
                        main_query,
                        {
                            "max_nodes": MAX_GRAPH_NODES,
                            "label": label,
                            "inclusive": inclusive,
                            "max_depth": max_depth,
                            "min_degree": min_degree,
                        },
                    )

                try:
                    record = await result_set.single()

                    if record:
                        # Handle nodes (compatible with multi-label cases)
                        for node_info in record["node_info"]:
                            node = node_info["node"]
                            node_id = node.id
                            if node_id not in seen_nodes:
                                result.nodes.append(
                                    KnowledgeGraphNode(
                                        id=f"{node_id}",
                                        labels=list(node.labels),
                                        properties=dict(node),
                                    )
                                )
                                seen_nodes.add(node_id)

                        # Handle relationships (including direction information)
                        for rel in record["relationships"]:
                            edge_id = rel.id
                            if edge_id not in seen_edges:
                                start = rel.start_node
                                end = rel.end_node
                                result.edges.append(
                                    KnowledgeGraphEdge(
                                        id=f"{edge_id}",
                                        type=rel.type,
                                        source=f"{start.id}",
                                        target=f"{end.id}",
                                        properties=dict(rel),
                                    )
                                )
                                seen_edges.add(edge_id)

                        logger.info(
                            f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
                        )
                finally:
                    await result_set.consume()  # Ensure result set is consumed

            except neo4jExceptions.ClientError as e:
                logger.warning(
                    f"APOC plugin error: {str(e)}, falling back to basic Cypher implementation"
                )
                if inclusive:
                    logger.warning(
                        "Inclusive search mode is not supported in recursive query, using exact matching"
                    )
                return await self._robust_fallback(label, max_depth, min_degree)

        return result

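    # Usage sketch (hypothetical label): `await storage.get_knowledge_graph(
    # "Alice", max_depth=2, inclusive=True)` matches any label containing
    # "Alice"; pass "*" to sample the whole graph, capped at MAX_GRAPH_NODES.
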
    async def _robust_fallback(
        self, label: str, max_depth: int, min_degree: int = 0
    ) -> Dict[str, List[Dict]]:
        """
        Fallback implementation when the APOC plugin is not available or incompatible.
        This method provides the same functionality as get_knowledge_graph but uses
        only basic Cypher queries and recursive traversal instead of APOC procedures.
        Note: it returns plain node/edge dictionaries rather than a KnowledgeGraph
        instance.
        """
        result = {"nodes": [], "edges": []}
        visited_nodes = set()
        visited_edges = set()

        async def traverse(current_label: str, current_depth: int):
            # Check traversal limits
            if current_depth > max_depth:
                logger.debug(f"Reached max depth: {max_depth}")
                return
            if len(visited_nodes) >= MAX_GRAPH_NODES:
                logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
                return

            # Get current node details
            node = await self.get_node(current_label)
            if not node:
                return

            node_id = f"{current_label}"
            if node_id in visited_nodes:
                return
            visited_nodes.add(node_id)

            # Add node data with label as ID
            result["nodes"].append(
                {
                    "id": current_label,
                    "labels": current_label,
                    "properties": node,
                }
            )

            # Get connected nodes that meet the degree requirement
            # Note: We don't need to check a's degree since it's the current node
            # and was already validated in the previous iteration
            query = f"""
            MATCH (a:`{current_label}`)-[r]-(b)
            WITH r, b,
                 COUNT((b)--()) AS b_degree
            WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
            RETURN r, b
            """
            async with self._driver.session(
                database=self._DATABASE, default_access_mode="READ"
            ) as session:
                results = await session.run(query, {"min_degree": min_degree})
                async for record in results:
                    # Handle edges
                    rel = record["r"]
                    edge_id = f"{rel.id}_{rel.type}"
                    if edge_id not in visited_edges:
                        b_node = record["b"]
                        if b_node.labels:  # Only process if target node has labels
                            target_label = list(b_node.labels)[0]
                            result["edges"].append(
                                {
                                    "id": f"{current_label}_{target_label}",
                                    "type": rel.type,
                                    "source": current_label,
                                    "target": target_label,
                                    "properties": dict(rel),
                                }
                            )
                            visited_edges.add(edge_id)
                            # Continue traversal
                            await traverse(target_label, current_depth + 1)
                        else:
                            logger.warning(
                                f"Skipping edge {edge_id} due to missing labels on target node"
                            )

        await traverse(label, 0)
        return result

    async def get_all_labels(self) -> list[str]:
        """
        Get all existing node labels in the database

        Returns:
            ["Person", "Company", ...]  # Alphabetically sorted label list
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            # Method 1: Direct metadata query (Available for Neo4j 4.3+)
            # query = "CALL db.labels() YIELD label RETURN label"

            # Method 2: Query compatible with older versions
            query = """
                MATCH (n)
                WITH DISTINCT labels(n) AS node_labels
                UNWIND node_labels AS label
                RETURN DISTINCT label
                ORDER BY label
            """
            result = await session.run(query)
            labels = []
            try:
                async for record in result:
                    labels.append(record["label"])
            finally:
                await result.consume()  # Ensure results are consumed even if processing fails
            return labels

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
                neo4jExceptions.ClientError,
            )
        ),
    )
    async def delete_node(self, node_id: str) -> None:
        """Delete a node with the specified label

        Args:
            node_id: The label of the node to delete
        """
        label = await self._ensure_label(node_id)

        async def _do_delete(tx: AsyncManagedTransaction):
            query = f"""
            MATCH (n:`{label}`)
            DETACH DELETE n
            """
            await tx.run(query)
            logger.debug(f"Deleted node with label '{label}'")

        try:
            async with self._driver.session(database=self._DATABASE) as session:
                await session.execute_write(_do_delete)
        except Exception as e:
            logger.error(f"Error during node deletion: {str(e)}")
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
                neo4jExceptions.ClientError,
            )
        ),
    )
    async def remove_nodes(self, nodes: list[str]):
        """Delete multiple nodes

        Args:
            nodes: List of node labels to be deleted
        """
        for node in nodes:
            await self.delete_node(node)

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
                neo4jExceptions.ClientError,
            )
        ),
    )
    async def remove_edges(self, edges: list[tuple[str, str]]):
        """Delete multiple edges

        Args:
            edges: List of edges to be deleted, each edge is a (source, target) tuple
        """
        for source, target in edges:
            source_label = await self._ensure_label(source)
            target_label = await self._ensure_label(target)

            async def _do_delete_edge(tx: AsyncManagedTransaction):
                query = f"""
                MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
                DELETE r
                """
                await tx.run(query)
                logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")

            try:
                async with self._driver.session(database=self._DATABASE) as session:
                    await session.execute_write(_do_delete_edge)
            except Exception as e:
                logger.error(f"Error during edge deletion: {str(e)}")
                raise

    async def embed_nodes(
        self, algorithm: str
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
        raise NotImplementedError