Mirror of https://github.com/deepset-ai/haystack.git
fix: Added checks for DataParallel and WrappedDataParallel (#3366)
* Added checks for DataParallel and WrappedDataParallel
* Update isinstance checks according to pylint recommendation
* Using isinstance over types
* Added test for dpr training
This commit is contained in:
parent db6e5754cd
commit 75641dd024
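The gist of the fix, as a standalone sketch: torch.nn.DataParallel (and Haystack's WrappedDataParallel variant of it) hides the real model behind its .module attribute, so isinstance checks against AdaptiveModel or BiAdaptiveModel must be made on the unwrapped module. A minimal, self-contained illustration of the failure mode and the unwrapping pattern the patch applies (DummyModel is a hypothetical stand-in for AdaptiveModel):

import torch
from torch.nn import DataParallel

class DummyModel(torch.nn.Module):  # hypothetical stand-in for AdaptiveModel
    def forward(self, x):
        return x

model = DataParallel(DummyModel())

# The wrapper is not an instance of the wrapped class ...
assert not isinstance(model, DummyModel)

# ... so unwrap before dispatching on type, exactly as the patch does:
module = model.module if isinstance(model, DataParallel) else model
assert isinstance(module, DummyModel)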
@@ -3,12 +3,14 @@ from typing import Dict, List, Optional, Any
 import logging
 import numbers
 import torch
+from torch.nn import DataParallel
 import numpy as np
 from tqdm import tqdm
 
 from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
 from haystack.modeling.model.adaptive_model import AdaptiveModel
 from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
+from haystack.modeling.model.optimization import WrappedDataParallel
 from haystack.utils.experiment_tracking import Tracker as tracker
 from haystack.modeling.visual import BUSH_SEP
 
@@ -70,9 +72,13 @@ class Evaluator:
         for step, batch in enumerate(tqdm(self.data_loader, desc="Evaluating", mininterval=10)):
             batch = {key: batch[key].to(self.device) for key in batch}
 
-            with torch.no_grad():
-                if isinstance(model, AdaptiveModel):
+            if isinstance(model, (DataParallel, WrappedDataParallel)):
+                module = model.module
+            else:
+                module = model
+
+            with torch.no_grad():
+                if isinstance(module, AdaptiveModel):
                     logits = model.forward(
                         input_ids=batch.get("input_ids", None),
                         segment_ids=batch.get("segment_ids", None),
@@ -80,7 +86,7 @@ class Evaluator:
                         output_hidden_states=batch.get("output_hidden_states", False),
                         output_attentions=batch.get("output_attentions", False),
                     )
-                elif isinstance(model, BiAdaptiveModel):
+                elif isinstance(module, BiAdaptiveModel):
                     logits = model.forward(
                         query_input_ids=batch.get("query_input_ids", None),
                         query_segment_ids=batch.get("query_segment_ids", None),
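A detail worth noting in the patched Evaluator loop above: the isinstance dispatch now runs on the unwrapped module, but the forward call still goes through model, so a DataParallel wrapper keeps performing its multi-GPU scatter/gather instead of being silently bypassed.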
@@ -18,7 +18,7 @@ from haystack.modeling.data_handler.data_silo import DataSilo, DistillationDataS
 from haystack.modeling.evaluation.eval import Evaluator
 from haystack.modeling.model.adaptive_model import AdaptiveModel
 from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
-from haystack.modeling.model.optimization import get_scheduler
+from haystack.modeling.model.optimization import get_scheduler, WrappedDataParallel
 from haystack.modeling.utils import GracefulKiller
 from haystack.utils.experiment_tracking import Tracker as tracker
 from haystack.utils.early_stopping import EarlyStopping
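For reference, WrappedDataParallel (imported here from haystack.modeling.model.optimization) is Haystack's thin DataParallel subclass whose job is to delegate attribute access to the wrapped model. A sketch of the idea, under the assumption that the actual implementation may differ in detail:

from torch.nn import DataParallel

class WrappedDataParallelSketch(DataParallel):
    # Sketch only: forward attribute lookups that DataParallel itself
    # cannot resolve (e.g. prediction_heads) to the wrapped module.
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)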
@@ -292,12 +292,17 @@ class Trainer:
 
     def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
         # Forward & backward pass through model
-        if isinstance(self.model, AdaptiveModel):
+        if isinstance(self.model, (DataParallel, WrappedDataParallel)):
+            module = self.model.module
+        else:
+            module = self.model
+
+        if isinstance(module, AdaptiveModel):
             logits = self.model.forward(
                 input_ids=batch["input_ids"], segment_ids=None, padding_mask=batch["padding_mask"]
             )
-        elif isinstance(self.model, BiAdaptiveModel):
+        elif isinstance(module, BiAdaptiveModel):
             logits = self.model.forward(
                 query_input_ids=batch["query_input_ids"],
                 query_segment_ids=batch["query_segment_ids"],
@@ -1,6 +1,6 @@
+import os
 from typing import Tuple
 
-import os
 import logging
 from pathlib import Path
 
@@ -17,8 +17,11 @@ from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
 from haystack.modeling.model.language_model import get_language_model, DPREncoder
 from haystack.modeling.model.prediction_head import TextSimilarityHead
 from haystack.modeling.model.tokenization import get_tokenizer
+from haystack.nodes.retriever.dense import DensePassageRetriever
 from haystack.modeling.utils import set_all_seeds, initialize_device_settings
 
+from ..conftest import SAMPLES_PATH
+
 
 def test_dpr_modules(caplog=None):
     if caplog:
@@ -970,6 +973,33 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
     assert np.array_equal(all_embeddings["query"][0], all_embeddings3["query"][0])
 
 
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+def test_dpr_training(document_store, tmp_path):
+    retriever = DensePassageRetriever(
+        document_store=document_store,
+        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
+        max_seq_len_query=8,
+        max_seq_len_passage=8,
+    )
+
+    save_dir = f"{tmp_path}/test_dpr_training"
+    retriever.train(
+        data_dir=str(SAMPLES_PATH / "dpr"),
+        train_filename="sample.json",
+        dev_filename="sample.json",
+        test_filename="sample.json",
+        n_epochs=1,
+        batch_size=1,
+        grad_acc_steps=1,
+        save_dir=save_dir,
+        evaluate_every=10,
+        embed_title=True,
+        num_positives=1,
+        num_hard_negatives=1,
+    )
+
+
 # TODO fix CI errors (test pass locally or on AWS, next steps: isolate PyTorch versions once FARM dependency is removed)
 # def test_dpr_training():
 #     batch_size = 1
@@ -982,8 +1012,6 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
 #     use_fast = True
 #     similarity_function = "dot_product"
 #
-#
-#
 #     device, n_gpu = initialize_device_settings(use_cuda=False)
 #
 #     query_tokenizer = get_tokenizer(pretrained_model_name_or_path=question_lang_model,
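The added test drives DensePassageRetriever.train end to end on the bundled sample.json with deliberately tiny settings (sequence length 8, batch size 1, one epoch), so the training loop, and with it the new unwrapping logic in compute_loss, is exercised in CI. To run just this test locally, something like the following should work (an assumption about the local setup, not a documented invocation):

import pytest

# Select the new test by name; -x stops at the first failure.
pytest.main(["-k", "test_dpr_training", "-x"])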