From 03d0fa3014bc84756c94da2f2ac59dbe90dd94f9 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 29 Aug 2025 18:15:45 +0800
Subject: [PATCH] perf: add optional query_embedding parameter to avoid redundant embedding calls

---
 lightrag/base.py                   | 13 +++++++++++--
 lightrag/kg/faiss_impl.py          | 18 ++++++++++++------
 lightrag/kg/milvus_impl.py         | 14 ++++++++++----
 lightrag/kg/mongo_impl.py          | 20 ++++++++++++--------
 lightrag/kg/nano_vector_db_impl.py | 18 ++++++++++++------
 lightrag/kg/postgres_impl.py       | 16 +++++++++++-----
 lightrag/kg/qdrant_impl.py         | 17 ++++++++++++-----
 lightrag/operate.py                | 21 ++++++++++++++-------
 8 files changed, 94 insertions(+), 43 deletions(-)

diff --git a/lightrag/base.py b/lightrag/base.py
index e88a9d3e..c5518d23 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -211,8 +211,17 @@ class BaseVectorStorage(StorageNameSpace, ABC):
     meta_fields: set[str] = field(default_factory=set)
 
     @abstractmethod
-    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
-        """Query the vector storage and retrieve top_k results."""
+    async def query(
+        self, query: str, top_k: int, query_embedding: list[float] | None = None
+    ) -> list[dict[str, Any]]:
+        """Query the vector storage and retrieve top_k results.
+
+        Args:
+            query: The query string to search for
+            top_k: Number of top results to return
+            query_embedding: Optional pre-computed embedding for the query.
+                If provided, skips embedding computation for better performance.
+        """
 
     @abstractmethod
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py
index 29e7c5dd..7d6a6dac 100644
--- a/lightrag/kg/faiss_impl.py
+++ b/lightrag/kg/faiss_impl.py
@@ -179,15 +179,21 @@ class FaissVectorDBStorage(BaseVectorStorage):
         )
         return [m["__id__"] for m in list_data]
 
-    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
+    async def query(
+        self, query: str, top_k: int, query_embedding: list[float] | None = None
+    ) -> list[dict[str, Any]]:
         """
         Search by a textual query; returns top_k results with their metadata + similarity distance.
""" - embedding = await self.embedding_func( - [query], _priority=5 - ) # higher priority for query - # embedding is shape (1, dim) - embedding = np.array(embedding, dtype=np.float32) + if query_embedding is not None: + embedding = np.array([query_embedding], dtype=np.float32) + else: + embedding = await self.embedding_func( + [query], _priority=5 + ) # higher priority for query + # embedding is shape (1, dim) + embedding = np.array(embedding, dtype=np.float32) + faiss.normalize_L2(embedding) # we do in-place normalization # Perform the similarity search diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 9fa79022..f2368afe 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -1046,13 +1046,19 @@ class MilvusVectorDBStorage(BaseVectorStorage): ) return results - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + async def query( + self, query: str, top_k: int, query_embedding: list[float] = None + ) -> list[dict[str, Any]]: # Ensure collection is loaded before querying self._ensure_collection_loaded() - embedding = await self.embedding_func( - [query], _priority=5 - ) # higher priority for query + # Use provided embedding or compute it + if query_embedding is not None: + embedding = [query_embedding] # Milvus expects a list of embeddings + else: + embedding = await self.embedding_func( + [query], _priority=5 + ) # higher priority for query # Include all meta_fields (created_at is now always included) output_fields = list(self.meta_fields) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 6a38a86c..8d52af64 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1809,15 +1809,19 @@ class MongoVectorDBStorage(BaseVectorStorage): return list_data - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + async def query( + self, query: str, top_k: int, query_embedding: list[float] = None + ) -> list[dict[str, Any]]: """Queries the vector database using Atlas Vector Search.""" - # Generate the embedding - embedding = await self.embedding_func( - [query], _priority=5 - ) # higher priority for query - - # Convert numpy array to a list to ensure compatibility with MongoDB - query_vector = embedding[0].tolist() + if query_embedding is not None: + query_vector = query_embedding + else: + # Generate the embedding + embedding = await self.embedding_func( + [query], _priority=5 + ) # higher priority for query + # Convert numpy array to a list to ensure compatibility with MongoDB + query_vector = embedding[0].tolist() # Define the aggregation pipeline with the converted query vector pipeline = [ diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index bc1b72d3..def5a83d 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -136,12 +136,18 @@ class NanoVectorDBStorage(BaseVectorStorage): f"[{self.workspace}] embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" ) - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: - # Execute embedding outside of lock to avoid improve cocurrent - embedding = await self.embedding_func( - [query], _priority=5 - ) # higher priority for query - embedding = embedding[0] + async def query( + self, query: str, top_k: int, query_embedding: list[float] = None + ) -> list[dict[str, Any]]: + # Use provided embedding or compute it + if query_embedding is not None: + embedding = query_embedding + else: + # Execute embedding outside of lock to avoid 
+            embedding = await self.embedding_func(
+                [query], _priority=5
+            )  # higher priority for query
+            embedding = embedding[0]
 
         client = await self._get_client()
         results = client.query(
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 2811de5b..03a26f54 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -2004,11 +2004,17 @@ class PGVectorStorage(BaseVectorStorage):
         await self.db.execute(upsert_sql, data)
 
     #################### query method ###############
-    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
-        embeddings = await self.embedding_func(
-            [query], _priority=5
-        )  # higher priority for query
-        embedding = embeddings[0]
+    async def query(
+        self, query: str, top_k: int, query_embedding: list[float] | None = None
+    ) -> list[dict[str, Any]]:
+        if query_embedding is not None:
+            embedding = query_embedding
+        else:
+            embeddings = await self.embedding_func(
+                [query], _priority=5
+            )  # higher priority for query
+            embedding = embeddings[0]
+
         embedding_string = ",".join(map(str, embedding))
 
         sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py
index fbd6bb10..dad95bbc 100644
--- a/lightrag/kg/qdrant_impl.py
+++ b/lightrag/kg/qdrant_impl.py
@@ -199,13 +199,20 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         )
         return results
 
-    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
-        embedding = await self.embedding_func(
-            [query], _priority=5
-        )  # higher priority for query
+    async def query(
+        self, query: str, top_k: int, query_embedding: list[float] | None = None
+    ) -> list[dict[str, Any]]:
+        if query_embedding is not None:
+            embedding = query_embedding
+        else:
+            embedding_result = await self.embedding_func(
+                [query], _priority=5
+            )  # higher priority for query
+            embedding = embedding_result[0]
+
         results = self._client.search(
             collection_name=self.final_namespace,
-            query_vector=embedding[0],
+            query_vector=embedding,
             limit=top_k,
             with_payload=True,
             score_threshold=self.cosine_better_than_threshold,
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 5e87f6af..afa8205f 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -2234,6 +2234,7 @@ async def _get_vector_context(
     query: str,
     chunks_vdb: BaseVectorStorage,
     query_param: QueryParam,
+    query_embedding: list[float] | None = None,
 ) -> list[dict]:
     """
     Retrieve text chunks from the vector database without reranking or truncation.
@@ -2245,6 +2246,7 @@ async def _get_vector_context(
         query: The query string to search for
         chunks_vdb: Vector database containing document chunks
         query_param: Query parameters including chunk_top_k and ids
+        query_embedding: Optional pre-computed query embedding to avoid redundant embedding calls
 
     Returns:
         List of text chunks with metadata
@@ -2253,7 +2255,9 @@ async def _get_vector_context(
         # Use chunk_top_k if specified, otherwise fall back to top_k
         search_top_k = query_param.chunk_top_k or query_param.top_k
 
-        results = await chunks_vdb.query(query, top_k=search_top_k)
+        results = await chunks_vdb.query(
+            query, top_k=search_top_k, query_embedding=query_embedding
+        )
         if not results:
             logger.info(f"Naive query: 0 chunks (chunk_top_k: {search_top_k})")
             return []
@@ -2291,6 +2295,10 @@ async def _build_query_context(
     query_param: QueryParam,
     chunks_vdb: BaseVectorStorage = None,
 ):
+    if not query:
+        logger.warning("Query is empty, skipping context building")
+        return ""
+
     logger.info(f"Process {os.getpid()} building query context...")
 
     # Collect chunks from different sources separately
@@ -2309,12 +2317,12 @@ async def _build_query_context(
     # Track chunk sources and metadata for final logging
     chunk_tracking = {}  # chunk_id -> {source, frequency, order}
 
-    # Pre-compute query embedding if vector similarity method is used
+    # Pre-compute query embedding once for all vector operations
     kg_chunk_pick_method = text_chunks_db.global_config.get(
         "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
     )
     query_embedding = None
-    if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
+    if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
         embedding_func_config = text_chunks_db.embedding_func
         if embedding_func_config and embedding_func_config.func:
             try:
@@ -2322,9 +2330,7 @@ async def _build_query_context(
                 query_embedding = query_embedding[
                     0
                 ]  # Extract first embedding from batch result
-                logger.debug(
-                    "Pre-computed query embedding for vector similarity chunk selection"
-                )
+                logger.debug("Pre-computed query embedding for all vector operations")
             except Exception as e:
                 logger.warning(f"Failed to pre-compute query embedding: {e}")
                 query_embedding = None
@@ -2368,6 +2374,7 @@ async def _build_query_context(
                 query,
                 chunks_vdb,
                 query_param,
+                query_embedding,
             )
             # Track vector chunks with source metadata
             for i, chunk in enumerate(vector_chunks):
@@ -3429,7 +3436,7 @@ async def naive_query(
 
     tokenizer: Tokenizer = global_config["tokenizer"]
 
-    chunks = await _get_vector_context(query, chunks_vdb, query_param)
+    chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
 
     if chunks is None or len(chunks) == 0:
         return PROMPTS["fail_response"]
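
Reviewer note: below is a minimal caller-side sketch of the pattern this patch enables. It is illustrative, not part of the diff; the three storage instances and the multi_store_search helper are assumed for the example, while BaseVectorStorage.query, the shared embedding_func attribute, and the _priority=5 convention come from the code above.

    # Illustrative sketch only -- not part of the patch. Assumes three
    # initialized BaseVectorStorage instances sharing one embedding_func,
    # as LightRAG's entity/relation/chunk stores do.
    import asyncio

    async def multi_store_search(
        entities_vdb, relations_vdb, chunks_vdb, query: str, top_k: int = 10
    ):
        # Embed the query once. embedding_func takes a batch of texts and
        # returns one vector per text; _priority=5 follows the diff's
        # convention of prioritizing query embeddings.
        batch = await chunks_vdb.embedding_func([query], _priority=5)
        query_embedding = [float(x) for x in batch[0]]  # coerce to list[float]

        # Each storage skips its own embedding call and goes straight to the
        # similarity search -- one embedding request instead of three.
        return await asyncio.gather(
            entities_vdb.query(query, top_k=top_k, query_embedding=query_embedding),
            relations_vdb.query(query, top_k=top_k, query_embedding=query_embedding),
            chunks_vdb.query(query, top_k=top_k, query_embedding=query_embedding),
        )

Because the parameter defaults to None, call sites that pass only query and top_k behave exactly as before.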
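
A note on failure handling: _build_query_context wraps the pre-computation in try/except and falls back to query_embedding = None, so a failed embedding call degrades to the old per-storage behavior instead of aborting the query. A condensed sketch of that guard follows (names taken from the diff; the exact embedding call sits outside the hunks above, so its shape here is an assumption):

    query_embedding = None
    if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
        try:
            batch = await embedding_func([query])  # assumed call shape
            query_embedding = batch[0]  # first vector of the batch result
        except Exception as e:
            logger.warning(f"Failed to pre-compute query embedding: {e}")
            query_embedding = None  # storages will embed the query themselves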