From 73cc86662ac82b72e3f73b60f053f8916fa5f2bc Mon Sep 17 00:00:00 2001 From: Ken Chen Date: Sat, 28 Jun 2025 20:00:13 +0800 Subject: [PATCH] Add two BFS subgraph search support for MongoDBGraph --- lightrag/kg/mongo_impl.py | 345 ++++++++++++++++++++++++-------------- 1 file changed, 222 insertions(+), 123 deletions(-) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 6b158d7a..4bdba7ae 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1,4 +1,5 @@ import os +import time from dataclasses import dataclass, field import numpy as np import configparser @@ -35,6 +36,7 @@ config.read("config.ini", "utf-8") # Get maximum number of graph nodes from environment variable, default is 1000 MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) +GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional") class ClientManager: @@ -822,6 +824,198 @@ class MongoGraphStorage(BaseGraphStorage): return result + async def _bidirectional_bfs_nodes( + self, + node_labels: list[str], + seen_nodes: set[str], + result: KnowledgeGraph, + depth: int = 0, + max_depth: int = 3, + max_nodes: int = MAX_GRAPH_NODES, + ) -> KnowledgeGraph: + if depth > max_depth or len(result.nodes) > max_nodes: + return result + + cursor = self.collection.find({"_id": {"$in": node_labels}}) + + async for node in cursor: + node_id = node["_id"] + if node_id not in seen_nodes: + seen_nodes.add(node_id) + result.nodes.append(self._construct_graph_node(node_id, node)) + if len(result.nodes) > max_nodes: + return result + + # Collect neighbors + # Get both inbound and outbound one hop nodes + cursor = self.edge_collection.find( + { + "$or": [ + {"source_node_id": {"$in": node_labels}}, + {"target_node_id": {"$in": node_labels}}, + ] + } + ) + + neighbor_nodes = [] + async for edge in cursor: + if edge["source_node_id"] not in seen_nodes: + neighbor_nodes.append(edge["source_node_id"]) + if edge["target_node_id"] not in seen_nodes: + neighbor_nodes.append(edge["target_node_id"]) + + if neighbor_nodes: + result = await self._bidirectional_bfs_nodes( + neighbor_nodes, seen_nodes, result, depth + 1, max_depth, max_nodes + ) + + return result + + async def get_knowledge_subgraph_bidirectional_bfs( + self, + node_label: str, + depth=0, + max_depth: int = 3, + max_nodes: int = MAX_GRAPH_NODES, + ) -> KnowledgeGraph: + seen_nodes = set() + seen_edges = set() + result = KnowledgeGraph() + + result = await self._bidirectional_bfs_nodes( + [node_label], seen_nodes, result, depth, max_depth, max_nodes + ) + + # Get all edges from seen_nodes + all_node_ids = list(seen_nodes) + cursor = self.edge_collection.find( + { + "$and": [ + {"source_node_id": {"$in": all_node_ids}}, + {"target_node_id": {"$in": all_node_ids}}, + ] + } + ) + + async for edge in cursor: + edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" + if edge_id not in seen_edges: + result.edges.append(self._construct_graph_edge(edge_id, edge)) + seen_edges.add(edge_id) + + return result + + async def get_knowledge_subgraph_in_out_bound_bfs( + self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES + ) -> KnowledgeGraph: + seen_nodes = set() + seen_edges = set() + result = KnowledgeGraph() + project_doc = { + "source_ids": 0, + "created_at": 0, + "entity_type": 0, + "file_path": 0, + } + + # Verify if starting node exists + start_node = await self.collection.find_one({"_id": node_label}) + if not start_node: + logger.warning(f"Starting node with label {node_label} does not exist!") + return result + + seen_nodes.add(node_label) + result.nodes.append(self._construct_graph_node(node_label, start_node)) + + if max_depth == 0: + return result + + # In MongoDB, depth = 0 means one-hop + max_depth = max_depth - 1 + + pipeline = [ + {"$match": {"_id": node_label}}, + {"$project": project_doc}, + { + "$graphLookup": { + "from": self._edge_collection_name, + "startWith": "$_id", + "connectFromField": "target_node_id", + "connectToField": "source_node_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_edges", + }, + }, + { + "$unionWith": { + "coll": self._collection_name, + "pipeline": [ + {"$match": {"_id": node_label}}, + {"$project": project_doc}, + { + "$graphLookup": { + "from": self._edge_collection_name, + "startWith": "$_id", + "connectFromField": "source_node_id", + "connectToField": "target_node_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_edges", + } + }, + ], + } + }, + ] + + cursor = await self.collection.aggregate(pipeline, allowDiskUse=True) + node_edges = [] + + # Two records for node_label are returned capturing outbound and inbound connected_edges + async for doc in cursor: + if doc.get("connected_edges", []): + node_edges.extend(doc.get("connected_edges")) + + # Sort the connected edges by depth ascending and weight descending + # And stores the source_node_id and target_node_id in sequence to retrieve the neighbouring nodes + node_edges = sorted( + node_edges, + key=lambda x: (x["depth"], -x["weight"]), + ) + + # As order matters, we need to use another list to store the node_id + # And only take the first max_nodes ones + node_ids = [] + for edge in node_edges: + if len(node_ids) < max_nodes and edge["source_node_id"] not in seen_nodes: + node_ids.append(edge["source_node_id"]) + seen_nodes.add(edge["source_node_id"]) + + if len(node_ids) < max_nodes and edge["target_node_id"] not in seen_nodes: + node_ids.append(edge["target_node_id"]) + seen_nodes.add(edge["target_node_id"]) + + # Filter out all the node whose id is same as node_label so that we do not check existence next step + cursor = self.collection.find({"_id": {"$in": node_ids}}) + + async for doc in cursor: + result.nodes.append(self._construct_graph_node(str(doc["_id"]), doc)) + + for edge in node_edges: + if ( + edge["source_node_id"] not in seen_nodes + or edge["target_node_id"] not in seen_nodes + ): + continue + + edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" + if edge_id not in seen_edges: + result.edges.append(self._construct_graph_edge(edge_id, edge)) + seen_edges.add(edge_id) + + return result + async def get_knowledge_graph( self, node_label: str, @@ -830,6 +1024,23 @@ class MongoGraphStorage(BaseGraphStorage): ) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) + If a graph is like this and starting from B: + A → B ← C ← F, B -> E, C → D + + Outbound BFS: + B → E + + Inbound BFS: + A → B + C → B + F → C + + Bidirectional BFS: + A → B + B → E + F → C + C → B + C → D Args: node_label: Label of the nodes to start from @@ -838,138 +1049,28 @@ class MongoGraphStorage(BaseGraphStorage): Returns: KnowledgeGraph object containing nodes and edges of the subgraph """ - label = node_label result = KnowledgeGraph() - seen_nodes = set() - seen_edges = set() - node_edges = [] - project_doc = { - "source_ids": 0, - "created_at": 0, - "entity_type": 0, - "file_path": 0, - } + start = time.perf_counter() try: # Optimize pipeline to avoid memory issues with large datasets - if label == "*": - return await self.get_knowledge_graph_all_by_degree( + if node_label == "*": + result = await self.get_knowledge_graph_all_by_degree( max_depth, max_nodes ) + elif GRAPH_BFS_MODE == "in_out_bound": + result = await self.get_knowledge_subgraph_in_out_bound_bfs( + node_label, max_depth, max_nodes + ) else: - # Verify if starting node exists - start_node = await self.collection.find_one({"_id": label}) - if not start_node: - logger.warning(f"Starting node with label {label} does not exist!") - return result - - # For specific node queries, use the original pipeline but optimized - pipeline = [ - {"$match": {"_id": label}}, - {"$project": project_doc}, - { - "$graphLookup": { - "from": self._edge_collection_name, - "startWith": "$_id", - "connectFromField": "target_node_id", - "connectToField": "source_node_id", - "maxDepth": max_depth, - "depthField": "depth", - "as": "connected_edges", - }, - }, - { - "$unionWith": { - "coll": self._collection_name, - "pipeline": [ - {"$match": {"_id": label}}, - {"$project": project_doc}, - { - "$graphLookup": { - "from": self._edge_collection_name, - "startWith": "$_id", - "connectFromField": "source_node_id", - "connectToField": "target_node_id", - "maxDepth": max_depth, - "depthField": "depth", - "as": "connected_edges", - } - }, - ], - } - }, - ] - - cursor = await self.collection.aggregate(pipeline, allowDiskUse=True) - nodes_processed = 0 - - async for doc in cursor: - # Add the start nodes - node_id = str(doc["_id"]) - if node_id not in seen_nodes: - seen_nodes.add(node_id) - result.nodes.append(self._construct_graph_node(node_id, doc)) - - if doc.get("connected_edges", []): - node_edges.extend(doc.get("connected_edges")) - - nodes_processed += 1 - - # Additional safety check to prevent memory issues - if nodes_processed >= max_nodes: - result.is_truncated = True - break - - # When label != "*", cursor above only have one node and we need to get the subgraph by connected edges - # Sort the connected edges by depth ascending and weight descending - # And stores the source_node_id and target_node_id in sequence to retrieve the nodes again - if label != "*": - node_edges = sorted( - node_edges, - key=lambda x: (x["depth"], -x["weight"]), + result = await self.get_knowledge_subgraph_bidirectional_bfs( + node_label, 0, max_depth, max_nodes ) - # As order matters, we need to use another list to store the node_id - # And only take the first max_nodes ones - node_ids = [] - for edge in node_edges: - if ( - len(node_ids) < max_nodes - and edge["source_node_id"] not in seen_nodes - ): - node_ids.append(edge["source_node_id"]) - seen_nodes.add(edge["source_node_id"]) - - if ( - len(node_ids) < max_nodes - and edge["target_node_id"] not in seen_nodes - ): - node_ids.append(edge["target_node_id"]) - seen_nodes.add(edge["target_node_id"]) - - # Filter out all the node whose id is same as label so that we do not check existence next step - cursor = self.collection.find( - {"_id": {"$in": [node for node in node_ids if node != label]}} - ) - - async for doc in cursor: - node_id = str(doc["_id"]) - result.nodes.append(self._construct_graph_node(node_id, doc)) - - for edge in node_edges: - if ( - edge["source_node_id"] not in seen_nodes - or edge["target_node_id"] not in seen_nodes - ): - continue - - edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" - if edge_id not in seen_edges: - result.edges.append(self._construct_graph_edge(edge_id, edge)) - seen_edges.add(edge_id) + duration = time.perf_counter() - start logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}" + f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}" ) except PyMongoError as e: @@ -1149,8 +1250,6 @@ class MongoVectorDBStorage(BaseVectorStorage): return # Add current time as Unix timestamp - import time - current_time = int(time.time()) list_data = [