Adding multi gpu support for DPR inference (#1414)
* Added support for multi-GPU inference to DPR, including a benchmark
* Fixed multi-GPU handling
* Added batch size to the benchmark to better reflect multi-GPU capabilities
* Removed an unnecessary entry in config.json
* Fixed typos
* Fixed the config name
* Updated the benchmark to use the DEVICES constant
* Changed the multi-GPU parameters and updated the docstring
* Added a silent fallback to CPU
* Updated the docstring, warning, and config

Co-authored-by: Michel Bartels <kontakt@michelbartels.com>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
Parent: 1f859694f1
Commit: da2e8da561
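For context, a minimal usage sketch of the new `devices` parameter (the two-GPU setup and the `doc_store` variable are illustrative, not part of the commit; the import path assumes haystack 0.x):

    from haystack.retriever.dense import DensePassageRetriever

    # Restrict DPR inference to two specific GPUs. If `devices` is omitted,
    # the retriever uses all visible GPUs when use_gpu=True, and silently
    # falls back to CPU when none is available.
    retriever = DensePassageRetriever(
        document_store=doc_store,  # any initialized DocumentStore
        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
        batch_size=64,             # total batch size, split across the listed devices
        devices=["cuda:0", "cuda:1"],
    )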
@@ -2,6 +2,7 @@ import logging
 from abc import abstractmethod
 from typing import List, Union, Optional
 import torch
+from torch.nn import DataParallel
 import numpy as np
 from pathlib import Path
 
@@ -52,7 +53,8 @@ class DensePassageRetriever(BaseRetriever):
                  infer_tokenizer_classes: bool = False,
                  similarity_function: str = "dot_product",
                  global_loss_buffer_size: int = 150000,
-                 progress_bar: bool = True
+                 progress_bar: bool = True,
+                 devices: Optional[List[Union[int, str, torch.device]]] = None
                  ):
        """
        Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -82,8 +84,8 @@ class DensePassageRetriever(BaseRetriever):
        :param max_seq_len_query: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down."
        :param max_seq_len_passage: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down."
        :param top_k: How many documents to return per query.
-        :param use_gpu: Whether to use gpu or not
-        :param batch_size: Number of questions or passages to encode at once
+        :param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
+        :param batch_size: Number of questions or passages to encode at once. In case of multiple gpus, this will be the total batch size.
        :param embed_title: Whether to concatenate title and passage to a text pair that is then used to create the embedding.
                            This is the approach used in the original paper and is likely to improve performance if your
                            titles contain meaningful information for retrieval (topic, entities etc.) .
@@ -99,6 +101,8 @@ class DensePassageRetriever(BaseRetriever):
                                        Increase if errors like "encoded data exceeds max_size ..." come up
        :param progress_bar: Whether to show a tqdm progress bar or not.
                             Can be helpful to disable in production deployments to keep the logs clean.
+        :param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
+                        As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
        """
 
        # save init parameters to enable export of component config as YAML
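Per the new type hint, a device can be named three equivalent ways; for instance, pinning inference to the second GPU (illustrative values):

    import torch

    pin_second_gpu = [1]                       # int: CUDA ordinal
    pin_second_gpu = ["cuda:1"]                # str: device string
    pin_second_gpu = [torch.device("cuda:1")]  # torch.device object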
@@ -108,9 +112,19 @@ class DensePassageRetriever(BaseRetriever):
            model_version=model_version, max_seq_len_query=max_seq_len_query, max_seq_len_passage=max_seq_len_passage,
            top_k=top_k, use_gpu=use_gpu, batch_size=batch_size, embed_title=embed_title,
            use_fast_tokenizers=use_fast_tokenizers, infer_tokenizer_classes=infer_tokenizer_classes,
-            similarity_function=similarity_function, progress_bar=progress_bar,
+            similarity_function=similarity_function, progress_bar=progress_bar, devices=devices
        )
 
+        if devices is not None:
+            self.devices = devices
+        elif use_gpu and torch.cuda.is_available():
+            self.devices = [torch.device(device) for device in range(torch.cuda.device_count())]
+        else:
+            self.devices = [torch.device("cpu")]
+
+        if batch_size < len(self.devices):
+            logger.warning("Batch size is less than the number of devices. All gpus will not be utilized.")
+
        self.document_store = document_store
        self.batch_size = batch_size
        self.progress_bar = progress_bar
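In isolation, the fallback logic added above reads as follows (a standalone sketch; `resolve_devices` is a name invented for this example, not part of the commit):

    import torch

    def resolve_devices(devices=None, use_gpu=True):
        # An explicit device list always wins.
        if devices is not None:
            return devices
        # Otherwise use every visible GPU ...
        if use_gpu and torch.cuda.is_available():
            return [torch.device(d) for d in range(torch.cuda.device_count())]
        # ... and silently fall back to CPU, as noted in the commit message.
        return [torch.device("cpu")]

    print(resolve_devices(use_gpu=False))  # [device(type='cpu')]

The warning fires because DataParallel splits each batch along dimension 0; with fewer samples than devices, some GPUs receive no work.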
@@ -125,8 +139,6 @@ class DensePassageRetriever(BaseRetriever):
                           "We recommend you use dot_product instead. "
                           "This can be set when initializing the DocumentStore")
 
-        self.device, _ = initialize_device_settings(use_cuda=use_gpu)
-
        self.infer_tokenizer_classes = infer_tokenizer_classes
        tokenizers_default_classes = {
            "query": "DPRQuestionEncoderTokenizer",
@@ -171,11 +183,14 @@ class DensePassageRetriever(BaseRetriever):
                                        embeds_dropout_prob=0.1,
                                        lm1_output_types=["per_sequence"],
                                        lm2_output_types=["per_sequence"],
-                                        device=self.device,
+                                        device=self.devices[0],
        )
 
        self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)
 
+        if len(self.devices) > 1:
+            self.model = DataParallel(self.model, device_ids=self.devices)
+
    def retrieve(self, query: str, filters: dict = None, top_k: Optional[int] = None, index: str = None) -> List[Document]:
        """
        Scan through documents in DocumentStore and return a small number documents
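To make the wrapper's behaviour concrete, a toy sketch of what DataParallel does at inference time (a generic nn.Linear, not the DPR encoders):

    import torch
    from torch import nn
    from torch.nn import DataParallel

    model = nn.Linear(8, 4)
    if torch.cuda.device_count() > 1:
        devices = [torch.device(i) for i in range(torch.cuda.device_count())]
        model = DataParallel(model.to(devices[0]), device_ids=devices)

    # The wrapper splits the batch along dim 0, runs each chunk on one GPU,
    # and gathers the outputs back on the first device.
    x = torch.randn(32, 8)
    if isinstance(model, DataParallel):
        x = x.to(devices[0])
    out = model(x)  # shape: (32, 4)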
@@ -234,7 +249,7 @@ class DensePassageRetriever(BaseRetriever):
        with tqdm(total=len(data_loader)*self.batch_size, unit=" Docs", desc=f"Create embeddings", position=1,
                  leave=False, disable=disable_tqdm) as progress_bar:
            for batch in data_loader:
-                batch = {key: batch[key].to(self.device) for key in batch}
+                batch = {key: batch[key].to(self.devices[0]) for key in batch}
 
                # get logits
                with torch.no_grad():
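Note on the `.to(self.devices[0])` change: one line covers both execution modes. With a single device the tensors simply land there; with several, the conventional setup is to place inputs on the first entry of `device_ids`, from which DataParallel scatters the batch to the replicas.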
@@ -371,7 +386,7 @@ class DensePassageRetriever(BaseRetriever):
            n_batches=len(data_silo.loaders["train"]),
            n_epochs=n_epochs,
            grad_acc_steps=grad_acc_steps,
-            device=self.device,
+            device=self.devices[0],  # Only use first device while multi-gpu training is not implemented
            use_amp=use_amp
        )
 
@@ -384,7 +399,7 @@ class DensePassageRetriever(BaseRetriever):
            n_gpu=n_gpu,
            lr_schedule=lr_schedule,
            evaluate_every=evaluate_every,
-            device=self.device,
+            device=self.devices[0],  # Only use first device while multi-gpu training is not implemented
            use_amp=use_amp
        )
 
@@ -395,6 +410,8 @@ class DensePassageRetriever(BaseRetriever):
        self.query_tokenizer.save_pretrained(f"{save_dir}/{query_encoder_save_dir}")
        self.passage_tokenizer.save_pretrained(f"{save_dir}/{passage_encoder_save_dir}")
 
+        self.model = DataParallel(self.model, device_ids=self.devices)
+
    def save(self, save_dir: Union[Path, str], query_encoder_dir: str = "query_encoder",
             passage_encoder_dir: str = "passage_encoder"):
        """
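A likely reason for re-wrapping here: a DataParallel model hides the real network behind `.module`, so saving typically unwraps first and restores the wrapper afterwards. A generic sketch of that pattern (the unwrap itself is not visible in this diff, so this is an assumption about the surrounding code):

    import torch
    from torch import nn
    from torch.nn import DataParallel

    model = DataParallel(nn.Linear(4, 2))  # DataParallel is a no-op without GPUs

    # unwrap -> save -> re-wrap, so the checkpoint holds plain module weights
    # (keys without the "module." prefix) rather than the wrapper's state dict
    inner = model.module
    torch.save(inner.state_dict(), "model.pt")
    model = DataParallel(inner)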
@@ -35,6 +35,8 @@ overview_json = "../../docs/_src/benchmarks/retriever_performance.json"
 map_json = "../../docs/_src/benchmarks/retriever_map.json"
 speed_json = "../../docs/_src/benchmarks/retriever_speed.json"
 
+DEVICES = None
+
 
 seed = 42
 random.seed(42)
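`DEVICES = None` preserves the previous behaviour (the retriever picks devices itself); to pin a benchmark run to specific GPUs one would set, for example:

    DEVICES = ["cuda:0", "cuda:1"]  # illustrative; passed through to get_retriever()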
@@ -47,7 +49,7 @@ def benchmark_indexing(n_docs_options, retriever_doc_stores, data_dir, filename_
            logger.info(f"##### Start indexing run: {retriever_name}, {doc_store_name}, {n_docs} docs ##### ")
            try:
                doc_store = get_document_store(doc_store_name)
-                retriever = get_retriever(retriever_name, doc_store)
+                retriever = get_retriever(retriever_name, doc_store, DEVICES)
                docs, _ = prepare_data(data_dir=data_dir,
                                       filename_gold=filename_gold,
                                       filename_negative=filename_negative,
@@ -143,7 +145,7 @@ def benchmark_querying(n_docs_options,
            else:
                similarity = "dot_product"
            doc_store = get_document_store(doc_store_name, similarity=similarity)
-            retriever = get_retriever(retriever_name, doc_store)
+            retriever = get_retriever(retriever_name, doc_store, DEVICES)
            add_precomputed = retriever_name in ["dpr"]
            # For DPR, precomputed embeddings are loaded from file
            docs, labels = prepare_data(data_dir=data_dir,
@@ -94,7 +94,7 @@ def get_document_store(document_store_type, similarity='dot_product', index="doc
        raise Exception(f"No document store fixture for '{document_store_type}'")
    return document_store
 
-def get_retriever(retriever_name, doc_store):
+def get_retriever(retriever_name, doc_store, devices):
    if retriever_name == "elastic":
        return ElasticsearchRetriever(doc_store)
    if retriever_name == "tfidf":
@@ -104,7 +104,8 @@ def get_retriever(retriever_name, doc_store):
                                     query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                     passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                     use_gpu=True,
-                                     use_fast_tokenizers=False)
+                                     use_fast_tokenizers=False,
+                                     devices=devices)
    if retriever_name == "sentence_transformers":
        return EmbeddingRetriever(document_store=doc_store,
                                  embedding_model="nq-distilbert-base-v1",
@@ -166,4 +167,4 @@ def download_from_url(url: str, filepath:Union[str, Path]):
    logger.info(f"Downloading {url} to {filepath} ")
    with open(filepath, "wb") as file:
        http_get(url=url, temp_file=file)
-    return filepath
+    return filepath