using neo4j async

This commit is contained in:
Ken Wiltshire 2024-11-02 18:35:07 -04:00
parent 40e80ebc9d
commit f19af82db1
5 changed files with 141 additions and 179 deletions

View File

@ -2,14 +2,16 @@ import asyncio
import html import html
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union, cast from typing import Any, Union, cast, Tuple, List, Dict
import numpy as np import numpy as np
import inspect import inspect
from lightrag.utils import load_json, logger, write_json from lightrag.utils import load_json, logger, write_json
from ..base import ( from ..base import (
BaseGraphStorage BaseGraphStorage
) )
from neo4j import GraphDatabase, exceptions as neo4jExceptions from neo4j import AsyncGraphDatabase,exceptions as neo4jExceptions,AsyncDriver,AsyncSession, AsyncManagedTransaction
from contextlib import asynccontextmanager
from tenacity import ( from tenacity import (
@ -20,126 +22,135 @@ from tenacity import (
) )
@dataclass @dataclass
class GraphStorage(BaseGraphStorage): class Neo4JStorage(BaseGraphStorage):
@staticmethod @staticmethod
def load_nx_graph(file_name): def load_nx_graph(file_name):
print ("no preloading of graph with neo4j in production") print ("no preloading of graph with neo4j in production")
def __init__(self, namespace, global_config):
super().__init__(namespace=namespace, global_config=global_config)
self._driver = None
self._driver_lock = asyncio.Lock()
URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"]
self._driver: AsyncDriver = AsyncGraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
return None
def __post_init__(self): def __post_init__(self):
# self._graph = preloaded_graph or nx.Graph() # self._graph = preloaded_graph or nx.Graph()
print("is this ever run")
credetial_parts = ['URI', 'USERNAME','PASSWORD'] credetial_parts = ['URI', 'USERNAME','PASSWORD']
credentials_set = all(x in os.environ for x in credetial_parts ) credentials_set = all(x in os.environ for x in credetial_parts )
if credentials_set:
URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"]
else:
raise Exception (f"One or more Neo4J Credentials, {credetial_parts}, not found in the environment")
self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
self._node_embed_algorithms = { self._node_embed_algorithms = {
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
} }
async def close(self):
if self._driver:
await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
if self._driver:
await self._driver.close()
async def index_done_callback(self): async def index_done_callback(self):
print ("KG successfully indexed.") print ("KG successfully indexed.")
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('\"') entity_name_label = node_id.strip('\"')
def _check_node_exists(tx, label): async with self._driver.session() as session:
query = f"MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists" query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
result = tx.run(query) result = await session.run(query)
single_result = result.single() single_result = await result.single()
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}' f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
) )
return single_result["node_exists"] return single_result["node_exists"]
with self._driver.session() as session:
return session.read_transaction(_check_node_exists, entity_name_label)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('\"') entity_name_label_source = source_node_id.strip('\"')
entity_name_label_target = target_node_id.strip('\"') entity_name_label_target = target_node_id.strip('\"')
async with self._driver.session() as session:
def _check_edge_existence(tx, label1, label2):
query = ( query = (
f"MATCH (a:`{label1}`)-[r]-(b:`{label2}`) " f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists" "RETURN COUNT(r) > 0 AS edgeExists"
) )
result = tx.run(query) result = await session.run(query)
single_result = result.single() single_result = await result.single()
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}' f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
) )
return single_result["edgeExists"] return single_result["edgeExists"]
def close(self): def close(self):
self._driver.close() self._driver.close()
#hard code relationship type, directed.
with self._driver.session() as session:
result = session.read_transaction(_check_edge_existence, entity_name_label_source, entity_name_label_target)
return result
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> Union[dict, None]:
entity_name_label = node_id.strip('\"') async with self._driver.session() as session:
with self._driver.session() as session: entity_name_label = node_id.strip('\"')
query = "MATCH (n:`{entity_name_label}`) RETURN n".format(entity_name_label=entity_name_label) query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = session.run(query) result = await session.run(query)
for record in result: record = await result.single()
result = record["n"] if record:
node = record["n"]
node_dict = dict(node)
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}' f'{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}'
) )
return result return node_dict
return None
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('\"') entity_name_label = node_id.strip('\"')
async with self._driver.session() as session:
def _find_node_degree(session, label): query = f"""
with session.begin_transaction() as tx: MATCH (n:`{entity_name_label}`)
query = f""" RETURN COUNT{{ (n)--() }} AS totalEdgeCount
MATCH (n:`{label}`) """
RETURN COUNT{{ (n)--() }} AS totalEdgeCount result = await session.run(query)
""" record = await result.single()
result = tx.run(query) if record:
record = result.single() edge_count = record["totalEdgeCount"]
if record: logger.debug(
edge_count = record["totalEdgeCount"] f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
logger.debug( )
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}' return edge_count
) else:
return edge_count return None
else:
return None
with self._driver.session() as session:
degree = _find_node_degree(session, entity_name_label)
return degree
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
entity_name_label_source = src_id.strip('\"') entity_name_label_source = src_id.strip('\"')
entity_name_label_target = tgt_id.strip('\"') entity_name_label_target = tgt_id.strip('\"')
with self._driver.session() as session: src_degree = await self.node_degree(entity_name_label_source)
query = f"""MATCH (n1:`{entity_name_label_source}`)-[r]-(n2:`{entity_name_label_target}`) trg_degree = await self.node_degree(entity_name_label_target)
RETURN count(r) AS degree"""
result = session.run(query) # Convert None to 0 for addition
record = result.single() src_degree = 0 if src_degree is None else src_degree
logger.debug( trg_degree = 0 if trg_degree is None else trg_degree
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{record["degree"]}'
) degrees = int(src_degree) + int(trg_degree)
return record["degree"] logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}'
)
return degrees
async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]: async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('\"') entity_name_label_source = source_node_id.strip('\"')
@ -154,15 +165,15 @@ class GraphStorage(BaseGraphStorage):
Returns: Returns:
list: List of all relationships/edges found list: List of all relationships/edges found
""" """
with self._driver.session() as session: async with self._driver.session() as session:
query = f""" query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties RETURN properties(r) as edge_properties
LIMIT 1 LIMIT 1
""".format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target) """.format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target)
result = session.run(query) result = await session.run(query)
record = result.single() record = await result.single()
if record: if record:
result = dict(record["edge_properties"]) result = dict(record["edge_properties"])
logger.debug( logger.debug(
@ -173,29 +184,20 @@ class GraphStorage(BaseGraphStorage):
return None return None
async def get_node_edges(self, source_node_id: str): async def get_node_edges(self, source_node_id: str)-> List[Tuple[str, str]]:
node_label = source_node_id.strip('\"') node_label = source_node_id.strip('\"')
""" """
Retrieves all edges (relationships) for a particular node identified by its label and ID. Retrieves all edges (relationships) for a particular node identified by its label.
:param uri: Neo4j database URI
:param username: Neo4j username
:param password: Neo4j password
:param node_label: Label of the node
:param node_id: ID property of the node
:return: List of dictionaries containing edge information :return: List of dictionaries containing edge information
""" """
query = f"""MATCH (n:`{node_label}`)
def fetch_edges(tx, label):
query = f"""MATCH (n:`{label}`)
OPTIONAL MATCH (n)-[r]-(connected) OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected""" RETURN n, r, connected"""
async with self._driver.session() as session:
results = tx.run(query) results = await session.run(query)
edges = [] edges = []
for record in results: async for record in results:
source_node = record['n'] source_node = record['n']
connected_node = record['connected'] connected_node = record['connected']
@ -207,7 +209,7 @@ class GraphStorage(BaseGraphStorage):
return edges return edges
with self._driver.session() as session: async with self._driver.session() as session:
edges = session.read_transaction(fetch_edges,node_label) edges = session.read_transaction(fetch_edges,node_label)
return edges return edges
@ -217,86 +219,51 @@ class GraphStorage(BaseGraphStorage):
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)), retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
) )
async def upsert_node(self, node_id: str, node_data: dict[str, str]): async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
"""
Upsert a node in the Neo4j database.
Args:
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
label = node_id.strip('\"') label = node_id.strip('\"')
properties = node_data properties = node_data
"""
Upsert a node with the given label and properties within a transaction.
Args:
label: The node label to search for and apply
properties: Dictionary of node properties
Returns:
Dictionary containing the node's properties after upsert, or None if operation fails
"""
def _do_upsert(tx, label: str, properties: dict[str, Any]):
"""
Args:
tx: Neo4j transaction object
label: The node label to search for and apply
properties: Dictionary of node properties
Returns:
Dictionary containing the node's properties after upsert, or None if operation fails
"""
async def _do_upsert(tx: AsyncManagedTransaction):
query = f""" query = f"""
MERGE (n:`{label}`) MERGE (n:`{label}`)
SET n += $properties SET n += $properties
RETURN n
""" """
# Execute the query with properties as parameters await tx.run(query, properties=properties)
# with session.begin_transaction() as tx: logger.debug(f"Upserted node with label '{label}' and properties: {properties}")
result = tx.run(query, properties=properties)
record = result.single()
if record:
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{dict(record["n"])}'
)
return dict(record["n"])
return None
try:
async with self._driver.session() as session:
await session.execute_write(_do_upsert)
except Exception as e:
logger.error(f"Error during upsert: {str(e)}")
raise
with self._driver.session() as session: @retry(
with session.begin_transaction() as tx: stop=stop_after_attempt(3),
try: wait=wait_exponential(multiplier=1, min=4, max=10),
result = _do_upsert(tx,label,properties) retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
tx.commit() )
return result async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]):
except Exception as e:
raise # roll back
async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]) -> None:
source_node_label = source_node_id.strip('\"')
target_node_label = target_node_id.strip('\"')
edge_properties = edge_data
""" """
Upsert an edge and its properties between two nodes identified by their labels. Upsert an edge and its properties between two nodes identified by their labels.
Args: Args:
source_node_label (str): Label of the source node (used as identifier) source_node_id (str): Label of the source node (used as identifier)
target_node_label (str): Label of the target node (used as identifier) target_node_id (str): Label of the target node (used as identifier)
edge_properties (dict): Dictionary of properties to set on the edge edge_data (dict): Dictionary of properties to set on the edge
""" """
source_node_label = source_node_id.strip('\"')
target_node_label = target_node_id.strip('\"')
edge_properties = edge_data
async def _do_upsert_edge(tx: AsyncManagedTransaction):
def _do_upsert_edge(tx, source_node_label: str, target_node_label: str, edge_properties: dict[str, Any]) -> None:
"""
Static method to perform the edge upsert within a transaction.
The query will:
1. Match the source and target nodes by their labels
2. Merge the DIRECTED relationship
3. Set all properties on the relationship, updating existing ones and adding new ones
"""
# Convert edge properties to Cypher parameter string
# props_string = ", ".join(f"r.{key} = ${key}" for key in edge_properties.keys())
# """.format(props_string)
query = f""" query = f"""
MATCH (source:`{source_node_label}`) MATCH (source:`{source_node_label}`)
WITH source WITH source
@ -305,22 +272,15 @@ class GraphStorage(BaseGraphStorage):
SET r += $properties SET r += $properties
RETURN r RETURN r
""" """
await tx.run(query, properties=edge_properties)
logger.debug(f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}")
result = tx.run(query, properties=edge_properties) try:
logger.debug( async with self._driver.session() as session:
f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}' await session.execute_write(_do_upsert_edge)
) except Exception as e:
return result.single() logger.error(f"Error during edge upsert: {str(e)}")
raise
with self._driver.session() as session:
session.execute_write(
_do_upsert_edge,
source_node_label,
target_node_label,
edge_properties
)
# return result
async def _node2vec_embed(self): async def _node2vec_embed(self):
print ("Implemented but never called.") print ("Implemented but never called.")

View File

@ -26,7 +26,7 @@ from .storage import (
) )
from .kg.neo4j_impl import ( from .kg.neo4j_impl import (
GraphStorage as Neo4JStorage Neo4JStorage
) )
#future KG integrations #future KG integrations
@ -57,9 +57,10 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
except RuntimeError: except RuntimeError:
logger.info("Creating a new event loop in a sub-thread.") logger.info("Creating a new event loop in main thread.")
loop = asyncio.new_event_loop() # loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) # asyncio.set_event_loop(loop)
loop = asyncio.get_event_loop()
return loop return loop

View File

@ -2,6 +2,7 @@ import os
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
######### #########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio # import nest_asyncio