mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 18:29:32 +00:00
parent
9891bfeddd
commit
35b2c99f43
@ -667,7 +667,7 @@ class DistillationTrainer(Trainer):
|
||||
:param disable_tqdm: Disable tqdm progress bar (helps to reduce verbosity in some environments)
|
||||
:param max_grad_norm: Max gradient norm for clipping, default 1.0, set to None to disable
|
||||
:param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
|
||||
:param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named paramters student_logits and teacher_logits)
|
||||
:param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
|
||||
:param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
|
||||
"""
|
||||
super().__init__(
|
||||
@ -819,7 +819,7 @@ class TinyBERTDistillationTrainer(Trainer):
|
||||
:param disable_tqdm: Disable tqdm progress bar (helps to reduce verbosity in some environments)
|
||||
:param max_grad_norm: Max gradient norm for clipping, default 1.0, set to None to disable
|
||||
:param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
|
||||
:param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named paramters student_logits and teacher_logits)
|
||||
:param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
|
||||
:param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
|
||||
"""
|
||||
super().__init__(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user