LightRAG/lightrag/kg/memgraph_impl.py

423 lines
18 KiB
Python
Raw Normal View History

2025-06-26 16:15:56 +02:00
import os
import re
from dataclasses import dataclass
from typing import final
import configparser
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
import pipmaster as pm
if not pm.is_installed("neo4j"):
pm.install("neo4j")
from neo4j import (
AsyncGraphDatabase,
AsyncManagedTransaction,
)
from dotenv import load_dotenv
# use the .env that is inside the current folder
load_dotenv(dotenv_path=".env", override=False)
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@final
@dataclass
class MemgraphStorage(BaseGraphStorage):
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
self._driver = None
async def initialize(self):
URI = os.environ.get("MEMGRAPH_URI", config.get("memgraph", "uri", fallback="bolt://localhost:7687"))
USERNAME = os.environ.get("MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback=""))
PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback=""))
DATABASE = os.environ.get("MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph"))
self._driver = AsyncGraphDatabase.driver(
URI,
auth=(USERNAME, PASSWORD),
)
self._DATABASE = DATABASE
try:
async with self._driver.session(database=DATABASE) as session:
# Create index for base nodes on entity_id if it doesn't exist
try:
await session.run("""CREATE INDEX ON :base(entity_id)""")
logger.info("Created index on :base(entity_id) in Memgraph.")
except Exception as e:
# Index may already exist, which is not an error
logger.warning(f"Index creation on :base(entity_id) may have failed or already exists: {e}")
await session.run("RETURN 1")
logger.info(f"Connected to Memgraph at {URI}")
except Exception as e:
logger.error(f"Failed to connect to Memgraph at {URI}: {e}")
raise
async def finalize(self):
if self._driver is not None:
await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
await self.finalize()
async def index_done_callback(self):
# Memgraph handles persistence automatically
pass
async def has_node(self, node_id: str) -> bool:
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
result = await session.run(query, entity_id=node_id)
single_result = await result.single()
await result.consume()
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = (
"MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
"RETURN COUNT(r) > 0 AS edgeExists"
)
result = await session.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
)
single_result = await result.single()
await result.consume()
return single_result["edgeExists"]
async def get_node(self, node_id: str) -> dict[str, str] | None:
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
result = await session.run(query, entity_id=node_id)
records = await result.fetch(2)
await result.consume()
if records:
node = records[0]["n"]
node_dict = dict(node)
if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
return node_dict
return None
async def get_all_labels(self) -> list[str]:
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = """
MATCH (n:base)
WHERE n.entity_id IS NOT NULL
RETURN DISTINCT n.entity_id AS label
ORDER BY label
"""
result = await session.run(query)
labels = []
async for record in result:
labels.append(record["label"])
await result.consume()
return labels
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = """
MATCH (n:base {entity_id: $entity_id})
OPTIONAL MATCH (n)-[r]-(connected:base)
WHERE connected.entity_id IS NOT NULL
RETURN n, r, connected
"""
results = await session.run(query, entity_id=source_node_id)
edges = []
async for record in results:
source_node = record["n"]
connected_node = record["connected"]
if not source_node or not connected_node:
continue
source_label = source_node.get("entity_id")
target_label = connected_node.get("entity_id")
if source_label and target_label:
edges.append((source_label, target_label))
await results.consume()
return edges
async def get_edge(self, source_node_id: str, target_node_id: str) -> dict[str, str] | None:
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = """
MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
RETURN properties(r) as edge_properties
"""
result = await session.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
)
records = await result.fetch(2)
await result.consume()
if records:
edge_result = dict(records[0]["edge_properties"])
for key, default_value in {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}.items():
if key not in edge_result:
edge_result[key] = default_value
return edge_result
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
properties = node_data
entity_type = properties.get("entity_type", "base")
if "entity_id" not in properties:
raise ValueError("Memgraph: node properties must contain an 'entity_id' field")
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction):
query = (
f"""
MERGE (n:base {{entity_id: $entity_id}})
SET n += $properties
SET n:`{entity_type}`
"""
)
result = await tx.run(query, entity_id=node_id, properties=properties)
await result.consume()
await session.execute_write(execute_upsert)
async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]) -> None:
edge_properties = edge_data
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction):
query = """
MATCH (source:base {entity_id: $source_entity_id})
WITH source
MATCH (target:base {entity_id: $target_entity_id})
MERGE (source)-[r:DIRECTED]-(target)
SET r += $properties
RETURN r, source, target
"""
result = await tx.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
properties=edge_properties,
)
await result.consume()
await session.execute_write(execute_upsert)
async def delete_node(self, node_id: str) -> None:
async def _do_delete(tx: AsyncManagedTransaction):
query = """
MATCH (n:base {entity_id: $entity_id})
DETACH DELETE n
"""
result = await tx.run(query, entity_id=node_id)
await result.consume()
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete)
async def remove_nodes(self, nodes: list[str]):
for node in nodes:
await self.delete_node(node)
async def remove_edges(self, edges: list[tuple[str, str]]):
for source, target in edges:
async def _do_delete_edge(tx: AsyncManagedTransaction):
query = """
MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
DELETE r
"""
result = await tx.run(
query, source_entity_id=source, target_entity_id=target
)
await result.consume()
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete_edge)
async def drop(self) -> dict[str, str]:
try:
async with self._driver.session(database=self._DATABASE) as session:
query = "MATCH (n) DETACH DELETE n"
result = await session.run(query)
await result.consume()
logger.info(f"Process {os.getpid()} drop Memgraph database {self._DATABASE}")
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}")
return {"status": "error", "message": str(e)}
async def node_degree(self, node_id: str) -> int:
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = """
MATCH (n:base {entity_id: $entity_id})
OPTIONAL MATCH (n)-[r]-()
RETURN COUNT(r) AS degree
"""
result = await session.run(query, entity_id=node_id)
record = await result.single()
await result.consume()
if not record:
return 0
return record["degree"]
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
return int(src_degree) + int(trg_degree)
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = """
UNWIND $chunk_ids AS chunk_id
MATCH (n:base)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
RETURN DISTINCT n
"""
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
nodes = []
async for record in result:
node = record["n"]
node_dict = dict(node)
node_dict["id"] = node_dict.get("entity_id")
nodes.append(node_dict)
await result.consume()
return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = """
UNWIND $chunk_ids AS chunk_id
MATCH (a:base)-[r]-(b:base)
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
"""
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
edges = []
async for record in result:
edge_properties = record["properties"]
edge_properties["source"] = record["source"]
edge_properties["target"] = record["target"]
edges.append(edge_properties)
await result.consume()
return edges
async def get_knowledge_graph(
self,
node_label: str,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
if node_label == "*":
count_query = "MATCH (n) RETURN count(n) as total"
count_result = await session.run(count_query)
count_record = await count_result.single()
await count_result.consume()
if count_record and count_record["total"] > max_nodes:
result.is_truncated = True
logger.info(f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}")
main_query = """
MATCH (n)
OPTIONAL MATCH (n)-[r]-()
WITH n, COALESCE(count(r), 0) AS degree
ORDER BY degree DESC
LIMIT $max_nodes
WITH collect({node: n}) AS filtered_nodes
UNWIND filtered_nodes AS node_info
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
OPTIONAL MATCH (a)-[r]-(b)
WHERE a IN kept_nodes AND b IN kept_nodes
RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships
"""
result_set = await session.run(main_query, {"max_nodes": max_nodes})
record = await result_set.single()
await result_set.consume()
else:
# BFS fallback for Memgraph (no APOC)
from collections import deque
# Get the starting node
start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
node_result = await session.run(start_query, entity_id=node_label)
node_record = await node_result.single()
await node_result.consume()
if not node_record:
return result
start_node = node_record["n"]
start_node_id = start_node.get("entity_id")
queue = deque([(start_node, 0)])
visited = set()
bfs_nodes = []
while queue and len(bfs_nodes) < max_nodes:
current_node, depth = queue.popleft()
node_id = current_node.get("entity_id")
if node_id in visited:
continue
visited.add(node_id)
bfs_nodes.append(current_node)
if depth < max_depth:
# Get neighbors
neighbor_query = """
MATCH (n:base {entity_id: $entity_id})-[]-(m:base)
RETURN m
"""
neighbors_result = await session.run(neighbor_query, entity_id=node_id)
neighbors = [rec["m"] for rec in await neighbors_result.to_list()]
await neighbors_result.consume()
for neighbor in neighbors:
neighbor_id = neighbor.get("entity_id")
if neighbor_id not in visited:
queue.append((neighbor, depth + 1))
# Build subgraph
subgraph_ids = [n.get("entity_id") for n in bfs_nodes]
# Nodes
for n in bfs_nodes:
node_id = n.get("entity_id")
if node_id not in seen_nodes:
result.nodes.append(KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties=dict(n),
))
seen_nodes.add(node_id)
# Edges
if subgraph_ids:
edge_query = """
MATCH (a:base)-[r]-(b:base)
WHERE a.entity_id IN $ids AND b.entity_id IN $ids
RETURN DISTINCT r, a, b
"""
edge_result = await session.run(edge_query, ids=subgraph_ids)
async for record in edge_result:
r = record["r"]
a = record["a"]
b = record["b"]
edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}"
if edge_id not in seen_edges:
result.edges.append(KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=a.get("entity_id"),
target=b.get("entity_id"),
properties=dict(r),
))
seen_edges.add(edge_id)
await edge_result.consume()
logger.info(f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}")
return result