Merge branch 'add-Memgraph-graph-db' into memgraph

This commit is contained in:
yangdx 2025-07-09 03:38:07 +08:00
commit 14d51518dd
2 changed files with 191 additions and 83 deletions

View File

@ -134,13 +134,14 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
# LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage
### Graph Storage (Recommended for production deployment)
# LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
# LIGHTRAG_GRAPH_STORAGE=MemgraphStorage
####################################################################
### Default workspace for all storage types
### Used to isolate data for each LightRAG instance
### Valid characters: a-z, A-Z, 0-9, and _
####################################################################
# WORKSPACE=doc
# WORKSPACE=space1
### PostgreSQL Configuration
POSTGRES_HOST=localhost
@ -179,3 +180,10 @@ QDRANT_URL=http://localhost:6333
### Redis
REDIS_URI=redis://localhost:6379
# REDIS_WORKSPACE=forced_workspace_name
### Memgraph Configuration
MEMGRAPH_URI=bolt://localhost:7687
MEMGRAPH_USERNAME=
MEMGRAPH_PASSWORD=
MEMGRAPH_DATABASE=memgraph
# MEMGRAPH_WORKSPACE=forced_workspace_name
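
For orientation, a minimal sketch of how these variables could be consumed, assuming the async neo4j driver (Memgraph is Bolt-compatible, and the storage class below imports AsyncManagedTransaction from it); names and defaults mirror this env block rather than the actual LightRAG wiring:

import os
from neo4j import AsyncGraphDatabase  # Memgraph speaks Bolt, so the neo4j driver works

def memgraph_driver_from_env():
    # Defaults mirror the sample values above; treat this as an illustration only.
    uri = os.environ.get("MEMGRAPH_URI", "bolt://localhost:7687")
    username = os.environ.get("MEMGRAPH_USERNAME", "")
    password = os.environ.get("MEMGRAPH_PASSWORD", "")
    return AsyncGraphDatabase.driver(uri, auth=(username, password))

# MEMGRAPH_DATABASE is passed per session, e.g. driver.session(database="memgraph")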

View File

@ -31,14 +31,23 @@ config.read("config.ini", "utf-8")
@final
@dataclass
class MemgraphStorage(BaseGraphStorage):
def __init__(self, namespace, global_config, embedding_func):
def __init__(self, namespace, global_config, embedding_func, workspace=None):
memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
if memgraph_workspace and memgraph_workspace.strip():
workspace = memgraph_workspace
super().__init__(
namespace=namespace,
workspace=workspace or "",
global_config=global_config,
embedding_func=embedding_func,
)
self._driver = None
def _get_workspace_label(self) -> str:
"""Get workspace label, return 'base' for compatibility when workspace is empty"""
workspace = getattr(self, "workspace", None)
return workspace if workspace else "base"
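
The workspace now resolves in three steps: the MEMGRAPH_WORKSPACE environment variable overrides the constructor argument, and an empty workspace falls back to the 'base' label for backward compatibility. A standalone sketch of that precedence, for illustration only:

import os

def resolve_workspace_label(workspace=None):
    # Mirrors __init__ plus _get_workspace_label above; not the actual class code.
    env_ws = os.environ.get("MEMGRAPH_WORKSPACE")
    if env_ws and env_ws.strip():
        workspace = env_ws  # env override wins over the constructor argument
    return workspace if workspace else "base"  # empty/None -> legacy 'base' label

# resolve_workspace_label()          -> "base"   (no override, no argument)
# resolve_workspace_label("space1")  -> "space1"
# MEMGRAPH_WORKSPACE=forced_name     -> "forced_name", regardless of the argument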
async def initialize(self):
URI = os.environ.get(
"MEMGRAPH_URI",
@ -63,12 +72,17 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session(database=DATABASE) as session:
# Create index on entity_id for the workspace label if it doesn't exist
try:
await session.run("""CREATE INDEX ON :base(entity_id)""")
logger.info("Created index on :base(entity_id) in Memgraph.")
workspace_label = self._get_workspace_label()
await session.run(
f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
)
logger.info(
f"Created index on :{workspace_label}(entity_id) in Memgraph."
)
except Exception as e:
# Index may already exist, which is not an error
logger.warning(
f"Index creation on :base(entity_id) may have failed or already exists: {e}"
f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
)
await session.run("RETURN 1")
logger.info(f"Connected to Memgraph at {URI}")
@ -101,15 +115,22 @@ class MemgraphStorage(BaseGraphStorage):
Raises:
Exception: If there is an error checking the node existence.
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
result = await session.run(query, entity_id=node_id)
single_result = await result.single()
await result.consume() # Ensure result is fully consumed
return single_result["node_exists"]
return (
single_result["node_exists"] if single_result is not None else False
)
except Exception as e:
logger.error(f"Error checking node existence for {node_id}: {str(e)}")
await result.consume() # Ensure the result is consumed even on error
@ -129,22 +150,29 @@ class MemgraphStorage(BaseGraphStorage):
Raises:
Exception: If there is an error checking the edge existence.
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
workspace_label = self._get_workspace_label()
query = (
"MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
"RETURN COUNT(r) > 0 AS edgeExists"
)
result = await session.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
)
) # type: ignore
single_result = await result.single()
await result.consume() # Ensure result is fully consumed
return single_result["edgeExists"]
return (
single_result["edgeExists"] if single_result is not None else False
)
except Exception as e:
logger.error(
f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
@ -165,11 +193,18 @@ class MemgraphStorage(BaseGraphStorage):
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
workspace_label = self._get_workspace_label()
query = (
f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
)
result = await session.run(query, entity_id=node_id)
try:
records = await result.fetch(
@ -183,12 +218,12 @@ class MemgraphStorage(BaseGraphStorage):
if records:
node = records[0]["n"]
node_dict = dict(node)
# Remove base label from labels list if it exists
# Remove workspace label from labels list if it exists
if "labels" in node_dict:
node_dict["labels"] = [
label
for label in node_dict["labels"]
if label != "base"
if label != workspace_label
]
return node_dict
return None
@ -212,12 +247,17 @@ class MemgraphStorage(BaseGraphStorage):
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
query = """
MATCH (n:base {entity_id: $entity_id})
workspace_label = self._get_workspace_label()
query = f"""
MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
OPTIONAL MATCH (n)-[r]-()
RETURN COUNT(r) AS degree
"""
@ -246,12 +286,17 @@ class MemgraphStorage(BaseGraphStorage):
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
query = """
MATCH (n:base)
workspace_label = self._get_workspace_label()
query = f"""
MATCH (n:`{workspace_label}`)
WHERE n.entity_id IS NOT NULL
RETURN DISTINCT n.entity_id AS label
ORDER BY label
@ -280,13 +325,18 @@ class MemgraphStorage(BaseGraphStorage):
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
query = """MATCH (n:base {entity_id: $entity_id})
OPTIONAL MATCH (n)-[r]-(connected:base)
workspace_label = self._get_workspace_label()
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
WHERE connected.entity_id IS NOT NULL
RETURN n, r, connected"""
results = await session.run(query, entity_id=source_node_id)
@ -341,12 +391,17 @@ class MemgraphStorage(BaseGraphStorage):
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
query = """
MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
workspace_label = self._get_workspace_label()
query = f"""
MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
RETURN properties(r) as edge_properties
"""
result = await session.run(
@ -386,6 +441,10 @@ class MemgraphStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
properties = node_data
entity_type = properties["entity_type"]
if "entity_id" not in properties:
@ -393,16 +452,14 @@ class MemgraphStorage(BaseGraphStorage):
try:
async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label()
async def execute_upsert(tx: AsyncManagedTransaction):
query = (
"""
MERGE (n:base {entity_id: $entity_id})
query = f"""
MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
SET n += $properties
SET n:`%s`
SET n:`{entity_type}`
"""
% entity_type
)
result = await tx.run(
query, entity_id=node_id, properties=properties
)
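
A short usage sketch for the upsert path; property names follow the checks above (entity_type is required, and entity_id appears to be filled in from node_id when absent, though that line falls outside this hunk):

# Continuing the earlier sketch, inside an async context:
await storage.upsert_node(
    "ent-123",
    {
        "entity_type": "Person",   # becomes an extra label on the node
        "description": "example",  # merged into the properties via SET n += $properties
    },
)
# The node ends up with both the workspace label and :Person, entity_id = "ent-123".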
@ -429,15 +486,20 @@ class MemgraphStorage(BaseGraphStorage):
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try:
edge_properties = edge_data
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction):
query = """
MATCH (source:base {entity_id: $source_entity_id})
workspace_label = self._get_workspace_label()
query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
WITH source
MATCH (target:base {entity_id: $target_entity_id})
MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
MERGE (source)-[r:DIRECTED]-(target)
SET r += $properties
RETURN r, source, target
@ -467,10 +529,15 @@ class MemgraphStorage(BaseGraphStorage):
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async def _do_delete(tx: AsyncManagedTransaction):
query = """
MATCH (n:base {entity_id: $entity_id})
workspace_label = self._get_workspace_label()
query = f"""
MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
DETACH DELETE n
"""
result = await tx.run(query, entity_id=node_id)
@ -490,6 +557,10 @@ class MemgraphStorage(BaseGraphStorage):
Args:
nodes: List of node labels to be deleted
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
for node in nodes:
await self.delete_node(node)
@ -502,11 +573,16 @@ class MemgraphStorage(BaseGraphStorage):
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
for source, target in edges:
async def _do_delete_edge(tx: AsyncManagedTransaction):
query = """
MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
workspace_label = self._get_workspace_label()
query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}})
DELETE r
"""
result = await tx.run(
@ -523,9 +599,9 @@ class MemgraphStorage(BaseGraphStorage):
raise
async def drop(self) -> dict[str, str]:
"""Drop all data from storage and clean up resources
"""Drop all data from the current workspace and clean up resources
This method will delete all nodes and relationships in the Neo4j database.
This method will delete all nodes and relationships in the Memgraph database.
Returns:
dict[str, str]: Operation status and message
@ -535,17 +611,24 @@ class MemgraphStorage(BaseGraphStorage):
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try:
async with self._driver.session(database=self._DATABASE) as session:
query = "MATCH (n) DETACH DELETE n"
workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
result = await session.run(query)
await result.consume()
logger.info(
f"Process {os.getpid()} drop Memgraph database {self._DATABASE}"
f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
)
return {"status": "success", "message": "data dropped"}
return {"status": "success", "message": "workspace data dropped"}
except Exception as e:
logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}")
logger.error(
f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
)
return {"status": "error", "message": str(e)}
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@ -558,6 +641,10 @@ class MemgraphStorage(BaseGraphStorage):
Returns:
int: Sum of the degrees of both nodes
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
@ -578,12 +665,17 @@ class MemgraphStorage(BaseGraphStorage):
list[dict]: A list of nodes, where each node is a dictionary of its properties.
An empty list if no matching nodes are found.
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
query = f"""
UNWIND $chunk_ids AS chunk_id
MATCH (n:base)
MATCH (n:`{workspace_label}`)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
RETURN DISTINCT n
"""
@ -607,12 +699,17 @@ class MemgraphStorage(BaseGraphStorage):
list[dict]: A list of edges, where each edge is a dictionary of its properties.
An empty list if no matching edges are found.
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
query = f"""
UNWIND $chunk_ids AS chunk_id
MATCH (a:base)-[r]-(b:base)
MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
// Ensure we only return each unique edge once by ordering the source and target
@ -652,9 +749,15 @@ class MemgraphStorage(BaseGraphStorage):
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
@ -682,19 +785,17 @@ class MemgraphStorage(BaseGraphStorage):
await count_result.consume()
# Run the main query to get nodes with highest degree
main_query = """
MATCH (n)
main_query = f"""
MATCH (n:`{workspace_label}`)
OPTIONAL MATCH (n)-[r]-()
WITH n, COALESCE(count(r), 0) AS degree
ORDER BY degree DESC
LIMIT $max_nodes
WITH collect({node: n}) AS filtered_nodes
UNWIND filtered_nodes AS node_info
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
OPTIONAL MATCH (a)-[r]-(b)
WITH collect(n) AS kept_nodes
MATCH (a)-[r]-(b)
WHERE a IN kept_nodes AND b IN kept_nodes
RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships
RETURN [node IN kept_nodes | {{node: node}}] AS node_info,
collect(DISTINCT r) AS relationships
"""
result_set = None
try:
@ -710,31 +811,33 @@ class MemgraphStorage(BaseGraphStorage):
await result_set.consume()
else:
bfs_query = """
MATCH (start) WHERE start.entity_id = $entity_id
bfs_query = f"""
MATCH (start:`{workspace_label}`)
WHERE start.entity_id = $entity_id
WITH start
CALL {
CALL {{
WITH start
MATCH path = (start)-[*0..$max_depth]-(node)
MATCH path = (start)-[*0..{max_depth}]-(node)
WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
UNWIND path_nodes AS n
WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
RETURN all_nodes, all_rels
}
}}
WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
// Apply node limiting here
WITH CASE
WHEN total_nodes <= $max_nodes THEN nodes
ELSE nodes[0..$max_nodes]
WITH
CASE
WHEN total_nodes <= {max_nodes} THEN nodes
ELSE nodes[0..{max_nodes}]
END AS limited_nodes,
relationships,
total_nodes,
total_nodes > $max_nodes AS is_truncated
UNWIND limited_nodes AS node
WITH collect({node: node}) AS node_info, relationships, total_nodes, is_truncated
RETURN node_info, relationships, total_nodes, is_truncated
total_nodes > {max_nodes} AS is_truncated
RETURN
[node IN limited_nodes | {{node: node}}] AS node_info,
relationships,
total_nodes,
is_truncated
"""
result_set = None
try:
@ -742,8 +845,6 @@ class MemgraphStorage(BaseGraphStorage):
bfs_query,
{
"entity_id": node_label,
"max_depth": max_depth,
"max_nodes": max_nodes,
},
)
record = await result_set.single()
@ -777,22 +878,21 @@ class MemgraphStorage(BaseGraphStorage):
)
)
if "relationships" in record and record["relationships"]:
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
seen_edges.add(edge_id)
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
seen_edges.add(edge_id)
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"