distinguish intermediate layer & prediction layer distillation phases with different parameters (#2001)

* add parameters to allow for different hyperparameters in stage 1 and 2 of tinybert distillation

* Add latest docstring and tutorial changes

* improve default parameters

* Add latest docstring and tutorial changes

* split up distillation method

* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
MichelBartels 2022-01-14 20:40:38 +01:00 committed by GitHub
parent f42d2e8ba0
commit 0cca2b97cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 188 additions and 27 deletions

View File

@ -157,23 +157,26 @@ If any checkpoints are stored, a subsequent run of train() will resume training
None
<a name="farm.FARMReader.distil_from"></a>
#### distil\_from
<a name="farm.FARMReader.distil_prediction_layer_from"></a>
#### distil\_prediction\_layer\_from
```python
| distil_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 2, learning_rate: float = 1e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss_weight: float = 0.5, distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div", temperature: float = 1.0, tinybert_loss: bool = False, tinybert_epochs: int = 1)
| distil_prediction_layer_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 2, learning_rate: float = 3e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss_weight: float = 0.5, distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div", temperature: float = 1.0)
```
Fine-tune a model on a QA dataset using distillation. You need to provide a teacher model that is already finetuned on the dataset
and a student model that will be trained using the teacher's logits. The idea of this is to increase the accuracy of a lightweight student model
Fine-tune a model on a QA dataset using logit-based distillation. You need to provide a teacher model that is already finetuned on the dataset
and a student model that will be trained using the teacher's logits. The idea of this is to increase the accuracy of a lightweight student model.
using a more complex teacher.
Originally proposed in: https://arxiv.org/pdf/1503.02531.pdf
This can also be considered as the second stage of distillation finetuning as described in the TinyBERT paper:
https://arxiv.org/pdf/1909.10351.pdf
**Example**
```python
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
teacher = FARMReader(model_name_or_path="deepset/bert-large-uncased-whole-word-masking-squad2")
student.distil_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
student.distil_prediction_layer_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5)
```
@ -222,6 +225,75 @@ If any checkpoints are stored, a subsequent run of train() will resume training
- `temperature`: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
- `tinybert_loss`: Whether to use the TinyBERT loss function for distillation. This requires the student to be a TinyBERT model and the teacher to be a finetuned version of bert-base-uncased.
- `tinybert_epochs`: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function.
- `tinybert_learning_rate`: Learning rate to use when training the student model with the TinyBERT loss function.
- `tinybert_train_filename`: Filename of training data to use when training the student model with the TinyBERT loss function. To best follow the original paper, this should be an augmented version of the training data created using the augment_squad.py script. If not specified, the training data from the original training is used.
**Returns**:
None
<a name="farm.FARMReader.distil_intermediate_layers_from"></a>
#### distil\_intermediate\_layers\_from
```python
| distil_intermediate_layers_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 5, learning_rate: float = 5e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "mse", temperature: float = 1.0)
```
The first stage of distillation finetuning as described in the TinyBERT paper:
https://arxiv.org/pdf/1909.10351.pdf
**Example**
```python
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
teacher = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D")
student.distil_intermediate_layers_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5)
```
Checkpoints can be stored via setting `checkpoint_every` to a custom number of steps.
If any checkpoints are stored, a subsequent run of train() will resume training from the latest available checkpoint.
**Arguments**:
- `teacher_model`: Model whose logits will be used to improve accuracy
- `data_dir`: Path to directory containing your training data in SQuAD style
- `train_filename`: Filename of training data. To best follow the original paper, this should be an augmented version of the training data created using the augment_squad.py script
- `dev_filename`: Filename of dev / eval data
- `test_filename`: Filename of test data
- `dev_split`: Instead of specifying a dev_filename, you can also specify a ratio (e.g. 0.1) here
that gets split off from training data for eval.
- `use_gpu`: Whether to use GPU (if available)
- `student_batch_size`: Number of samples the student model receives in one batch for training
- `student_batch_size`: Number of samples the teacher model receives in one batch for distillation
- `n_epochs`: Number of iterations on the whole training data set
- `learning_rate`: Learning rate of the optimizer
- `max_seq_len`: Maximum text length (in tokens). Everything longer gets cut down.
- `warmup_proportion`: Proportion of training steps until maximum learning rate is reached.
Until that point LR is increasing linearly. After that it's decreasing again linearly.
Options for different schedules are available in FARM.
- `evaluate_every`: Evaluate the model every X steps on the hold-out eval dataset
- `save_dir`: Path to store the final model
- `num_processes`: The number of processes for `multiprocessing.Pool` during preprocessing.
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
Set to None to use all CPU cores minus one.
- `use_amp`: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
Available options:
None (Don't use AMP)
"O0" (Normal FP32 training)
"O1" (Mixed Precision => Recommended)
"O2" (Almost FP16)
"O3" (Pure FP16).
See details on: https://nvidia.github.io/apex/amp.html
- `checkpoint_root_dir`: the Path of directory where all train checkpoints are saved. For each individual
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
- `checkpoint_every`: save a train checkpoint after this many steps of training.
- `checkpoints_to_keep`: maximum number of train checkpoints to save.
:param caching whether or not to use caching for preprocessed dataset and teacher logits
- `cache_path`: Path to cache the preprocessed dataset and teacher logits
- `distillation_loss_weight`: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
- `distillation_loss`: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
- `temperature`: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
**Returns**:

View File

@ -203,6 +203,8 @@ class FARMReader(BaseReader):
if not save_dir:
save_dir = f"../../saved_models/{self.inferencer.model.language_model.name}"
if tinybert:
save_dir += "_tinybert_stage_1"
# 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
label_list = ["start_token", "end_token"]
@ -378,7 +380,7 @@ class FARMReader(BaseReader):
checkpoint_every=checkpoint_every, checkpoints_to_keep=checkpoints_to_keep,
caching=caching, cache_path=cache_path)
def distil_from(
def distil_prediction_layer_from(
self,
teacher_model: "FARMReader",
data_dir: str,
@ -389,7 +391,7 @@ class FARMReader(BaseReader):
student_batch_size: int = 10,
teacher_batch_size: Optional[int] = None,
n_epochs: int = 2,
learning_rate: float = 1e-5,
learning_rate: float = 3e-5,
max_seq_len: Optional[int] = None,
warmup_proportion: float = 0.2,
dev_split: float = 0,
@ -405,20 +407,21 @@ class FARMReader(BaseReader):
distillation_loss_weight: float = 0.5,
distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div",
temperature: float = 1.0,
tinybert_loss: bool = False,
tinybert_epochs: int = 1,
):
"""
Fine-tune a model on a QA dataset using distillation. You need to provide a teacher model that is already finetuned on the dataset
and a student model that will be trained using the teacher's logits. The idea of this is to increase the accuracy of a lightweight student model
Fine-tune a model on a QA dataset using logit-based distillation. You need to provide a teacher model that is already finetuned on the dataset
and a student model that will be trained using the teacher's logits. The idea of this is to increase the accuracy of a lightweight student model.
using a more complex teacher.
Originally proposed in: https://arxiv.org/pdf/1503.02531.pdf
This can also be considered as the second stage of distillation finetuning as described in the TinyBERT paper:
https://arxiv.org/pdf/1909.10351.pdf
**Example**
```python
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
teacher = FARMReader(model_name_or_path="deepset/bert-large-uncased-whole-word-masking-squad2")
student.distil_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
student.distil_prediction_layer_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5)
```
@ -465,20 +468,10 @@ class FARMReader(BaseReader):
:param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
:param tinybert_loss: Whether to use the TinyBERT loss function for distillation. This requires the student to be a TinyBERT model and the teacher to be a finetuned version of bert-base-uncased.
:param tinybert_epochs: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function.
:param tinybert_learning_rate: Learning rate to use when training the student model with the TinyBERT loss function.
:param tinybert_train_filename: Filename of training data to use when training the student model with the TinyBERT loss function. To best follow the original paper, this should be an augmented version of the training data created using the augment_squad.py script. If not specified, the training data from the original training is used.
:return: None
"""
if tinybert_loss: # do hidden state and attention distillation as additional stage
self._training_procedure(data_dir=data_dir, train_filename=train_filename,
dev_filename=dev_filename, test_filename=test_filename,
use_gpu=use_gpu, batch_size=student_batch_size,
n_epochs=tinybert_epochs, learning_rate=learning_rate,
max_seq_len=max_seq_len, warmup_proportion=warmup_proportion,
dev_split=dev_split, evaluate_every=evaluate_every,
save_dir=save_dir, num_processes=num_processes,
use_amp=use_amp, checkpoint_root_dir=checkpoint_root_dir,
checkpoint_every=checkpoint_every, checkpoints_to_keep=checkpoints_to_keep,
teacher_model=teacher_model, teacher_batch_size=teacher_batch_size,
caching=caching, cache_path=cache_path, tinybert=True)
return self._training_procedure(data_dir=data_dir, train_filename=train_filename,
dev_filename=dev_filename, test_filename=test_filename,
use_gpu=use_gpu, batch_size=student_batch_size,
@ -492,6 +485,102 @@ class FARMReader(BaseReader):
caching=caching, cache_path=cache_path, distillation_loss_weight=distillation_loss_weight,
distillation_loss=distillation_loss, temperature=temperature)
def distil_intermediate_layers_from(
self,
teacher_model: "FARMReader",
data_dir: str,
train_filename: str,
dev_filename: Optional[str] = None,
test_filename: Optional[str] = None,
use_gpu: Optional[bool] = None,
student_batch_size: int = 10,
teacher_batch_size: Optional[int] = None,
n_epochs: int = 5,
learning_rate: float = 5e-5,
max_seq_len: Optional[int] = None,
warmup_proportion: float = 0.2,
dev_split: float = 0,
evaluate_every: int = 300,
save_dir: Optional[str] = None,
num_processes: Optional[int] = None,
use_amp: str = None,
checkpoint_root_dir: Path = Path("model_checkpoints"),
checkpoint_every: Optional[int] = None,
checkpoints_to_keep: int = 3,
caching: bool = False,
cache_path: Path = Path("cache/data_silo"),
distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "mse",
temperature: float = 1.0,
):
"""
The first stage of distillation finetuning as described in the TinyBERT paper:
https://arxiv.org/pdf/1909.10351.pdf
**Example**
```python
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
teacher = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D")
student.distil_intermediate_layers_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5)
```
Checkpoints can be stored via setting `checkpoint_every` to a custom number of steps.
If any checkpoints are stored, a subsequent run of train() will resume training from the latest available checkpoint.
:param teacher_model: Model whose logits will be used to improve accuracy
:param data_dir: Path to directory containing your training data in SQuAD style
:param train_filename: Filename of training data. To best follow the original paper, this should be an augmented version of the training data created using the augment_squad.py script
:param dev_filename: Filename of dev / eval data
:param test_filename: Filename of test data
:param dev_split: Instead of specifying a dev_filename, you can also specify a ratio (e.g. 0.1) here
that gets split off from training data for eval.
:param use_gpu: Whether to use GPU (if available)
:param student_batch_size: Number of samples the student model receives in one batch for training
:param student_batch_size: Number of samples the teacher model receives in one batch for distillation
:param n_epochs: Number of iterations on the whole training data set
:param learning_rate: Learning rate of the optimizer
:param max_seq_len: Maximum text length (in tokens). Everything longer gets cut down.
:param warmup_proportion: Proportion of training steps until maximum learning rate is reached.
Until that point LR is increasing linearly. After that it's decreasing again linearly.
Options for different schedules are available in FARM.
:param evaluate_every: Evaluate the model every X steps on the hold-out eval dataset
:param save_dir: Path to store the final model
:param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing.
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
Set to None to use all CPU cores minus one.
:param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
Available options:
None (Don't use AMP)
"O0" (Normal FP32 training)
"O1" (Mixed Precision => Recommended)
"O2" (Almost FP16)
"O3" (Pure FP16).
See details on: https://nvidia.github.io/apex/amp.html
:param checkpoint_root_dir: the Path of directory where all train checkpoints are saved. For each individual
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
:param checkpoint_every: save a train checkpoint after this many steps of training.
:param checkpoints_to_keep: maximum number of train checkpoints to save.
:param caching whether or not to use caching for preprocessed dataset and teacher logits
:param cache_path: Path to cache the preprocessed dataset and teacher logits
:param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
:param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
:param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
:return: None
"""
return self._training_procedure(data_dir=data_dir, train_filename=train_filename,
dev_filename=dev_filename, test_filename=test_filename,
use_gpu=use_gpu, batch_size=student_batch_size,
n_epochs=n_epochs, learning_rate=learning_rate,
max_seq_len=max_seq_len, warmup_proportion=warmup_proportion,
dev_split=dev_split, evaluate_every=evaluate_every,
save_dir=save_dir, num_processes=num_processes,
use_amp=use_amp, checkpoint_root_dir=checkpoint_root_dir,
checkpoint_every=checkpoint_every, checkpoints_to_keep=checkpoints_to_keep,
teacher_model=teacher_model, teacher_batch_size=teacher_batch_size,
caching=caching, cache_path=cache_path,
distillation_loss=distillation_loss, temperature=temperature, tinybert=True)
def update_parameters(
self,
context_window_size: Optional[int] = None,

View File

@ -23,7 +23,7 @@ def test_distillation():
student_weights.pop(-2) # pooler is not updated due to different attention head
student.distil_from(teacher, data_dir="samples/squad", train_filename="tiny.json")
student.distil_prediction_layer_from(teacher, data_dir="samples/squad", train_filename="tiny.json")
# create new checkpoint
new_student_weights = create_checkpoint(student)
@ -47,7 +47,7 @@ def test_tinybert_distillation():
student_weights.pop(-1) # last layer is not affected by tinybert loss
student_weights.pop(-1) # pooler is not updated due to different attention head
student._training_procedure(teacher_model=teacher, tinybert=True, data_dir="samples/squad", train_filename="tiny.json")
student.distil_intermediate_layers_from(teacher_model=teacher, data_dir="samples/squad", train_filename="tiny.json")
# create new checkpoint
new_student_weights = create_checkpoint(student)