fix: Added checks for DataParallel and WrappedDataParallel (#3366)

* Added checks for DataParallel and WrappedDataParallel

* Update isinstance checks according to pylint recommendation

* Using isinstance over types

* Added test for dpr training
Sebastian authored on 2022-10-13 08:05:56 +02:00; committed by GitHub
parent db6e5754cd
commit 75641dd024
3 changed files with 48 additions and 9 deletions
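
Background on the fix: torch.nn.DataParallel and Haystack's WrappedDataParallel (imported from haystack.modeling.model.optimization in the diffs below) wrap the model in another nn.Module and keep the original one under .module. Before this change, isinstance(model, AdaptiveModel) and isinstance(model, BiAdaptiveModel) were evaluated against the wrapper, so neither branch matched on multi-GPU runs. The fix unwraps the model only for the type check and still calls forward() on the wrapper; the "isinstance over types" bullet presumably refers to pylint's unidiomatic-typecheck recommendation to prefer isinstance(x, SomeClass) over type(x) == SomeClass. A minimal, self-contained sketch of the idea (illustrative only; TinyEncoder is a made-up stand-in, not Haystack code):

import torch
from torch import nn


class TinyEncoder(nn.Module):  # made-up stand-in for AdaptiveModel / BiAdaptiveModel
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


model: nn.Module = nn.DataParallel(TinyEncoder())

# The wrapper is not a TinyEncoder, so a naive isinstance() check misses it ...
assert not isinstance(model, TinyEncoder)

# ... but the underlying module is; unwrap only for the type check, and keep
# calling forward() on the wrapper so DataParallel's scatter/gather still runs.
module = model.module if isinstance(model, nn.DataParallel) else model
assert isinstance(module, TinyEncoder)
logits = model(torch.ones(2, 3))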

Changed file 1 of 3 (class Evaluator, haystack.modeling.evaluation.eval):

@@ -3,12 +3,14 @@ from typing import Dict, List, Optional, Any
 import logging
 import numbers
 import torch
+from torch.nn import DataParallel
 import numpy as np
 from tqdm import tqdm
 from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
 from haystack.modeling.model.adaptive_model import AdaptiveModel
 from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
+from haystack.modeling.model.optimization import WrappedDataParallel
 from haystack.utils.experiment_tracking import Tracker as tracker
 from haystack.modeling.visual import BUSH_SEP
@@ -70,9 +72,13 @@ class Evaluator:
         for step, batch in enumerate(tqdm(self.data_loader, desc="Evaluating", mininterval=10)):
             batch = {key: batch[key].to(self.device) for key in batch}
-            with torch.no_grad():
+            if isinstance(model, (DataParallel, WrappedDataParallel)):
+                module = model.module
+            else:
+                module = model
-                if isinstance(model, AdaptiveModel):
+            with torch.no_grad():
+                if isinstance(module, AdaptiveModel):
                     logits = model.forward(
                         input_ids=batch.get("input_ids", None),
                         segment_ids=batch.get("segment_ids", None),
@@ -80,7 +86,7 @@ class Evaluator:
                         output_hidden_states=batch.get("output_hidden_states", False),
                         output_attentions=batch.get("output_attentions", False),
                     )
-                elif isinstance(model, BiAdaptiveModel):
+                elif isinstance(module, BiAdaptiveModel):
                     logits = model.forward(
                         query_input_ids=batch.get("query_input_ids", None),
                         query_segment_ids=batch.get("query_segment_ids", None),
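
Condensed, the patched evaluation step reads roughly as follows. This is a paraphrased sketch, not the actual Haystack source: only the keyword arguments visible in the hunks above are kept, and the final ValueError is illustrative.

import torch
from torch.nn import DataParallel

from haystack.modeling.model.adaptive_model import AdaptiveModel
from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
from haystack.modeling.model.optimization import WrappedDataParallel


def eval_forward(model, batch, device):
    # Move the batch to the evaluation device, as in the loop above.
    batch = {key: batch[key].to(device) for key in batch}

    # Unwrap only to decide which forward signature to use; the call itself
    # still goes through `model`, so a DataParallel wrapper keeps splitting
    # the batch across GPUs.
    module = model.module if isinstance(model, (DataParallel, WrappedDataParallel)) else model

    with torch.no_grad():
        if isinstance(module, AdaptiveModel):
            return model.forward(
                input_ids=batch.get("input_ids", None),
                segment_ids=batch.get("segment_ids", None),
            )
        if isinstance(module, BiAdaptiveModel):
            return model.forward(
                query_input_ids=batch.get("query_input_ids", None),
                query_segment_ids=batch.get("query_segment_ids", None),
            )
        raise ValueError(f"Unsupported model type: {type(module).__name__}")  # illustrative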

Changed file 2 of 3 (class Trainer):

@@ -18,7 +18,7 @@ from haystack.modeling.data_handler.data_silo import DataSilo, DistillationDataS
 from haystack.modeling.evaluation.eval import Evaluator
 from haystack.modeling.model.adaptive_model import AdaptiveModel
 from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
-from haystack.modeling.model.optimization import get_scheduler
+from haystack.modeling.model.optimization import get_scheduler, WrappedDataParallel
 from haystack.modeling.utils import GracefulKiller
 from haystack.utils.experiment_tracking import Tracker as tracker
 from haystack.utils.early_stopping import EarlyStopping
@@ -292,12 +292,17 @@ class Trainer:
     def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
         # Forward & backward pass through model
-        if isinstance(self.model, AdaptiveModel):
+        if isinstance(self.model, (DataParallel, WrappedDataParallel)):
+            module = self.model.module
+        else:
+            module = self.model
+        if isinstance(module, AdaptiveModel):
             logits = self.model.forward(
                 input_ids=batch["input_ids"], segment_ids=None, padding_mask=batch["padding_mask"]
             )
-        elif isinstance(self.model, BiAdaptiveModel):
+        elif isinstance(module, BiAdaptiveModel):
             logits = self.model.forward(
                 query_input_ids=batch["query_input_ids"],
                 query_segment_ids=batch["query_segment_ids"],
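
The Trainer change mirrors the Evaluator one: module is used only for the isinstance dispatch, while the loss-producing forward pass still goes through self.model. WrappedDataParallel itself is not defined in this diff; the checks only rely on it behaving like torch.nn.DataParallel, i.e. exposing the wrapped model as .module. Under that assumption, a wrapper of that shape can be sketched as follows (illustrative, not Haystack's actual class):

from torch.nn import DataParallel


class DataParallelLike(DataParallel):  # illustrative stand-in, not WrappedDataParallel itself
    def __getattr__(self, name):
        try:
            # nn.Module.__getattr__ resolves registered parameters, buffers and submodules.
            return super().__getattr__(name)
        except AttributeError:
            # Fall back to the wrapped model so code written against the plain
            # model (e.g. custom attributes) keeps working on the wrapper.
            return getattr(self.module, name)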

Changed file 3 of 3 (DPR modeling tests):

@@ -1,6 +1,6 @@
import os
from typing import Tuple
import os
import logging
from pathlib import Path
@@ -17,8 +17,11 @@ from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
 from haystack.modeling.model.language_model import get_language_model, DPREncoder
 from haystack.modeling.model.prediction_head import TextSimilarityHead
 from haystack.modeling.model.tokenization import get_tokenizer
+from haystack.nodes.retriever.dense import DensePassageRetriever
 from haystack.modeling.utils import set_all_seeds, initialize_device_settings
+from ..conftest import SAMPLES_PATH
 def test_dpr_modules(caplog=None):
if caplog:
@@ -970,6 +973,33 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
     assert np.array_equal(all_embeddings["query"][0], all_embeddings3["query"][0])
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+def test_dpr_training(document_store, tmp_path):
+    retriever = DensePassageRetriever(
+        document_store=document_store,
+        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
+        max_seq_len_query=8,
+        max_seq_len_passage=8,
+    )
+    save_dir = f"{tmp_path}/test_dpr_training"
+    retriever.train(
+        data_dir=str(SAMPLES_PATH / "dpr"),
+        train_filename="sample.json",
+        dev_filename="sample.json",
+        test_filename="sample.json",
+        n_epochs=1,
+        batch_size=1,
+        grad_acc_steps=1,
+        save_dir=save_dir,
+        evaluate_every=10,
+        embed_title=True,
+        num_positives=1,
+        num_hard_negatives=1,
+    )
 # TODO fix CI errors (test pass locally or on AWS, next steps: isolate PyTorch versions once FARM dependency is removed)
 # def test_dpr_training():
 # batch_size = 1
@@ -982,8 +1012,6 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
 # use_fast = True
 # similarity_function = "dot_product"
 #
-#
-#
 # device, n_gpu = initialize_device_settings(use_cuda=False)
 #
 # query_tokenizer = get_tokenizer(pretrained_model_name_or_path=question_lang_model,
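
The new test_dpr_training test drives DensePassageRetriever.train() end-to-end on the tiny sample.json DPR file and therefore exercises the patched Trainer and Evaluator code paths above. As a usage note, the save_dir written by train() can be reloaded afterwards with the usual Haystack 1.x pattern; DensePassageRetriever.load() is not part of this diff, so treat this as a hedged sketch:

from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes.retriever.dense import DensePassageRetriever

# Reload the checkpoint that retriever.train(save_dir=...) wrote in the test above.
document_store = InMemoryDocumentStore()
reloaded_retriever = DensePassageRetriever.load(
    load_dir="test_dpr_training",  # the save_dir used in test_dpr_training
    document_store=document_store,
)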