Add multi-GPU support for DPR inference (#1414)

* Added support for multi-GPU inference to DPR, including a benchmark

* Fixed multi-GPU handling

* Added batch size to the benchmark to better reflect multi-GPU capabilities

* Removed an unnecessary entry in config.json

* Fixed typos

* Fixed config name

* Updated the benchmark to use the DEVICES constant

* Changed the multi-GPU parameters and updated the docstring

* Added a silent fallback to CPU

* Updated docstring, warning, and config

Co-authored-by: Michel Bartels <kontakt@michelbartels.com>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
MichelBartels 2021-09-10 13:25:02 +02:00 committed by GitHub
parent 1f859694f1
commit da2e8da561
3 changed files with 35 additions and 15 deletions

File 1 of 3

@@ -2,6 +2,7 @@ import logging
 from abc import abstractmethod
 from typing import List, Union, Optional
 import torch
+from torch.nn import DataParallel
 import numpy as np
 from pathlib import Path
@@ -52,7 +53,8 @@ class DensePassageRetriever(BaseRetriever):
                  infer_tokenizer_classes: bool = False,
                  similarity_function: str = "dot_product",
                  global_loss_buffer_size: int = 150000,
-                 progress_bar: bool = True
+                 progress_bar: bool = True,
+                 devices: Optional[List[Union[int, str, torch.device]]] = None
                  ):
         """
         Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -82,8 +84,8 @@ class DensePassageRetriever(BaseRetriever):
         :param max_seq_len_query: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down.
         :param max_seq_len_passage: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down.
         :param top_k: How many documents to return per query.
-        :param use_gpu: Whether to use gpu or not
-        :param batch_size: Number of questions or passages to encode at once
+        :param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
+        :param batch_size: Number of questions or passages to encode at once. In case of multiple GPUs, this is the total batch size.
         :param embed_title: Whether to concatenate title and passage to a text pair that is then used to create the embedding.
                             This is the approach used in the original paper and is likely to improve performance if your
                             titles contain meaningful information for retrieval (topic, entities etc.).
@@ -99,6 +101,8 @@ class DensePassageRetriever(BaseRetriever):
                                        Increase if errors like "encoded data exceeds max_size ..." come up
         :param progress_bar: Whether to show a tqdm progress bar or not.
                              Can be helpful to disable in production deployments to keep the logs clean.
+        :param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
+                        As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
         """
         # save init parameters to enable export of component config as YAML
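
For context when reading the hunks below, a minimal usage sketch of the new `devices` parameter; the import paths are assumptions based on the pre-1.0 Haystack module layout this commit targets:

    from haystack.document_store import InMemoryDocumentStore  # assumed import path
    from haystack.retriever.dense import DensePassageRetriever  # assumed import path

    retriever = DensePassageRetriever(
        document_store=InMemoryDocumentStore(),
        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
        devices=["cuda:0", "cuda:1"],  # restrict inference to these two GPUs
        batch_size=64,                 # total batch size, split across the listed devices
    )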
@@ -108,9 +112,19 @@ class DensePassageRetriever(BaseRetriever):
             model_version=model_version, max_seq_len_query=max_seq_len_query, max_seq_len_passage=max_seq_len_passage,
             top_k=top_k, use_gpu=use_gpu, batch_size=batch_size, embed_title=embed_title,
             use_fast_tokenizers=use_fast_tokenizers, infer_tokenizer_classes=infer_tokenizer_classes,
-            similarity_function=similarity_function, progress_bar=progress_bar,
+            similarity_function=similarity_function, progress_bar=progress_bar, devices=devices
         )
+
+        if devices is not None:
+            self.devices = devices
+        elif use_gpu and torch.cuda.is_available():
+            self.devices = [torch.device(device) for device in range(torch.cuda.device_count())]
+        else:
+            self.devices = [torch.device("cpu")]
+
+        if batch_size < len(self.devices):
+            logger.warning("Batch size is smaller than the number of devices. Not all GPUs will be utilized.")
+
         self.document_store = document_store
         self.batch_size = batch_size
         self.progress_bar = progress_bar
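
The block above resolves devices in three tiers: an explicit list wins, otherwise all visible CUDA devices, otherwise the CPU (the silent fallback from the commit message). A standalone sketch of the same order; the function name is ours, not the library's:

    import torch

    def resolve_devices(devices=None, use_gpu=True):
        # Mirrors the fallback order above: explicit list > all CUDA devices > CPU.
        if devices is not None:
            return [torch.device(d) for d in devices]
        if use_gpu and torch.cuda.is_available():
            # torch.device() accepts an int and interprets it as a CUDA index
            return [torch.device(i) for i in range(torch.cuda.device_count())]
        return [torch.device("cpu")]  # silent CPU fallback

    print(resolve_devices())  # e.g. [device(type='cuda', index=0)] or [device(type='cpu')]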
@@ -125,8 +139,6 @@ class DensePassageRetriever(BaseRetriever):
                            "We recommend you use dot_product instead. "
                            "This can be set when initializing the DocumentStore")
-        self.device, _ = initialize_device_settings(use_cuda=use_gpu)
-
         self.infer_tokenizer_classes = infer_tokenizer_classes
         tokenizers_default_classes = {
             "query": "DPRQuestionEncoderTokenizer",
@@ -171,11 +183,14 @@ class DensePassageRetriever(BaseRetriever):
             embeds_dropout_prob=0.1,
             lm1_output_types=["per_sequence"],
             lm2_output_types=["per_sequence"],
-            device=self.device,
+            device=self.devices[0],
         )
 
         self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)
 
+        if len(self.devices) > 1:
+            self.model = DataParallel(self.model, device_ids=self.devices)
+
     def retrieve(self, query: str, filters: dict = None, top_k: Optional[int] = None, index: str = None) -> List[Document]:
         """
         Scan through documents in DocumentStore and return a small number of documents
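
The `DataParallel` wrap is what makes `batch_size` the *total* batch size: inputs are split along dim 0 across `device_ids` and the outputs are gathered back on the first device. A self-contained sketch with a toy module of our own:

    import torch
    from torch import nn
    from torch.nn import DataParallel

    class TinyEncoder(nn.Module):  # hypothetical stand-in for the DPR bi-encoder
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 8)

        def forward(self, x):
            return self.proj(x)

    if torch.cuda.device_count() > 1:
        model = DataParallel(TinyEncoder().to("cuda:0"), device_ids=[0, 1])
        batch = torch.randn(64, 16, device="cuda:0")  # placed on devices[0], as in the loop below
        out = model(batch)  # each GPU processes 32 rows; results are gathered on cuda:0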
@@ -234,7 +249,7 @@
         with tqdm(total=len(data_loader)*self.batch_size, unit=" Docs", desc=f"Create embeddings", position=1,
                   leave=False, disable=disable_tqdm) as progress_bar:
             for batch in data_loader:
-                batch = {key: batch[key].to(self.device) for key in batch}
+                batch = {key: batch[key].to(self.devices[0]) for key in batch}
 
                 # get logits
                 with torch.no_grad():
@@ -371,7 +386,7 @@
             n_batches=len(data_silo.loaders["train"]),
             n_epochs=n_epochs,
             grad_acc_steps=grad_acc_steps,
-            device=self.device,
+            device=self.devices[0],  # Only use first device while multi-GPU training is not implemented
             use_amp=use_amp
         )
@@ -384,7 +399,7 @@
             n_gpu=n_gpu,
             lr_schedule=lr_schedule,
             evaluate_every=evaluate_every,
-            device=self.device,
+            device=self.devices[0],  # Only use first device while multi-GPU training is not implemented
             use_amp=use_amp
         )
@@ -395,6 +410,8 @@
         self.query_tokenizer.save_pretrained(f"{save_dir}/{query_encoder_save_dir}")
         self.passage_tokenizer.save_pretrained(f"{save_dir}/{passage_encoder_save_dir}")
 
+        self.model = DataParallel(self.model, device_ids=self.devices)
+
     def save(self, save_dir: Union[Path, str], query_encoder_dir: str = "query_encoder",
              passage_encoder_dir: str = "passage_encoder"):
         """

File 2 of 3

@@ -35,6 +35,8 @@ overview_json = "../../docs/_src/benchmarks/retriever_performance.json"
 map_json = "../../docs/_src/benchmarks/retriever_map.json"
 speed_json = "../../docs/_src/benchmarks/retriever_speed.json"
 
+DEVICES = None
+
 seed = 42
 random.seed(42)
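
To benchmark specific GPUs, the new module-level constant can be overridden before the runs start; the values below are examples:

    DEVICES = ["cuda:0", "cuda:1"]  # None means: let the retriever pick all available GPUs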
@@ -47,7 +49,7 @@ def benchmark_indexing(n_docs_options, retriever_doc_stores, data_dir, filename_
     logger.info(f"##### Start indexing run: {retriever_name}, {doc_store_name}, {n_docs} docs ##### ")
     try:
         doc_store = get_document_store(doc_store_name)
-        retriever = get_retriever(retriever_name, doc_store)
+        retriever = get_retriever(retriever_name, doc_store, DEVICES)
         docs, _ = prepare_data(data_dir=data_dir,
                                filename_gold=filename_gold,
                                filename_negative=filename_negative,
@@ -143,7 +145,7 @@ def benchmark_querying(n_docs_options,
         else:
             similarity = "dot_product"
         doc_store = get_document_store(doc_store_name, similarity=similarity)
-        retriever = get_retriever(retriever_name, doc_store)
+        retriever = get_retriever(retriever_name, doc_store, DEVICES)
         add_precomputed = retriever_name in ["dpr"]
         # For DPR, precomputed embeddings are loaded from file
         docs, labels = prepare_data(data_dir=data_dir,

File 3 of 3

@@ -94,7 +94,7 @@ def get_document_store(document_store_type, similarity='dot_product', index="doc
         raise Exception(f"No document store fixture for '{document_store_type}'")
     return document_store
 
-def get_retriever(retriever_name, doc_store):
+def get_retriever(retriever_name, doc_store, devices):
     if retriever_name == "elastic":
         return ElasticsearchRetriever(doc_store)
     if retriever_name == "tfidf":
@@ -104,7 +104,8 @@ def get_retriever(retriever_name, doc_store):
                                       query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                       passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                       use_gpu=True,
-                                      use_fast_tokenizers=False)
+                                      use_fast_tokenizers=False,
+                                      devices=devices)
     if retriever_name == "sentence_transformers":
         return EmbeddingRetriever(document_store=doc_store,
                                   embedding_model="nq-distilbert-base-v1",
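
With the parameter threaded through, the helper can now pin DPR to chosen GPUs; the document store name below is just an example:

    doc_store = get_document_store("faiss")
    retriever = get_retriever("dpr", doc_store, ["cuda:0"])  # non-DPR retrievers ignore the list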
@@ -166,4 +167,4 @@ def download_from_url(url: str, filepath:Union[str, Path]):
     logger.info(f"Downloading {url} to {filepath} ")
     with open(filepath, "wb") as file:
         http_get(url=url, temp_file=file)
-    return filepath
\ No newline at end of file
+    return filepath