using neo4j async

This commit is contained in:
Ken Wiltshire 2024-11-02 18:35:07 -04:00
parent 40e80ebc9d
commit f19af82db1
5 changed files with 141 additions and 179 deletions

View File

@ -2,14 +2,16 @@ import asyncio
import html
import os
from dataclasses import dataclass
from typing import Any, Union, cast
from typing import Any, Union, cast, Tuple, List, Dict
import numpy as np
import inspect
from lightrag.utils import load_json, logger, write_json
from ..base import (
BaseGraphStorage
)
from neo4j import GraphDatabase, exceptions as neo4jExceptions
from neo4j import AsyncGraphDatabase,exceptions as neo4jExceptions,AsyncDriver,AsyncSession, AsyncManagedTransaction
from contextlib import asynccontextmanager
from tenacity import (
@ -20,126 +22,135 @@ from tenacity import (
)
@dataclass
class GraphStorage(BaseGraphStorage):
class Neo4JStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name):
print ("no preloading of graph with neo4j in production")
def __init__(self, namespace, global_config):
    """Create the storage and open an async Neo4j driver.

    Connection settings are read from the ``NEO4J_URI``, ``NEO4J_USERNAME``
    and ``NEO4J_PASSWORD`` environment variables.

    Args:
        namespace: Forwarded unchanged to the base class initializer.
        global_config: Forwarded unchanged to the base class initializer.

    Raises:
        KeyError: If any of the three environment variables is missing. The
            message lists every missing name, not only the first one.
    """
    super().__init__(namespace=namespace, global_config=global_config)
    self._driver_lock = asyncio.Lock()

    required = ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD")
    missing = [name for name in required if name not in os.environ]
    if missing:
        # Still a KeyError, as a plain os.environ[...] lookup would raise,
        # but with all missing settings reported at once.
        raise KeyError(
            f"Missing Neo4j environment variable(s): {', '.join(missing)}"
        )

    uri, username, password = (os.environ[name] for name in required)
    self._driver: AsyncDriver = AsyncGraphDatabase.driver(
        uri, auth=(username, password)
    )
def __post_init__(self):
    """Register the node-embedding algorithms this storage exposes.

    The mapping goes from algorithm name to the bound coroutine method that
    implements it. Driver creation is handled in ``__init__``; nothing is
    read from the environment here.
    """
    self._node_embed_algorithms = {
        "node2vec": self._node2vec_embed,
    }
async def close(self):
    """Close the Neo4j driver, if one is open, and forget the reference."""
    driver = self._driver
    if not driver:
        # Nothing to release (never opened, or already closed).
        return
    await driver.close()
    self._driver = None
async def __aexit__(self, exc_type, exc, tb):
    """Close the driver when leaving an ``async with`` block.

    Args:
        exc_type: Exception class raised in the block, if any (unused).
        exc: Exception instance raised in the block, if any (unused).
        tb: Traceback of that exception, if any (unused).

    Returns:
        None, so an exception raised inside the block is not suppressed.
    """
    if self._driver:
        await self._driver.close()
        # Drop the reference, as close() does, so a later close() call does
        # not try to close the same driver a second time.
        self._driver = None
async def index_done_callback(self):
    """Report on stdout that indexing has finished."""
    message = "KG successfully indexed."
    print(message)
async def has_node(self, node_id: str) -> bool:
    """Return whether at least one node carries ``node_id`` as its label.

    Args:
        node_id: Entity name used as the node label. Surrounding double
            quotes are stripped before it is placed in the query.

    Returns:
        The ``node_exists`` value of the single row returned by the query.
    """
    # A backtick inside a backtick-quoted Cypher name has to be doubled;
    # otherwise an entity name containing one would end the label early.
    entity_name_label = node_id.strip('\"').replace("`", "``")
    query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
    async with self._driver.session() as session:
        result = await session.run(query)
        single_result = await result.single()
        logger.debug(
            f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
        )
        return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """Return whether any relationship links the two labelled nodes.

    The pattern is undirected (``-[r]-``), so a relationship in either
    direction counts.

    Args:
        source_node_id: Label of one endpoint; surrounding double quotes
            are stripped.
        target_node_id: Label of the other endpoint; surrounding double
            quotes are stripped.

    Returns:
        The ``edgeExists`` value of the single row returned by the query.
    """
    # Double any backtick so the label stays inside its backtick quoting.
    entity_name_label_source = source_node_id.strip('\"').replace("`", "``")
    entity_name_label_target = target_node_id.strip('\"').replace("`", "``")
    query = (
        f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
        "RETURN COUNT(r) > 0 AS edgeExists"
    )
    async with self._driver.session() as session:
        result = await session.run(query)
        single_result = await result.single()
        logger.debug(
            f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
        )
        return single_result["edgeExists"]
async def get_node(self, node_id: str) -> Union[dict, None]:
    """Fetch the node labelled ``node_id`` as a plain dictionary.

    Args:
        node_id: Entity name used as the node label. Surrounding double
            quotes are stripped before it is placed in the query.

    Returns:
        ``dict(node)`` for the returned record's ``n`` value, or ``None``
        when the query yields no record.
    """
    # Double any backtick so the label stays inside its backtick quoting.
    entity_name_label = node_id.strip('\"').replace("`", "``")
    query = f"MATCH (n:`{entity_name_label}`) RETURN n"
    async with self._driver.session() as session:
        result = await session.run(query)
        record = await result.single()
        if not record:
            return None
        node_dict = dict(record["n"])
        logger.debug(
            f'{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}'
        )
        return node_dict
async def node_degree(self, node_id: str) -> int:
    """Count the relationships attached to the node labelled ``node_id``.

    Args:
        node_id: Entity name used as the node label. Surrounding double
            quotes are stripped before it is placed in the query.

    Returns:
        The ``totalEdgeCount`` value of the returned record, or ``None``
        when the query yields no record.
    """
    # Double any backtick so the label stays inside its backtick quoting.
    entity_name_label = node_id.strip('\"').replace("`", "``")
    query = f"""
        MATCH (n:`{entity_name_label}`)
        RETURN COUNT{{ (n)--() }} AS totalEdgeCount
    """
    async with self._driver.session() as session:
        result = await session.run(query)
        record = await result.single()
        if not record:
            return None
        edge_count = record["totalEdgeCount"]
        logger.debug(
            f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
        )
        return edge_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Return the sum of the degrees of an edge's two endpoint nodes.

    Args:
        src_id: Label of the source node; surrounding double quotes are
            stripped.
        tgt_id: Label of the target node; surrounding double quotes are
            stripped.

    Returns:
        ``node_degree(src) + node_degree(tgt)`` as an int.
    """
    entity_name_label_source = src_id.strip('\"')
    entity_name_label_target = tgt_id.strip('\"')
    src_degree = await self.node_degree(entity_name_label_source)
    trg_degree = await self.node_degree(entity_name_label_target)

    # node_degree may return None (no matching record); count that as 0.
    degrees = int(src_degree or 0) + int(trg_degree or 0)
    logger.debug(
        f'{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}'
    )
    return degrees
async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('\"')
@ -154,15 +165,15 @@ class GraphStorage(BaseGraphStorage):
Returns:
list: List of all relationships/edges found
"""
with self._driver.session() as session:
async with self._driver.session() as session:
query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
LIMIT 1
""".format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target)
result = session.run(query)
record = result.single()
result = await session.run(query)
record = await result.single()
if record:
result = dict(record["edge_properties"])
logger.debug(
@ -173,29 +184,20 @@ class GraphStorage(BaseGraphStorage):
return None
async def get_node_edges(self, source_node_id: str):
async def get_node_edges(self, source_node_id: str)-> List[Tuple[str, str]]:
node_label = source_node_id.strip('\"')
"""
Retrieves all edges (relationships) for a particular node identified by its label and ID.
:param uri: Neo4j database URI
:param username: Neo4j username
:param password: Neo4j password
:param node_label: Label of the node
:param node_id: ID property of the node
Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information
"""
def fetch_edges(tx, label):
query = f"""MATCH (n:`{label}`)
query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected"""
results = tx.run(query)
async with self._driver.session() as session:
results = await session.run(query)
edges = []
for record in results:
async for record in results:
source_node = record['n']
connected_node = record['connected']
@ -207,7 +209,7 @@ class GraphStorage(BaseGraphStorage):
return edges
with self._driver.session() as session:
async with self._driver.session() as session:
edges = session.read_transaction(fetch_edges,node_label)
return edges
@ -217,86 +219,51 @@ class GraphStorage(BaseGraphStorage):
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
"""
Upsert a node in the Neo4j database.
Args:
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
label = node_id.strip('\"')
properties = node_data
"""
Upsert a node with the given label and properties within a transaction.
Args:
label: The node label to search for and apply
properties: Dictionary of node properties
Returns:
Dictionary containing the node's properties after upsert, or None if operation fails
"""
def _do_upsert(tx, label: str, properties: dict[str, Any]):
"""
Args:
tx: Neo4j transaction object
label: The node label to search for and apply
properties: Dictionary of node properties
Returns:
Dictionary containing the node's properties after upsert, or None if operation fails
"""
async def _do_upsert(tx: AsyncManagedTransaction):
query = f"""
MERGE (n:`{label}`)
SET n += $properties
RETURN n
"""
# Execute the query with properties as parameters
# with session.begin_transaction() as tx:
result = tx.run(query, properties=properties)
record = result.single()
if record:
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{dict(record["n"])}'
)
return dict(record["n"])
return None
with self._driver.session() as session:
with session.begin_transaction() as tx:
try:
result = _do_upsert(tx,label,properties)
tx.commit()
return result
except Exception as e:
raise # roll back
await tx.run(query, properties=properties)
logger.debug(f"Upserted node with label '{label}' and properties: {properties}")
try:
async with self._driver.session() as session:
await session.execute_write(_do_upsert)
except Exception as e:
logger.error(f"Error during upsert: {str(e)}")
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
)
async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]):
"""
Upsert an edge and its properties between two nodes identified by their labels.
async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]) -> None:
Args:
source_node_id (str): Label of the source node (used as identifier)
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
"""
source_node_label = source_node_id.strip('\"')
target_node_label = target_node_id.strip('\"')
edge_properties = edge_data
"""
Upsert an edge and its properties between two nodes identified by their labels.
Args:
source_node_label (str): Label of the source node (used as identifier)
target_node_label (str): Label of the target node (used as identifier)
edge_properties (dict): Dictionary of properties to set on the edge
"""
def _do_upsert_edge(tx, source_node_label: str, target_node_label: str, edge_properties: dict[str, Any]) -> None:
"""
Static method to perform the edge upsert within a transaction.
The query will:
1. Match the source and target nodes by their labels
2. Merge the DIRECTED relationship
3. Set all properties on the relationship, updating existing ones and adding new ones
"""
# Convert edge properties to Cypher parameter string
# props_string = ", ".join(f"r.{key} = ${key}" for key in edge_properties.keys())
# """.format(props_string)
async def _do_upsert_edge(tx: AsyncManagedTransaction):
query = f"""
MATCH (source:`{source_node_label}`)
WITH source
@ -305,22 +272,15 @@ class GraphStorage(BaseGraphStorage):
SET r += $properties
RETURN r
"""
await tx.run(query, properties=edge_properties)
logger.debug(f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}")
result = tx.run(query, properties=edge_properties)
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}'
)
return result.single()
with self._driver.session() as session:
session.execute_write(
_do_upsert_edge,
source_node_label,
target_node_label,
edge_properties
)
# return result
try:
async with self._driver.session() as session:
await session.execute_write(_do_upsert_edge)
except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}")
raise
async def _node2vec_embed(self):
    """Placeholder node2vec hook; it only prints a notice and returns None."""
    notice = "Implemented but never called."
    print(notice)

View File

@ -26,7 +26,7 @@ from .storage import (
)
from .kg.neo4j_impl import (
GraphStorage as Neo4JStorage
Neo4JStorage
)
#future KG integrations
@ -57,9 +57,10 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
logger.info("Creating a new event loop in a sub-thread.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
logger.info("Creating a new event loop in main thread.")
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
loop = asyncio.get_event_loop()
return loop
@ -329,4 +330,4 @@ class LightRAG:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
await asyncio.gather(*tasks)

View File

@ -798,4 +798,4 @@ if __name__ == "__main__":
result = await gpt_4o_mini_complete("How are you?")
print(result)
asyncio.run(main())
asyncio.run(main())

View File

@ -1083,4 +1083,4 @@ async def naive_query(
.strip()
)
return response
return response

View File

@ -2,6 +2,7 @@ import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio