From 687ccd49239ea86db289f9dcb638b80e8b92a3b8 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 26 Jun 2025 14:37:04 +0800 Subject: [PATCH] fix: optimize MongoDB aggregation pipeline to prevent memory limit errors - Move $limit operation early in pipeline for "*" queries to reduce memory usage - Remove memory-intensive $sort operation for large dataset queries - Add fallback mechanism for memory limit errors with simple query - Implement additional safety checks to enforce max_nodes limit - Improve error handling and logging for memory-related issues --- lightrag/kg/mongo_impl.py | 92 ++++++++++++++++++++++++++++++--------- 1 file changed, 71 insertions(+), 21 deletions(-) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index ac32268f..fbea463b 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -732,24 +732,25 @@ class MongoGraphStorage(BaseGraphStorage): node_edges = [] try: - pipeline = [ - { - "$graphLookup": { - "from": self._edge_collection_name, - "startWith": "$_id", - "connectFromField": "target_node_id", - "connectToField": "source_node_id", - "maxDepth": max_depth, - "depthField": "depth", - "as": "connected_edges", - }, - }, - {"$addFields": {"edge_count": {"$size": "$connected_edges"}}}, - {"$sort": {"edge_count": -1}}, - {"$limit": max_nodes}, - ] - + # Optimize pipeline to avoid memory issues with large datasets if label == "*": + # For getting all nodes, use a simpler pipeline to avoid memory issues + pipeline = [ + {"$limit": max_nodes}, # Limit early to reduce memory usage + { + "$graphLookup": { + "from": self._edge_collection_name, + "startWith": "$_id", + "connectFromField": "target_node_id", + "connectToField": "source_node_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_edges", + }, + }, + ] + + # Check if we need to set truncation flag all_node_count = await self.collection.count_documents({}) result.is_truncated = all_node_count > max_nodes else: @@ -759,10 +760,28 @@ class MongoGraphStorage(BaseGraphStorage): logger.warning(f"Starting node with label {label} does not exist!") return result - # Add starting node to pipeline - pipeline.insert(0, {"$match": {"_id": label}}) + # For specific node queries, use the original pipeline but optimized + pipeline = [ + {"$match": {"_id": label}}, + { + "$graphLookup": { + "from": self._edge_collection_name, + "startWith": "$_id", + "connectFromField": "target_node_id", + "connectToField": "source_node_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_edges", + }, + }, + {"$addFields": {"edge_count": {"$size": "$connected_edges"}}}, + {"$sort": {"edge_count": -1}}, + {"$limit": max_nodes}, + ] cursor = await self.collection.aggregate(pipeline, allowDiskUse=True) + nodes_processed = 0 + async for doc in cursor: # Add the start node node_id = str(doc["_id"]) @@ -786,6 +805,13 @@ class MongoGraphStorage(BaseGraphStorage): if doc.get("connected_edges", []): node_edges.extend(doc.get("connected_edges")) + nodes_processed += 1 + + # Additional safety check to prevent memory issues + if nodes_processed >= max_nodes: + result.is_truncated = True + break + for edge in node_edges: if ( edge["source_node_id"] not in seen_nodes @@ -817,11 +843,35 @@ class MongoGraphStorage(BaseGraphStorage): seen_edges.add(edge_id) logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}" ) except PyMongoError as e: - logger.error(f"MongoDB query failed: {str(e)}") + # Handle memory limit errors specifically + if "memory limit" in str(e).lower() or "sort exceeded" in str(e).lower(): + logger.warning( + f"MongoDB memory limit exceeded, falling back to simple query: {str(e)}" + ) + # Fallback to a simple query without complex aggregation + try: + simple_cursor = self.collection.find({}).limit(max_nodes) + async for doc in simple_cursor: + node_id = str(doc["_id"]) + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_id], + properties={k: v for k, v in doc.items() if k != "_id"}, + ) + ) + result.is_truncated = True + logger.info( + f"Fallback query completed | Node count: {len(result.nodes)}" + ) + except PyMongoError as fallback_error: + logger.error(f"Fallback query also failed: {str(fallback_error)}") + else: + logger.error(f"MongoDB query failed: {str(e)}") return result