Prevent wrapping DataParallel in second DataParallel (#1855)

* Prevent wrapping DataParallel in second DataParallel

* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
bogdankostic authored 2021-12-08 09:56:45 +01:00, committed by GitHub
parent 8cb513c2c6
commit cbfe2b4626
3 changed files with 6 additions and 3 deletions

@@ -64,7 +64,9 @@ class SentenceTransformersRanker(BaseRanker):
     Sentence Transformer based pre-trained Cross-Encoder model for Document Re-ranking (https://huggingface.co/cross-encoder).
     Re-Ranking can be used on top of a retriever to boost the performance for document search. This is particularly useful if the retriever has a high recall but is bad in sorting the documents by relevance.
-    SentenceTransformerRanker handles Cross-Encoder models that use a single logit as similarity score.
+    SentenceTransformerRanker handles Cross-Encoder models
+        - use a single logit as similarity score e.g. cross-encoder/ms-marco-MiniLM-L-12-v2
+        - use two output logits (no_answer, has_answer) e.g. deepset/gbert-base-germandpr-reranking
     https://www.sbert.net/docs/pretrained-models/ce-msmarco.html#usage-with-transformers
     | With a SentenceTransformersRanker, you can:
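
For context, a minimal sketch of the single-logit case, following the sbert.net usage pattern linked in the docstring. This uses plain transformers rather than Haystack's ranker API; the query and document texts are illustrative:

# Minimal sketch: scoring (query, document) pairs with a single-logit
# cross-encoder, per the sbert.net link above; not Haystack's ranker API.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

query = "How many people live in Berlin?"
docs = ["Berlin has a population of 3.6 million.", "Munich is in Bavaria."]
features = tokenizer([query] * len(docs), docs, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**features).logits.squeeze(-1)  # one logit per pair = similarity score
reranked = sorted(zip(docs, scores.tolist()), key=lambda pair: pair[1], reverse=True)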

@@ -303,7 +303,7 @@ def optimize_model(model, device, local_rank, optimizer=None, distributed=False,
                                             find_unused_parameters=True)
     elif torch.cuda.device_count() > 1 and device.type == "cuda":
-        model = WrappedDataParallel(model)
+        model = WrappedDataParallel(model) if not isinstance(model, DataParallel) else WrappedDataParallel(model.module)
         logger.info("Multi-GPU Training via DataParallel")
     return model, optimizer
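
The point of the isinstance check: wrapping an already wrapped model nests one DataParallel inside another, so the fix unwraps via .module before re-wrapping. A minimal sketch of the pattern, with plain torch.nn.DataParallel standing in for Haystack's WrappedDataParallel subclass and wrap_once as an illustrative helper:

# Minimal sketch, assuming plain torch.nn.DataParallel stands in for
# WrappedDataParallel; wrap_once is a hypothetical helper, not Haystack code.
from torch import nn
from torch.nn import DataParallel

def wrap_once(model: nn.Module) -> DataParallel:
    if isinstance(model, DataParallel):
        model = model.module  # unwrap so we never build DataParallel(DataParallel(...))
    return DataParallel(model)

layer = nn.Linear(4, 2)
wrapped = wrap_once(layer)
rewrapped = wrap_once(wrapped)   # idempotent: still a single level of wrapping
assert rewrapped.module is layer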

@@ -413,7 +413,8 @@ class DensePassageRetriever(BaseRetriever):
         self.query_tokenizer.save_pretrained(f"{save_dir}/{query_encoder_save_dir}")
         self.passage_tokenizer.save_pretrained(f"{save_dir}/{passage_encoder_save_dir}")
-        self.model = DataParallel(self.model, device_ids=self.devices)
+        if len(self.devices) > 1 and not isinstance(self.model, DataParallel):
+            self.model = DataParallel(self.model, device_ids=self.devices)

     def save(self, save_dir: Union[Path, str], query_encoder_dir: str = "query_encoder",
              passage_encoder_dir: str = "passage_encoder"):
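
The retriever gets the same guard: wrap only when more than one device is configured and the model is not already a DataParallel instance. A minimal sketch of that condition in isolation (hypothetical setup, not Haystack's exact code; the Linear layer stands in for the retriever's bi-encoder):

# Minimal sketch of the multi-GPU guard above; nn.Linear is a stand-in
# for the retriever's model.
import torch
from torch import nn
from torch.nn import DataParallel

devices = [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
model = nn.Linear(8, 8)
if len(devices) > 1 and not isinstance(model, DataParallel):
    model = DataParallel(model, device_ids=devices)
# Running this block a second time leaves an already-wrapped model untouched.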