Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-12-03 18:36:04 +00:00.
Prevent wrapping DataParallel in second DataParallel (#1855)
* Prevent wrapping DataParallel in second DataParallel * Add latest docstring and tutorial changes Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
8cb513c2c6
commit
cbfe2b4626
@@ -64,7 +64,9 @@ class SentenceTransformersRanker(BaseRanker)
     Sentence Transformer based pre-trained Cross-Encoder model for Document Re-ranking (https://huggingface.co/cross-encoder).
     Re-Ranking can be used on top of a retriever to boost the performance for document search. This is particularly useful if the retriever has a high recall but is bad in sorting the documents by relevance.

-    SentenceTransformerRanker handles Cross-Encoder models that use a single logit as similarity score.
+    SentenceTransformerRanker handles Cross-Encoder models
+        - use a single logit as similarity score e.g. cross-encoder/ms-marco-MiniLM-L-12-v2
+        - use two output logits (no_answer, has_answer) e.g. deepset/gbert-base-germandpr-reranking
+    https://www.sbert.net/docs/pretrained-models/ce-msmarco.html#usage-with-transformers

     With a SentenceTransformersRanker, you can:
@@ -303,7 +303,7 @@ def optimize_model(model, device, local_rank, optimizer=None, distributed=False,
                                             find_unused_parameters=True)

     elif torch.cuda.device_count() > 1 and device.type == "cuda":
-        model = WrappedDataParallel(model)
+        model = WrappedDataParallel(model) if not isinstance(model, DataParallel) else WrappedDataParallel(model.module)
         logger.info("Multi-GPU Training via DataParallel")

     return model, optimizer
@@ -413,7 +413,8 @@ class DensePassageRetriever(BaseRetriever):
         self.query_tokenizer.save_pretrained(f"{save_dir}/{query_encoder_save_dir}")
         self.passage_tokenizer.save_pretrained(f"{save_dir}/{passage_encoder_save_dir}")

-        self.model = DataParallel(self.model, device_ids=self.devices)
+        if len(self.devices) > 1 and not isinstance(self.model, DataParallel):
+            self.model = DataParallel(self.model, device_ids=self.devices)

     def save(self, save_dir: Union[Path, str], query_encoder_dir: str = "query_encoder",
              passage_encoder_dir: str = "passage_encoder"):
||||
Loading…
x
Reference in New Issue
Block a user