fix: Added checks for DataParallel and WrappedDataParallel (#3366)

* Added checks for DataParallel and WrappedDataParallel

* Updated isinstance checks according to pylint's recommendation

* Using isinstance over types

* Added test for DPR training
Authored by Sebastian on 2022-10-13 08:05:56 +02:00, committed by GitHub
parent db6e5754cd
commit 75641dd024
3 changed files with 48 additions and 9 deletions
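
Background for the change: torch.nn.DataParallel (and Haystack's WrappedDataParallel subclass) keep the real model behind a .module attribute, so isinstance checks against AdaptiveModel or BiAdaptiveModel stop matching once the model has been wrapped for multi-GPU runs. A minimal standalone sketch of the failure mode and the .module unwrap; the BiAdaptiveModel below is a toy stand-in, not Haystack's class:

    from torch.nn import DataParallel, Module

    class BiAdaptiveModel(Module):  # toy stand-in for the Haystack class
        def forward(self, x):
            return x * 2

    model = DataParallel(BiAdaptiveModel())     # constructible even on CPU-only machines
    print(isinstance(model, BiAdaptiveModel))   # False: the wrapper hides the class
    module = model.module if isinstance(model, DataParallel) else model
    print(isinstance(module, BiAdaptiveModel))  # True: unwrap before the type check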

File: haystack/modeling/evaluation/eval.py

@@ -3,12 +3,14 @@ from typing import Dict, List, Optional, Any
 import logging
 import numbers
 import torch
+from torch.nn import DataParallel
 import numpy as np
 from tqdm import tqdm
 from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
 from haystack.modeling.model.adaptive_model import AdaptiveModel
 from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
+from haystack.modeling.model.optimization import WrappedDataParallel
 from haystack.utils.experiment_tracking import Tracker as tracker
 from haystack.modeling.visual import BUSH_SEP
@@ -70,9 +72,13 @@ class Evaluator:
         for step, batch in enumerate(tqdm(self.data_loader, desc="Evaluating", mininterval=10)):
             batch = {key: batch[key].to(self.device) for key in batch}
-            with torch.no_grad():
-                if isinstance(model, AdaptiveModel):
+            if isinstance(model, (DataParallel, WrappedDataParallel)):
+                module = model.module
+            else:
+                module = model
+            with torch.no_grad():
+                if isinstance(module, AdaptiveModel):
                     logits = model.forward(
                         input_ids=batch.get("input_ids", None),
                         segment_ids=batch.get("segment_ids", None),
@@ -80,7 +86,7 @@ class Evaluator:
                         output_hidden_states=batch.get("output_hidden_states", False),
                         output_attentions=batch.get("output_attentions", False),
                     )
-                elif isinstance(model, BiAdaptiveModel):
+                elif isinstance(module, BiAdaptiveModel):
                     logits = model.forward(
                         query_input_ids=batch.get("query_input_ids", None),
                         query_segment_ids=batch.get("query_segment_ids", None),
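
Note that only the type dispatch uses the unwrapped module; the forward call stays on model, so a DataParallel wrapper still handles multi-GPU execution. The same three-line unwrap is repeated in the Trainer below; a hypothetical helper capturing the pattern (not part of this commit) might look like:

    from torch.nn import DataParallel, Module
    from haystack.modeling.model.optimization import WrappedDataParallel

    def unwrap(model: Module) -> Module:
        # DataParallel-style wrappers keep the real model in .module; return it
        # so type checks see the original class. Listing both types mirrors the
        # commit, although WrappedDataParallel already subclasses DataParallel.
        if isinstance(model, (DataParallel, WrappedDataParallel)):
            return model.module
        return model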

File: haystack/modeling/training/base.py

@@ -18,7 +18,7 @@ from haystack.modeling.data_handler.data_silo import DataSilo, DistillationDataS
 from haystack.modeling.evaluation.eval import Evaluator
 from haystack.modeling.model.adaptive_model import AdaptiveModel
 from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
-from haystack.modeling.model.optimization import get_scheduler
+from haystack.modeling.model.optimization import get_scheduler, WrappedDataParallel
 from haystack.modeling.utils import GracefulKiller
 from haystack.utils.experiment_tracking import Tracker as tracker
 from haystack.utils.early_stopping import EarlyStopping
@@ -292,12 +292,17 @@ class Trainer:
     def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
         # Forward & backward pass through model
-        if isinstance(self.model, AdaptiveModel):
+        if isinstance(self.model, (DataParallel, WrappedDataParallel)):
+            module = self.model.module
+        else:
+            module = self.model
+
+        if isinstance(module, AdaptiveModel):
             logits = self.model.forward(
                 input_ids=batch["input_ids"], segment_ids=None, padding_mask=batch["padding_mask"]
             )
-        elif isinstance(self.model, BiAdaptiveModel):
+        elif isinstance(module, BiAdaptiveModel):
             logits = self.model.forward(
                 query_input_ids=batch["query_input_ids"],
                 query_segment_ids=batch["query_segment_ids"],
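
As in the Evaluator, the forward pass is still invoked on self.model rather than on the unwrapped module: DataParallel's forward is what scatters the batch across devices and gathers the outputs, so the unwrapped model is needed only to pick the right argument set.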

File: DPR modeling tests

@@ -1,6 +1,6 @@
-import os
 from typing import Tuple
+import os
 import logging
 from pathlib import Path
@@ -17,8 +17,11 @@ from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
 from haystack.modeling.model.language_model import get_language_model, DPREncoder
 from haystack.modeling.model.prediction_head import TextSimilarityHead
 from haystack.modeling.model.tokenization import get_tokenizer
+from haystack.nodes.retriever.dense import DensePassageRetriever
 from haystack.modeling.utils import set_all_seeds, initialize_device_settings
+from ..conftest import SAMPLES_PATH
+
 
 def test_dpr_modules(caplog=None):
     if caplog:
@@ -970,6 +973,33 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
     assert np.array_equal(all_embeddings["query"][0], all_embeddings3["query"][0])
 
+
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+def test_dpr_training(document_store, tmp_path):
+    retriever = DensePassageRetriever(
+        document_store=document_store,
+        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
+        max_seq_len_query=8,
+        max_seq_len_passage=8,
+    )
+    save_dir = f"{tmp_path}/test_dpr_training"
+    retriever.train(
+        data_dir=str(SAMPLES_PATH / "dpr"),
+        train_filename="sample.json",
+        dev_filename="sample.json",
+        test_filename="sample.json",
+        n_epochs=1,
+        batch_size=1,
+        grad_acc_steps=1,
+        save_dir=save_dir,
+        evaluate_every=10,
+        embed_title=True,
+        num_positives=1,
+        num_hard_negatives=1,
+    )
+
+
 # TODO fix CI errors (test pass locally or on AWS, next steps: isolate PyTorch versions once FARM dependency is removed)
 # def test_dpr_training():
 #     batch_size = 1
@@ -982,8 +1012,6 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
 #     use_fast = True
 #     similarity_function = "dot_product"
 #
-#
-#
 #     device, n_gpu = initialize_device_settings(use_cuda=False)
 #
 #     query_tokenizer = get_tokenizer(pretrained_model_name_or_path=question_lang_model,
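
The new test_dpr_training is deliberately tiny so it can run on CPU in CI: one epoch, batch size 1, and query/passage sequence lengths capped at 8, trained on the sample.json fixture under SAMPLES_PATH / "dpr". Assuming a standard pytest setup for the repository, it can be run on its own with pytest -k test_dpr_training.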