from typing import Optional, Union, List, Callable

import sys
import shutil
import logging
from pathlib import Path

import numpy
from tqdm import tqdm
import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.nn import MSELoss, Linear, Module, ModuleList, DataParallel
import torch.nn.functional as F
from torch.optim import Optimizer

from haystack.modeling.data_handler.data_silo import DataSilo, DistillationDataSilo
from haystack.modeling.evaluation.eval import Evaluator
from haystack.modeling.model.adaptive_model import AdaptiveModel
from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
from haystack.modeling.model.optimization import get_scheduler, WrappedDataParallel
from haystack.modeling.utils import GracefulKiller
from haystack.utils.experiment_tracking import Tracker as tracker
from haystack.utils.early_stopping import EarlyStopping
from haystack.telemetry import send_event


logger = logging.getLogger(__name__)


class Trainer:
    """
    Handles the main model training procedure. This includes performing evaluation on the dev set at regular
    intervals during training as well as evaluation on the test set at the end of training.
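
    **Example**

    A minimal usage sketch (illustrative only): it assumes that a ``processor``, a ``device``, and a ``model``
    (an ``AdaptiveModel`` with prediction heads matching the processor's tasks) already exist, and that
    ``initialize_optimizer`` from ``haystack.modeling.model.optimization`` is used to create the optimizer
    and learning-rate schedule.

    ```python
    data_silo = DataSilo(processor=processor, batch_size=8)
    model, optimizer, lr_schedule = initialize_optimizer(
        model=model, n_batches=len(data_silo.loaders["train"]), n_epochs=3, device=device, learning_rate=3e-5
    )
    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        data_silo=data_silo,
        epochs=3,
        n_gpu=1,
        device=device,
        lr_schedule=lr_schedule,
    )
    trainer.train()
    ```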
    """

    def __init__(
        self,
        model,
        optimizer,
        data_silo: DataSilo,
        epochs: int,
        n_gpu: int,
        device: torch.device,
        lr_schedule=None,
        evaluate_every: int = 100,
        eval_report: bool = True,
        use_amp: bool = False,
        grad_acc_steps: int = 1,
        local_rank: int = -1,
        early_stopping: Optional[EarlyStopping] = None,
        log_learning_rate: bool = False,
        log_loss_every: int = 10,
        checkpoint_on_sigterm: bool = False,
        checkpoint_every: Optional[int] = None,
        checkpoint_root_dir: Optional[Path] = None,
        checkpoints_to_keep: int = 3,
        from_epoch: int = 0,
        from_step: int = 0,
        global_step: int = 0,
        evaluator_test: bool = True,
        disable_tqdm: bool = False,
        max_grad_norm: float = 1.0,
    ):
        """
        :param model: The model to be trained.
        :param optimizer: An optimizer object that determines the learning strategy to be used during training
        :param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
        :param epochs: How many times the training procedure will loop through the train dataset
        :param n_gpu: The number of GPUs available for training and evaluation.
        :param device: The device on which the train, dev and test tensors should be hosted. Choose from torch.device("cpu") and torch.device("cuda").
        :param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
        :param evaluate_every: Perform dev set evaluation after this many steps of training.
        :param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating
        :param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
                        training speed and reduce GPU memory usage.
                        For more information, see [Haystack Optimization](https://haystack.deepset.ai/guides/optimization)
                        and [Automatic Mixed Precision Package - torch.amp](https://pytorch.org/docs/stable/amp.html).
        :param grad_acc_steps: Number of training steps for which the gradients should be accumulated.
                               Useful to achieve larger effective batch sizes that would not fit in GPU memory.
        :param local_rank: Local rank of the process when distributed training via DDP is used.
        :param early_stopping: An initialized EarlyStopping object to control early stopping and saving of the best models.
        :param log_learning_rate: Whether to log the learning rate to the experiment tracker (e.g. MLflow)
        :param log_loss_every: Log the current train loss after this many train steps.
        :param checkpoint_on_sigterm: Save a checkpoint for the Trainer when a SIGTERM signal is sent. The checkpoint
                                      can be used to resume training. It is useful in frameworks like AWS SageMaker with Spot instances where
                                      a SIGTERM signals that the training state should be saved before the instance is terminated.
        :param checkpoint_every: Save a training checkpoint after this many steps of training.
        :param checkpoint_root_dir: The directory Path where all training checkpoints are saved. For each individual
                                    checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
        :param checkpoints_to_keep: The maximum number of training checkpoints to save.
        :param from_epoch: The epoch number to start the training from. When training resumes from a saved
                           checkpoint, it is used to fast-forward training to the last epoch in the checkpoint.
        :param from_step: The step number to start the training from. When training resumes from a saved
                          checkpoint, it is used to fast-forward training to the last step in the checkpoint.
        :param global_step: The global step number across the training epochs.
        :param evaluator_test: Whether to perform evaluation on the test set.
        :param disable_tqdm: Disable the tqdm progress bar (helps to reduce verbosity in some environments).
        :param max_grad_norm: Max gradient norm for gradient clipping, default 1.0; set to None to disable.
        """
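        # Map deprecated Apex optimization levels ("O0"-"O3") to the native torch.cuda.amp on/off flag used by this Trainer.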
        amp_mapping = {"O0": False, "O1": True, "O2": True, "O3": True}
        self.model = model
        self.data_silo = data_silo
        self.epochs = int(epochs)
        if isinstance(use_amp, str):
            if use_amp in amp_mapping:
                logger.warning(
                    "The Trainer only supports native PyTorch automatic mixed precision and no longer supports the Apex library.\n"
                    "Because you provided Apex optimization level %s, automatic mixed precision was set to %s.\n"
                    "In the future, set `use_amp=True` to turn on automatic mixed precision.",
                    use_amp,
                    amp_mapping[use_amp],
                )
                use_amp = amp_mapping[use_amp]
            else:
                raise Exception(
                    f"use_amp value {use_amp} is not supported. Please provide use_amp=True to turn on automatic mixed precision."
                )
        self.use_amp = use_amp
        self.optimizer = optimizer
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)
        self.evaluate_every = evaluate_every
        self.eval_report = eval_report
        self.evaluator_test = evaluator_test
        self.n_gpu = n_gpu
        self.grad_acc_steps = grad_acc_steps
        self.lr_schedule = lr_schedule
        self.device = device
        self.local_rank = local_rank
        self.log_params()
        self.early_stopping = early_stopping
        self.log_learning_rate = log_learning_rate
        self.log_loss_every = log_loss_every
        self.disable_tqdm = disable_tqdm
        self.max_grad_norm = max_grad_norm
        self.test_result = None

        self.checkpoint_on_sigterm = checkpoint_on_sigterm
        if checkpoint_on_sigterm:
            self.sigterm_handler = GracefulKiller()  # type: Optional[GracefulKiller]
        else:
            self.sigterm_handler = None
        self.checkpoint_root_dir = checkpoint_root_dir
        self.checkpoints_to_keep = checkpoints_to_keep
        self.checkpoint_every = checkpoint_every
        if self.checkpoint_every and not checkpoint_root_dir:
            raise Exception("checkpoint_root_dir needs to be supplied when using checkpoint_every.")
        if checkpoint_on_sigterm and not checkpoint_root_dir:
            raise Exception("checkpoint_root_dir needs to be supplied when using checkpoint_on_sigterm.")

        self.from_epoch = from_epoch
        self.from_step = from_step
        self.global_step = global_step

    def train(self):
        """
        Perform the training procedure.

        The training is visualized by a progress bar. It counts the epochs in a zero-based manner.
        For example, when you specify ``epochs=20``, it starts to count from 0 to 19.

        If the trainer evaluates the model on a test set, the result of the
        evaluation is stored in ``test_result``.

        :return: Returns the model after training. When you use ``early_stopping``
            with a ``save_dir``, the best model is loaded and returned.
        """
        send_event(event_name="Training", event_properties={"class": self.__class__.__name__, "function_name": "train"})
        # connect the prediction heads with the right output from processor
        self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
        # Check that the tokenizer(s) fits the language model(s)
        if hasattr(self.model, "language_model3"):
            self.model.verify_vocab_size(
                vocab_size1=len(self.data_silo.processor.query_tokenizer),
                vocab_size2=len(self.data_silo.processor.passage_tokenizer),
                vocab_size3=len(self.data_silo.processor.table_tokenizer),
            )
        elif hasattr(self.model, "language_model2"):
            self.model.verify_vocab_size(
                vocab_size1=len(self.data_silo.processor.query_tokenizer),
                vocab_size2=len(self.data_silo.processor.passage_tokenizer),
            )
        elif (
            self.model.language_model.name != "DebertaV2"
        ):  # DebertaV2 has mismatched vocab size on purpose (see https://github.com/huggingface/transformers/issues/12428)
            self.model.verify_vocab_size(vocab_size=len(self.data_silo.processor.tokenizer))
        self.model.train()

        do_stopping = False
        evalnr = 0
        loss = 0
        resume_from_step = self.from_step

        for epoch in range(self.from_epoch, self.epochs):
            early_break = False
            self.from_epoch = epoch
            train_data_loader = self.data_silo.get_data_loader("train")
            progress_bar = tqdm(train_data_loader, disable=self.local_rank not in [0, -1] or self.disable_tqdm)
            for step, batch in enumerate(progress_bar):
                # when resuming training from a checkpoint, we want to fast forward to the step of the checkpoint
                if resume_from_step and step <= resume_from_step:
                    if step % 10000 == 0:
                        logger.info("Skipping %s out of %s steps ...", step, resume_from_step)
                    if resume_from_step == step:
                        logger.info("Finished skipping %s steps ...", resume_from_step)
                        resume_from_step = None
                    else:
                        continue

                progress_bar.set_description(f"Train epoch {epoch}/{self.epochs-1} (Cur. train loss: {loss:.4f})")

                # Only for distributed training: we need to ensure that all ranks still have a batch left for training
                if self.local_rank != -1 and not self._all_ranks_have_data(has_data=True, step=step):
                    early_break = True
                    break

                # Move batch of samples to device
                batch = {key: batch[key].to(self.device) for key in batch}
                loss = self.compute_loss(batch, step)

                # Perform evaluation
                if (
                    self.evaluate_every != 0
                    and self.global_step % self.evaluate_every == 0
                    and self.global_step != 0
                    and self.local_rank in [0, -1]
                ):
                    dev_data_loader = self.data_silo.get_data_loader("dev")
                    if dev_data_loader is not None:
                        evaluator_dev = Evaluator(
                            data_loader=dev_data_loader,
                            tasks=self.data_silo.processor.tasks,
                            device=self.device,
                            report=self.eval_report,
                        )
                        evalnr += 1
                        result = evaluator_dev.eval(self.model)
                        evaluator_dev.log_results(result, "Dev", self.global_step)
                        if self.early_stopping:
                            do_stopping, save_model, eval_value = self.early_stopping.check_stopping(result)
                            if save_model:
                                logger.info(
                                    "Saving current best model to %s, eval=%s", self.early_stopping.save_dir, eval_value
                                )
                                self.model.save(self.early_stopping.save_dir)
                                self.data_silo.processor.save(self.early_stopping.save_dir)
                            if do_stopping:
                                # log the stopping
                                logger.info("STOPPING EARLY AT EPOCH %s, STEP %s, EVALUATION %s", epoch, step, evalnr)
                if do_stopping:
                    break

                self.global_step += 1
                self.from_step = step + 1

                # save the current state as a checkpoint before exiting if a SIGTERM signal is received
                if self.sigterm_handler and self.sigterm_handler.kill_now:
                    logger.info("Received a SIGTERM signal. Saving the current train state as a checkpoint ...")
                    if self.local_rank in [0, -1]:
                        self._save()
                        torch.distributed.destroy_process_group()
                        sys.exit(0)

                # save a checkpoint and continue train
                if self.checkpoint_every and step % self.checkpoint_every == 0:
                    if self.local_rank in [0, -1]:
                        self._save()
                    # Let other ranks wait until rank 0 has finished saving
                    if self.local_rank != -1:
                        torch.distributed.barrier()

            if do_stopping:
                break

        # Only for distributed training: we need to ensure that all ranks still have a batch left for training
        if self.local_rank != -1 and not early_break:
            self._all_ranks_have_data(has_data=False)

        # With early stopping we want to restore the best model
        if self.early_stopping and self.early_stopping.save_dir:
            logger.info("Restoring best model so far from %s", self.early_stopping.save_dir)
            self.model = self.model.load(self.early_stopping.save_dir, self.device)
            self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)

        # Eval on test set
        if self.evaluator_test and self.local_rank in [0, -1]:
            test_data_loader = self.data_silo.get_data_loader("test")
            if test_data_loader is not None:
                evaluator_test = Evaluator(
                    data_loader=test_data_loader, tasks=self.data_silo.processor.tasks, device=self.device
                )
                self.test_result = evaluator_test.eval(self.model)
                evaluator_test.log_results(self.test_result, "Test", self.global_step)
        self.model.eval()
        return self.model

    def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
        # Forward & backward pass through model
        if isinstance(self.model, (DataParallel, WrappedDataParallel)):
            module = self.model.module
        else:
            module = self.model

        with torch.cuda.amp.autocast(enabled=self.use_amp):
            if isinstance(module, AdaptiveModel):
                logits = self.model.forward(
                    input_ids=batch["input_ids"], segment_ids=None, padding_mask=batch["padding_mask"]
                )

            elif isinstance(module, BiAdaptiveModel):
                logits = self.model.forward(
                    query_input_ids=batch["query_input_ids"],
                    query_segment_ids=batch["query_segment_ids"],
                    query_attention_mask=batch["query_attention_mask"],
                    passage_input_ids=batch["passage_input_ids"],
                    passage_segment_ids=batch["passage_segment_ids"],
                    passage_attention_mask=batch["passage_attention_mask"],
                )

            else:
                logits = self.model.forward(**batch)

            per_sample_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
            loss = self.adjust_loss(per_sample_loss)
        return self.backward_propagate(loss, step)

    def backward_propagate(self, loss: torch.Tensor, step: int):
        if self.global_step % self.log_loss_every == 0 and self.local_rank in [-1, 0]:
            tracker.track_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step)
            if self.log_learning_rate:
                tracker.track_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step)

        self.scaler.scale(loss).backward()
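
        # The optimizer is only stepped every `grad_acc_steps` batches; in between, the scaled gradients simply
        # accumulate on the parameters, which emulates a larger effective batch size.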
        if step % self.grad_acc_steps == 0:
            if self.max_grad_norm is not None:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()
            if self.lr_schedule:
                self.lr_schedule.step()
        return loss

    def adjust_loss(self, loss: torch.Tensor):
        loss = loss.mean()
        if self.grad_acc_steps > 1:
            loss = loss / self.grad_acc_steps
        return loss

    def log_params(self):
        params = {"epochs": self.epochs, "n_gpu": self.n_gpu, "device": self.device, "use_amp": self.use_amp}
        tracker.track_params(params)

    @classmethod
    def create_or_load_checkpoint(
        cls,
        data_silo: DataSilo,
        checkpoint_root_dir: Path,
        model,
        optimizer,
        local_rank: int = -1,
        resume_from_checkpoint: str = "latest",
        **kwargs,
    ):
        """
        Try to load a saved Trainer checkpoint. If no checkpoint is found, a new Trainer instance is created.

        :param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
        :param checkpoint_root_dir: Path of the directory where all train checkpoints are saved. Each individual
               checkpoint is stored in a sub-directory under it.
        :param resume_from_checkpoint: The checkpoint name to start training from, e.g., "epoch_1_step_4532". It
               defaults to "latest", using the checkpoint with the highest train steps.
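
        **Example**

        A hedged usage sketch: it assumes that ``model``, ``optimizer``, and ``data_silo`` objects already exist;
        the remaining keyword arguments are only used when no checkpoint is found and a fresh Trainer is created.

        ```python
        trainer = Trainer.create_or_load_checkpoint(
            data_silo=data_silo,
            checkpoint_root_dir=Path("checkpoints"),
            model=model,
            optimizer=optimizer,
            epochs=3,
            n_gpu=1,
            device=torch.device("cuda"),
        )
        trainer.train()
        ```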
        """
        checkpoint_to_load = None
        if checkpoint_root_dir and checkpoint_root_dir.exists():
            if resume_from_checkpoint == "latest":
                saved_checkpoints = cls._get_checkpoints(checkpoint_root_dir)
                if saved_checkpoints:
                    checkpoint_to_load = saved_checkpoints[0]  # latest checkpoint
                else:
                    checkpoint_to_load = None
            else:
                checkpoint_to_load = checkpoint_root_dir / resume_from_checkpoint

        if checkpoint_to_load:
            # TODO load empty model class from config instead of passing here?
            trainer = cls._load_checkpoint(
                path=checkpoint_to_load, data_silo=data_silo, model=model, optimizer=optimizer, local_rank=local_rank
            )
            logger.info("Resuming training from the train checkpoint at %s ...", checkpoint_to_load)
        else:
            logger.info("No train checkpoints found. Starting a new training ...")
            trainer = cls(
                data_silo=data_silo,
                model=model,
                optimizer=optimizer,
                local_rank=local_rank,
                checkpoint_root_dir=checkpoint_root_dir,
                **kwargs,
            )
        return trainer

    @classmethod
    def _load_checkpoint(cls, path: Path, data_silo: DataSilo, model, optimizer, local_rank: int = -1):
        """
        Load the train checkpoint at the given path.

        :param path: The checkpoint path, a subdirectory under checkpoint_root_dir. The individual checkpoint dirs
               follow a default naming convention of "epoch_{epoch_num}_step_{step_num}".
        :param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
        """
        if not path.exists():
            raise Exception(f"The checkpoint path {path} does not exist.")

        # In distributed mode, we save the model only once from process 0 (using cuda:0)
        # At loading time, we need to load the model to the current cuda device (instead of back to cuda:0)
        # Note: This assumes exactly one GPU per process (as recommended by PyTorch)
        if local_rank == -1:
            map_location = None
        else:
            device = torch.device(f"cuda:{local_rank}")
            map_location = {"cuda:0": f"cuda:{local_rank}"}

        trainer_checkpoint = torch.load(path / "trainer", map_location=map_location)
        trainer_state_dict = trainer_checkpoint["trainer_state"]
        if local_rank != -1:
            trainer_state_dict["device"] = device
            trainer_state_dict["local_rank"] = local_rank

        # Just setting seeds is not sufficient to have deterministic results when resuming
        # training from a checkpoint. Additionally, the previous states of Random Number
        # Generators also need to be restored from the saved checkpoint.
        numpy_rng_state = trainer_checkpoint["numpy_rng_state"]
        numpy.random.set_state(numpy_rng_state)
        rng_state = trainer_checkpoint["rng_state"]
        cuda_rng_state = trainer_checkpoint["cuda_rng_state"]
        torch.set_rng_state(rng_state)
        torch.cuda.set_rng_state(cuda_rng_state)

        model.load_state_dict(trainer_checkpoint["model_state"], strict=True)
        optimizer.load_state_dict(trainer_checkpoint["optimizer_state"])

        scheduler_state_dict = trainer_checkpoint["scheduler_state"]
        scheduler_opts = trainer_checkpoint["scheduler_opts"]
        scheduler = get_scheduler(optimizer, scheduler_opts)
        scheduler.load_state_dict(scheduler_state_dict)

        trainer = Trainer(
            data_silo=data_silo, model=model, optimizer=optimizer, lr_schedule=scheduler, **trainer_state_dict
        )

        logger.info("Loaded a train checkpoint from %s", path)
        return trainer

    @classmethod
    def _get_checkpoints(cls, checkpoint_root_dir: Path):
        """
        Get a list of checkpoint dirs, sorted from the most recent (highest epoch and step) to the oldest.
        """
        dirs = [d for d in checkpoint_root_dir.iterdir() if d.is_dir() and d.name.startswith("epoch")]

        checkpoints_with_epoch_and_step = []  # list of tuple(checkpoint_dir, epoch, step)
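        # each checkpoint dir is named "epoch_{epoch_num}_step_{step_num}", so the two integers in the name are
        # the epoch and the step at which the checkpoint was taken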
        for d in dirs:
            epoch, step = [int(s) for s in str(d).split("_") if s.isdigit()]
            checkpoints_with_epoch_and_step.append((d, epoch, step))

        sorted_checkpoints_with_epoch_and_step = sorted(
            checkpoints_with_epoch_and_step, key=lambda tup: (tup[1], tup[2]), reverse=True  # sort by epoch and step
        )
        sorted_checkpoints = [tup[0] for tup in sorted_checkpoints_with_epoch_and_step]

        return sorted_checkpoints

    def _save(self):
        """
        Save a train checkpoint under the Trainer's checkpoint_root_dir.

        Some objects (e.g. the scheduler) in the Trainer are not serializable using the Pickle module. For these
        objects, the state_dict is stored in the checkpoint and can be used to reconstruct a similar state when
        resuming training from the checkpoint.

        # TODO The model is currently saved as a whole serialized object. The disadvantage of this approach is that it is
        bound to a specific Python version, haystack version, directory structure etc. A more modular and reusable approach
        is to save using AdaptiveModel's save() method where the model and the state_dict are stored separately.

        # TODO custom defined evaluators are not saved in the checkpoint.
        """
        logger.info("Saving a train checkpoint ...")
        checkpoint_path = self.checkpoint_root_dir / "checkpoint_in_progress"
        checkpoint_path.mkdir(parents=True, exist_ok=True)

        trainer_state_dict = self._get_state_dict()

        # save as a regular AdaptiveModel (e.g. for down-stream eval during training from scratch)
        self.model.save(checkpoint_path)

        # save all state dicts (incl. the model) to have full reproducibility
        torch.save(
            {
                "trainer_state": trainer_state_dict,
                "model_state": self.model.state_dict(),
                "optimizer_state": self.optimizer.state_dict(),
                "scheduler_opts": self.lr_schedule.opts,
                "scheduler_state": self.lr_schedule.state_dict(),
                "numpy_rng_state": numpy.random.get_state(),
                "rng_state": torch.get_rng_state(),
                "cuda_rng_state": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
            },
            checkpoint_path / "trainer",
        )
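
        # Everything is first written into a temporary "checkpoint_in_progress" directory and only renamed to its
        # final epoch/step name below, so an interrupted save does not leave a complete-looking but broken checkpoint.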
        checkpoint_name = f"epoch_{self.from_epoch}_step_{self.from_step-1}"
        checkpoint_path.replace(Path(checkpoint_path.parent) / checkpoint_name)

        saved_checkpoints = self._get_checkpoints(self.checkpoint_root_dir)
        if len(saved_checkpoints) > self.checkpoints_to_keep:
            for cp in saved_checkpoints[self.checkpoints_to_keep :]:
                shutil.rmtree(cp)

        logger.info("Saved a training checkpoint after %s", checkpoint_name)

    def _get_state_dict(self):
        """
        Serializable state dictionary of a Trainer object
        """
        state_dict = {
            "evaluate_every": self.evaluate_every,
            "n_gpu": self.n_gpu,
            "grad_acc_steps": self.grad_acc_steps,
            "device": self.device,
            "local_rank": self.local_rank,
            "early_stopping": self.early_stopping,
            "epochs": self.epochs,
            "checkpoint_on_sigterm": self.checkpoint_on_sigterm,
            "checkpoint_root_dir": self.checkpoint_root_dir,
            "checkpoint_every": self.checkpoint_every,
            "checkpoints_to_keep": self.checkpoints_to_keep,
            "from_epoch": self.from_epoch,
            "from_step": self.from_step,
            "global_step": self.global_step,
            "log_learning_rate": self.log_learning_rate,
            "log_loss_every": self.log_loss_every,
            "disable_tqdm": self.disable_tqdm,
            "use_amp": self.use_amp,
        }

        return state_dict

    def _all_ranks_have_data(self, has_data: bool, step: Optional[int] = None):
        """
        Verify in distributed training whether all ranks still have data left. Each rank sends a "1" if it has data
        and a "0" if it has none.
        If all ranks have data, they all send a 1, the sum equals world_size, and training continues.
        If not, the ranks that still have data must break out of the loop so that all ranks stay synchronized.

        :param has_data: bool, whether the current rank still has data
        :param step: int, current step (only used for logging)
        :return: bool, whether all ranks have training data left
        """
        if has_data:
            ranks_with_data = torch.ones(1).to(self.device)
        else:
            ranks_with_data = torch.zeros(1).to(self.device)

        torch.distributed.all_reduce(ranks_with_data, op=torch.distributed.ReduceOp.SUM)

        if ranks_with_data < torch.distributed.get_world_size():
            if step is not None:
                logger.info(
                    "Stopping epoch %s at step %s for rank %s since at least one other rank "
                    "(~ one GPU) in distributed training doesn't have any more batches... ",
                    self.from_epoch,
                    step,
                    self.local_rank,
                )
            return False
        else:
            return True


class DistillationTrainer(Trainer):
    """
    This trainer uses the teacher logits from a DistillationDataSilo
    to compute a distillation loss in addition to the loss based on the labels.

    **Example**
    ```python
    student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
    teacher = FARMReader(model_name_or_path="deepset/bert-large-uncased-whole-word-masking-squad2")

    processor = SquadProcessor(tokenizer=student.inferencer.processor.tokenizer, max_seq_len=384)
    data_silo = DistillationDataSilo(teacher_model=teacher, teacher_batch_size=2, batch_size=8, device="cuda:0", processor=processor)
    student, optimizer, _ = initialize_optimizer(student, n_batches=len(data_silo.loaders["train"]), n_epochs=3, device="cuda:0", learning_rate=3e-5)

    trainer = DistillationTrainer(model=student, optimizer=optimizer, data_silo=data_silo, epochs=3, n_gpu=1, device="cuda:0")
    trainer.train()
    ```
    """

    def __init__(
        self,
        model: "AdaptiveModel",
        optimizer: Optimizer,
        data_silo: DistillationDataSilo,
        epochs: int,
        n_gpu: int,
        device: torch.device,
        lr_schedule: Optional[_LRScheduler] = None,
        evaluate_every: int = 100,
        eval_report: bool = True,
        use_amp: bool = False,
        grad_acc_steps: int = 1,
        local_rank: int = -1,
        early_stopping: Optional[EarlyStopping] = None,
        log_learning_rate: bool = False,
        log_loss_every: int = 10,
        checkpoint_on_sigterm: bool = False,
        checkpoint_every: Optional[int] = None,
        checkpoint_root_dir: Optional[Path] = None,
        checkpoints_to_keep: int = 3,
        from_epoch: int = 0,
        from_step: int = 0,
        global_step: int = 0,
        evaluator_test: bool = True,
        disable_tqdm: bool = False,
        max_grad_norm: float = 1.0,
        distillation_loss_weight: float = 0.5,
        distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div",
        temperature: float = 1.0,
    ):
        """
        :param optimizer: An optimizer object that determines the learning strategy to be used during training
        :param model: The model to be trained
        :param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
        :param epochs: How many times the training procedure will loop through the train dataset
        :param n_gpu: The number of GPUs available for training and evaluation.
        :param device: The device on which the train, dev and test tensors should be hosted. Choose from torch.device("cpu") and torch.device("cuda").
        :param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
        :param evaluate_every: Perform dev set evaluation after this many steps of training.
        :param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating
        :param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
                        training speed and reduce GPU memory usage.
                        For more information, see [Haystack Optimization](https://haystack.deepset.ai/guides/optimization)
                        and [Automatic Mixed Precision Package - torch.amp](https://pytorch.org/docs/stable/amp.html).
        :param grad_acc_steps: Number of training steps for which the gradients should be accumulated.
                               Useful to achieve larger effective batch sizes that would not fit in GPU memory.
        :param local_rank: Local rank of the process when distributed training via DDP is used.
        :param early_stopping: An initialized EarlyStopping object to control early stopping and saving of the best models.
        :param log_learning_rate: Whether to log the learning rate to the experiment tracker (e.g. MLflow)
        :param log_loss_every: Log the current train loss after this many train steps.
        :param checkpoint_on_sigterm: Save a checkpoint for the Trainer when a SIGTERM signal is sent. The checkpoint
                                      can be used to resume training. It is useful in frameworks like AWS SageMaker with Spot instances where
                                      a SIGTERM signals that the training state should be saved before the instance is terminated.
        :param checkpoint_every: Save a train checkpoint after this many steps of training.
        :param checkpoint_root_dir: The Path of the directory where all train checkpoints are saved. For each individual
                                    checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
        :param checkpoints_to_keep: The maximum number of train checkpoints to save.
        :param from_epoch: The epoch number to start the training from. When training resumes from a saved
                           checkpoint, it is used to fast-forward training to the last epoch in the checkpoint.
        :param from_step: The step number to start the training from. When training resumes from a saved
                          checkpoint, it is used to fast-forward training to the last step in the checkpoint.
        :param global_step: The global step number across the training epochs.
        :param evaluator_test: Whether to perform evaluation on the test set.
        :param disable_tqdm: Disable the tqdm progress bar (helps to reduce verbosity in some environments).
        :param max_grad_norm: Max gradient norm for gradient clipping, default 1.0; set to None to disable.
        :param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
        :param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for KL divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
        :param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
        """
        super().__init__(
            model=model,
            optimizer=optimizer,
            data_silo=data_silo,
            epochs=epochs,
            n_gpu=n_gpu,
            device=device,
            lr_schedule=lr_schedule,
            evaluate_every=evaluate_every,
            eval_report=eval_report,
            use_amp=use_amp,
            grad_acc_steps=grad_acc_steps,
            local_rank=local_rank,
            early_stopping=early_stopping,
            log_learning_rate=log_learning_rate,
            log_loss_every=log_loss_every,
            checkpoint_on_sigterm=checkpoint_on_sigterm,
            checkpoint_every=checkpoint_every,
            checkpoint_root_dir=checkpoint_root_dir,
            checkpoints_to_keep=checkpoints_to_keep,
            from_epoch=from_epoch,
            from_step=from_step,
            global_step=global_step,
            evaluator_test=evaluator_test,
            disable_tqdm=disable_tqdm,
            max_grad_norm=max_grad_norm,
        )
        self.distillation_loss_weight = distillation_loss_weight
        if distillation_loss == "mse":
            self.distillation_loss_fn = MSELoss()
        elif distillation_loss == "kl_div":
            self.distillation_loss_fn = self._kl_div  # type: ignore [assignment]
        elif callable(distillation_loss):
            # as documented above, a custom loss callable with named parameters student_logits and teacher_logits
            self.distillation_loss_fn = distillation_loss  # type: ignore [assignment]
        self.temperature = temperature

    def _kl_div(self, student_logits, teacher_logits):
        student_log_probs = F.log_softmax(student_logits, dim=-2)
        teacher_probs = F.softmax(teacher_logits, dim=-2)
        return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")

    def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
        keys = list(batch.keys())
        keys = [key for key in keys if key.startswith("teacher_output")]
        teacher_logits = [batch.pop(key) for key in keys]
        with torch.cuda.amp.autocast(enabled=self.use_amp):
            logits = self.model.forward(
                input_ids=batch.get("input_ids"),
                segment_ids=batch.get("segment_ids"),
                padding_mask=batch.get("padding_mask"),
                output_hidden_states=batch.get("output_hidden_states"),
                output_attentions=batch.get("output_attentions"),
            )
            student_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
            distillation_loss = self.distillation_loss_fn(
                student_logits=logits[0] / self.temperature, teacher_logits=teacher_logits[0] / self.temperature
            )
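        # The distillation term is scaled by temperature**2: softening the logits with a temperature T shrinks the
        # gradients of the soft-target loss by roughly 1/T**2, so this factor keeps the distillation loss and the
        # label-based student loss on comparable scales (a standard convention in knowledge distillation).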
        combined_loss = distillation_loss * self.distillation_loss_weight * (
            self.temperature**2
        ) + student_loss * (1 - self.distillation_loss_weight)
        loss = self.adjust_loss(combined_loss)
        return self.backward_propagate(loss, step)


class TinyBERTDistillationTrainer(Trainer):
    """
    This Trainer implements the first stage of task-specific distillation as described in the TinyBERT paper.
    The standard DistillationTrainer can be used for the second stage. Unlike the DistillationTrainer, this Trainer
    does not use cached teacher outputs, as caching them would require too much memory. This makes it much slower
    than the DistillationTrainer.

    **Example**
    ```python
    student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D")
    teacher = FARMReader(model_name_or_path="twmkn9/bert-base-uncased-squad2")

    processor = SquadProcessor(tokenizer=student.inferencer.processor.tokenizer, max_seq_len=384)
    data_silo = DataSilo(processor=processor, batch_size=8)
    student, optimizer, _ = initialize_optimizer(student, n_batches=len(data_silo.loaders["train"]), n_epochs=3, device="cuda:0", learning_rate=3e-5)

    trainer = TinyBERTDistillationTrainer(model=student, teacher_model=teacher, optimizer=optimizer, data_silo=data_silo, epochs=3, n_gpu=1, device="cuda:0")
    trainer.train()
    ```
    """

    def __init__(
        self,
        model: AdaptiveModel,
        teacher_model: AdaptiveModel,
        optimizer: Optimizer,
        data_silo: DistillationDataSilo,
        epochs: int,
        n_gpu: int,
        device: torch.device,
        lr_schedule: Optional[_LRScheduler] = None,
        evaluate_every: int = 100,
        eval_report: bool = True,
        use_amp: bool = False,
        grad_acc_steps: int = 1,
        local_rank: int = -1,
        early_stopping: Optional[EarlyStopping] = None,
        log_learning_rate: bool = False,
        log_loss_every: int = 10,
        checkpoint_on_sigterm: bool = False,
        checkpoint_every: Optional[int] = None,
        checkpoint_root_dir: Optional[Path] = None,
        checkpoints_to_keep: int = 3,
        from_epoch: int = 0,
        from_step: int = 0,
        global_step: int = 0,
        evaluator_test: bool = True,
        disable_tqdm: bool = False,
        max_grad_norm: float = 1.0,
    ):
        """
        :param optimizer: An optimizer object that determines the learning strategy to be used during training
        :param model: The model to be trained. It needs to be a TinyBERT model.
        :param teacher_model: The teacher model used for distillation. This has to be based on bert-base-uncased.
        :param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
        :param epochs: How many times the training procedure will loop through the train dataset
        :param n_gpu: The number of GPUs available for training and evaluation.
        :param device: The device on which the train, dev and test tensors should be hosted. Choose from torch.device("cpu") and torch.device("cuda").
        :param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
        :param evaluate_every: Perform dev set evaluation after this many steps of training.
        :param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating
        :param use_amp: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
                        training speed and reduce GPU memory usage.
                        For more information, see [Haystack Optimization](https://haystack.deepset.ai/guides/optimization)
                        and [Automatic Mixed Precision Package - torch.amp](https://pytorch.org/docs/stable/amp.html).
        :param grad_acc_steps: Number of training steps for which the gradients should be accumulated.
                               Useful to achieve larger effective batch sizes that would not fit in GPU memory.
        :param local_rank: Local rank of the process when distributed training via DDP is used.
        :param early_stopping: An initialized EarlyStopping object to control early stopping and saving of the best models.
        :param log_learning_rate: Whether to log the learning rate to the experiment tracker (e.g. MLflow)
        :param log_loss_every: Log the current train loss after this many train steps.
        :param checkpoint_on_sigterm: Save a checkpoint for the Trainer when a SIGTERM signal is sent. The checkpoint
                                      can be used to resume training. It is useful in frameworks like AWS SageMaker with Spot instances where
                                      a SIGTERM signals that the training state should be saved before the instance is terminated.
        :param checkpoint_every: Save a train checkpoint after this many steps of training.
        :param checkpoint_root_dir: The Path of the directory where all train checkpoints are saved. For each individual
                                    checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
        :param checkpoints_to_keep: The maximum number of train checkpoints to save.
        :param from_epoch: The epoch number to start the training from. When training resumes from a saved
                           checkpoint, it is used to fast-forward training to the last epoch in the checkpoint.
        :param from_step: The step number to start the training from. When training resumes from a saved
                          checkpoint, it is used to fast-forward training to the last step in the checkpoint.
        :param global_step: The global step number across the training epochs.
        :param evaluator_test: Whether to perform evaluation on the test set.
        :param disable_tqdm: Disable the tqdm progress bar (helps to reduce verbosity in some environments).
        :param max_grad_norm: Max gradient norm for gradient clipping, default 1.0; set to None to disable.
        """
        super().__init__(
            model=model,
            optimizer=optimizer,
            data_silo=data_silo,
            epochs=epochs,
            n_gpu=n_gpu,
            device=device,
            lr_schedule=lr_schedule,
            evaluate_every=evaluate_every,
            eval_report=eval_report,
            use_amp=use_amp,
            grad_acc_steps=grad_acc_steps,
            local_rank=local_rank,
            early_stopping=early_stopping,
            log_learning_rate=log_learning_rate,
            log_loss_every=log_loss_every,
            checkpoint_on_sigterm=checkpoint_on_sigterm,
            checkpoint_every=checkpoint_every,
            checkpoint_root_dir=checkpoint_root_dir,
            checkpoints_to_keep=checkpoints_to_keep,
            from_epoch=from_epoch,
            from_step=from_step,
            global_step=global_step,
            evaluator_test=evaluator_test,
            disable_tqdm=disable_tqdm,
            max_grad_norm=max_grad_norm,
        )

        self.loss = DistillationLoss(model, teacher_model, device)
        if torch.cuda.device_count() > 1 and device.type == "cuda":
            self.loss = DataParallel(self.loss).to(device)  # type: ignore [assignment]

    def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
        with torch.cuda.amp.autocast(enabled=self.use_amp):
            loss = torch.sum(
                self.loss(
                    input_ids=batch.get("input_ids"),
                    segment_ids=batch.get("segment_ids"),
                    padding_mask=batch.get("padding_mask"),
                )
            )
        loss = self.adjust_loss(loss)
        return self.backward_propagate(loss, step)


class DistillationLoss(Module):
    """
    Calculates the distillation loss in a separate module to allow for data parallelization.
    """

    def __init__(self, model: Union[DataParallel, AdaptiveModel], teacher_model: Module, device: torch.device):
        super().__init__()
        self.model = model.module.to(device) if isinstance(model, DataParallel) else model.to(device)
        self.teacher_model = teacher_model.to(device)

        # creating dummy inputs to get the shapes of hidden states and attention of teacher and student model
        dummy_inputs = teacher_model.language_model.model.dummy_inputs  # type: ignore [union-attr]
        dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(device)  # type: ignore [operator,index]
        dummy_inputs["padding_mask"] = torch.ones_like(dummy_inputs["input_ids"], device=device)  # type: ignore [operator,index]
        dummy_inputs["segment_ids"] = torch.zeros_like(dummy_inputs["input_ids"], device=device)  # type: ignore [operator,index]

        with torch.no_grad():
            _, teacher_hidden_states, teacher_attentions = self.teacher_model.forward(  # type: ignore [arg-type]
                **dummy_inputs, output_attentions=True, output_hidden_states=True
            )
            _, hidden_states, attentions = self.model.forward(  # type: ignore [arg-type]
                **dummy_inputs, output_attentions=True, output_hidden_states=True
            )

        if len(teacher_attentions) % len(attentions) != 0:
            raise ValueError(
                "Teacher and student model do not seem to be compatible. Have you made sure that the student is a TinyBERT model and that the teacher is a BERT model?"
            )
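
        # Each student layer is aligned with a block of `teacher_block_size` consecutive teacher layers, i.e. the
        # student only learns from every teacher_block_size-th attention map and hidden state (TinyBERT layer mapping).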
        self.teacher_block_size = len(teacher_attentions) // len(attentions)

        teacher_dims = [hidden_state.shape[-1] for hidden_state in teacher_hidden_states]
        student_dims = [hidden_state.shape[-1] for hidden_state in hidden_states]

        # creating linear mappings in case the teacher and student model have different hidden state dimensions
        self.dim_mappings: List[Optional[Linear]] = ModuleList([])  # type: ignore [assignment]

        for teacher_dim, student_dim in zip(teacher_dims, student_dims):
            if teacher_dim != student_dim:
                self.dim_mappings.append(Linear(student_dim, teacher_dim, bias=False).to(device))
            else:
                self.dim_mappings.append(None)

    def forward(self, input_ids: torch.Tensor, segment_ids: torch.Tensor, padding_mask: torch.Tensor):
        with torch.no_grad():
            _, teacher_hidden_states, teacher_attentions = self.teacher_model.forward(
                input_ids=input_ids,
                segment_ids=segment_ids,
                padding_mask=padding_mask,
                output_attentions=True,
                output_hidden_states=True,
            )
        _, hidden_states, attentions = self.model.forward(
            input_ids=input_ids,
            segment_ids=segment_ids,
            padding_mask=padding_mask,
            output_attentions=True,
            output_hidden_states=True,
        )
        loss = torch.tensor(0.0, device=input_ids.device)

        # calculating attention loss
        for student_attention, teacher_attention, dim_mapping in zip(
            attentions, teacher_attentions[self.teacher_block_size - 1 :: self.teacher_block_size], self.dim_mappings
        ):
            # this wasn't described in the paper, but it was used in the original implementation
            student_attention = torch.where(
                student_attention <= -1e2, torch.zeros_like(student_attention), student_attention
            )
            teacher_attention = torch.where(
                teacher_attention <= -1e2, torch.zeros_like(teacher_attention), teacher_attention
            )

            loss += F.mse_loss(student_attention, teacher_attention)

        # calculating hidden state loss
        for student_hidden_state, teacher_hidden_state, dim_mapping in zip(
            hidden_states, teacher_hidden_states[:: self.teacher_block_size], self.dim_mappings
        ):
            # linear mapping in case the teacher and student model have different hidden state dimensions; the
            # mapping belonging to this layer is applied (not necessary for attention, as the attention shape is
            # determined by the number of attention heads and the sequence length)
            if dim_mapping:
                student_hidden_state = dim_mapping(student_hidden_state)

            loss += F.mse_loss(student_hidden_state, teacher_hidden_state)

        return torch.unsqueeze(loss, -1)