Mirror of https://github.com/HKUDS/LightRAG.git (synced 2025-12-04 02:46:27 +00:00)
perf: add optional query_embedding parameter to avoid redundant embedding calls
commit 03d0fa3014 (parent a923d378dd)
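The change threads a single pre-computed query embedding through every vector-storage query() call so the same query string is not embedded repeatedly. A minimal usage sketch of the intended call pattern (the embedding_func and the two storage instances below are illustrative assumptions, not part of this diff):

    async def search_both(chunks_vdb, entities_vdb, embedding_func, query: str, top_k: int = 10):
        # Embed the query text once up front.
        batch = await embedding_func([query])
        query_embedding = batch[0]

        # Reuse the same vector for every backend lookup instead of re-embedding per call.
        chunks = await chunks_vdb.query(query, top_k=top_k, query_embedding=query_embedding)
        entities = await entities_vdb.query(query, top_k=top_k, query_embedding=query_embedding)
        return chunks, entities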
@@ -211,8 +211,17 @@ class BaseVectorStorage(StorageNameSpace, ABC):
     meta_fields: set[str] = field(default_factory=set)
 
     @abstractmethod
-    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
-        """Query the vector storage and retrieve top_k results."""
+    async def query(
+        self, query: str, top_k: int, query_embedding: list[float] = None
+    ) -> list[dict[str, Any]]:
+        """Query the vector storage and retrieve top_k results.
+
+        Args:
+            query: The query string to search for
+            top_k: Number of top results to return
+            query_embedding: Optional pre-computed embedding for the query.
+                If provided, skips embedding computation for better performance.
+        """
 
     @abstractmethod
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
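Because query_embedding defaults to None, existing call sites keep working unchanged while new callers can opt in. A hedged sketch of both call styles against any concrete BaseVectorStorage (the storage and vector variables are illustrative):

    async def query_with_and_without(storage, vector):
        # Old style: the backend computes the embedding itself (behaviour unchanged).
        baseline = await storage.query("what is LightRAG?", top_k=5)
        # New style: hand over a pre-computed vector; query() skips its embedding call.
        reused = await storage.query("what is LightRAG?", top_k=5, query_embedding=vector)
        return baseline, reused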
@@ -179,15 +179,21 @@ class FaissVectorDBStorage(BaseVectorStorage):
         )
         return [m["__id__"] for m in list_data]
 
-    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
+    async def query(
+        self, query: str, top_k: int, query_embedding: list[float] = None
+    ) -> list[dict[str, Any]]:
         """
         Search by a textual query; returns top_k results with their metadata + similarity distance.
         """
-        embedding = await self.embedding_func(
-            [query], _priority=5
-        )  # higher priority for query
-        # embedding is shape (1, dim)
-        embedding = np.array(embedding, dtype=np.float32)
+        if query_embedding is not None:
+            embedding = np.array([query_embedding], dtype=np.float32)
+        else:
+            embedding = await self.embedding_func(
+                [query], _priority=5
+            )  # higher priority for query
+            # embedding is shape (1, dim)
+            embedding = np.array(embedding, dtype=np.float32)
 
         faiss.normalize_L2(embedding)  # we do in-place normalization
 
         # Perform the similarity search
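In the FAISS branch the caller-supplied vector is wrapped as a (1, dim) float32 matrix, since FAISS searches over row matrices and normalize_L2 normalizes rows in place. A small standalone sketch (the toy 4-dimensional vector is illustrative):

    import numpy as np
    import faiss  # faiss-cpu

    query_embedding = [0.1, 0.3, 0.5, 0.2]                  # stand-in for a real model output
    matrix = np.array([query_embedding], dtype=np.float32)  # shape (1, dim), as FAISS expects
    faiss.normalize_L2(matrix)                               # in-place row normalization
    print(float(np.linalg.norm(matrix[0])))                  # ~1.0 after normalization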
@@ -1046,13 +1046,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         )
         return results
 
-    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
+    async def query(
+        self, query: str, top_k: int, query_embedding: list[float] = None
+    ) -> list[dict[str, Any]]:
         # Ensure collection is loaded before querying
         self._ensure_collection_loaded()
 
-        embedding = await self.embedding_func(
-            [query], _priority=5
-        )  # higher priority for query
+        # Use provided embedding or compute it
+        if query_embedding is not None:
+            embedding = [query_embedding]  # Milvus expects a list of embeddings
+        else:
+            embedding = await self.embedding_func(
+                [query], _priority=5
+            )  # higher priority for query
 
         # Include all meta_fields (created_at is now always included)
         output_fields = list(self.meta_fields)
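The Milvus branch wraps the single vector in a one-element list because the search call takes a batch of query vectors, which is also the shape embedding_func returns. A tiny illustration (values are placeholders):

    query_embedding = [0.4, 0.1, 0.9]   # one flat vector supplied by the caller
    data = [query_embedding]            # batch of one query vector, matching embedding_func output
    assert data[0] is query_embedding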
@@ -1809,15 +1809,19 @@ class MongoVectorDBStorage(BaseVectorStorage):
 
         return list_data
 
-    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
+    async def query(
+        self, query: str, top_k: int, query_embedding: list[float] = None
+    ) -> list[dict[str, Any]]:
         """Queries the vector database using Atlas Vector Search."""
-        # Generate the embedding
-        embedding = await self.embedding_func(
-            [query], _priority=5
-        )  # higher priority for query
-
-        # Convert numpy array to a list to ensure compatibility with MongoDB
-        query_vector = embedding[0].tolist()
+        if query_embedding is not None:
+            query_vector = query_embedding
+        else:
+            # Generate the embedding
+            embedding = await self.embedding_func(
+                [query], _priority=5
+            )  # higher priority for query
+            # Convert numpy array to a list to ensure compatibility with MongoDB
+            query_vector = embedding[0].tolist()
 
         # Define the aggregation pipeline with the converted query vector
         pipeline = [
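In the MongoDB branch the computed embedding is a numpy row that must become a plain Python list before it is placed in the aggregation pipeline, while a caller-supplied query_embedding is passed through as-is and is therefore expected to already be a plain list. A small sketch of that conversion (the random vector is a stand-in for embedding_func output):

    import numpy as np

    embedding = np.random.rand(1, 4).astype(np.float32)  # stand-in for embedding_func([query])
    query_vector = embedding[0].tolist()                  # plain list[float], safe for BSON encoding
    print(type(query_vector), len(query_vector))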
@@ -136,12 +136,18 @@ class NanoVectorDBStorage(BaseVectorStorage):
                 f"[{self.workspace}] embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
             )
 
-    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
-        # Execute embedding outside of lock to improve concurrency
-        embedding = await self.embedding_func(
-            [query], _priority=5
-        )  # higher priority for query
-        embedding = embedding[0]
+    async def query(
+        self, query: str, top_k: int, query_embedding: list[float] = None
+    ) -> list[dict[str, Any]]:
+        # Use provided embedding or compute it
+        if query_embedding is not None:
+            embedding = query_embedding
+        else:
+            # Execute embedding outside of lock to improve concurrency
+            embedding = await self.embedding_func(
+                [query], _priority=5
+            )  # higher priority for query
+            embedding = embedding[0]
 
         client = await self._get_client()
         results = client.query(
@@ -2004,11 +2004,17 @@ class PGVectorStorage(BaseVectorStorage):
         await self.db.execute(upsert_sql, data)
 
     #################### query method ###############
-    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
-        embeddings = await self.embedding_func(
-            [query], _priority=5
-        )  # higher priority for query
-        embedding = embeddings[0]
+    async def query(
+        self, query: str, top_k: int, query_embedding: list[float] = None
+    ) -> list[dict[str, Any]]:
+        if query_embedding is not None:
+            embedding = query_embedding
+        else:
+            embeddings = await self.embedding_func(
+                [query], _priority=5
+            )  # higher priority for query
+            embedding = embeddings[0]
 
         embedding_string = ",".join(map(str, embedding))
 
         sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
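In the PostgreSQL branch the vector, whether supplied or computed, is flattened into a comma-joined string that is then substituted into the namespace's SQL template. A minimal illustration of that serialization only (the values are placeholders, and the repo's actual SQL_TEMPLATES are not reproduced here):

    embedding = [0.12, -0.03, 0.88]                   # stand-in query embedding
    embedding_string = ",".join(map(str, embedding))  # -> "0.12,-0.03,0.88"
    print(embedding_string)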
@@ -199,13 +199,20 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         )
         return results
 
-    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
-        embedding = await self.embedding_func(
-            [query], _priority=5
-        )  # higher priority for query
+    async def query(
+        self, query: str, top_k: int, query_embedding: list[float] = None
+    ) -> list[dict[str, Any]]:
+        if query_embedding is not None:
+            embedding = query_embedding
+        else:
+            embedding_result = await self.embedding_func(
+                [query], _priority=5
+            )  # higher priority for query
+            embedding = embedding_result[0]
 
         results = self._client.search(
             collection_name=self.final_namespace,
-            query_vector=embedding[0],
+            query_vector=embedding,
             limit=top_k,
             with_payload=True,
             score_threshold=self.cosine_better_than_threshold,
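Note the paired change in the Qdrant search call: query_vector=embedding[0] becomes query_vector=embedding, because both branches now resolve to one flat vector before the search (embedding_func returns a batch, so the computed branch indexes [0], while a caller-supplied query_embedding is already flat). A shape-only sketch with placeholder values:

    precomputed = [0.2, 0.1, 0.7]   # caller-supplied query_embedding: already one flat vector
    batch = [[0.2, 0.1, 0.7]]       # what embedding_func returns for a single query text
    assert precomputed == batch[0]  # hence embedding_result[0], and query_vector=embedding with no [0]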
@@ -2234,6 +2234,7 @@ async def _get_vector_context(
     query: str,
     chunks_vdb: BaseVectorStorage,
     query_param: QueryParam,
+    query_embedding: list[float] = None,
 ) -> list[dict]:
     """
     Retrieve text chunks from the vector database without reranking or truncation.
@@ -2245,6 +2246,7 @@ async def _get_vector_context(
         query: The query string to search for
         chunks_vdb: Vector database containing document chunks
         query_param: Query parameters including chunk_top_k and ids
+        query_embedding: Optional pre-computed query embedding to avoid redundant embedding calls
 
     Returns:
         List of text chunks with metadata
@@ -2253,7 +2255,9 @@ async def _get_vector_context(
     # Use chunk_top_k if specified, otherwise fall back to top_k
     search_top_k = query_param.chunk_top_k or query_param.top_k
 
-    results = await chunks_vdb.query(query, top_k=search_top_k)
+    results = await chunks_vdb.query(
+        query, top_k=search_top_k, query_embedding=query_embedding
+    )
     if not results:
         logger.info(f"Naive query: 0 chunks (chunk_top_k: {search_top_k})")
         return []
@@ -2291,6 +2295,10 @@ async def _build_query_context(
     query_param: QueryParam,
     chunks_vdb: BaseVectorStorage = None,
 ):
+    if not query:
+        logger.warning("Query is empty, skipping context building")
+        return ""
+
     logger.info(f"Process {os.getpid()} building query context...")
 
     # Collect chunks from different sources separately
@@ -2309,12 +2317,12 @@ async def _build_query_context(
     # Track chunk sources and metadata for final logging
     chunk_tracking = {}  # chunk_id -> {source, frequency, order}
 
-    # Pre-compute query embedding if vector similarity method is used
+    # Pre-compute query embedding once for all vector operations
     kg_chunk_pick_method = text_chunks_db.global_config.get(
         "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
    )
     query_embedding = None
-    if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
+    if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
         embedding_func_config = text_chunks_db.embedding_func
         if embedding_func_config and embedding_func_config.func:
             try:
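The gating condition is broadened from requiring the VECTOR chunk-pick method to pre-computing whenever the vector will be needed at all, i.e. either VECTOR chunk selection or a chunks_vdb lookup. A small helper (not in the diff) restating the condition, with a few illustrative evaluations:

    def should_precompute(query: str, kg_chunk_pick_method: str, chunks_vdb) -> bool:
        # Mirrors the new condition: embed up front whenever there is query text and
        # either VECTOR chunk selection or a chunks_vdb lookup will consume the vector.
        return bool(query) and (kg_chunk_pick_method == "VECTOR" or bool(chunks_vdb))

    print(should_precompute("q", "VECTOR", None))      # True  (KG chunk selection path)
    print(should_precompute("q", "WEIGHT", object()))  # True  (chunks_vdb retrieval path)
    print(should_precompute("", "VECTOR", object()))   # False (no query text)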
@@ -2322,9 +2330,7 @@ async def _build_query_context(
                 query_embedding = query_embedding[
                     0
                 ]  # Extract first embedding from batch result
-                logger.debug(
-                    "Pre-computed query embedding for vector similarity chunk selection"
-                )
+                logger.debug("Pre-computed query embedding for all vector operations")
             except Exception as e:
                 logger.warning(f"Failed to pre-compute query embedding: {e}")
                 query_embedding = None
@@ -2368,6 +2374,7 @@ async def _build_query_context(
                 query,
                 chunks_vdb,
                 query_param,
+                query_embedding,
             )
             # Track vector chunks with source metadata
             for i, chunk in enumerate(vector_chunks):
@@ -3429,7 +3436,7 @@ async def naive_query(
 
     tokenizer: Tokenizer = global_config["tokenizer"]
 
-    chunks = await _get_vector_context(query, chunks_vdb, query_param)
+    chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
 
     if chunks is None or len(chunks) == 0:
         return PROMPTS["fail_response"]