mirror of https://github.com/HKUDS/LightRAG.git
synced 2025-08-05 15:21:53 +00:00

Add support for two BFS subgraph search modes to MongoGraphStorage

This commit is contained in:
parent 5739f52d29
commit 73cc86662a
@@ -1,4 +1,5 @@
 import os
+import time
 from dataclasses import dataclass, field
 import numpy as np
 import configparser
@@ -35,6 +36,7 @@ config.read("config.ini", "utf-8")

 # Get maximum number of graph nodes from environment variable, default is 1000
 MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
+GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")


 class ClientManager:
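The new GRAPH_BFS_MODE switch above is read once, at module import time, from the MONGO_GRAPH_BFS_MODE environment variable. A minimal sketch of how a deployment might opt into the in/out-bound strategy (the variable names come from the diff; everything else is illustrative):

# Illustrative sketch only: set the mode before the module above is imported,
# because GRAPH_BFS_MODE is evaluated once at import time.
import os

os.environ["MONGO_GRAPH_BFS_MODE"] = "in_out_bound"  # default is "bidirectional"
os.environ["MAX_GRAPH_NODES"] = "500"  # optional cap on returned nodes

# ... then import the MongoDB storage module and build MongoGraphStorage as usual.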
@@ -822,6 +824,198 @@ class MongoGraphStorage(BaseGraphStorage):

         return result

+    async def _bidirectional_bfs_nodes(
+        self,
+        node_labels: list[str],
+        seen_nodes: set[str],
+        result: KnowledgeGraph,
+        depth: int = 0,
+        max_depth: int = 3,
+        max_nodes: int = MAX_GRAPH_NODES,
+    ) -> KnowledgeGraph:
+        if depth > max_depth or len(result.nodes) > max_nodes:
+            return result
+
+        cursor = self.collection.find({"_id": {"$in": node_labels}})
+
+        async for node in cursor:
+            node_id = node["_id"]
+            if node_id not in seen_nodes:
+                seen_nodes.add(node_id)
+                result.nodes.append(self._construct_graph_node(node_id, node))
+                if len(result.nodes) > max_nodes:
+                    return result
+
+        # Collect neighbors:
+        # get both inbound and outbound one-hop nodes
+        cursor = self.edge_collection.find(
+            {
+                "$or": [
+                    {"source_node_id": {"$in": node_labels}},
+                    {"target_node_id": {"$in": node_labels}},
+                ]
+            }
+        )
+
+        neighbor_nodes = []
+        async for edge in cursor:
+            if edge["source_node_id"] not in seen_nodes:
+                neighbor_nodes.append(edge["source_node_id"])
+            if edge["target_node_id"] not in seen_nodes:
+                neighbor_nodes.append(edge["target_node_id"])
+
+        if neighbor_nodes:
+            result = await self._bidirectional_bfs_nodes(
+                neighbor_nodes, seen_nodes, result, depth + 1, max_depth, max_nodes
+            )
+
+        return result
+
+    async def get_knowledge_subgraph_bidirectional_bfs(
+        self,
+        node_label: str,
+        depth=0,
+        max_depth: int = 3,
+        max_nodes: int = MAX_GRAPH_NODES,
+    ) -> KnowledgeGraph:
+        seen_nodes = set()
+        seen_edges = set()
+        result = KnowledgeGraph()
+
+        result = await self._bidirectional_bfs_nodes(
+            [node_label], seen_nodes, result, depth, max_depth, max_nodes
+        )
+
+        # Get all edges between nodes in seen_nodes
+        all_node_ids = list(seen_nodes)
+        cursor = self.edge_collection.find(
+            {
+                "$and": [
+                    {"source_node_id": {"$in": all_node_ids}},
+                    {"target_node_id": {"$in": all_node_ids}},
+                ]
+            }
+        )
+
+        async for edge in cursor:
+            edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
+            if edge_id not in seen_edges:
+                result.edges.append(self._construct_graph_edge(edge_id, edge))
+                seen_edges.add(edge_id)
+
+        return result
+
+    async def get_knowledge_subgraph_in_out_bound_bfs(
+        self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
+    ) -> KnowledgeGraph:
+        seen_nodes = set()
+        seen_edges = set()
+        result = KnowledgeGraph()
+        project_doc = {
+            "source_ids": 0,
+            "created_at": 0,
+            "entity_type": 0,
+            "file_path": 0,
+        }
+
+        # Verify that the starting node exists
+        start_node = await self.collection.find_one({"_id": node_label})
+        if not start_node:
+            logger.warning(f"Starting node with label {node_label} does not exist!")
+            return result
+
+        seen_nodes.add(node_label)
+        result.nodes.append(self._construct_graph_node(node_label, start_node))
+
+        if max_depth == 0:
+            return result
+
+        # In MongoDB's $graphLookup, maxDepth = 0 already means one hop
+        max_depth = max_depth - 1
+
+        pipeline = [
+            {"$match": {"_id": node_label}},
+            {"$project": project_doc},
+            {
+                "$graphLookup": {
+                    "from": self._edge_collection_name,
+                    "startWith": "$_id",
+                    "connectFromField": "target_node_id",
+                    "connectToField": "source_node_id",
+                    "maxDepth": max_depth,
+                    "depthField": "depth",
+                    "as": "connected_edges",
+                },
+            },
+            {
+                "$unionWith": {
+                    "coll": self._collection_name,
+                    "pipeline": [
+                        {"$match": {"_id": node_label}},
+                        {"$project": project_doc},
+                        {
+                            "$graphLookup": {
+                                "from": self._edge_collection_name,
+                                "startWith": "$_id",
+                                "connectFromField": "source_node_id",
+                                "connectToField": "target_node_id",
+                                "maxDepth": max_depth,
+                                "depthField": "depth",
+                                "as": "connected_edges",
+                            }
+                        },
+                    ],
+                }
+            },
+        ]

+        cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
+        node_edges = []
+
+        # Two records are returned for node_label, capturing the outbound and
+        # inbound connected_edges respectively
+        async for doc in cursor:
+            if doc.get("connected_edges", []):
+                node_edges.extend(doc.get("connected_edges"))
+
+        # Sort the connected edges by depth ascending and weight descending,
+        # then walk source_node_id/target_node_id in order to collect the neighbouring nodes
+        node_edges = sorted(
+            node_edges,
+            key=lambda x: (x["depth"], -x["weight"]),
+        )
+
+        # Order matters, so collect node ids in a separate list
+        # and keep only the first max_nodes of them
+        node_ids = []
+        for edge in node_edges:
+            if len(node_ids) < max_nodes and edge["source_node_id"] not in seen_nodes:
+                node_ids.append(edge["source_node_id"])
+                seen_nodes.add(edge["source_node_id"])
+
+            if len(node_ids) < max_nodes and edge["target_node_id"] not in seen_nodes:
+                node_ids.append(edge["target_node_id"])
+                seen_nodes.add(edge["target_node_id"])
+
+        # node_label is already in seen_nodes, so node_ids contains only new
+        # nodes and no further existence check is needed here
+        cursor = self.collection.find({"_id": {"$in": node_ids}})
+
+        async for doc in cursor:
+            result.nodes.append(self._construct_graph_node(str(doc["_id"]), doc))
+
+        for edge in node_edges:
+            if (
+                edge["source_node_id"] not in seen_nodes
+                or edge["target_node_id"] not in seen_nodes
+            ):
+                continue
+
+            edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
+            if edge_id not in seen_edges:
+                result.edges.append(self._construct_graph_edge(edge_id, edge))
+                seen_edges.add(edge_id)
+
+        return result
+
     async def get_knowledge_graph(
         self,
         node_label: str,
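A note on reading the two $graphLookup stages above: $graphLookup selects documents in `from` whose connectToField matches the startWith value, then recurses by matching connectToField against each match's connectFromField. With connectToField set to source_node_id, the stage walks outbound edges; the $unionWith branch swaps the two fields to walk inbound edges, which is what the method name in_out_bound refers to. A standalone restatement of the outbound stage (the collection name "edges" is a placeholder):

# Sketch: the outbound half of the pipeline above, shown in isolation.
outbound_lookup = {
    "$graphLookup": {
        "from": "edges",                       # placeholder edge collection name
        "startWith": "$_id",                   # begin from the matched node's id
        "connectFromField": "target_node_id",  # recurse from each edge's target
        "connectToField": "source_node_id",    # match edges leaving the node
        "maxDepth": 2,                         # maxDepth=0 already returns one hop
        "depthField": "depth",
        "as": "connected_edges",
    }
}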
@@ -830,6 +1024,23 @@ class MongoGraphStorage(BaseGraphStorage):
     ) -> KnowledgeGraph:
         """
         Get complete connected subgraph for specified node (including the starting node itself)
+        If a graph looks like this and we start from B:
+        A → B ← C ← F, B → E, C → D
+
+        Outbound BFS:
+        B → E
+
+        Inbound BFS:
+        A → B
+        C → B
+        F → C
+
+        Bidirectional BFS:
+        A → B
+        B → E
+        F → C
+        C → B
+        C → D

         Args:
             node_label: Label of the nodes to start from
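To make the three traversal modes concrete, the following self-contained sketch (illustration only, not library code; the depth and node caps are omitted for brevity) reproduces the edge sets listed in the docstring above:

# Example graph from the docstring: A → B ← C ← F, B → E, C → D
EDGES = [("A", "B"), ("C", "B"), ("F", "C"), ("B", "E"), ("C", "D")]

def reachable_edges(start, follow_out=True, follow_in=True):
    """Collect edges reachable from `start`, following the allowed directions."""
    seen, frontier, picked = {start}, {start}, []
    while frontier:
        nxt = set()
        for s, t in EDGES:
            if follow_out and s in frontier and (s, t) not in picked:
                picked.append((s, t))  # traverse the edge forwards
                nxt.add(t)
            if follow_in and t in frontier and (s, t) not in picked:
                picked.append((s, t))  # traverse the edge backwards
                nxt.add(s)
        frontier = nxt - seen
        seen |= nxt
    return picked

print(reachable_edges("B", follow_in=False))   # outbound: [('B', 'E')]
print(reachable_edges("B", follow_out=False))  # inbound: A→B, C→B, F→C
print(reachable_edges("B"))                    # bidirectional: all five edges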
@@ -838,138 +1049,28 @@
         Returns:
             KnowledgeGraph object containing nodes and edges of the subgraph
         """
-        label = node_label
         result = KnowledgeGraph()
-        seen_nodes = set()
-        seen_edges = set()
-        node_edges = []
-        project_doc = {
-            "source_ids": 0,
-            "created_at": 0,
-            "entity_type": 0,
-            "file_path": 0,
-        }
+        start = time.perf_counter()

         try:
-            # Optimize pipeline to avoid memory issues with large datasets
-            if label == "*":
-                return await self.get_knowledge_graph_all_by_degree(
+            if node_label == "*":
+                result = await self.get_knowledge_graph_all_by_degree(
                     max_depth, max_nodes
                 )
+            elif GRAPH_BFS_MODE == "in_out_bound":
+                result = await self.get_knowledge_subgraph_in_out_bound_bfs(
+                    node_label, max_depth, max_nodes
+                )
             else:
-                # Verify if starting node exists
-                start_node = await self.collection.find_one({"_id": label})
-                if not start_node:
-                    logger.warning(f"Starting node with label {label} does not exist!")
-                    return result
-
-                # For specific node queries, use the original pipeline but optimized
-                pipeline = [
-                    {"$match": {"_id": label}},
-                    {"$project": project_doc},
-                    {
-                        "$graphLookup": {
-                            "from": self._edge_collection_name,
-                            "startWith": "$_id",
-                            "connectFromField": "target_node_id",
-                            "connectToField": "source_node_id",
-                            "maxDepth": max_depth,
-                            "depthField": "depth",
-                            "as": "connected_edges",
-                        },
-                    },
-                    {
-                        "$unionWith": {
-                            "coll": self._collection_name,
-                            "pipeline": [
-                                {"$match": {"_id": label}},
-                                {"$project": project_doc},
-                                {
-                                    "$graphLookup": {
-                                        "from": self._edge_collection_name,
-                                        "startWith": "$_id",
-                                        "connectFromField": "source_node_id",
-                                        "connectToField": "target_node_id",
-                                        "maxDepth": max_depth,
-                                        "depthField": "depth",
-                                        "as": "connected_edges",
-                                    }
-                                },
-                            ],
-                        }
-                    },
-                ]
-
-                cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
-                nodes_processed = 0
-
-                async for doc in cursor:
-                    # Add the start nodes
-                    node_id = str(doc["_id"])
-                    if node_id not in seen_nodes:
-                        seen_nodes.add(node_id)
-                        result.nodes.append(self._construct_graph_node(node_id, doc))
-
-                    if doc.get("connected_edges", []):
-                        node_edges.extend(doc.get("connected_edges"))
-
-                    nodes_processed += 1
-
-                    # Additional safety check to prevent memory issues
-                    if nodes_processed >= max_nodes:
-                        result.is_truncated = True
-                        break
-
-            # When label != "*", the cursor above only has one node and we need to get the subgraph by connected edges
-            # Sort the connected edges by depth ascending and weight descending
-            # And store the source_node_id and target_node_id in sequence to retrieve the nodes again
-            if label != "*":
-                node_edges = sorted(
-                    node_edges,
-                    key=lambda x: (x["depth"], -x["weight"]),
+                result = await self.get_knowledge_subgraph_bidirectional_bfs(
+                    node_label, 0, max_depth, max_nodes
                 )
-
-                # As order matters, we need to use another list to store the node_id
-                # And only take the first max_nodes ones
-                node_ids = []
-                for edge in node_edges:
-                    if (
-                        len(node_ids) < max_nodes
-                        and edge["source_node_id"] not in seen_nodes
-                    ):
-                        node_ids.append(edge["source_node_id"])
-                        seen_nodes.add(edge["source_node_id"])
-
-                    if (
-                        len(node_ids) < max_nodes
-                        and edge["target_node_id"] not in seen_nodes
-                    ):
-                        node_ids.append(edge["target_node_id"])
-                        seen_nodes.add(edge["target_node_id"])
-
-                # Filter out the node whose id is the same as label so that we do not check existence in the next step
-                cursor = self.collection.find(
-                    {"_id": {"$in": [node for node in node_ids if node != label]}}
-                )
-
-                async for doc in cursor:
-                    node_id = str(doc["_id"])
-                    result.nodes.append(self._construct_graph_node(node_id, doc))
-
-                for edge in node_edges:
-                    if (
-                        edge["source_node_id"] not in seen_nodes
-                        or edge["target_node_id"] not in seen_nodes
-                    ):
-                        continue
-
-                    edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
-                    if edge_id not in seen_edges:
-                        result.edges.append(self._construct_graph_edge(edge_id, edge))
-                        seen_edges.add(edge_id)
+            duration = time.perf_counter() - start

             logger.info(
-                f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
+                f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
             )

         except PyMongoError as e:
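After this refactor, get_knowledge_graph is a thin dispatcher: "*" takes the degree-based whole-graph path, GRAPH_BFS_MODE == "in_out_bound" takes the $graphLookup path, and everything else falls through to the bidirectional BFS. A hypothetical call, assuming a running MongoDB and an initialized MongoGraphStorage instance named storage:

# Hypothetical usage inside an async context; construction of `storage` is omitted.
kg = await storage.get_knowledge_graph("B", max_depth=3, max_nodes=100)
print(f"{len(kg.nodes)} nodes, {len(kg.edges)} edges, truncated={kg.is_truncated}")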
@@ -1149,8 +1250,6 @@ class MongoVectorDBStorage(BaseVectorStorage):
             return

         # Add current time as Unix timestamp
-        import time
-
         current_time = int(time.time())

         list_data = [