import asyncio
import os
import configparser
from dataclasses import dataclass

import numpy as np
from tqdm.asyncio import tqdm as tqdm_async

from lightrag.utils import logger
from ..base import BaseVectorStorage

import pipmaster as pm

# Install the Milvus client on demand so the module does not require it up front.
if not pm.is_installed("pymilvus"):
    pm.install("pymilvus")

from pymilvus import MilvusClient

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")


@dataclass
class MilvusVectorDBStorage(BaseVectorStorage):
    cosine_better_than_threshold: float = None

    @staticmethod
    def create_collection_if_not_exist(
        client: MilvusClient, collection_name: str, **kwargs
    ):
        if client.has_collection(collection_name):
            return
        client.create_collection(
            collection_name, max_length=64, id_type="string", **kwargs
        )

    def __post_init__(self):
        # The cosine similarity threshold is required; query hits below it are filtered out.
        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

        # Connection settings: environment variables take precedence over config.ini;
        # with neither present, fall back to a local Milvus Lite database file.
        self._client = MilvusClient(
            uri=os.environ.get(
                "MILVUS_URI",
                config.get(
                    "milvus",
                    "uri",
                    fallback=os.path.join(
                        self.global_config["working_dir"], "milvus_lite.db"
                    ),
                ),
            ),
            user=os.environ.get(
                "MILVUS_USER", config.get("milvus", "user", fallback=None)
            ),
            password=os.environ.get(
                "MILVUS_PASSWORD", config.get("milvus", "password", fallback=None)
            ),
            token=os.environ.get(
                "MILVUS_TOKEN", config.get("milvus", "token", fallback=None)
            ),
            db_name=os.environ.get(
                "MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None)
            ),
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        MilvusVectorDBStorage.create_collection_if_not_exist(
            self._client,
            self.namespace,
            dimension=self.embedding_func.embedding_dim,
        )

    async def upsert(self, data: dict[str, dict]):
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You are inserting empty data into the vector DB")
            return []
        # Keep only the configured meta fields alongside the primary key.
        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        # Embed the contents in batches to respect the embedding batch limit.
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        pbar = tqdm_async(
            total=len(batches), desc="Generating embeddings", unit="batch"
        )

        async def wrapped_task(batch):
            result = await self.embedding_func(batch)
            pbar.update(1)
            return result

        embedding_tasks = [wrapped_task(batch) for batch in batches]
        embeddings_list = await asyncio.gather(*embedding_tasks)

        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["vector"] = embeddings[i]
        results = self._client.upsert(collection_name=self.namespace, data=list_data)
        return results

    async def query(self, query, top_k=5):
        embedding = await self.embedding_func([query])
        # Range search on cosine similarity: only hits whose similarity exceeds the
        # configured threshold (radius) are returned, up to top_k results.
        results = self._client.search(
            collection_name=self.namespace,
            data=embedding,
            limit=top_k,
            output_fields=list(self.meta_fields),
            search_params={
                "metric_type": "COSINE",
                "params": {"radius": self.cosine_better_than_threshold},
            },
        )
        logger.debug(f"Milvus query results: {results}")
        return [
            {**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
            for dp in results[0]
        ]
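

# A minimal usage sketch, not part of the library proper. It assumes BaseVectorStorage
# exposes the namespace, global_config, embedding_func, and meta_fields fields used
# above; the fake embedding function, dimension, and config values below are
# hypothetical placeholders for illustration only.
if __name__ == "__main__":

    class _FakeEmbeddingFunc:
        """Stand-in embedding function exposing the attributes this storage relies on."""

        embedding_dim = 8

        async def __call__(self, texts: list[str]) -> np.ndarray:
            # Deterministic pseudo-embeddings; a real deployment would call a model here.
            rng = np.random.default_rng(0)
            return rng.random((len(texts), self.embedding_dim))

    async def _demo():
        storage = MilvusVectorDBStorage(
            namespace="demo_chunks",
            global_config={
                "working_dir": ".",
                "embedding_batch_num": 16,
                "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
            },
            embedding_func=_FakeEmbeddingFunc(),
            meta_fields={"content"},
        )
        await storage.upsert({"doc-1": {"content": "hello milvus"}})
        print(await storage.query("hello", top_k=3))

    asyncio.run(_demo())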