Adding distillation loss functions from TinyBERT (#1879)

* initial tinybertdistill commit

* add tinybert distill loss

* remove teacher caching for tinybert

* add tinybert to distil_from method

* Add latest docstring and tutorial changes

* add dim mapping and fix type hints

* fix type hints

* fix dummy input

* fix dim mapping for tinybert loss and add comments/doc strings

* add test for tinybert loss

* Add latest docstring and tutorial changes

* add comment

* fix BERT forward parameters

* add doc string to AdaptiveModel forward method

* remove unnecessary data silo

* fix farm import

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
MichelBartels 2021-12-23 14:54:02 +01:00 committed by GitHub
parent fc8df2163d
commit f33c2b987a
8 changed files with 306 additions and 45 deletions

View File

@@ -161,7 +161,7 @@ None
#### distil\_from
```python
| distil_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 2, learning_rate: float = 1e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss_weight: float = 0.5, distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div", temperature: float = 1.0, tinybert_loss: bool = False, tinybert_epochs: int = 1)
```
Fine-tune a model on a QA dataset using distillation. You need to provide a teacher model that is already finetuned on the dataset
@@ -218,8 +218,10 @@ If any checkpoints are stored, a subsequent run of train() will resume training
- `caching`: Whether or not to use caching for the preprocessed dataset and teacher logits
- `cache_path`: Path to cache the preprocessed dataset and teacher logits
- `distillation_loss_weight`: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
- `distillation_loss`: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for KL divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
- `temperature`: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
- `tinybert_loss`: Whether to use the TinyBERT loss function for distillation. This requires the student to be a TinyBERT model and the teacher to be a finetuned version of bert-base-uncased.
- `tinybert_epochs`: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function.
**Returns**:
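For orientation, a minimal usage sketch of the new two-stage TinyBERT distillation (model names are taken from examples elsewhere in this commit; the data path is an assumption):

```python
from haystack.nodes import FARMReader

# TinyBERT student and a BERT-base teacher already finetuned on SQuAD
student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D")
teacher = FARMReader(model_name_or_path="twmkn9/bert-base-uncased-squad2")

# Stage 1 (hidden state and attention distillation) runs for tinybert_epochs,
# then the usual prediction-layer distillation runs for n_epochs.
student.distil_from(
    teacher_model=teacher,
    data_dir="data/squad20",          # assumed local path to a SQuAD-style dataset
    train_filename="train.json",
    tinybert_loss=True,
    tinybert_epochs=1,
    n_epochs=2,
    distillation_loss="kl_div",
    temperature=1.0,
    distillation_loss_weight=0.5,
)
```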

View File

@@ -743,19 +743,29 @@ class DistillationDataSilo(DataSilo):
super().__init__(max_processes=max_processes, processor=processor, batch_size=batch_size, eval_batch_size=eval_batch_size,
distributed=distributed, automatic_loading=automatic_loading, caching=caching, cache_path=cache_path)
def _run_teacher(self, batch: dict) -> List[torch.Tensor]:
"""
Run the teacher model on the given batch.
"""
return self.teacher.inferencer.model(**batch)
def _pass_batches(self, batch: List[List[torch.Tensor]], corresponding_chunks: List[int],
teacher_outputs: List[List[Tuple[torch.Tensor, ...]]], tensor_names: List[str]):
with torch.no_grad():
batch_transposed = zip(*batch) # transpose dimensions (from batch, features, ... to features, batch, ...)
batch_transposed_list = [torch.stack(b) for b in batch_transposed] # create tensors for each feature
batch_dict = {key: tensor.to(self.device) for key, tensor in zip(tensor_names, batch_transposed_list)} # create input dict
y = self._run_teacher(batch=batch_dict) # run teacher model
y = [y.cpu() for y in y]
self.output_len = len(y)
# grouping by chunk
for i, data in zip(corresponding_chunks, zip(*y)): # transpose back
teacher_outputs[i].append(data)
return
def _teacher_output_names(self) -> List[str]:
return ["teacher_output_" + str(i) for i in range(self.output_len)]
def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[List[Dict]] = None):
concat_datasets, tensor_names = super()._get_dataset(filename, dicts)
@@ -772,16 +782,16 @@ class DistillationDataSilo(DataSilo):
batch.append(x)
corresponding_chunks.append(i)
if len(batch) == self.teacher_batch_size:
self._pass_batches(batch, corresponding_chunks, teacher_outputs, tensor_names) # doing forward pass on teacher model
batch = []
corresponding_chunks = []
if batch:
self._pass_batches(batch, corresponding_chunks, teacher_outputs, tensor_names)
# appending teacher outputs to original dataset
for dataset, teacher_output in zip(concat_datasets.datasets, teacher_outputs):
dataset.tensors += tuple(torch.stack(tensors) for tensors in zip(*teacher_output))
tensor_names += self._teacher_output_names()
concat_datasets = ConcatDataset(concat_datasets.datasets) # making sure metrics are updated
return concat_datasets, tensor_names
@@ -796,7 +806,8 @@ class DistillationDataSilo(DataSilo):
"max_seq_len": self.processor.max_seq_len,
"dev_split": self.processor.dev_split,
"tasks": self.processor.tasks,
"teacher_name_or_path": self.teacher.pipeline_config["params"]["model_name_or_path"],
"data_silo_type": self.__class__.__name__,
}
checksum = get_dict_checksum(payload_dict)
return checksum
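As a side note, the transpose-and-stack step in `_pass_batches` can be illustrated with a tiny standalone sketch (feature names and values here are made up):

```python
import torch

# A "batch" is a list of samples, each a list of per-feature tensors.
batch = [
    [torch.tensor([101, 2054]), torch.tensor([0, 0])],  # sample 1: input_ids, segment_ids
    [torch.tensor([101, 2129]), torch.tensor([0, 0])],  # sample 2
]
tensor_names = ["input_ids", "segment_ids"]

# zip(*batch) regroups per feature; stacking yields one [batch_size, ...] tensor per feature.
batch_transposed = zip(*batch)
batch_dict = {name: torch.stack(list(feature)) for name, feature in zip(tensor_names, batch_transposed)}
print(batch_dict["input_ids"].shape)  # torch.Size([2, 2])
```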

View File

@@ -8,7 +8,7 @@ from typing import Iterable, Dict, Union, List, Optional, Callable
import numpy
import torch
from torch import nn, set_warn_always
from transformers import AutoConfig
from transformers.convert_graph_to_onnx import convert, quantize as quantize_model
@@ -356,18 +356,29 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
all_labels.append(labels)
return all_labels
def forward(self, output_hidden_states: bool = False, output_attentions: bool = False, **kwargs):
"""
Push data through the whole model and return logits. The data will
propagate through the language model and each of the attached prediction heads.
:param kwargs: Holds all arguments that need to be passed to the language model
and prediction head(s).
:param output_hidden_states: Whether to output hidden states
:param output_attentions: Whether to output attentions
:return: All logits as torch.tensor or multiple tensors.
"""
# Run forward pass of language model
output_tuple = self.language_model.forward(**kwargs, output_hidden_states=output_hidden_states, output_attentions=output_attentions)
if output_hidden_states:
if output_attentions:
sequence_output, pooled_output, hidden_states, attentions = output_tuple
else:
sequence_output, pooled_output, hidden_states = output_tuple
else:
if output_attentions:
sequence_output, pooled_output, attentions = output_tuple
else:
sequence_output, pooled_output = output_tuple
# Run forward pass of (multiple) prediction heads using the output from above
all_logits = []
if len(self.prediction_heads) > 0:
@@ -392,6 +403,13 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
# just return LM output (e.g. useful for extracting embeddings at inference time)
all_logits.append((sequence_output, pooled_output))
if output_hidden_states:
if output_attentions:
return all_logits, hidden_states, attentions
else:
return all_logits, hidden_states
elif output_attentions:
return all_logits, attentions
return all_logits
def forward_lm(self, **kwargs):
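To show how the new flags change the return value, here is a small sketch that mirrors how the TinyBERT trainer further below builds its dummy batch (the model name is borrowed from the tests in this commit and is only an illustrative choice):

```python
import torch
from haystack.nodes import FARMReader

reader = FARMReader(model_name_or_path="prajjwal1/bert-tiny", use_gpu=False)
model = reader.inferencer.model  # the underlying AdaptiveModel

# Build a toy batch the same way TinyBERTDistillationTrainer builds its dummy inputs
input_ids = model.language_model.model.dummy_inputs["input_ids"]
batch = {
    "input_ids": input_ids,
    "padding_mask": torch.ones_like(input_ids),
    "segment_ids": torch.zeros_like(input_ids),
}

with torch.no_grad():
    logits = model.forward(**batch)  # default flags: just the prediction head logits
    logits, hidden_states = model.forward(**batch, output_hidden_states=True)
    logits, hidden_states, attentions = model.forward(**batch, output_hidden_states=True, output_attentions=True)
```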

View File

@@ -484,6 +484,8 @@ class Bert(LanguageModel):
input_ids: torch.Tensor,
segment_ids: torch.Tensor,
padding_mask: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
**kwargs,
):
"""
@@ -495,19 +497,24 @@ class Bert(LanguageModel):
It is a tensor of shape [batch_size, max_seq_len]
:param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
of shape [batch_size, max_seq_len]
:param output_hidden_states: Whether to output hidden states in addition to the embeddings
:param output_attentions: Whether to output attentions in addition to the embeddings
:return: Embeddings for each token in the input sequence. Can also return hidden states and attentions if specified via the arguments output_hidden_states and output_attentions
"""
if output_hidden_states is None:
output_hidden_states = self.model.encoder.config.output_hidden_states
if output_attentions is None:
output_attentions = self.model.encoder.config.output_attentions
output_tuple = self.model(
input_ids,
token_type_ids=segment_ids,
attention_mask=padding_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=False
)
-if self.model.encoder.config.output_hidden_states == True:
-sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
-return sequence_output, pooled_output, all_hidden_states
-else:
-sequence_output, pooled_output = output_tuple[0], output_tuple[1]
-return sequence_output, pooled_output
return output_tuple
def enable_hidden_states_output(self):
self.model.encoder.config.output_hidden_states = True
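For reference, the unpacking above relies on the tuple layout Hugging Face BERT models use when `return_dict=False`; a small standalone sketch (the checkpoint name is an arbitrary small model, also used in the tests of this commit):

```python
import torch
from transformers import AutoTokenizer, BertModel

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
model = BertModel.from_pretrained("prajjwal1/bert-tiny")
inputs = tokenizer("Who wrote TinyBERT?", return_tensors="pt")

with torch.no_grad():
    output_tuple = model(
        inputs["input_ids"],
        token_type_ids=inputs["token_type_ids"],
        attention_mask=inputs["attention_mask"],
        output_hidden_states=True,
        output_attentions=True,
        return_dict=False,
    )

# With both flags set, the tuple is (sequence_output, pooled_output, hidden_states, attentions)
sequence_output, pooled_output, hidden_states, attentions = output_tuple
print(sequence_output.shape, len(hidden_states), len(attentions))  # [1, seq_len, hidden], n_layers + 1, n_layers
```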

View File

@@ -1 +1 @@
from haystack.modeling.training.base import Trainer, DistillationTrainer, TinyBERTDistillationTrainer

View File

@@ -1,9 +1,6 @@
from typing import Optional, Union, Tuple, List, Callable
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-from haystack.nodes import FARMReader
from torch.optim.lr_scheduler import _LRScheduler
import sys
import shutil
@@ -14,7 +11,7 @@ import torch
from tqdm import tqdm
from pathlib import Path
from torch.nn import MSELoss, Linear
import torch.nn.functional as F
from torch.optim import Optimizer
@@ -630,7 +627,7 @@ class DistillationTrainer(Trainer):
"""
def __init__(
self,
model: "AdaptiveModel",
optimizer: Optimizer,
data_silo: DistillationDataSilo,
epochs: int,
@@ -730,4 +727,165 @@ class DistillationTrainer(Trainer):
student_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
distillation_loss = self.distillation_loss_fn(student_logits=logits[0] / self.temperature, teacher_logits=teacher_logits[0] / self.temperature)
combined_loss = distillation_loss * self.distillation_loss_weight * (self.temperature ** 2) + student_loss * (1 - self.distillation_loss_weight)
return self.backward_propagate(combined_loss, step)
class TinyBERTDistillationTrainer(Trainer):
"""
This Trainer implements the first stage of task specific distillation as described in the TinyBERT paper.
The standard DistillationTrainer can be used for the second stage. Unlike the DistillationTrainer, this Trainer does not cache
the teacher outputs, as doing so would require too much memory. This makes it considerably slower than the DistillationTrainer.
**Example**
```python
student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D")
teacher = FARMReader(model_name_or_path="twmkn9/bert-base-uncased-squad2")
processor = SquadProcessor(tokenizer=student.inferencer.processor.tokenizer, max_seq_len=384)
data_silo = DataSilo(processor=processor, batch_size=8)
model, optimizer, _ = initialize_optimizer(student.inferencer.model, n_batches=len(data_silo.loaders["train"]), n_epochs=3, device="cuda:0", learning_rate=3e-5)
trainer = TinyBERTDistillationTrainer(model=model, teacher_model=teacher.inferencer.model, optimizer=optimizer, data_silo=data_silo, epochs=3, n_gpu=1, device="cuda:0")
trainer.train()
```
"""
def __init__(
self,
model: AdaptiveModel,
teacher_model: AdaptiveModel,
optimizer: Optimizer,
data_silo: DistillationDataSilo,
epochs: int,
n_gpu: int,
device: str,
lr_schedule: Optional["_LRScheduler"]=None,
evaluate_every: int = 100,
eval_report: bool = True,
use_amp: Optional[str] = None,
grad_acc_steps: int = 1,
local_rank: int = -1,
early_stopping: Optional[EarlyStopping] = None,
log_learning_rate: bool = False,
log_loss_every: int = 10,
checkpoint_on_sigterm: bool = False,
checkpoint_every: Optional[int] = None,
checkpoint_root_dir: Optional[Path] = None,
checkpoints_to_keep: int = 3,
from_epoch: int = 0,
from_step: int = 0,
global_step: int = 0,
evaluator_test: bool = True,
disable_tqdm: bool = False,
max_grad_norm: float = 1.0,
):
"""
:param optimizer: An optimizer object that determines the learning strategy to be used during training
:param model: The model to be trained. It needs to be a TinyBERT model.
:param teacher_model: The teacher model used for distillation. This has to be based on bert-base-uncased.
:param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
:param epochs: How many times the training procedure will loop through the train dataset
:param n_gpu: The number of gpus available for training and evaluation.
:param device: The device on which the train, dev and test tensors should be hosted. Choose from "cpu" and "cuda".
:param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
:param evaluate_every: Perform dev set evaluation after this many steps of training.
:param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating
:param use_amp: Whether to use automatic mixed precision with Apex. One of the optimization levels must be chosen.
"O1" is recommended in almost all cases.
:param grad_acc_steps: Number of training steps for which the gradients should be accumulated.
Useful to achieve larger effective batch sizes that would not fit in GPU memory.
:param local_rank: Local rank of process when distributed training via DDP is used.
:param early_stopping: an initialized EarlyStopping object to control early stopping and saving of best models.
:param log_learning_rate: Whether to log learning rate to Mlflow
:param log_loss_every: Log current train loss after this many train steps.
:param checkpoint_on_sigterm: save a checkpoint for the Trainer when a SIGTERM signal is sent. The checkpoint
can be used to resume training. This is useful in frameworks like AWS SageMaker with Spot instances, where
a SIGTERM signal notifies the process to save the training state before the instance is terminated.
:param checkpoint_every: save a train checkpoint after this many steps of training.
:param checkpoint_root_dir: the Path of directory where all train checkpoints are saved. For each individual
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
:param checkpoints_to_keep: maximum number of train checkpoints to save.
:param from_epoch: the epoch number to start the training from. In the case when training resumes from a saved
checkpoint, it is used to fast-forward training to the last epoch in the checkpoint.
:param from_step: the step number to start the training from. In the case when training resumes from a saved
checkpoint, it is used to fast-forward training to the last step in the checkpoint.
:param global_step: the global step number across the training epochs.
:param evaluator_test: whether to perform evaluation on the test set
:param disable_tqdm: Disable tqdm progress bar (helps to reduce verbosity in some environments)
:param max_grad_norm: Max gradient norm for clipping, default 1.0, set to None to disable
"""
super().__init__(model=model, optimizer=optimizer,
data_silo=data_silo, epochs=epochs,
n_gpu=n_gpu, device=device,
lr_schedule=lr_schedule, evaluate_every=evaluate_every,
eval_report=eval_report, use_amp=use_amp,
grad_acc_steps=grad_acc_steps, local_rank=local_rank,
early_stopping=early_stopping, log_learning_rate=log_learning_rate,
log_loss_every=log_loss_every, checkpoint_on_sigterm=checkpoint_on_sigterm,
checkpoint_every=checkpoint_every, checkpoint_root_dir=checkpoint_root_dir,
checkpoints_to_keep=checkpoints_to_keep, from_epoch=from_epoch,
from_step=from_step, global_step=global_step,
evaluator_test=evaluator_test, disable_tqdm=disable_tqdm,
max_grad_norm=max_grad_norm)
self.teacher_model = teacher_model
# creating dummy inputs to get the shapes of hidden states and attention of teacher and student model
dummy_inputs = teacher_model.language_model.model.dummy_inputs
dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(device)
dummy_inputs["padding_mask"] = torch.ones_like(dummy_inputs["input_ids"], device=device)
dummy_inputs["segment_ids"] = torch.zeros_like(dummy_inputs["input_ids"], device=device)
with torch.no_grad():
_, teacher_hidden_states, teacher_attentions = self.teacher_model.forward(**dummy_inputs, output_attentions=True, output_hidden_states=True)
_, hidden_states, attentions = self.model.forward(**dummy_inputs, output_attentions=True, output_hidden_states=True)
if len(teacher_attentions) % len(attentions) != 0:
raise ValueError("Teacher and student model do not seem to be compatible. Have you made sure that the student is a TinyBERT model and that the teacher is a BERT model?")
self.teacher_block_size = len(teacher_attentions) // len(attentions)
teacher_dims = [hidden_state.shape[-1] for hidden_state in teacher_hidden_states]
student_dims = [hidden_state.shape[-1] for hidden_state in hidden_states]
# creating linear mappings in case the teacher and student model have different hidden state dimensions
self.dim_mappings: List[Optional[Linear]] = []
for teacher_dim, student_dim in zip(teacher_dims, student_dims):
if teacher_dim != student_dim:
self.dim_mappings.append(Linear(student_dim, teacher_dim, bias=False).to(device)) # keep the mapping on the same device as the hidden states
else:
self.dim_mappings.append(None)
def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
with torch.no_grad():
_, teacher_hidden_states, teacher_attentions = self.teacher_model.forward(**batch, output_attentions=True, output_hidden_states=True)
_, hidden_states, attentions = self.model.forward(**batch, output_attentions=True, output_hidden_states=True)
loss = torch.tensor(0., device=self.device)
# calculating attention loss
for student_attention, teacher_attention, dim_mapping in zip(attentions,
teacher_attentions[self.teacher_block_size - 1::self.teacher_block_size],
self.dim_mappings):
# this wasn't described in the paper, but it was used in the original implementation
student_attention = torch.where(student_attention <= -1e2, torch.zeros_like(student_attention),
student_attention)
teacher_attention = torch.where(teacher_attention <= -1e2, torch.zeros_like(teacher_attention),
teacher_attention)
loss += F.mse_loss(student_attention, teacher_attention)
# calculating hidden state loss
for student_hidden_state, teacher_hidden_state, dim_mapping in zip(hidden_states, teacher_hidden_states[::self.teacher_block_size], self.dim_mappings):
# linear mapping in case the teacher and student model have different hidden state dimensions, not necessary for attention as attention shape is determined by number of attention heads and sequence length
if dim_mapping:
student_hidden_state = dim_mapping(student_hidden_state)
loss += F.mse_loss(student_hidden_state, teacher_hidden_state)
return self.backward_propagate(loss, step)
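To make the block slicing above concrete, here is a worked example with the layer counts of bert-base-uncased (12 layers) and the 4-layer TinyBERT used in the test below; this mirrors the indexing only, not the actual training code:

```python
teacher_layers = 12   # bert-base-uncased
student_layers = 4    # huawei-noah/TinyBERT_General_4L_312D
block_size = teacher_layers // student_layers  # -> 3, as derived from the attention tuples above

# attention loss: student layer i is compared with the last teacher layer of its block
teacher_attention_idx = list(range(block_size - 1, teacher_layers, block_size))
print(teacher_attention_idx)  # [2, 5, 8, 11]

# hidden state loss: hidden_states also include the embedding output (13 entries for the teacher),
# so [::block_size] picks the embeddings plus every third layer output
teacher_hidden_idx = list(range(0, teacher_layers + 1, block_size))
print(teacher_hidden_idx)  # [0, 3, 6, 9, 12]
```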

View File

@@ -15,7 +15,7 @@ from haystack.modeling.infer import QAInferencer
from haystack.modeling.model.optimization import initialize_optimizer
from haystack.modeling.model.predictions import QAPred, QACandidate
from haystack.modeling.model.adaptive_model import AdaptiveModel
from haystack.modeling.training import Trainer, DistillationTrainer, TinyBERTDistillationTrainer
from haystack.modeling.evaluation import Evaluator
from haystack.modeling.utils import set_all_seeds, initialize_device_settings
@@ -181,7 +181,8 @@ class FARMReader(BaseReader):
cache_path: Path = Path("cache/data_silo"),
distillation_loss_weight: float = 0.5,
distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div",
temperature: float = 1.0,
tinybert: bool = False,
):
if dev_filename:
dev_split = 0
@@ -221,10 +222,10 @@
# 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
# and calculates a few descriptive statistics of our datasets
if teacher_model and not tinybert: # checks if teacher model is passed as parameter, in that case assume model distillation is used
data_silo = DistillationDataSilo(teacher_model, teacher_batch_size or batch_size, device=devices[0], processor=processor, batch_size=batch_size, distributed=False,
max_processes=num_processes, caching=caching, cache_path=cache_path)
else: # caching would need too much memory for tinybert distillation so in that case we use the default data silo
data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False, max_processes=num_processes, caching=caching, cache_path=cache_path)
# 3. Create an optimizer and pass the already initialized model
@@ -239,7 +240,27 @@
use_amp=use_amp,
)
# 4. Feed everything to the Trainer, which takes care of growing our model and evaluates it from time to time
if tinybert:
if not teacher_model:
raise ValueError("TinyBERT distillation requires a teacher model.")
trainer = TinyBERTDistillationTrainer.create_or_load_checkpoint(
model=model,
teacher_model=teacher_model.inferencer.model, # teacher needs to be passed as teacher outputs aren't cached
optimizer=optimizer,
data_silo=data_silo,
epochs=n_epochs,
n_gpu=n_gpu,
lr_schedule=lr_schedule,
evaluate_every=evaluate_every,
device=devices[0],
use_amp=use_amp,
disable_tqdm=not self.progress_bar,
checkpoint_root_dir=Path(checkpoint_root_dir),
checkpoint_every=checkpoint_every,
checkpoints_to_keep=checkpoints_to_keep,
)
elif teacher_model: # checks again if teacher model is passed as parameter, in that case assume model distillation is used
trainer = DistillationTrainer.create_or_load_checkpoint(
model=model,
optimizer=optimizer,
@@ -383,7 +404,9 @@ class FARMReader(BaseReader):
cache_path: Path = Path("cache/data_silo"),
distillation_loss_weight: float = 0.5,
distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div",
temperature: float = 1.0,
tinybert_loss: bool = False,
tinybert_epochs: int = 1,
):
"""
Fine-tune a model on a QA dataset using distillation. You need to provide a teacher model that is already finetuned on the dataset
@@ -438,10 +461,24 @@ class FARMReader(BaseReader):
:param caching: Whether or not to use caching for the preprocessed dataset and teacher logits
:param cache_path: Path to cache the preprocessed dataset and teacher logits
:param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
:param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for KL divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
:param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
:param tinybert_loss: Whether to use the TinyBERT loss function for distillation. This requires the student to be a TinyBERT model and the teacher to be a finetuned version of bert-base-uncased.
:param tinybert_epochs: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function.
:return: None
"""
if tinybert_loss: # do hidden state and attention distillation as additional stage
self._training_procedure(data_dir=data_dir, train_filename=train_filename,
dev_filename=dev_filename, test_filename=test_filename,
use_gpu=use_gpu, batch_size=student_batch_size,
n_epochs=tinybert_epochs, learning_rate=learning_rate,
max_seq_len=max_seq_len, warmup_proportion=warmup_proportion,
dev_split=dev_split, evaluate_every=evaluate_every,
save_dir=save_dir, num_processes=num_processes,
use_amp=use_amp, checkpoint_root_dir=checkpoint_root_dir,
checkpoint_every=checkpoint_every, checkpoints_to_keep=checkpoints_to_keep,
teacher_model=teacher_model, teacher_batch_size=teacher_batch_size,
caching=caching, cache_path=cache_path, tinybert=True)
return self._training_procedure(data_dir=data_dir, train_filename=train_filename,
dev_filename=dev_filename, test_filename=test_filename,
use_gpu=use_gpu, batch_size=student_batch_size,
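Since distillation_loss also accepts a callable with the named parameters student_logits and teacher_logits, here is a hedged sketch of such a custom loss (the built-in "kl_div" option is assumed to be of this general form; model names and paths are taken from the test below):

```python
import torch
import torch.nn.functional as F
from haystack.nodes import FARMReader

def my_distillation_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    # standard KL-divergence formulation; distil_from additionally applies the temperature
    # scaling and the distillation_loss_weight mixing shown in DistillationTrainer above
    return F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(teacher_logits, dim=-1),
        reduction="batchmean",
    )

student = FARMReader(model_name_or_path="prajjwal1/bert-tiny", num_processes=0)
teacher = FARMReader(model_name_or_path="prajjwal1/bert-small", num_processes=0)
student.distil_from(teacher, data_dir="samples/squad", train_filename="tiny.json",
                    distillation_loss=my_distillation_loss, temperature=2.0)
```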

View File

@@ -1,15 +1,23 @@
from haystack.nodes import FARMReader
import torch
def create_checkpoint(model):
weights = []
for name, weight in model.inferencer.model.named_parameters():
if "weight" in name and weight.requires_grad:
weights.append(torch.clone(weight))
return weights
def assert_weight_change(weights, new_weights):
print([torch.equal(old_weight, new_weight) for old_weight, new_weight in zip(weights, new_weights)])
assert not any(torch.equal(old_weight, new_weight) for old_weight, new_weight in zip(weights, new_weights))
def test_distillation():
student = FARMReader(model_name_or_path="prajjwal1/bert-tiny", num_processes=0)
teacher = FARMReader(model_name_or_path="prajjwal1/bert-small", num_processes=0)
# create a checkpoint of weights before distillation
student_weights = create_checkpoint(student)
assert len(student_weights) == 22
@@ -18,16 +26,36 @@ def test_distillation():
student.distil_from(teacher, data_dir="samples/squad", train_filename="tiny.json")
# create new checkpoint
new_student_weights = create_checkpoint(student)
assert len(new_student_weights) == 22
new_student_weights.pop(-2) # pooler is not updated due to different attention head
# check if weights have changed
assert_weight_change(student_weights, new_student_weights)
def test_tinybert_distillation():
student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
teacher = FARMReader(model_name_or_path="bert-base-uncased")
# create a checkpoint of weights before distillation
student_weights = create_checkpoint(student)
assert len(student_weights) == 38
student_weights.pop(-1) # last layer is not affected by tinybert loss
student_weights.pop(-1) # pooler is not updated due to different attention head
student._training_procedure(teacher_model=teacher, tinybert=True, data_dir="samples/squad", train_filename="tiny.json")
# create new checkpoint
new_student_weights = create_checkpoint(student)
assert len(new_student_weights) == 38
new_student_weights.pop(-1) # last layer is not affected by tinybert loss
new_student_weights.pop(-1) # pooler is not updated due to different attention head
# check if weights have changed
assert_weight_change(student_weights, new_student_weights)