Add support for two BFS subgraph search modes to MongoDBGraph

This commit is contained in:
Ken Chen 2025-06-28 20:00:13 +08:00
parent 5739f52d29
commit 73cc86662a

View File

@ -1,4 +1,5 @@
import os import os
import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
import numpy as np import numpy as np
import configparser import configparser
@ -35,6 +36,7 @@ config.read("config.ini", "utf-8")
# Get maximum number of graph nodes from environment variable, default is 1000 # Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
class ClientManager: class ClientManager:
@ -822,27 +824,93 @@ class MongoGraphStorage(BaseGraphStorage):
return result return result
async def _bidirectional_bfs_nodes(
    self,
    node_labels: list[str],
    seen_nodes: set[str],
    result: KnowledgeGraph,
    depth: int = 0,
    max_depth: int = 3,
    max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
    """Recursively expand the subgraph one BFS level at a time, following
    edges in both directions (inbound and outbound).

    Args:
        node_labels: Node ids (``_id`` values) forming the current BFS frontier.
        seen_nodes: Ids already appended to ``result`` (mutated in place).
        result: Accumulator graph; newly discovered nodes are appended to it.
        depth: Current BFS level (0 = the starting frontier).
        max_depth: Maximum number of BFS levels to expand.
        max_nodes: Soft cap on the number of collected nodes.

    Returns:
        The same ``result`` object, with nodes of this and deeper levels added.
    """
    if depth > max_depth:
        # Depth cap reached: normal termination, not a truncation.
        return result
    if len(result.nodes) > max_nodes:
        # Node cap reached: surface it so get_knowledge_graph's log line
        # ("Truncated: ...") reports the truncation to the caller.
        result.is_truncated = True
        return result

    # Materialise the frontier nodes that have not been emitted yet.
    cursor = self.collection.find({"_id": {"$in": node_labels}})
    async for node in cursor:
        node_id = node["_id"]
        if node_id not in seen_nodes:
            seen_nodes.add(node_id)
            result.nodes.append(self._construct_graph_node(node_id, node))

    if len(result.nodes) > max_nodes:
        result.is_truncated = True
        return result

    # Collect the next frontier: both inbound and outbound one-hop neighbours.
    cursor = self.edge_collection.find(
        {
            "$or": [
                {"source_node_id": {"$in": node_labels}},
                {"target_node_id": {"$in": node_labels}},
            ]
        }
    )
    neighbor_nodes = []
    async for edge in cursor:
        if edge["source_node_id"] not in seen_nodes:
            neighbor_nodes.append(edge["source_node_id"])
        if edge["target_node_id"] not in seen_nodes:
            neighbor_nodes.append(edge["target_node_id"])

    if neighbor_nodes:
        result = await self._bidirectional_bfs_nodes(
            neighbor_nodes, seen_nodes, result, depth + 1, max_depth, max_nodes
        )
    return result
async def get_knowledge_subgraph_bidirectional_bfs(
    self,
    node_label: str,
    depth: int = 0,
    max_depth: int = 3,
    max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
    """Build a connected subgraph around ``node_label`` via bidirectional BFS.

    Nodes are gathered level by level by ``_bidirectional_bfs_nodes``; the
    edges are then fetched in one query restricted to the collected node set,
    so only edges with BOTH endpoints inside the subgraph are returned.

    Args:
        node_label: ``_id`` of the node to start the traversal from.
        depth: Initial BFS level passed to the recursive helper (normally 0).
        max_depth: Maximum number of BFS levels to expand.
        max_nodes: Soft cap on the number of collected nodes.

    Returns:
        KnowledgeGraph with the discovered nodes and their interconnecting edges.
    """
    seen_nodes: set[str] = set()
    seen_edges: set[str] = set()
    result = KnowledgeGraph()
    result = await self._bidirectional_bfs_nodes(
        [node_label], seen_nodes, result, depth, max_depth, max_nodes
    )

    # Fetch every edge whose endpoints are both inside the collected node set.
    all_node_ids = list(seen_nodes)
    cursor = self.edge_collection.find(
        {
            "$and": [
                {"source_node_id": {"$in": all_node_ids}},
                {"target_node_id": {"$in": all_node_ids}},
            ]
        }
    )
    async for edge in cursor:
        # De-duplicate by the directed (source, target) pair; this is the
        # same edge-id scheme used elsewhere in this class.
        edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
        if edge_id not in seen_edges:
            seen_edges.add(edge_id)
            result.edges.append(self._construct_graph_edge(edge_id, edge))
    return result
async def get_knowledge_subgraph_in_out_bound_bfs(
self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
) -> KnowledgeGraph:
seen_nodes = set()
seen_edges = set()
result = KnowledgeGraph()
project_doc = { project_doc = {
"source_ids": 0, "source_ids": 0,
"created_at": 0, "created_at": 0,
@ -850,22 +918,23 @@ class MongoGraphStorage(BaseGraphStorage):
"file_path": 0, "file_path": 0,
} }
try:
# Optimize pipeline to avoid memory issues with large datasets
if label == "*":
return await self.get_knowledge_graph_all_by_degree(
max_depth, max_nodes
)
else:
# Verify if starting node exists # Verify if starting node exists
start_node = await self.collection.find_one({"_id": label}) start_node = await self.collection.find_one({"_id": node_label})
if not start_node: if not start_node:
logger.warning(f"Starting node with label {label} does not exist!") logger.warning(f"Starting node with label {node_label} does not exist!")
return result return result
# For specific node queries, use the original pipeline but optimized seen_nodes.add(node_label)
result.nodes.append(self._construct_graph_node(node_label, start_node))
if max_depth == 0:
return result
# In MongoDB, depth = 0 means one-hop
max_depth = max_depth - 1
pipeline = [ pipeline = [
{"$match": {"_id": label}}, {"$match": {"_id": node_label}},
{"$project": project_doc}, {"$project": project_doc},
{ {
"$graphLookup": { "$graphLookup": {
@ -882,7 +951,7 @@ class MongoGraphStorage(BaseGraphStorage):
"$unionWith": { "$unionWith": {
"coll": self._collection_name, "coll": self._collection_name,
"pipeline": [ "pipeline": [
{"$match": {"_id": label}}, {"$match": {"_id": node_label}},
{"$project": project_doc}, {"$project": project_doc},
{ {
"$graphLookup": { "$graphLookup": {
@ -901,29 +970,15 @@ class MongoGraphStorage(BaseGraphStorage):
] ]
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True) cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
nodes_processed = 0 node_edges = []
# Two records for node_label are returned capturing outbound and inbound connected_edges
async for doc in cursor: async for doc in cursor:
# Add the start nodes
node_id = str(doc["_id"])
if node_id not in seen_nodes:
seen_nodes.add(node_id)
result.nodes.append(self._construct_graph_node(node_id, doc))
if doc.get("connected_edges", []): if doc.get("connected_edges", []):
node_edges.extend(doc.get("connected_edges")) node_edges.extend(doc.get("connected_edges"))
nodes_processed += 1
# Additional safety check to prevent memory issues
if nodes_processed >= max_nodes:
result.is_truncated = True
break
# When label != "*", cursor above only have one node and we need to get the subgraph by connected edges
# Sort the connected edges by depth ascending and weight descending # Sort the connected edges by depth ascending and weight descending
# And stores the source_node_id and target_node_id in sequence to retrieve the nodes again # And stores the source_node_id and target_node_id in sequence to retrieve the neighbouring nodes
if label != "*":
node_edges = sorted( node_edges = sorted(
node_edges, node_edges,
key=lambda x: (x["depth"], -x["weight"]), key=lambda x: (x["depth"], -x["weight"]),
@ -933,28 +988,19 @@ class MongoGraphStorage(BaseGraphStorage):
# And only take the first max_nodes ones # And only take the first max_nodes ones
node_ids = [] node_ids = []
for edge in node_edges: for edge in node_edges:
if ( if len(node_ids) < max_nodes and edge["source_node_id"] not in seen_nodes:
len(node_ids) < max_nodes
and edge["source_node_id"] not in seen_nodes
):
node_ids.append(edge["source_node_id"]) node_ids.append(edge["source_node_id"])
seen_nodes.add(edge["source_node_id"]) seen_nodes.add(edge["source_node_id"])
if ( if len(node_ids) < max_nodes and edge["target_node_id"] not in seen_nodes:
len(node_ids) < max_nodes
and edge["target_node_id"] not in seen_nodes
):
node_ids.append(edge["target_node_id"]) node_ids.append(edge["target_node_id"])
seen_nodes.add(edge["target_node_id"]) seen_nodes.add(edge["target_node_id"])
# Filter out all the node whose id is same as label so that we do not check existence next step # Filter out all the node whose id is same as node_label so that we do not check existence next step
cursor = self.collection.find( cursor = self.collection.find({"_id": {"$in": node_ids}})
{"_id": {"$in": [node for node in node_ids if node != label]}}
)
async for doc in cursor: async for doc in cursor:
node_id = str(doc["_id"]) result.nodes.append(self._construct_graph_node(str(doc["_id"]), doc))
result.nodes.append(self._construct_graph_node(node_id, doc))
for edge in node_edges: for edge in node_edges:
if ( if (
@ -968,8 +1014,63 @@ class MongoGraphStorage(BaseGraphStorage):
result.edges.append(self._construct_graph_edge(edge_id, edge)) result.edges.append(self._construct_graph_edge(edge_id, edge))
seen_edges.add(edge_id) seen_edges.add(edge_id)
return result
async def get_knowledge_graph(
self,
node_label: str,
max_depth: int = 5,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
If a graph is like this and starting from B:
A B C F, B -> E, C D
Outbound BFS:
B E
Inbound BFS:
A B
C B
F C
Bidirectional BFS:
A B
B E
F C
C B
C D
Args:
node_label: Label of the nodes to start from
max_depth: Maximum depth of traversal (default: 5)
Returns:
KnowledgeGraph object containing nodes and edges of the subgraph
"""
result = KnowledgeGraph()
start = time.perf_counter()
try:
# Optimize pipeline to avoid memory issues with large datasets
if node_label == "*":
result = await self.get_knowledge_graph_all_by_degree(
max_depth, max_nodes
)
elif GRAPH_BFS_MODE == "in_out_bound":
result = await self.get_knowledge_subgraph_in_out_bound_bfs(
node_label, max_depth, max_nodes
)
else:
result = await self.get_knowledge_subgraph_bidirectional_bfs(
node_label, 0, max_depth, max_nodes
)
duration = time.perf_counter() - start
logger.info( logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}" f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
) )
except PyMongoError as e: except PyMongoError as e:
@ -1149,8 +1250,6 @@ class MongoVectorDBStorage(BaseVectorStorage):
return return
# Add current time as Unix timestamp # Add current time as Unix timestamp
import time
current_time = int(time.time()) current_time = int(time.time())
list_data = [ list_data = [