diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 6b249b94..cf4822e5 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -40,36 +40,39 @@ GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional") class ClientManager: _instances = {"db": None, "ref_count": 0} + _lock = asyncio.Lock() @classmethod async def get_client(cls) -> AsyncMongoClient: - if cls._instances["db"] is None: - uri = os.environ.get( - "MONGO_URI", - config.get( - "mongodb", - "uri", - fallback="mongodb://root:root@localhost:27017/", - ), - ) - database_name = os.environ.get( - "MONGO_DATABASE", - config.get("mongodb", "database", fallback="LightRAG"), - ) - client = AsyncMongoClient(uri) - db = client.get_database(database_name) - cls._instances["db"] = db - cls._instances["ref_count"] = 0 - cls._instances["ref_count"] += 1 - return cls._instances["db"] + async with cls._lock: + if cls._instances["db"] is None: + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", + "uri", + fallback="mongodb://root:root@localhost:27017/", + ), + ) + database_name = os.environ.get( + "MONGO_DATABASE", + config.get("mongodb", "database", fallback="LightRAG"), + ) + client = AsyncMongoClient(uri) + db = client.get_database(database_name) + cls._instances["db"] = db + cls._instances["ref_count"] = 0 + cls._instances["ref_count"] += 1 + return cls._instances["db"] @classmethod async def release_client(cls, db: AsyncDatabase): - if db is not None: - if db is cls._instances["db"]: - cls._instances["ref_count"] -= 1 - if cls._instances["ref_count"] == 0: - cls._instances["db"] = None + async with cls._lock: + if db is not None: + if db is cls._instances["db"]: + cls._instances["ref_count"] -= 1 + if cls._instances["ref_count"] == 0: + cls._instances["db"] = None @final