Mirror of https://github.com/deepset-ai/haystack.git · synced 2025-11-02 10:49:30 +00:00
fix: Added checks for DataParallel and WrappedDataParallel (#3366)
* Added checks for DataParallel and WrappedDataParallel
* Updated isinstance checks according to pylint recommendation
* Using isinstance over type() checks
* Added test for DPR training
This commit is contained in:
parent db6e5754cd
commit 75641dd024
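
The commit bullets describe one pattern applied in two places: torch.nn.DataParallel (and Haystack's WrappedDataParallel subclass) hides the wrapped network's class, so isinstance dispatch against AdaptiveModel/BiAdaptiveModel silently fails on multi-GPU runs. A minimal sketch of the unwrap-then-dispatch pattern the diff introduces (toy stand-in class, not the real AdaptiveModel):

    import torch
    from torch.nn import DataParallel

    class AdaptiveModel(torch.nn.Module):  # toy stand-in for Haystack's AdaptiveModel
        def forward(self, input_ids):
            return input_ids * 2

    model = DataParallel(AdaptiveModel())

    # The wrapper changes the apparent class, so naive dispatch fails:
    assert not isinstance(model, AdaptiveModel)

    # The fix: unwrap for the type check, but keep calling forward on the
    # wrapper so DataParallel can still scatter batches across GPUs
    # (mirroring how the diff calls model.forward while checking module).
    module = model.module if isinstance(model, DataParallel) else model
    assert isinstance(module, AdaptiveModel)
    logits = model.forward(torch.ones(2, 4))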
@@ -3,12 +3,14 @@ from typing import Dict, List, Optional, Any
 import logging
 import numbers
 import torch
+from torch.nn import DataParallel
 import numpy as np
 from tqdm import tqdm
 
 from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
 from haystack.modeling.model.adaptive_model import AdaptiveModel
 from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
+from haystack.modeling.model.optimization import WrappedDataParallel
 from haystack.utils.experiment_tracking import Tracker as tracker
 from haystack.modeling.visual import BUSH_SEP
 
@@ -70,9 +72,13 @@ class Evaluator:
         for step, batch in enumerate(tqdm(self.data_loader, desc="Evaluating", mininterval=10)):
             batch = {key: batch[key].to(self.device) for key in batch}
 
+            if isinstance(model, (DataParallel, WrappedDataParallel)):
+                module = model.module
+            else:
+                module = model
+
             with torch.no_grad():
-                if isinstance(model, AdaptiveModel):
+                if isinstance(module, AdaptiveModel):
                     logits = model.forward(
                         input_ids=batch.get("input_ids", None),
                         segment_ids=batch.get("segment_ids", None),
@@ -80,7 +86,7 @@ class Evaluator:
                         output_hidden_states=batch.get("output_hidden_states", False),
                         output_attentions=batch.get("output_attentions", False),
                     )
-                elif isinstance(model, BiAdaptiveModel):
+                elif isinstance(module, BiAdaptiveModel):
                     logits = model.forward(
                         query_input_ids=batch.get("query_input_ids", None),
                         query_segment_ids=batch.get("query_segment_ids", None),
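
Why doesn't WrappedDataParallel already make these checks pass? It delegates attribute access to the wrapped module, but delegation cannot change the wrapper's class, which is all isinstance sees. A sketch of such a delegating wrapper (modeled on the FARM-era implementation; the exact body in haystack.modeling.model.optimization may differ):

    from torch.nn import DataParallel

    class WrappedDataParallel(DataParallel):
        """DataParallel that forwards unknown attribute lookups to the wrapped module."""

        def __getattr__(self, name):
            try:
                return super().__getattr__(name)  # parameters, buffers, submodules
            except AttributeError:
                return getattr(self.module, name)  # fall through to the real model

Attribute and method lookups resolve transparently, but isinstance(wrapped, AdaptiveModel) stays False, hence the explicit unwrap in both Evaluator.eval() and Trainer.compute_loss().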
@@ -18,7 +18,7 @@ from haystack.modeling.data_handler.data_silo import DataSilo, DistillationDataSilo
 from haystack.modeling.evaluation.eval import Evaluator
 from haystack.modeling.model.adaptive_model import AdaptiveModel
 from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
-from haystack.modeling.model.optimization import get_scheduler
+from haystack.modeling.model.optimization import get_scheduler, WrappedDataParallel
 from haystack.modeling.utils import GracefulKiller
 from haystack.utils.experiment_tracking import Tracker as tracker
 from haystack.utils.early_stopping import EarlyStopping
@@ -292,12 +292,17 @@ class Trainer:
 
     def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
         # Forward & backward pass through model
-        if isinstance(self.model, AdaptiveModel):
+        if isinstance(self.model, (DataParallel, WrappedDataParallel)):
+            module = self.model.module
+        else:
+            module = self.model
+
+        if isinstance(module, AdaptiveModel):
             logits = self.model.forward(
                 input_ids=batch["input_ids"], segment_ids=None, padding_mask=batch["padding_mask"]
             )
 
-        elif isinstance(self.model, BiAdaptiveModel):
+        elif isinstance(module, BiAdaptiveModel):
             logits = self.model.forward(
                 query_input_ids=batch["query_input_ids"],
                 query_segment_ids=batch["query_segment_ids"],
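
The "isinstance over types" and pylint bullets in the commit message refer to the idiomatic form used above: one isinstance call with a tuple of classes instead of exact type() comparisons or chained checks. A generic illustration (plain Python, not Haystack code):

    class Base: ...
    class Child(Base): ...

    obj = Child()
    assert type(obj) != Base               # exact-type comparison ignores subclasses
    assert isinstance(obj, Base)           # isinstance respects inheritance
    assert isinstance(obj, (Base, Child))  # a tuple merges several checks into one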
@@ -1,6 +1,6 @@
-import os
 from typing import Tuple
 
+import os
 import logging
 from pathlib import Path
 
@@ -17,8 +17,11 @@ from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
 from haystack.modeling.model.language_model import get_language_model, DPREncoder
 from haystack.modeling.model.prediction_head import TextSimilarityHead
 from haystack.modeling.model.tokenization import get_tokenizer
+from haystack.nodes.retriever.dense import DensePassageRetriever
 from haystack.modeling.utils import set_all_seeds, initialize_device_settings
 
+from ..conftest import SAMPLES_PATH
+
 
 def test_dpr_modules(caplog=None):
     if caplog:
@@ -970,6 +973,33 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
     assert np.array_equal(all_embeddings["query"][0], all_embeddings3["query"][0])
 
 
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+def test_dpr_training(document_store, tmp_path):
+    retriever = DensePassageRetriever(
+        document_store=document_store,
+        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
+        max_seq_len_query=8,
+        max_seq_len_passage=8,
+    )
+
+    save_dir = f"{tmp_path}/test_dpr_training"
+    retriever.train(
+        data_dir=str(SAMPLES_PATH / "dpr"),
+        train_filename="sample.json",
+        dev_filename="sample.json",
+        test_filename="sample.json",
+        n_epochs=1,
+        batch_size=1,
+        grad_acc_steps=1,
+        save_dir=save_dir,
+        evaluate_every=10,
+        embed_title=True,
+        num_positives=1,
+        num_hard_negatives=1,
+    )
+
+
 # TODO fix CI errors (test pass locally or on AWS, next steps: isolate PyTorch versions once FARM dependency is removed)
 # def test_dpr_training():
 #     batch_size = 1
@@ -982,8 +1012,6 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
 #     use_fast = True
 #     similarity_function = "dot_product"
 #
-#
-#
 #     device, n_gpu = initialize_device_settings(use_cuda=False)
 #
 #     query_tokenizer = get_tokenizer(pretrained_model_name_or_path=question_lang_model,
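
The new test exercises the fixed path end to end: DensePassageRetriever.train() drives a BiAdaptiveModel through Trainer.compute_loss(), and on hosts with more than one GPU the model is wrapped before it gets there, which is exactly the case the old dispatch missed. The diff does not show sample.json, but Haystack's DPR processor consumes the standard DPR training schema; a plausible minimal record looks like this (illustrative values, not the actual sample file):

    sample = [
        {
            "question": "Who wrote Hamlet?",
            "answers": ["William Shakespeare"],
            "positive_ctxs": [
                {"title": "Hamlet", "text": "Hamlet is a tragedy written by William Shakespeare.", "passage_id": "0"}
            ],
            "negative_ctxs": [],
            "hard_negative_ctxs": [
                {"title": "Macbeth", "text": "Macbeth is another tragedy by Shakespeare.", "passage_id": "1"}
            ],
        }
    ]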