"""
|
|
NanoVectorDB Storage Module
|
|
=======================
|
|
|
|
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
|
|
|
|
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
|
|
|
|
Author: lightrag team
|
|
Created: 2024-01-25
|
|
License: MIT
|
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
of this software and associated documentation files (the "Software"), to deal
|
|
in the Software without restriction, including without limitation the rights
|
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
copies of the Software, and to permit persons to whom the Software is
|
|
furnished to do so, subject to the following conditions:
|
|
|
|
The above copyright notice and this permission notice shall be included in all
|
|
copies or substantial portions of the Software.
|
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
SOFTWARE.
|
|
|
|
Version: 1.0.0
|
|
|
|
Dependencies:
|
|
- NetworkX
|
|
- NumPy
|
|
- LightRAG
|
|
- graspologic
|
|
|
|
Features:
|
|
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
|
|
- Query graph nodes and edges
|
|
- Calculate node and edge degrees
|
|
- Embed nodes using various algorithms (e.g., Node2Vec)
|
|
- Remove nodes and edges from the graph
|
|
|
|
Usage:
|
|
from lightrag.storage.networkx_storage import NetworkXStorage
|
|
|
|
"""
|
|
|
|
import asyncio
import os
import time
from dataclasses import dataclass

import numpy as np
import pipmaster as pm
from tqdm.asyncio import tqdm as tqdm_async

# Install nano-vectordb on the fly if it is not already available
if not pm.is_installed("nano-vectordb"):
    pm.install("nano-vectordb")

from nano_vectordb import NanoVectorDB

from lightrag.utils import (
    logger,
    compute_mdhash_id,
)
from lightrag.base import (
    BaseVectorStorage,
)


@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
    cosine_better_than_threshold: float = None

    def __post_init__(self):
        # Initialize lock only for file operations
        self._save_lock = asyncio.Lock()
        # Read the threshold from the global config; there is no usable default,
        # so a missing value is an error
        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

        self._client_file_name = os.path.join(
            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        self._client = NanoVectorDB(
            self.embedding_func.embedding_dim, storage_file=self._client_file_name
        )
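
    # A minimal sketch of the global_config entries read above. The key names
    # follow their use in __post_init__; the values are illustrative
    # assumptions, not defaults shipped by LightRAG:
    #
    #   global_config = {
    #       "working_dir": "./rag_storage",
    #       "embedding_batch_num": 32,
    #       "vector_db_storage_cls_kwargs": {
    #           "cosine_better_than_threshold": 0.2,
    #       },
    #   }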

    async def upsert(self, data: dict[str, dict]):
        logger.info(f"Inserting {len(data)} vectors into {self.namespace}")
        if not len(data):
            logger.warning("You are inserting empty data into the vector DB")
            return []

        current_time = time.time()
        list_data = [
            {
                "__id__": k,
                "__created_at__": current_time,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        async def wrapped_task(batch):
            result = await self.embedding_func(batch)
            pbar.update(1)
            return result

        # Create the progress bar before the tasks run, since wrapped_task
        # updates it as each batch completes
        pbar = tqdm_async(
            total=len(batches), desc="Generating embeddings", unit="batch"
        )
        embedding_tasks = [wrapped_task(batch) for batch in batches]
        embeddings_list = await asyncio.gather(*embedding_tasks)

        embeddings = np.concatenate(embeddings_list)
        if len(embeddings) == len(list_data):
            for i, d in enumerate(list_data):
                d["__vector__"] = embeddings[i]
            results = self._client.upsert(datas=list_data)
            return results
        else:
            # Sometimes the embedding is not returned correctly; just log it
            logger.error(
                f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
            )
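
    # A hedged sketch of the payload `upsert` expects: keys are vector IDs and
    # each value carries the text to embed plus any fields listed in
    # self.meta_fields. The concrete ID and field values below are
    # illustrative assumptions:
    #
    #   await storage.upsert({
    #       "ent-<md5-hash>": {"content": "Alice is an engineer", "entity_name": "Alice"},
    #   })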

    async def query(self, query: str, top_k=5):
        """Return up to top_k stored records closest to the query, filtered by the cosine threshold."""
        embedding = await self.embedding_func([query])
        embedding = embedding[0]
        results = self._client.query(
            query=embedding,
            top_k=top_k,
            better_than_threshold=self.cosine_better_than_threshold,
        )
        results = [
            {
                **dp,
                "id": dp["__id__"],
                "distance": dp["__metrics__"],
                "created_at": dp.get("__created_at__"),
            }
            for dp in results
        ]
        return results
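
    # Each hit returned by `query` is the stored record augmented with
    # convenience keys; the values shown are illustrative assumptions:
    #
    #   {
    #       "id": "ent-<md5-hash>",
    #       "distance": 0.83,  # the raw __metrics__ score from NanoVectorDB
    #       "created_at": 1700000000.0,
    #       ...
    #   }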

    @property
    def client_storage(self):
        # Reach into NanoVectorDB's private storage dict via its name-mangled attribute
        return getattr(self._client, "_NanoVectorDB__storage")

    async def delete(self, ids: list[str]):
        """Delete vectors with specified IDs

        Args:
            ids: List of vector IDs to be deleted
        """
        try:
            self._client.delete(ids)
            logger.info(
                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

    async def delete_entity(self, entity_name: str):
        """Delete the vector for a single entity, identified by its name."""
        try:
            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
            logger.debug(
                f"Attempting to delete entity {entity_name} with ID {entity_id}"
            )
            # Check if the entity exists
            if self._client.get([entity_id]):
                await self.delete([entity_id])
                logger.debug(f"Successfully deleted entity {entity_name}")
            else:
                logger.debug(f"Entity {entity_name} not found in storage")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str) -> None:
        """Delete all relation vectors whose source or target is the given entity."""
        try:
            relations = [
                dp
                for dp in self.client_storage["data"]
                if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
            ]
            logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
            ids_to_delete = [relation["__id__"] for relation in relations]

            if ids_to_delete:
                await self.delete(ids_to_delete)
                logger.debug(
                    f"Deleted {len(ids_to_delete)} relations for {entity_name}"
                )
            else:
                logger.debug(f"No relations found for entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting relations for {entity_name}: {e}")

    async def index_done_callback(self):
        # Protect the file write operation
        async with self._save_lock:
            self._client.save()
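

if __name__ == "__main__":
    # Hedged smoke test: a minimal sketch of driving this class directly,
    # outside LightRAG's own wiring. The random-vector "embedding model" and
    # the config values are stand-in assumptions; in normal use LightRAG
    # constructs and configures NanoVectorDBStorage itself.
    from lightrag.utils import EmbeddingFunc

    async def _fake_embed(texts: list[str]) -> np.ndarray:
        # Stand-in for a real embedding model: one random 8-dim vector per text
        return np.random.rand(len(texts), 8).astype(np.float32)

    async def _demo():
        storage = NanoVectorDBStorage(
            namespace="demo",
            global_config={
                "working_dir": ".",  # writes vdb_demo.json to the current directory
                "embedding_batch_num": 32,
                "vector_db_storage_cls_kwargs": {
                    "cosine_better_than_threshold": 0.2,
                },
            },
            embedding_func=EmbeddingFunc(
                embedding_dim=8, max_token_size=8192, func=_fake_embed
            ),
        )
        await storage.upsert({"demo-1": {"content": "hello world"}})
        print(await storage.query("hello", top_k=1))
        await storage.index_done_callback()

    asyncio.run(_demo())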