optimize: batch embedding and qdrant write_consistency_factor parameter (#21776)

Co-authored-by: hobo.l <hobo.l@binance.com>
luckylhb90 authored on 2025-07-10 10:16:59 +08:00, committed by GitHub
parent a316766ad7
commit a371390d6c
2 changed files with 20 additions and 2 deletions


@@ -47,6 +47,7 @@ class QdrantConfig(BaseModel):
     grpc_port: int = 6334
     prefer_grpc: bool = False
     replication_factor: int = 1
+    write_consistency_factor: int = 1

     def to_qdrant_params(self):
         if self.endpoint and self.endpoint.startswith("path:"):
@@ -127,6 +128,7 @@ class QdrantVector(BaseVector):
                 hnsw_config=hnsw_config,
                 timeout=int(self._client_config.timeout),
                 replication_factor=self._client_config.replication_factor,
+                write_consistency_factor=self._client_config.write_consistency_factor,
             )
             # create group_id payload index
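For context, the new `write_consistency_factor` travels alongside `replication_factor` into Qdrant's collection-creation call. Below is a minimal standalone sketch of the equivalent qdrant-client call; the URL, collection name, and vector size are illustrative assumptions, not part of this commit:

```python
# Minimal sketch, assuming the qdrant-client package and a local server;
# the URL, collection name, and vector size below are illustrative.
from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="http://localhost:6333")
client.create_collection(
    collection_name="example_collection",
    vectors_config=models.VectorParams(size=768, distance=models.Distance.COSINE),
    replication_factor=1,        # how many copies of each shard exist
    write_consistency_factor=1,  # replicas that must ack a write before success
)
```

Per the Qdrant docs, `replication_factor` controls how many copies of each shard the cluster keeps, while `write_consistency_factor` controls how many replicas must acknowledge a write before it is reported successful.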


@@ -1,3 +1,5 @@
+import logging
+import time
 from abc import ABC, abstractmethod
 from typing import Any, Optional
@@ -13,6 +15,8 @@ from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset, Whitelist

+logger = logging.getLogger(__name__)
+

 class AbstractVectorFactory(ABC):
     @abstractmethod
@@ -173,8 +177,20 @@ class Vector:
     def create(self, texts: Optional[list] = None, **kwargs):
         if texts:
-            embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
-            self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
+            start = time.time()
+            logger.info(f"Start embedding {len(texts)} texts")
+            batch_size = 1000
+            total_batches = (len(texts) + batch_size - 1) // batch_size
+            for i in range(0, len(texts), batch_size):
+                batch = texts[i : i + batch_size]
+                batch_start = time.time()
+                logger.info(f"Processing batch {i // batch_size + 1}/{total_batches} ({len(batch)} texts)")
+                batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch])
+                logger.info(
+                    f"Embedding batch {i // batch_size + 1}/{total_batches} took {time.time() - batch_start:.3f}s"
+                )
+                self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
+            logger.info(f"Embedding {len(texts)} texts took {time.time() - start:.3f}s")

     def add_texts(self, documents: list[Document], **kwargs):
         if kwargs.get("duplicate_check", False):
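The batching change above is a chunked map over the input texts: slice, embed, write, repeat. Note the ceiling division `(len(texts) + batch_size - 1) // batch_size`, which counts a partial final batch as a full batch. Here is a self-contained sketch of the same pattern with a stub embedder; every name in it (function, batch size, stub) is illustrative, not from the codebase:

```python
# Standalone sketch of the batch-embedding pattern; the embedder here
# is a stub standing in for a real embedding-model call.
import time


def embed_in_batches(texts, embed, batch_size=1000):
    """Embed texts in fixed-size batches, timing each batch."""
    # ceiling division: e.g. 2500 texts at 1000 per batch -> 3 batches
    total_batches = (len(texts) + batch_size - 1) // batch_size
    embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        t0 = time.time()
        embeddings.extend(embed(batch))  # one provider call per batch
        print(
            f"batch {i // batch_size + 1}/{total_batches} "
            f"({len(batch)} texts) took {time.time() - t0:.3f}s"
        )
    return embeddings


if __name__ == "__main__":
    def fake_embed(batch):  # stand-in for an embedding model
        return [[0.0] * 8 for _ in batch]

    vectors = embed_in_batches([f"doc {n}" for n in range(2500)], fake_embed)
    assert len(vectors) == 2500  # 3 batches: 1000 + 1000 + 500
```

Writing each batch to the vector store as soon as it is embedded, as the commit does, also bounds peak memory to one batch of vectors rather than the whole corpus.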