mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-18 11:28:38 +00:00
Adding distillation loss functions from TinyBERT (#1879)
* initial tinybertdistill commit * add tinybert distill loss * remove teacher caching for tinybert * add tinybert to distil_from method * Add latest docstring and tutorial changes * add dim mapping and fix type hints * fix type hints * fix dummy input * fix dim mapping for tinybert loss and add comments/doc strings * add test for tinybert loss * Add latest docstring and tutorial changes * add comment * fix BERT forward parameters * add doc string to AdaptiveModel forward method * remove unnecessary data silo * fix farm import Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
fc8df2163d
commit
f33c2b987a
@ -161,7 +161,7 @@ None
|
||||
#### distil\_from
|
||||
|
||||
```python
|
||||
| distil_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 2, learning_rate: float = 1e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss_weight: float = 0.5, distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div", temperature: float = 1.0)
|
||||
| distil_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 2, learning_rate: float = 1e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss_weight: float = 0.5, distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div", temperature: float = 1.0, tinybert_loss: bool = False, tinybert_epochs: int = 1)
|
||||
```
|
||||
|
||||
Fine-tune a model on a QA dataset using distillation. You need to provide a teacher model that is already finetuned on the dataset
|
||||
@ -218,8 +218,10 @@ If any checkpoints are stored, a subsequent run of train() will resume training
|
||||
:param caching whether or not to use caching for preprocessed dataset and teacher logits
|
||||
- `cache_path`: Path to cache the preprocessed dataset and teacher logits
|
||||
- `distillation_loss_weight`: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
|
||||
- `distillation_loss`: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named paramters student_logits and teacher_logits)
|
||||
- `distillation_loss`: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
|
||||
- `temperature`: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
|
||||
- `tinybert_loss`: Whether to use the TinyBERT loss function for distillation. This requires the student to be a TinyBERT model and the teacher to be a finetuned version of bert-base-uncased.
|
||||
- `tinybert_epochs`: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function.
|
||||
|
||||
**Returns**:
|
||||
|
||||
|
@ -743,19 +743,29 @@ class DistillationDataSilo(DataSilo):
|
||||
super().__init__(max_processes=max_processes, processor=processor, batch_size=batch_size, eval_batch_size=eval_batch_size,
|
||||
distributed=distributed, automatic_loading=automatic_loading, caching=caching, cache_path=cache_path)
|
||||
|
||||
def _run_teacher(self, batch: List[List[torch.Tensor]], corresponding_chunks: List[int],
|
||||
def _run_teacher(self, batch: dict) -> List[torch.Tensor]:
|
||||
"""
|
||||
Run the teacher model on the given batch.
|
||||
"""
|
||||
return self.teacher.inferencer.model(**batch)
|
||||
|
||||
def _pass_batches(self, batch: List[List[torch.Tensor]], corresponding_chunks: List[int],
|
||||
teacher_outputs: List[List[Tuple[torch.Tensor, ...]]], tensor_names: List[str]):
|
||||
with torch.no_grad():
|
||||
batch_transposed = zip(*batch) # transpose dimensions (from batch, features, ... to features, batch, ...)
|
||||
batch_transposed_list = [torch.stack(b) for b in batch_transposed] # create tensors for each feature
|
||||
batch_dict = {key: tensor.to(self.device) for key, tensor in zip(tensor_names, batch_transposed_list)} # create input dict
|
||||
y = self.teacher.inferencer.model(**batch_dict)
|
||||
y = self._run_teacher(batch=batch_dict) # run teacher model
|
||||
y = [y.cpu() for y in y]
|
||||
self.output_len = len(y)
|
||||
|
||||
# grouping by chunk
|
||||
for i, data in zip(corresponding_chunks, zip(*y)): # transpose back
|
||||
teacher_outputs[i].append(data)
|
||||
return
|
||||
|
||||
def _teacher_output_names(self) -> List[str]:
|
||||
return ["teacher_output_" + str(i) for i in range(self.output_len)]
|
||||
|
||||
def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[List[Dict]] = None):
|
||||
concat_datasets, tensor_names = super()._get_dataset(filename, dicts)
|
||||
@ -772,16 +782,16 @@ class DistillationDataSilo(DataSilo):
|
||||
batch.append(x)
|
||||
corresponding_chunks.append(i)
|
||||
if len(batch) == self.teacher_batch_size:
|
||||
self._run_teacher(batch, corresponding_chunks, teacher_outputs, tensor_names) # doing forward pass on teacher model
|
||||
self._pass_batches(batch, corresponding_chunks, teacher_outputs, tensor_names) # doing forward pass on teacher model
|
||||
batch = []
|
||||
corresponding_chunks = []
|
||||
if batch:
|
||||
self._run_teacher(batch, corresponding_chunks, teacher_outputs, tensor_names)
|
||||
self._pass_batches(batch, corresponding_chunks, teacher_outputs, tensor_names)
|
||||
|
||||
# appending teacher outputs to original dataset
|
||||
for dataset, teacher_output in zip(concat_datasets.datasets, teacher_outputs):
|
||||
dataset.tensors += tuple(torch.stack(tensors) for tensors in zip(*teacher_output))
|
||||
tensor_names.extend(["teacher_output_" + str(i) for i, _ in enumerate(zip(*teacher_output))])
|
||||
tensor_names += self._teacher_output_names()
|
||||
concat_datasets = ConcatDataset(concat_datasets.datasets) # making sure metrics are updated
|
||||
return concat_datasets, tensor_names
|
||||
|
||||
@ -796,7 +806,8 @@ class DistillationDataSilo(DataSilo):
|
||||
"max_seq_len": self.processor.max_seq_len,
|
||||
"dev_split": self.processor.dev_split,
|
||||
"tasks": self.processor.tasks,
|
||||
"teacher_name_or_path": self.teacher.pipeline_config["params"]["model_name_or_path"]
|
||||
"teacher_name_or_path": self.teacher.pipeline_config["params"]["model_name_or_path"],
|
||||
"data_silo_type": self.__class__.__name__,
|
||||
}
|
||||
checksum = get_dict_checksum(payload_dict)
|
||||
return checksum
|
@ -8,7 +8,7 @@ from typing import Iterable, Dict, Union, List, Optional, Callable
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import nn, set_warn_always
|
||||
from transformers import AutoConfig
|
||||
from transformers.convert_graph_to_onnx import convert, quantize as quantize_model
|
||||
|
||||
@ -356,18 +356,29 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
|
||||
all_labels.append(labels)
|
||||
return all_labels
|
||||
|
||||
def forward(self, **kwargs):
|
||||
def forward(self, output_hidden_states: bool = False, output_attentions: bool = False, **kwargs):
|
||||
"""
|
||||
Push data through the whole model and returns logits. The data will
|
||||
propagate through the language model and each of the attached prediction heads.
|
||||
|
||||
:param kwargs: Holds all arguments that need to be passed to the language model
|
||||
and prediction head(s).
|
||||
:param output_hidden_states: Whether to output hidden states
|
||||
:param output_attentions: Whether to output attentions
|
||||
:return: All logits as torch.tensor or multiple tensors.
|
||||
"""
|
||||
# Run forward pass of language model
|
||||
sequence_output, pooled_output = self.forward_lm(**kwargs)
|
||||
|
||||
output_tuple = self.language_model.forward(**kwargs, output_hidden_states=output_hidden_states, output_attentions=output_attentions)
|
||||
if output_hidden_states:
|
||||
if output_attentions:
|
||||
sequence_output, pooled_output, hidden_states, attentions = output_tuple
|
||||
else:
|
||||
sequence_output, pooled_output, hidden_states = output_tuple
|
||||
else:
|
||||
if output_attentions:
|
||||
sequence_output, pooled_output, attentions = output_tuple
|
||||
else:
|
||||
sequence_output, pooled_output = output_tuple
|
||||
# Run forward pass of (multiple) prediction heads using the output from above
|
||||
all_logits = []
|
||||
if len(self.prediction_heads) > 0:
|
||||
@ -392,6 +403,13 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
|
||||
# just return LM output (e.g. useful for extracting embeddings at inference time)
|
||||
all_logits.append((sequence_output, pooled_output))
|
||||
|
||||
if output_hidden_states:
|
||||
if output_attentions:
|
||||
return all_logits, hidden_states, attentions
|
||||
else:
|
||||
return all_logits, hidden_states
|
||||
elif output_attentions:
|
||||
return all_logits, attentions
|
||||
return all_logits
|
||||
|
||||
def forward_lm(self, **kwargs):
|
||||
|
@ -484,6 +484,8 @@ class Bert(LanguageModel):
|
||||
input_ids: torch.Tensor,
|
||||
segment_ids: torch.Tensor,
|
||||
padding_mask: torch.Tensor,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -495,19 +497,24 @@ class Bert(LanguageModel):
|
||||
It is a tensor of shape [batch_size, max_seq_len]
|
||||
:param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
|
||||
of shape [batch_size, max_seq_len]
|
||||
:return: Embeddings for each token in the input sequence.
|
||||
:param output_hidden_states: Whether to output hidden states in addition to the embeddings
|
||||
:param output_attentions: Whether to output attentions in addition to the embeddings
|
||||
:return: Embeddings for each token in the input sequence. Can also return hidden states and attentions if specified via the arguments output_hidden_states and output_attentions
|
||||
"""
|
||||
if output_hidden_states is None:
|
||||
output_hidden_states = self.model.encoder.config.output_hidden_states
|
||||
if output_attentions is None:
|
||||
output_attentions = self.model.encoder.config.output_attentions
|
||||
|
||||
output_tuple = self.model(
|
||||
input_ids,
|
||||
token_type_ids=segment_ids,
|
||||
attention_mask=padding_mask,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
return_dict=False
|
||||
)
|
||||
if self.model.encoder.config.output_hidden_states == True:
|
||||
sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
|
||||
return sequence_output, pooled_output, all_hidden_states
|
||||
else:
|
||||
sequence_output, pooled_output = output_tuple[0], output_tuple[1]
|
||||
return sequence_output, pooled_output
|
||||
return output_tuple
|
||||
|
||||
def enable_hidden_states_output(self):
|
||||
self.model.encoder.config.output_hidden_states = True
|
||||
|
@ -1 +1 @@
|
||||
from haystack.modeling.training.base import Trainer, DistillationTrainer
|
||||
from haystack.modeling.training.base import Trainer, DistillationTrainer, TinyBERTDistillationTrainer
|
@ -1,9 +1,6 @@
|
||||
from typing import Optional, Union, Tuple, List, Callable
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from haystack.nodes import FARMReader
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
import sys
|
||||
import shutil
|
||||
@ -14,7 +11,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
|
||||
from torch.nn import MSELoss
|
||||
from torch.nn import MSELoss, Linear
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Optimizer
|
||||
|
||||
@ -630,7 +627,7 @@ class DistillationTrainer(Trainer):
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model: "FARMReader",
|
||||
model: "AdaptiveModel",
|
||||
optimizer: Optimizer,
|
||||
data_silo: DistillationDataSilo,
|
||||
epochs: int,
|
||||
@ -730,4 +727,165 @@ class DistillationTrainer(Trainer):
|
||||
student_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch)
|
||||
distillation_loss = self.distillation_loss_fn(student_logits=logits[0] / self.temperature, teacher_logits=teacher_logits[0] / self.temperature)
|
||||
combined_loss = distillation_loss * self.distillation_loss_weight * (self.temperature ** 2) + student_loss * (1 - self.distillation_loss_weight)
|
||||
return self.backward_propagate(combined_loss, step)
|
||||
return self.backward_propagate(combined_loss, step)
|
||||
|
||||
class TinyBERTDistillationTrainer(Trainer):
|
||||
"""
|
||||
This Trainer implements the first stage of task specific distillation as described in the TinyBERT paper.
|
||||
The standard DistillationTrainer can be used for the second stage. Unlike the DistillationTrainer, this Trainer does not use
|
||||
cached teacher outputs as it would be too memory expensive. This means it is much slower than the DistillationTrainer.
|
||||
|
||||
**Example**
|
||||
```python
|
||||
student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D")
|
||||
teacher = FARMReader(model_name_or_path="twmkn9/bert-base-uncased-squad2")
|
||||
|
||||
processor = SquadProcessor(tokenizer=student.inferencer.processor.tokenizer, max_seq_len=384)
|
||||
student, optimizer, _ = initialize_optimizer(student, n_batches=len(data_silo.loaders["train"]), n_epochs=3, device="cuda:0", learning_rate=3e-5)
|
||||
|
||||
data_silo = DataSilo(teacher_model=teacher, batch_size=8, device="cuda:0", processor=processor)
|
||||
trainer = TinyBertDistillationTrainer(student=student, optimizer=optimizer, data_silo=data_silo, epochs=3, n_gpu=1, device="cuda:0")
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model: AdaptiveModel,
|
||||
teacher_model: AdaptiveModel,
|
||||
optimizer: Optimizer,
|
||||
data_silo: DistillationDataSilo,
|
||||
epochs: int,
|
||||
n_gpu: int,
|
||||
device: str,
|
||||
lr_schedule: Optional["_LRScheduler"]=None,
|
||||
evaluate_every: int = 100,
|
||||
eval_report: bool = True,
|
||||
use_amp: Optional[str] = None,
|
||||
grad_acc_steps: int = 1,
|
||||
local_rank: int = -1,
|
||||
early_stopping: Optional[EarlyStopping] = None,
|
||||
log_learning_rate: bool = False,
|
||||
log_loss_every: int = 10,
|
||||
checkpoint_on_sigterm: bool = False,
|
||||
checkpoint_every: Optional[int] = None,
|
||||
checkpoint_root_dir: Optional[Path] = None,
|
||||
checkpoints_to_keep: int = 3,
|
||||
from_epoch: int = 0,
|
||||
from_step: int = 0,
|
||||
global_step: int = 0,
|
||||
evaluator_test: bool = True,
|
||||
disable_tqdm: bool = False,
|
||||
max_grad_norm: float = 1.0,
|
||||
):
|
||||
"""
|
||||
:param optimizer: An optimizer object that determines the learning strategy to be used during training
|
||||
:param model: The model to be trained. It needs to be a TinyBERT model.
|
||||
:param teacher_model: The teacher model used for distillation. This has to be based on bert-base-uncased.
|
||||
:param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
|
||||
:param epochs: How many times the training procedure will loop through the train dataset
|
||||
:param n_gpu: The number of gpus available for training and evaluation.
|
||||
:param device: The device on which the train, dev and test tensors should be hosted. Choose from "cpu" and "cuda".
|
||||
:param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
|
||||
:param evaluate_every: Perform dev set evaluation after this many steps of training.
|
||||
:param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating
|
||||
:param use_amp: Whether to use automatic mixed precision with Apex. One of the optimization levels must be chosen.
|
||||
"O1" is recommended in almost all cases.
|
||||
:param grad_acc_steps: Number of training steps for which the gradients should be accumulated.
|
||||
Useful to achieve larger effective batch sizes that would not fit in GPU memory.
|
||||
:param local_rank: Local rank of process when distributed training via DDP is used.
|
||||
:param early_stopping: an initialized EarlyStopping object to control early stopping and saving of best models.
|
||||
:param log_learning_rate: Whether to log learning rate to Mlflow
|
||||
:param log_loss_every: Log current train loss after this many train steps.
|
||||
:param checkpoint_on_sigterm: save a checkpoint for the Trainer when a SIGTERM signal is sent. The checkpoint
|
||||
can be used to resume training. It is useful in frameworks like AWS SageMaker with Spot instances where
|
||||
a SIGTERM notifies to save the training state and subsequently the instance is terminated.
|
||||
:param checkpoint_every: save a train checkpoint after this many steps of training.
|
||||
:param checkpoint_root_dir: the Path of directory where all train checkpoints are saved. For each individual
|
||||
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
|
||||
:param checkpoints_to_keep: maximum number of train checkpoints to save.
|
||||
:param from_epoch: the epoch number to start the training from. In the case when training resumes from a saved
|
||||
checkpoint, it is used to fast-forward training to the last epoch in the checkpoint.
|
||||
:param from_step: the step number to start the training from. In the case when training resumes from a saved
|
||||
checkpoint, it is used to fast-forward training to the last step in the checkpoint.
|
||||
:param global_step: the global step number across the training epochs.
|
||||
:param evaluator_test: whether to perform evaluation on the test set
|
||||
:param disable_tqdm: Disable tqdm progress bar (helps to reduce verbosity in some environments)
|
||||
:param max_grad_norm: Max gradient norm for clipping, default 1.0, set to None to disable
|
||||
:param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
|
||||
:param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named paramters student_logits and teacher_logits)
|
||||
:param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
|
||||
"""
|
||||
super().__init__(model=model, optimizer=optimizer,
|
||||
data_silo=data_silo, epochs=epochs,
|
||||
n_gpu=n_gpu, device=device,
|
||||
lr_schedule=lr_schedule, evaluate_every=evaluate_every,
|
||||
eval_report=eval_report, use_amp=use_amp,
|
||||
grad_acc_steps=grad_acc_steps, local_rank=local_rank,
|
||||
early_stopping=early_stopping, log_learning_rate=log_learning_rate,
|
||||
log_loss_every=log_loss_every, checkpoint_on_sigterm=checkpoint_on_sigterm,
|
||||
checkpoint_every=checkpoint_every, checkpoint_root_dir=checkpoint_root_dir,
|
||||
checkpoints_to_keep=checkpoints_to_keep, from_epoch=from_epoch,
|
||||
from_step=from_step, global_step=global_step,
|
||||
evaluator_test=evaluator_test, disable_tqdm=disable_tqdm,
|
||||
max_grad_norm=max_grad_norm)
|
||||
self.teacher_model = teacher_model
|
||||
|
||||
# creating dummy inputs to get the shapes of hidden states and attention of teacher and student model
|
||||
dummy_inputs = teacher_model.language_model.model.dummy_inputs
|
||||
dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(device)
|
||||
dummy_inputs["padding_mask"] = torch.ones_like(dummy_inputs["input_ids"], device=device)
|
||||
dummy_inputs["segment_ids"] = torch.zeros_like(dummy_inputs["input_ids"], device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
_, teacher_hidden_states, teacher_attentions = self.teacher_model.forward(**dummy_inputs, output_attentions=True, output_hidden_states=True)
|
||||
_, hidden_states, attentions = self.model.forward(**dummy_inputs, output_attentions=True, output_hidden_states=True)
|
||||
|
||||
if len(teacher_attentions) % len(attentions) != 0:
|
||||
raise ValueError("Teacher and student model do not seem to be compatible. Have you made sure that the student is a TinyBERT model and that the teacher is a BERT model?")
|
||||
|
||||
self.teacher_block_size = len(teacher_attentions) // len(attentions)
|
||||
|
||||
teacher_dims = [hidden_state.shape[-1] for hidden_state in teacher_hidden_states]
|
||||
student_dims = [hidden_state.shape[-1] for hidden_state in hidden_states]
|
||||
|
||||
# creating linear mappings in case the teacher and student model have different hidden state dimensions
|
||||
self.dim_mappings: List[Optional[Linear]] = []
|
||||
|
||||
for teacher_dim, student_dim in zip(teacher_dims, student_dims):
|
||||
if teacher_dim != student_dim:
|
||||
self.dim_mappings.append(Linear(student_dim, teacher_dim, bias=False))
|
||||
else:
|
||||
self.dim_mappings.append(None)
|
||||
|
||||
|
||||
def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
_, teacher_hidden_states, teacher_attentions = self.teacher_model.forward(**batch, output_attentions=True, output_hidden_states=True)
|
||||
|
||||
_, hidden_states, attentions = self.model.forward(**batch, output_attentions=True, output_hidden_states=True)
|
||||
|
||||
loss = torch.tensor(0., device=self.device)
|
||||
|
||||
# calculating attention loss
|
||||
for student_attention, teacher_attention, dim_mapping in zip(attentions,
|
||||
teacher_attentions[self.teacher_block_size - 1::self.teacher_block_size],
|
||||
self.dim_mappings):
|
||||
|
||||
# this wasn't described in the paper, but it was used in the original implementation
|
||||
student_attention = torch.where(student_attention <= -1e2, torch.zeros_like(student_attention),
|
||||
student_attention)
|
||||
teacher_attention = torch.where(teacher_attention <= -1e2, torch.zeros_like(teacher_attention),
|
||||
teacher_attention)
|
||||
|
||||
loss += F.mse_loss(student_attention, teacher_attention)
|
||||
|
||||
# calculating hidden state loss
|
||||
for student_hidden_state, teacher_hidden_state in zip(hidden_states, teacher_hidden_states[::self.teacher_block_size]):
|
||||
# linear mapping in case the teacher and student model have different hidden state dimensions, not necessary for attention as attention shape is determined by number of attention heads and sequence length
|
||||
if dim_mapping:
|
||||
student_hidden_state = dim_mapping(student_hidden_state)
|
||||
|
||||
loss += F.mse_loss(student_hidden_state, teacher_hidden_state)
|
||||
|
||||
return self.backward_propagate(loss, step)
|
@ -15,7 +15,7 @@ from haystack.modeling.infer import QAInferencer
|
||||
from haystack.modeling.model.optimization import initialize_optimizer
|
||||
from haystack.modeling.model.predictions import QAPred, QACandidate
|
||||
from haystack.modeling.model.adaptive_model import AdaptiveModel
|
||||
from haystack.modeling.training import Trainer, DistillationTrainer
|
||||
from haystack.modeling.training import Trainer, DistillationTrainer, TinyBERTDistillationTrainer
|
||||
from haystack.modeling.evaluation import Evaluator
|
||||
from haystack.modeling.utils import set_all_seeds, initialize_device_settings
|
||||
|
||||
@ -181,7 +181,8 @@ class FARMReader(BaseReader):
|
||||
cache_path: Path = Path("cache/data_silo"),
|
||||
distillation_loss_weight: float = 0.5,
|
||||
distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div",
|
||||
temperature: float = 1.0
|
||||
temperature: float = 1.0,
|
||||
tinybert: bool = False,
|
||||
):
|
||||
if dev_filename:
|
||||
dev_split = 0
|
||||
@ -221,10 +222,10 @@ class FARMReader(BaseReader):
|
||||
|
||||
# 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
|
||||
# and calculates a few descriptive statistics of our datasets
|
||||
if teacher_model: # checks if teacher model is passed as parameter, in that case assume model distillation is used
|
||||
if teacher_model and not tinybert: # checks if teacher model is passed as parameter, in that case assume model distillation is used
|
||||
data_silo = DistillationDataSilo(teacher_model, teacher_batch_size or batch_size, device=devices[0], processor=processor, batch_size=batch_size, distributed=False,
|
||||
max_processes=num_processes, caching=caching, cache_path=cache_path)
|
||||
else:
|
||||
else: # caching would need too much memory for tinybert distillation so in that case we use the default data silo
|
||||
data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False, max_processes=num_processes, caching=caching, cache_path=cache_path)
|
||||
|
||||
# 3. Create an optimizer and pass the already initialized model
|
||||
@ -239,7 +240,27 @@ class FARMReader(BaseReader):
|
||||
use_amp=use_amp,
|
||||
)
|
||||
# 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time
|
||||
if teacher_model: # checks again if teacher model is passed as parameter, in that case assume model distillation is used
|
||||
if tinybert:
|
||||
if not teacher_model:
|
||||
raise ValueError("TinyBERT distillation requires a teacher model.")
|
||||
trainer = TinyBERTDistillationTrainer.create_or_load_checkpoint(
|
||||
model=model,
|
||||
teacher_model=teacher_model.inferencer.model, # teacher needs to be passed as teacher outputs aren't cached
|
||||
optimizer=optimizer,
|
||||
data_silo=data_silo,
|
||||
epochs=n_epochs,
|
||||
n_gpu=n_gpu,
|
||||
lr_schedule=lr_schedule,
|
||||
evaluate_every=evaluate_every,
|
||||
device=devices[0],
|
||||
use_amp=use_amp,
|
||||
disable_tqdm=not self.progress_bar,
|
||||
checkpoint_root_dir=Path(checkpoint_root_dir),
|
||||
checkpoint_every=checkpoint_every,
|
||||
checkpoints_to_keep=checkpoints_to_keep,
|
||||
)
|
||||
|
||||
elif teacher_model: # checks again if teacher model is passed as parameter, in that case assume model distillation is used
|
||||
trainer = DistillationTrainer.create_or_load_checkpoint(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
@ -383,7 +404,9 @@ class FARMReader(BaseReader):
|
||||
cache_path: Path = Path("cache/data_silo"),
|
||||
distillation_loss_weight: float = 0.5,
|
||||
distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div",
|
||||
temperature: float = 1.0
|
||||
temperature: float = 1.0,
|
||||
tinybert_loss: bool = False,
|
||||
tinybert_epochs: int = 1,
|
||||
):
|
||||
"""
|
||||
Fine-tune a model on a QA dataset using distillation. You need to provide a teacher model that is already finetuned on the dataset
|
||||
@ -438,10 +461,24 @@ class FARMReader(BaseReader):
|
||||
:param caching whether or not to use caching for preprocessed dataset and teacher logits
|
||||
:param cache_path: Path to cache the preprocessed dataset and teacher logits
|
||||
:param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
|
||||
:param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named paramters student_logits and teacher_logits)
|
||||
:param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
|
||||
:param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
|
||||
:param tinybert_loss: Whether to use the TinyBERT loss function for distillation. This requires the student to be a TinyBERT model and the teacher to be a finetuned version of bert-base-uncased.
|
||||
:param tinybert_epochs: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function.
|
||||
:return: None
|
||||
"""
|
||||
if tinybert_loss: # do hidden state and attention distillation as additional stage
|
||||
self._training_procedure(data_dir=data_dir, train_filename=train_filename,
|
||||
dev_filename=dev_filename, test_filename=test_filename,
|
||||
use_gpu=use_gpu, batch_size=student_batch_size,
|
||||
n_epochs=tinybert_epochs, learning_rate=learning_rate,
|
||||
max_seq_len=max_seq_len, warmup_proportion=warmup_proportion,
|
||||
dev_split=dev_split, evaluate_every=evaluate_every,
|
||||
save_dir=save_dir, num_processes=num_processes,
|
||||
use_amp=use_amp, checkpoint_root_dir=checkpoint_root_dir,
|
||||
checkpoint_every=checkpoint_every, checkpoints_to_keep=checkpoints_to_keep,
|
||||
teacher_model=teacher_model, teacher_batch_size=teacher_batch_size,
|
||||
caching=caching, cache_path=cache_path, tinybert=True)
|
||||
return self._training_procedure(data_dir=data_dir, train_filename=train_filename,
|
||||
dev_filename=dev_filename, test_filename=test_filename,
|
||||
use_gpu=use_gpu, batch_size=student_batch_size,
|
||||
|
@ -1,15 +1,23 @@
|
||||
from haystack.nodes import FARMReader
|
||||
import torch
|
||||
|
||||
def create_checkpoint(model):
|
||||
weights = []
|
||||
for name, weight in model.inferencer.model.named_parameters():
|
||||
if "weight" in name and weight.requires_grad:
|
||||
weights.append(torch.clone(weight))
|
||||
return weights
|
||||
|
||||
def assert_weight_change(weights, new_weights):
|
||||
print([torch.equal(old_weight, new_weight) for old_weight, new_weight in zip(weights, new_weights)])
|
||||
assert not any(torch.equal(old_weight, new_weight) for old_weight, new_weight in zip(weights, new_weights))
|
||||
|
||||
def test_distillation():
|
||||
student = FARMReader(model_name_or_path="prajjwal1/bert-tiny", num_processes=0)
|
||||
teacher = FARMReader(model_name_or_path="prajjwal1/bert-small", num_processes=0)
|
||||
|
||||
# create a checkpoint of weights before distillation
|
||||
student_weights = []
|
||||
for name, weight in student.inferencer.model.named_parameters():
|
||||
if "weight" in name and weight.requires_grad:
|
||||
student_weights.append(torch.clone(weight))
|
||||
student_weights = create_checkpoint(student)
|
||||
|
||||
assert len(student_weights) == 22
|
||||
|
||||
@ -18,16 +26,36 @@ def test_distillation():
|
||||
student.distil_from(teacher, data_dir="samples/squad", train_filename="tiny.json")
|
||||
|
||||
# create new checkpoint
|
||||
new_student_weights = [torch.clone(param) for param in student.inferencer.model.parameters()]
|
||||
|
||||
new_student_weights = []
|
||||
for name, weight in student.inferencer.model.named_parameters():
|
||||
if "weight" in name and weight.requires_grad:
|
||||
new_student_weights.append(weight)
|
||||
new_student_weights = create_checkpoint(student)
|
||||
|
||||
assert len(new_student_weights) == 22
|
||||
|
||||
new_student_weights.pop(-2) # pooler is not updated due to different attention head
|
||||
|
||||
# check if weights have changed
|
||||
assert not any(torch.equal(old_weight, new_weight) for old_weight, new_weight in zip(student_weights, new_student_weights))
|
||||
assert_weight_change(student_weights, new_student_weights)
|
||||
|
||||
def test_tinybert_distillation():
|
||||
student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
|
||||
teacher = FARMReader(model_name_or_path="bert-base-uncased")
|
||||
|
||||
# create a checkpoint of weights before distillation
|
||||
student_weights = create_checkpoint(student)
|
||||
|
||||
assert len(student_weights) == 38
|
||||
|
||||
student_weights.pop(-1) # last layer is not affected by tinybert loss
|
||||
student_weights.pop(-1) # pooler is not updated due to different attention head
|
||||
|
||||
student._training_procedure(teacher_model=teacher, tinybert=True, data_dir="samples/squad", train_filename="tiny.json")
|
||||
|
||||
# create new checkpoint
|
||||
new_student_weights = create_checkpoint(student)
|
||||
|
||||
assert len(new_student_weights) == 38
|
||||
|
||||
new_student_weights.pop(-1) # last layer is not affected by tinybert loss
|
||||
new_student_weights.pop(-1) # pooler is not updated due to different attention head
|
||||
|
||||
# check if weights have changed
|
||||
assert_weight_change(student_weights, new_student_weights)
|
Loading…
x
Reference in New Issue
Block a user