From 75641dd024976591b96285e6eb7cded6254f87d5 Mon Sep 17 00:00:00 2001
From: Sebastian
Date: Thu, 13 Oct 2022 08:05:56 +0200
Subject: [PATCH] fix: Added checks for DataParallel and WrappedDataParallel
 (#3366)

* Added checks for DataParallel and WrappedDataParallel

* Update isinstance checks according to pylint recommendation

* Using isinstance over types

* Added test for dpr training
---
 haystack/modeling/evaluation/eval.py | 12 +++++++---
 haystack/modeling/training/base.py   | 11 ++++++---
 test/modeling/test_dpr.py            | 34 +++++++++++++++++++++++++---
 3 files changed, 48 insertions(+), 9 deletions(-)

diff --git a/haystack/modeling/evaluation/eval.py b/haystack/modeling/evaluation/eval.py
index 4df198b2d..b12593dd6 100644
--- a/haystack/modeling/evaluation/eval.py
+++ b/haystack/modeling/evaluation/eval.py
@@ -3,12 +3,14 @@ from typing import Dict, List, Optional, Any
 import logging
 import numbers
 import torch
+from torch.nn import DataParallel
 import numpy as np
 from tqdm import tqdm
 
 from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
 from haystack.modeling.model.adaptive_model import AdaptiveModel
 from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
+from haystack.modeling.model.optimization import WrappedDataParallel
 from haystack.utils.experiment_tracking import Tracker as tracker
 from haystack.modeling.visual import BUSH_SEP
 
@@ -70,9 +72,13 @@ class Evaluator:
         for step, batch in enumerate(tqdm(self.data_loader, desc="Evaluating", mininterval=10)):
             batch = {key: batch[key].to(self.device) for key in batch}
 
-            with torch.no_grad():
+            if isinstance(model, (DataParallel, WrappedDataParallel)):
+                module = model.module
+            else:
+                module = model
 
-                if isinstance(model, AdaptiveModel):
+            with torch.no_grad():
+                if isinstance(module, AdaptiveModel):
                     logits = model.forward(
                         input_ids=batch.get("input_ids", None),
                         segment_ids=batch.get("segment_ids", None),
@@ -80,7 +86,7 @@
                         output_hidden_states=batch.get("output_hidden_states", False),
                         output_attentions=batch.get("output_attentions", False),
                     )
-                elif isinstance(model, BiAdaptiveModel):
+                elif isinstance(module, BiAdaptiveModel):
                     logits = model.forward(
                         query_input_ids=batch.get("query_input_ids", None),
                         query_segment_ids=batch.get("query_segment_ids", None),
diff --git a/haystack/modeling/training/base.py b/haystack/modeling/training/base.py
index e278ac4f7..a761a62b3 100644
--- a/haystack/modeling/training/base.py
+++ b/haystack/modeling/training/base.py
@@ -18,7 +18,7 @@ from haystack.modeling.data_handler.data_silo import DataSilo, DistillationDataS
 from haystack.modeling.evaluation.eval import Evaluator
 from haystack.modeling.model.adaptive_model import AdaptiveModel
 from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
-from haystack.modeling.model.optimization import get_scheduler
+from haystack.modeling.model.optimization import get_scheduler, WrappedDataParallel
 from haystack.modeling.utils import GracefulKiller
 from haystack.utils.experiment_tracking import Tracker as tracker
 from haystack.utils.early_stopping import EarlyStopping
@@ -292,12 +292,17 @@ class Trainer:
 
     def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
         # Forward & backward pass through model
-        if isinstance(self.model, AdaptiveModel):
+        if isinstance(self.model, (DataParallel, WrappedDataParallel)):
+            module = self.model.module
+        else:
+            module = self.model
+
+        if isinstance(module, AdaptiveModel):
             logits = self.model.forward(
                 input_ids=batch["input_ids"], segment_ids=None,
                 padding_mask=batch["padding_mask"]
             )
-        elif isinstance(self.model, BiAdaptiveModel):
+        elif isinstance(module, BiAdaptiveModel):
             logits = self.model.forward(
                 query_input_ids=batch["query_input_ids"],
                 query_segment_ids=batch["query_segment_ids"],
diff --git a/test/modeling/test_dpr.py b/test/modeling/test_dpr.py
index af1cf0e91..7cc6d0119 100644
--- a/test/modeling/test_dpr.py
+++ b/test/modeling/test_dpr.py
@@ -1,6 +1,6 @@
+import os
 from typing import Tuple
 
-import os
 import logging
 from pathlib import Path
 
@@ -17,8 +17,11 @@ from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
 from haystack.modeling.model.language_model import get_language_model, DPREncoder
 from haystack.modeling.model.prediction_head import TextSimilarityHead
 from haystack.modeling.model.tokenization import get_tokenizer
+from haystack.nodes.retriever.dense import DensePassageRetriever
 from haystack.modeling.utils import set_all_seeds, initialize_device_settings
+
+from ..conftest import SAMPLES_PATH
 
 
 def test_dpr_modules(caplog=None):
     if caplog:
@@ -970,6 +973,33 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
     assert np.array_equal(all_embeddings["query"][0], all_embeddings3["query"][0])
 
 
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+def test_dpr_training(document_store, tmp_path):
+    retriever = DensePassageRetriever(
+        document_store=document_store,
+        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
+        max_seq_len_query=8,
+        max_seq_len_passage=8,
+    )
+
+    save_dir = f"{tmp_path}/test_dpr_training"
+    retriever.train(
+        data_dir=str(SAMPLES_PATH / "dpr"),
+        train_filename="sample.json",
+        dev_filename="sample.json",
+        test_filename="sample.json",
+        n_epochs=1,
+        batch_size=1,
+        grad_acc_steps=1,
+        save_dir=save_dir,
+        evaluate_every=10,
+        embed_title=True,
+        num_positives=1,
+        num_hard_negatives=1,
+    )
+
+
 # TODO fix CI errors (test pass locally or on AWS, next steps: isolate PyTorch versions once FARM dependency is removed)
 # def test_dpr_training():
 #     batch_size = 1
@@ -982,8 +1012,6 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
 #     use_fast = True
 #     similarity_function = "dot_product"
 #
-#
-#
 #     device, n_gpu = initialize_device_settings(use_cuda=False)
 #
 #     query_tokenizer = get_tokenizer(pretrained_model_name_or_path=question_lang_model,
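
Note on the pattern (not part of the patch itself): both eval.py and base.py apply the standard PyTorch idiom for models wrapped in torch.nn.DataParallel (or in WrappedDataParallel, which the diff imports from haystack.modeling.model.optimization and treats the same way). Wrapping changes the object's class, so isinstance checks must inspect the unwrapped .module attribute, while forward() is still called on the wrapper so inputs are scattered across GPUs and outputs gathered back. A minimal standalone sketch of that idiom, with a hypothetical ToyModel standing in for AdaptiveModel:

    import torch
    from torch.nn import DataParallel

    class ToyModel(torch.nn.Module):  # hypothetical stand-in for AdaptiveModel
        def forward(self, x):
            return x * 2

    model: torch.nn.Module = ToyModel()
    if torch.cuda.device_count() > 1:
        # Wrapping changes the outer class from ToyModel to DataParallel
        model = DataParallel(model.to("cuda"))

    # Type checks must target the unwrapped module...
    module = model.module if isinstance(model, DataParallel) else model

    if isinstance(module, ToyModel):
        # ...but forward() still goes through the wrapper, preserving
        # DataParallel's multi-GPU scatter/gather
        out = model(torch.ones(4, 1))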