fix: inconsistent batch_size parameter names in distillation (#3811)

commit 0e42a9015e
parent dea10a51d3
Author: Julian Risch, 2023-01-10 11:38:21 +01:00 (committed via GitHub)


@@ -470,7 +470,7 @@ class FARMReader(BaseReader):
         test_filename: Optional[str] = None,
         use_gpu: Optional[bool] = None,
         devices: List[torch.device] = [],
-        student_batch_size: int = 10,
+        batch_size: int = 10,
         teacher_batch_size: Optional[int] = None,
         n_epochs: int = 2,
         learning_rate: float = 3e-5,
@@ -489,6 +489,7 @@ class FARMReader(BaseReader):
         distillation_loss_weight: float = 0.5,
         distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div",
         temperature: float = 1.0,
+        processor: Optional[Processor] = None,
         grad_acc_steps: int = 1,
         early_stopping: Optional[EarlyStopping] = None,
     ):
@@ -522,7 +523,7 @@ class FARMReader(BaseReader):
                        A list containing torch device objects and/or strings is supported (For example
                        [torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
                        parameter is not used and a single cpu device is used for inference.
-        :param student_batch_size: Number of samples the student model receives in one batch for training
+        :param batch_size: Number of samples the student model receives in one batch for training
         :param teacher_batch_size: Number of samples the teacher model receives in one batch for distillation
         :param n_epochs: Number of iterations on the whole training data set
         :param learning_rate: Learning rate of the optimizer
@@ -548,10 +549,6 @@
         :param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
         :param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
         :param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
-        :param tinybert_loss: Whether to use the TinyBERT loss function for distillation. This requires the student to be a TinyBERT model and the teacher to be a finetuned version of bert-base-uncased.
-        :param tinybert_epochs: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function.
-        :param tinybert_learning_rate: Learning rate to use when training the student model with the TinyBERT loss function.
-        :param tinybert_train_filename: Filename of training data to use when training the student model with the TinyBERT loss function. To best follow the original paper, this should be an augmented version of the training data created using the augment_squad.py script. If not specified, the training data from the original training is used.
         :param processor: The processor to use for preprocessing. If None, the default SquadProcessor is used.
         :param grad_acc_steps: The number of steps to accumulate gradients for before performing a backward pass.
         :param early_stopping: An initialized EarlyStopping object to control early stopping and saving of the best models.
@@ -564,7 +561,7 @@
             test_filename=test_filename,
             use_gpu=use_gpu,
             devices=devices,
-            batch_size=student_batch_size,
+            batch_size=batch_size,
             n_epochs=n_epochs,
             learning_rate=learning_rate,
             max_seq_len=max_seq_len,
@@ -584,6 +581,7 @@
             distillation_loss_weight=distillation_loss_weight,
             distillation_loss=distillation_loss,
             temperature=temperature,
+            processor=processor,
             grad_acc_steps=grad_acc_steps,
             early_stopping=early_stopping,
             distributed=False,
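After this change, the prediction-layer distillation method takes the student batch size as `batch_size` (formerly `student_batch_size`), while `teacher_batch_size` stays separate. A hedged usage sketch; the method name `distil_prediction_layer_from`, the model names, and the data paths are assumptions not shown in this diff:

    from haystack.nodes import FARMReader

    # Assumed student and teacher models; any compatible QA models could be used.
    student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
    teacher = FARMReader(model_name_or_path="deepset/bert-base-uncased-squad2")

    student.distil_prediction_layer_from(
        teacher,
        data_dir="data/squad20",           # assumed path
        train_filename="train-v2.0.json",  # assumed filename
        batch_size=10,                     # renamed from student_batch_size in this commit
        teacher_batch_size=16,             # teacher batches can be sized independently
        n_epochs=2,
        learning_rate=3e-5,
        distillation_loss="kl_div",
        distillation_loss_weight=0.5,
        temperature=1.0,
    )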
@@ -599,6 +597,7 @@ class FARMReader(BaseReader):
         use_gpu: Optional[bool] = None,
         devices: List[torch.device] = [],
         batch_size: int = 10,
+        teacher_batch_size: Optional[int] = None,
         n_epochs: int = 5,
         learning_rate: float = 5e-5,
         max_seq_len: Optional[int] = None,
@@ -614,6 +613,7 @@ class FARMReader(BaseReader):
         caching: bool = False,
         cache_path: Path = Path("cache/data_silo"),
         distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "mse",
+        distillation_loss_weight: float = 0.5,
         temperature: float = 1.0,
         processor: Optional[Processor] = None,
         grad_acc_steps: int = 1,
@@ -645,7 +645,8 @@ class FARMReader(BaseReader):
                        A list containing torch device objects and/or strings is supported (For example
                        [torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
                        parameter is not used and a single cpu device is used for inference.
-        :param batch_size: Number of samples the student model and teacher model receives in one batch for training
+        :param batch_size: Number of samples the student model receives in one batch for training
+        :param teacher_batch_size: Number of samples the teacher model receives in one batch for distillation.
         :param n_epochs: Number of iterations on the whole training data set
         :param learning_rate: Learning rate of the optimizer
         :param max_seq_len: Maximum text length (in tokens). Everything longer gets cut down.
@@ -668,6 +669,7 @@
         :param caching: whether or not to use caching for preprocessed dataset and teacher logits
         :param cache_path: Path to cache the preprocessed dataset and teacher logits
         :param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
+        :param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
         :param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
         :param processor: The processor to use for preprocessing. If None, the default SquadProcessor is used.
         :param grad_acc_steps: The number of steps to accumulate gradients for before performing a backward pass.
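The newly documented `distillation_loss_weight` typically blends the student's task loss with the distillation loss. A sketch of the common convention; the exact combination inside Haystack may differ:

    def combined_loss(task_loss, distill_loss, distillation_loss_weight=0.5):
        # A weight of 0.0 ignores the teacher entirely; 1.0 trains on teacher outputs only.
        return (1.0 - distillation_loss_weight) * task_loss + distillation_loss_weight * distill_loss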
@@ -695,10 +697,11 @@
             checkpoint_every=checkpoint_every,
             checkpoints_to_keep=checkpoints_to_keep,
             teacher_model=teacher_model,
-            teacher_batch_size=batch_size,
+            teacher_batch_size=teacher_batch_size,
             caching=caching,
             cache_path=cache_path,
             distillation_loss=distillation_loss,
+            distillation_loss_weight=distillation_loss_weight,
             temperature=temperature,
             tinybert=True,
             processor=processor,
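In this second method, the teacher batch size is no longer tied to `batch_size` and the distillation loss weight is now configurable and forwarded to the training procedure. A hedged usage sketch; the method name `distil_intermediate_layers_from` and the setup mirror the example above and are assumptions not shown in this diff:

    student.distil_intermediate_layers_from(
        teacher,
        data_dir="data/squad20",           # assumed path
        train_filename="train-v2.0.json",  # assumed filename
        batch_size=10,
        teacher_batch_size=16,             # newly exposed by this commit
        n_epochs=5,
        learning_rate=5e-5,
        distillation_loss="mse",
        distillation_loss_weight=0.5,      # newly forwarded by this commit
        temperature=1.0,
    )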