Add support for two BFS subgraph search modes to MongoDBGraph

Ken Chen 2025-06-28 20:00:13 +08:00
parent 5739f52d29
commit 73cc86662a


@@ -1,4 +1,5 @@
 import os
+import time
 from dataclasses import dataclass, field
 import numpy as np
 import configparser
@@ -35,6 +36,7 @@ config.read("config.ini", "utf-8")
 # Get maximum number of graph nodes from environment variable, default is 1000
 MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
+GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")


 class ClientManager:
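Note that GRAPH_BFS_MODE is read once at module import, so the traversal mode must be set in the environment before the storage module is loaded. A minimal sketch of selecting the new mode (the only value recognized by the dispatch below is "in_out_bound"; anything else falls back to the default "bidirectional"):

    import os

    # Set this before importing the MongoDB storage module, because
    # GRAPH_BFS_MODE is evaluated once at import time.
    os.environ["MONGO_GRAPH_BFS_MODE"] = "in_out_bound"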
@@ -822,6 +824,198 @@ class MongoGraphStorage(BaseGraphStorage):
         return result

+    async def _bidirectional_bfs_nodes(
+        self,
+        node_labels: list[str],
+        seen_nodes: set[str],
+        result: KnowledgeGraph,
+        depth: int = 0,
+        max_depth: int = 3,
+        max_nodes: int = MAX_GRAPH_NODES,
+    ) -> KnowledgeGraph:
+        if depth > max_depth or len(result.nodes) > max_nodes:
+            return result
+
+        cursor = self.collection.find({"_id": {"$in": node_labels}})
+        async for node in cursor:
+            node_id = node["_id"]
+            if node_id not in seen_nodes:
+                seen_nodes.add(node_id)
+                result.nodes.append(self._construct_graph_node(node_id, node))
+                if len(result.nodes) > max_nodes:
+                    return result
+
+        # Collect neighbors: both inbound and outbound one-hop nodes
+        cursor = self.edge_collection.find(
+            {
+                "$or": [
+                    {"source_node_id": {"$in": node_labels}},
+                    {"target_node_id": {"$in": node_labels}},
+                ]
+            }
+        )
+        neighbor_nodes = []
+        async for edge in cursor:
+            if edge["source_node_id"] not in seen_nodes:
+                neighbor_nodes.append(edge["source_node_id"])
+            if edge["target_node_id"] not in seen_nodes:
+                neighbor_nodes.append(edge["target_node_id"])
+
+        if neighbor_nodes:
+            result = await self._bidirectional_bfs_nodes(
+                neighbor_nodes, seen_nodes, result, depth + 1, max_depth, max_nodes
+            )
+        return result
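The helper above recurses on the whole frontier at once, so each depth level costs two queries (one find() against the node collection, one against the edge collection) regardless of frontier size, and the recursion depth is bounded by max_depth. For intuition, a rough iterative equivalent; this is a sketch only, assuming the same collection/edge_collection attributes and the KnowledgeGraph type imported by the file above:

    async def bidirectional_bfs_iterative(storage, node_label, max_depth=3, max_nodes=1000):
        # Sketch: the same frontier-at-a-time traversal as _bidirectional_bfs_nodes,
        # written as a loop instead of recursion.
        result, seen, frontier = KnowledgeGraph(), set(), [node_label]
        for _ in range(max_depth + 1):
            # One query fetches every node in the current frontier
            async for node in storage.collection.find({"_id": {"$in": frontier}}):
                if node["_id"] not in seen:
                    seen.add(node["_id"])
                    result.nodes.append(storage._construct_graph_node(node["_id"], node))
                    if len(result.nodes) > max_nodes:
                        return result
            # One query fetches every edge touching the frontier, in either direction
            next_frontier = []
            async for edge in storage.edge_collection.find(
                {
                    "$or": [
                        {"source_node_id": {"$in": frontier}},
                        {"target_node_id": {"$in": frontier}},
                    ]
                }
            ):
                for node_id in (edge["source_node_id"], edge["target_node_id"]):
                    if node_id not in seen and node_id not in next_frontier:
                        next_frontier.append(node_id)
            if not next_frontier:
                break
            frontier = next_frontier
        return result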
+    async def get_knowledge_subgraph_bidirectional_bfs(
+        self,
+        node_label: str,
+        depth: int = 0,
+        max_depth: int = 3,
+        max_nodes: int = MAX_GRAPH_NODES,
+    ) -> KnowledgeGraph:
+        seen_nodes = set()
+        seen_edges = set()
+        result = KnowledgeGraph()
+        result = await self._bidirectional_bfs_nodes(
+            [node_label], seen_nodes, result, depth, max_depth, max_nodes
+        )
+
+        # Fetch every edge whose endpoints are both in seen_nodes
+        all_node_ids = list(seen_nodes)
+        cursor = self.edge_collection.find(
+            {
+                "$and": [
+                    {"source_node_id": {"$in": all_node_ids}},
+                    {"target_node_id": {"$in": all_node_ids}},
+                ]
+            }
+        )
+        async for edge in cursor:
+            edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
+            if edge_id not in seen_edges:
+                result.edges.append(self._construct_graph_edge(edge_id, edge))
+                seen_edges.add(edge_id)
+        return result
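A hypothetical call site for the method above, assuming storage is an already-initialized MongoGraphStorage and "B" is an existing node _id (it matches the docstring example further down):

    async def demo(storage):
        subgraph = await storage.get_knowledge_subgraph_bidirectional_bfs(
            "B", depth=0, max_depth=3, max_nodes=1000
        )
        print(f"{len(subgraph.nodes)} nodes, {len(subgraph.edges)} edges")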
+    async def get_knowledge_subgraph_in_out_bound_bfs(
+        self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
+    ) -> KnowledgeGraph:
+        seen_nodes = set()
+        seen_edges = set()
+        result = KnowledgeGraph()
+        project_doc = {
+            "source_ids": 0,
+            "created_at": 0,
+            "entity_type": 0,
+            "file_path": 0,
+        }
+
+        # Verify that the starting node exists
+        start_node = await self.collection.find_one({"_id": node_label})
+        if not start_node:
+            logger.warning(f"Starting node with label {node_label} does not exist!")
+            return result
+
+        seen_nodes.add(node_label)
+        result.nodes.append(self._construct_graph_node(node_label, start_node))
+
+        if max_depth == 0:
+            return result
+
+        # In MongoDB's $graphLookup, maxDepth = 0 already means one hop
+        max_depth = max_depth - 1
+
+        pipeline = [
+            {"$match": {"_id": node_label}},
+            {"$project": project_doc},
+            {
+                "$graphLookup": {
+                    "from": self._edge_collection_name,
+                    "startWith": "$_id",
+                    "connectFromField": "target_node_id",
+                    "connectToField": "source_node_id",
+                    "maxDepth": max_depth,
+                    "depthField": "depth",
+                    "as": "connected_edges",
+                },
+            },
+            {
+                "$unionWith": {
+                    "coll": self._collection_name,
+                    "pipeline": [
+                        {"$match": {"_id": node_label}},
+                        {"$project": project_doc},
+                        {
+                            "$graphLookup": {
+                                "from": self._edge_collection_name,
+                                "startWith": "$_id",
+                                "connectFromField": "source_node_id",
+                                "connectToField": "target_node_id",
+                                "maxDepth": max_depth,
+                                "depthField": "depth",
+                                "as": "connected_edges",
+                            }
+                        },
+                    ],
+                }
+            },
+        ]
+        cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
+
+        node_edges = []
+        # Two documents come back for node_label, capturing the outbound and
+        # inbound connected_edges respectively
+        async for doc in cursor:
+            if doc.get("connected_edges", []):
+                node_edges.extend(doc.get("connected_edges"))
+
+        # Sort the connected edges by depth ascending and weight descending,
+        # then walk source_node_id / target_node_id in sequence to pick the
+        # neighbouring nodes
+        node_edges = sorted(
+            node_edges,
+            key=lambda x: (x["depth"], -x["weight"]),
+        )
+
+        # Order matters, so collect the node ids in a list and keep only the
+        # first max_nodes of them
+        node_ids = []
+        for edge in node_edges:
+            if len(node_ids) < max_nodes and edge["source_node_id"] not in seen_nodes:
+                node_ids.append(edge["source_node_id"])
+                seen_nodes.add(edge["source_node_id"])
+            if len(node_ids) < max_nodes and edge["target_node_id"] not in seen_nodes:
+                node_ids.append(edge["target_node_id"])
+                seen_nodes.add(edge["target_node_id"])
+
+        # node_label is already in seen_nodes, so node_ids never contains it and
+        # the start node is not fetched or existence-checked a second time
+        cursor = self.collection.find({"_id": {"$in": node_ids}})
+        async for doc in cursor:
+            result.nodes.append(self._construct_graph_node(str(doc["_id"]), doc))
+
+        for edge in node_edges:
+            if (
+                edge["source_node_id"] not in seen_nodes
+                or edge["target_node_id"] not in seen_nodes
+            ):
+                continue
+            edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
+            if edge_id not in seen_edges:
+                result.edges.append(self._construct_graph_edge(edge_id, edge))
+                seen_edges.add(edge_id)
+        return result
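What makes get_knowledge_subgraph_in_out_bound_bfs traverse both directions is the $unionWith of two $graphLookup stages: the first matches connectToField source_node_id against the start _id and so walks outbound edges, while the second swaps the two fields and walks inbound edges. The outbound half, as a standalone synchronous PyMongo sketch against hypothetical nodes/edges collections:

    from pymongo import MongoClient

    db = MongoClient()["demo"]  # hypothetical db with `nodes` and `edges` collections
    outbound = [
        {"$match": {"_id": "B"}},
        {
            "$graphLookup": {
                "from": "edges",
                "startWith": "$_id",
                "connectFromField": "target_node_id",  # continue from each edge's target
                "connectToField": "source_node_id",  # match edges leaving the frontier
                "maxDepth": 0,  # maxDepth=0 already returns one-hop edges
                "depthField": "depth",
                "as": "connected_edges",
            }
        },
    ]
    for doc in db.nodes.aggregate(outbound):
        print(doc["connected_edges"])  # for the docstring graph below: the single edge B -> E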
     async def get_knowledge_graph(
         self,
         node_label: str,
@@ -830,6 +1024,23 @@ class MongoGraphStorage(BaseGraphStorage):
     ) -> KnowledgeGraph:
         """
         Get complete connected subgraph for specified node (including the starting node itself)
+
+        If the graph looks like this and we start from B:
+            A -> B <- C <- F, B -> E, C -> D
+        Outbound BFS:
+            B -> E
+        Inbound BFS:
+            A -> B
+            C -> B
+            F -> C
+        Bidirectional BFS:
+            A -> B
+            B -> E
+            F -> C
+            C -> B
+            C -> D

         Args:
             node_label: Label of the nodes to start from
@@ -838,138 +1049,28 @@ class MongoGraphStorage(BaseGraphStorage):
         Returns:
             KnowledgeGraph object containing nodes and edges of the subgraph
         """
-        label = node_label
         result = KnowledgeGraph()
-        seen_nodes = set()
-        seen_edges = set()
-        node_edges = []
-        project_doc = {
-            "source_ids": 0,
-            "created_at": 0,
-            "entity_type": 0,
-            "file_path": 0,
-        }
+        start = time.perf_counter()

         try:
-            # Optimize pipeline to avoid memory issues with large datasets
-            if label == "*":
-                return await self.get_knowledge_graph_all_by_degree(
+            if node_label == "*":
+                result = await self.get_knowledge_graph_all_by_degree(
                     max_depth, max_nodes
                 )
+            elif GRAPH_BFS_MODE == "in_out_bound":
+                result = await self.get_knowledge_subgraph_in_out_bound_bfs(
+                    node_label, max_depth, max_nodes
+                )
             else:
-                # Verify if starting node exists
-                start_node = await self.collection.find_one({"_id": label})
-                if not start_node:
-                    logger.warning(f"Starting node with label {label} does not exist!")
-                    return result
-
-                # For specific node queries, use the original pipeline but optimized
-                pipeline = [
-                    {"$match": {"_id": label}},
-                    {"$project": project_doc},
-                    {
-                        "$graphLookup": {
-                            "from": self._edge_collection_name,
-                            "startWith": "$_id",
-                            "connectFromField": "target_node_id",
-                            "connectToField": "source_node_id",
-                            "maxDepth": max_depth,
-                            "depthField": "depth",
-                            "as": "connected_edges",
-                        },
-                    },
-                    {
-                        "$unionWith": {
-                            "coll": self._collection_name,
-                            "pipeline": [
-                                {"$match": {"_id": label}},
-                                {"$project": project_doc},
-                                {
-                                    "$graphLookup": {
-                                        "from": self._edge_collection_name,
-                                        "startWith": "$_id",
-                                        "connectFromField": "source_node_id",
-                                        "connectToField": "target_node_id",
-                                        "maxDepth": max_depth,
-                                        "depthField": "depth",
-                                        "as": "connected_edges",
-                                    }
-                                },
-                            ],
-                        }
-                    },
-                ]
-
-                cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
-
-                nodes_processed = 0
-                async for doc in cursor:
-                    # Add the start nodes
-                    node_id = str(doc["_id"])
-                    if node_id not in seen_nodes:
-                        seen_nodes.add(node_id)
-                        result.nodes.append(self._construct_graph_node(node_id, doc))
-                        if doc.get("connected_edges", []):
-                            node_edges.extend(doc.get("connected_edges"))
-                        nodes_processed += 1
-                        # Additional safety check to prevent memory issues
-                        if nodes_processed >= max_nodes:
-                            result.is_truncated = True
-                            break
-
-            # When label != "*", the cursor above only has one node and we need to
-            # get the subgraph from its connected edges: sort them by depth
-            # ascending and weight descending, then walk the node ids in sequence
-            if label != "*":
-                node_edges = sorted(
-                    node_edges,
-                    key=lambda x: (x["depth"], -x["weight"]),
+                result = await self.get_knowledge_subgraph_bidirectional_bfs(
+                    node_label, 0, max_depth, max_nodes
                 )
-
-                # As order matters, use another list to store the node ids
-                # and keep only the first max_nodes of them
-                node_ids = []
-                for edge in node_edges:
-                    if (
-                        len(node_ids) < max_nodes
-                        and edge["source_node_id"] not in seen_nodes
-                    ):
-                        node_ids.append(edge["source_node_id"])
-                        seen_nodes.add(edge["source_node_id"])
-                    if (
-                        len(node_ids) < max_nodes
-                        and edge["target_node_id"] not in seen_nodes
-                    ):
-                        node_ids.append(edge["target_node_id"])
-                        seen_nodes.add(edge["target_node_id"])
-
-                # Filter out nodes whose id equals label so that existence does
-                # not have to be checked again in the next step
-                cursor = self.collection.find(
-                    {"_id": {"$in": [node for node in node_ids if node != label]}}
-                )
-                async for doc in cursor:
-                    node_id = str(doc["_id"])
-                    result.nodes.append(self._construct_graph_node(node_id, doc))
-
-                for edge in node_edges:
-                    if (
-                        edge["source_node_id"] not in seen_nodes
-                        or edge["target_node_id"] not in seen_nodes
-                    ):
-                        continue
-                    edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
-                    if edge_id not in seen_edges:
-                        result.edges.append(self._construct_graph_edge(edge_id, edge))
-                        seen_edges.add(edge_id)
+            duration = time.perf_counter() - start

             logger.info(
-                f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
+                f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
             )

         except PyMongoError as e:
@@ -1149,8 +1250,6 @@ class MongoVectorDBStorage(BaseVectorStorage):
             return

         # Add current time as Unix timestamp
-        import time
-
         current_time = int(time.time())

         list_data = [