mirror of
https://github.com/langgenius/dify.git
synced 2025-09-03 06:13:45 +00:00
optimize: batch embedding and qdrant write_consistency_factor parameter (#21776)
Co-authored-by: hobo.l <hobo.l@binance.com>
This commit is contained in:
parent
a316766ad7
commit
a371390d6c
@ -47,6 +47,7 @@ class QdrantConfig(BaseModel):
|
|||||||
grpc_port: int = 6334
|
grpc_port: int = 6334
|
||||||
prefer_grpc: bool = False
|
prefer_grpc: bool = False
|
||||||
replication_factor: int = 1
|
replication_factor: int = 1
|
||||||
|
write_consistency_factor: int = 1
|
||||||
|
|
||||||
def to_qdrant_params(self):
|
def to_qdrant_params(self):
|
||||||
if self.endpoint and self.endpoint.startswith("path:"):
|
if self.endpoint and self.endpoint.startswith("path:"):
|
||||||
@ -127,6 +128,7 @@ class QdrantVector(BaseVector):
|
|||||||
hnsw_config=hnsw_config,
|
hnsw_config=hnsw_config,
|
||||||
timeout=int(self._client_config.timeout),
|
timeout=int(self._client_config.timeout),
|
||||||
replication_factor=self._client_config.replication_factor,
|
replication_factor=self._client_config.replication_factor,
|
||||||
|
write_consistency_factor=self._client_config.write_consistency_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
# create group_id payload index
|
# create group_id payload index
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@ -13,6 +15,8 @@ from extensions.ext_database import db
|
|||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.dataset import Dataset, Whitelist
|
from models.dataset import Dataset, Whitelist
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AbstractVectorFactory(ABC):
|
class AbstractVectorFactory(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -173,8 +177,20 @@ class Vector:
|
|||||||
|
|
||||||
def create(self, texts: Optional[list] = None, **kwargs):
|
def create(self, texts: Optional[list] = None, **kwargs):
|
||||||
if texts:
|
if texts:
|
||||||
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
|
start = time.time()
|
||||||
self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
|
logger.info(f"start embedding {len(texts)} texts {start}")
|
||||||
|
batch_size = 1000
|
||||||
|
total_batches = len(texts) + batch_size - 1
|
||||||
|
for i in range(0, len(texts), batch_size):
|
||||||
|
batch = texts[i : i + batch_size]
|
||||||
|
batch_start = time.time()
|
||||||
|
logger.info(f"Processing batch {i // batch_size + 1}/{total_batches} ({len(batch)} texts)")
|
||||||
|
batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch])
|
||||||
|
logger.info(
|
||||||
|
f"Embedding batch {i // batch_size + 1}/{total_batches} took {time.time() - batch_start:.3f}s"
|
||||||
|
)
|
||||||
|
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
|
||||||
|
logger.info(f"Embedding {len(texts)} texts took {time.time() - start:.3f}s")
|
||||||
|
|
||||||
def add_texts(self, documents: list[Document], **kwargs):
|
def add_texts(self, documents: list[Document], **kwargs):
|
||||||
if kwargs.get("duplicate_check", False):
|
if kwargs.get("duplicate_check", False):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user