Mirror of https://github.com/deepset-ai/haystack.git, synced 2026-01-08 04:56:45 +00:00

Added max_seq_length and batch_size params to embeddingretriever (#1817)

* Added max_seq_length and batch_size params, added progress_bar to FAISS write_documents
* Add latest docstring and tutorial changes
* fixed typos
* Update dense.py: changed default batch_size and max_seq_len in EmbeddingRetriever
* Add latest docstring and tutorial changes
* Update faiss.py: change import tqdm.auto to tqdm
* Update faiss.py: changing tqdm back to tqdm.auto

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

This commit is contained in:
parent fb511dc4a3
commit 56e4e8486f
@@ -611,7 +611,7 @@ class EmbeddingRetriever(BaseRetriever)
 #### \_\_init\_\_

 ```python
- | __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str,bool]] = None)
+ | __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, batch_size: int = 32, max_seq_len: int = 512, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str,bool]] = None)
 ```

 **Arguments**:
@@ -620,6 +620,8 @@ class EmbeddingRetriever(BaseRetriever)
 - `embedding_model`: Local path or name of model in Hugging Face's model hub such as ``'sentence-transformers/all-MiniLM-L6-v2'``
 - `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
 - `use_gpu`: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
+- `batch_size`: Number of documents to encode at once.
+- `max_seq_len`: Longest length of each document sequence. Maximum number of tokens for the document text. Longer ones will be cut down.
 - `model_format`: Name of framework that was used for saving the model. Options:

     - ``'farm'``
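The two new arguments flow straight into the retriever constructor. As a quick orientation, here is a minimal usage sketch of the documented signature; the import paths and the document store setup are assumptions that vary between Haystack versions, not part of this commit:

```python
from haystack.document_store import FAISSDocumentStore  # hypothetical import path, adjust to your Haystack version
from haystack.retriever import EmbeddingRetriever       # hypothetical import path, adjust to your Haystack version

document_store = FAISSDocumentStore(similarity="cosine")

retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    model_format="sentence_transformers",
    batch_size=64,    # new: number of documents encoded per batch (default 32)
    max_seq_len=256,  # new: token limit per document, longer texts are cut down (default 512)
)
```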
@@ -7,7 +7,7 @@ import json
 import logging
 from pathlib import Path
 from typing import Union, List, Optional, Dict, Generator
-from tqdm import tqdm
+from tqdm.auto import tqdm

 try:
     import faiss
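A note on why the import ends up as `tqdm.auto`: it resolves to the widget-based bar when running under Jupyter (with ipywidgets available) and falls back to the plain console bar elsewhere, so the same progress code works in both environments. A tiny illustrative sketch, not part of the diff:

```python
from tqdm.auto import tqdm  # tqdm.notebook.tqdm under Jupyter, plain tqdm in a terminal

for _ in tqdm(range(3), desc="demo"):
    pass  # real work would go here
```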
@@ -233,26 +233,29 @@ class FAISSDocumentStore(SQLDocumentStore):
                              "Please call `update_embeddings` method to repopulate `faiss_index`")

         vector_id = self.faiss_indexes[index].ntotal
-        for i in range(0, len(document_objects), batch_size):
-            if add_vectors:
-                embeddings = [doc.embedding for doc in document_objects[i: i + batch_size]]
-                embeddings_to_index = np.array(embeddings, dtype="float32")
-
-                if self.similarity=="cosine": self.normalize_embedding(embeddings_to_index)
-
-                self.faiss_indexes[index].add(embeddings_to_index)
-
-            docs_to_write_in_sql = []
-            for doc in document_objects[i: i + batch_size]:
-                meta = doc.meta
-                if add_vectors:
-                    meta["vector_id"] = vector_id
-                    vector_id += 1
-                docs_to_write_in_sql.append(doc)
-
-            super(FAISSDocumentStore, self).write_documents(docs_to_write_in_sql, index=index,
-                                                            duplicate_documents=duplicate_documents)
+        with tqdm(total = len(document_objects), disable =not self.progress_bar, position=0,
+                  desc="Writing Documents") as progress_bar:
+            for i in range(0, len(document_objects), batch_size):
+                if add_vectors:
+                    embeddings = [doc.embedding for doc in document_objects[i: i + batch_size]]
+                    embeddings_to_index = np.array(embeddings, dtype="float32")
+
+                    if self.similarity=="cosine": self.normalize_embedding(embeddings_to_index)
+
+                    self.faiss_indexes[index].add(embeddings_to_index)
+
+                docs_to_write_in_sql = []
+                for doc in document_objects[i: i + batch_size]:
+                    meta = doc.meta
+                    if add_vectors:
+                        meta["vector_id"] = vector_id
+                        vector_id += 1
+                    docs_to_write_in_sql.append(doc)
+
+                super(FAISSDocumentStore, self).write_documents(docs_to_write_in_sql, index=index,
+                                                                duplicate_documents=duplicate_documents)
+                progress_bar.update(batch_size)
+        progress_bar.close()

     def _create_document_field_map(self) -> Dict:
         return {
             self.index: self.embedding_field,
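The change above wraps the existing batch loop in a single tqdm bar that advances once per batch. Stripped of the FAISS and SQL details, the pattern looks like the following sketch (illustrative only, not Haystack code; advancing by `len(batch)` rather than `batch_size` also avoids overshooting on the last, shorter slice):

```python
from tqdm.auto import tqdm

def write_in_batches(items, batch_size=10_000, show_progress=True):
    """Process `items` in fixed-size slices while reporting progress on one bar."""
    with tqdm(total=len(items), disable=not show_progress, desc="Writing Documents") as progress_bar:
        for i in range(0, len(items), batch_size):
            batch = items[i: i + batch_size]
            # ... index/write the batch here ...
            progress_bar.update(len(batch))

write_in_batches(list(range(25_000)))
```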
@@ -55,7 +55,7 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
             retriever.embedding_model, revision=retriever.model_version, task_type="embeddings",
             extraction_strategy=retriever.pooling_strategy,
             extraction_layer=retriever.emb_extraction_layer, gpu=retriever.use_gpu,
-            batch_size=4, max_seq_len=512, num_processes=0,use_auth_token=retriever.use_auth_token
+            batch_size=retriever.batch_size, max_seq_len=retriever.max_seq_len, num_processes=0,use_auth_token=retriever.use_auth_token
         )
         # Check that document_store has the right similarity function
         similarity = retriever.document_store.similarity
@@ -98,6 +98,8 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
         # pretrained embedding models coming from: https://github.com/UKPLab/sentence-transformers#pretrained-models
         # e.g. 'roberta-base-nli-stsb-mean-tokens'
         self.embedding_model = SentenceTransformer(retriever.embedding_model, device=str(retriever.devices[0]))
+        self.batch_size = retriever.batch_size
+        self.embedding_model.max_seq_length = retriever.max_seq_len
         self.show_progress_bar = retriever.progress_bar
         document_store = retriever.document_store
         if document_store.similarity != "cosine":
@@ -109,7 +111,7 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
     def embed(self, texts: Union[List[List[str]], List[str], str]) -> List[np.ndarray]:
         # texts can be a list of strings or a list of [title, text]
         # get back list of numpy embedding vectors
-        emb = self.embedding_model.encode(texts, batch_size=200, show_progress_bar=self.show_progress_bar)
+        emb = self.embedding_model.encode(texts, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar)
         emb = [r for r in emb]
         return emb

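For the sentence-transformers backend, both new settings map onto existing library features: `encode` already accepts a `batch_size` argument, and `max_seq_length` is a settable attribute on the model object, which is exactly what the encoder now forwards. A minimal sketch of those two knobs (the model name is only an example):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
model.max_seq_length = 256  # inputs longer than this many tokens are truncated

embeddings = model.encode(
    ["What is Haystack?", "FAISS stores dense vectors."],
    batch_size=32,            # mirrors the retriever's batch_size
    show_progress_bar=False,
)
print(embeddings.shape)  # (2, 384) for this particular model
```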
@@ -129,9 +131,11 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
     ):

         self.progress_bar = retriever.progress_bar
+        self.batch_size = retriever.batch_size
+        self.max_length = retriever.max_seq_len
         self.embedding_tokenizer = AutoTokenizer.from_pretrained(retriever.embedding_model)
         self.embedding_model = AutoModel.from_pretrained(retriever.embedding_model).to(str(retriever.devices[0]))


     def embed_queries(self, texts: List[str]) -> List[np.ndarray]:

@@ -171,7 +175,7 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):

         dataset, tensor_names = self.dataset_from_dicts(text_to_encode)
         dataloader = NamedDataLoader(dataset=dataset, sampler=SequentialSampler(dataset),
-                                     batch_size=32, tensor_names=tensor_names)
+                                     batch_size=self.batch_size, tensor_names=tensor_names)
         return dataloader

     def dataset_from_dicts(self, dicts: List[dict]):
@@ -180,6 +184,7 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
             texts,
             return_token_type_ids=True,
             return_attention_mask=True,
+            max_length=self.max_length,
             truncation=True,
             padding=True
         )
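Context for the added `max_length` argument: with Hugging Face tokenizers, `truncation=True` on its own clips inputs to the model's maximum length, while passing an explicit `max_length` enforces the retriever's `max_seq_len` instead. A small sketch of the same call shape (the tokenizer name is illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

batch = tokenizer(
    ["a very long document " * 200, "a short one"],
    return_token_type_ids=True,
    return_attention_mask=True,
    max_length=128,   # mirrors max_seq_len: longer inputs are cut down
    truncation=True,
    padding=True,
)
print(len(batch["input_ids"][0]))  # at most 128 token ids
```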
@@ -955,6 +955,8 @@ class EmbeddingRetriever(BaseRetriever):
         embedding_model: str,
         model_version: Optional[str] = None,
         use_gpu: bool = True,
+        batch_size: int = 32,
+        max_seq_len: int = 512,
         model_format: str = "farm",
         pooling_strategy: str = "reduce_mean",
         emb_extraction_layer: int = -1,
@@ -968,6 +970,8 @@ class EmbeddingRetriever(BaseRetriever):
         :param embedding_model: Local path or name of model in Hugging Face's model hub such as ``'sentence-transformers/all-MiniLM-L6-v2'``
         :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
         :param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
+        :param batch_size: Number of documents to encode at once.
+        :param max_seq_len: Longest length of each document sequence. Maximum number of tokens for the document text. Longer ones will be cut down.
         :param model_format: Name of framework that was used for saving the model. Options:

                              - ``'farm'``
@@ -993,7 +997,7 @@ class EmbeddingRetriever(BaseRetriever):
         # save init parameters to enable export of component config as YAML
         self.set_config(
             document_store=document_store, embedding_model=embedding_model, model_version=model_version,
-            use_gpu=use_gpu, model_format=model_format, pooling_strategy=pooling_strategy,
+            use_gpu=use_gpu, batch_size=batch_size, max_seq_len=max_seq_len, model_format=model_format, pooling_strategy=pooling_strategy,
             emb_extraction_layer=emb_extraction_layer, top_k=top_k,
         )

@@ -1001,12 +1005,17 @@ class EmbeddingRetriever(BaseRetriever):
             self.devices = devices
         else:
             self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=True)

+        if batch_size < len(self.devices):
+            logger.warning("Batch size is less than the number of devices. All gpus will not be utilized.")
+
         self.document_store = document_store
         self.embedding_model = embedding_model
         self.model_format = model_format
         self.model_version = model_version
         self.use_gpu = use_gpu
+        self.batch_size = batch_size
+        self.max_seq_len = max_seq_len
         self.pooling_strategy = pooling_strategy
         self.emb_extraction_layer = emb_extraction_layer
         self.top_k = top_k
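The new guard at the top of this hunk exists because, with multiple GPUs, a batch smaller than the number of devices cannot be split across all of them, so some GPUs sit idle. A standalone sketch of the same check, for illustration only:

```python
import logging

logger = logging.getLogger(__name__)

def check_batch_size(batch_size: int, devices: list) -> None:
    """Warn when a batch is too small to occupy every available device."""
    if batch_size < len(devices):
        logger.warning("Batch size is less than the number of devices. All gpus will not be utilized.")

check_batch_size(batch_size=2, devices=["cuda:0", "cuda:1", "cuda:2"])  # triggers the warning
```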