mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-28 18:36:36 +00:00

* initial commit * Add latest docstring and tutorial changes * added comments and fixed bug * fixed bugs, added benchmark and added documentation * Add latest docstring and tutorial changes * fix type: ignore comment * fix logging in benchmark * fixed distillation config * Add latest docstring and tutorial changes * added type annotations * fixed distillation loss calculation * added type annotations * fixed distillation mse loss * improved model distillation benchmark config loading * added temperature for model distillation * removed unnecessary imports, added comments, added named parameter calls * Add latest docstring and tutorial changes * added some more comments * added distillation test * fixed distillation test * removed unnecessary import * fix softmax dimension * add grid search * improved model distillation benchmark config * fixed model distillation hyperparameter search * added doc strings and type hints for model distillation * Add latest docstring and tutorial changes * fixed type hints * fixed type hints * fixed type hints * wrote out params instead of kwargs in DistillationDataSilo initializer * fixed type hints * fixed typo * fixed typo Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
33 lines
1.3 KiB
Python
33 lines
1.3 KiB
Python
from haystack.nodes import FARMReader
|
|
import torch
|
|
|
|
def test_distillation():
    """End-to-end check that FARMReader.distil_from() actually updates the student.

    Distils a tiny BERT student from a small BERT teacher on a tiny SQuAD sample
    and asserts that every trainable weight tensor (except the pooler, which is
    skipped) has changed after training.

    NOTE(review): requires network access to download the HF models and the
    local sample file ``samples/squad/tiny.json``.
    """
    student = FARMReader(model_name_or_path="prajjwal1/bert-tiny")
    teacher = FARMReader(model_name_or_path="prajjwal1/bert-small")

    def _trainable_weights(reader, clone):
        # Collect all trainable weight tensors in named_parameters() order.
        # clone=True snapshots the values so later in-place training updates
        # don't mutate the checkpoint we compare against.
        weights = [
            torch.clone(w) if clone else w
            for name, w in reader.inferencer.model.named_parameters()
            if "weight" in name and w.requires_grad
        ]
        # bert-tiny is expected to expose exactly 22 trainable weight tensors
        assert len(weights) == 22
        weights.pop(-2)  # pooler is not updated due to different attention head
        return weights

    # create a checkpoint of weights before distillation
    student_weights = _trainable_weights(student, clone=True)

    student.distil_from(teacher, data_dir="samples/squad", train_filename="tiny.json")

    # create new checkpoint (no clone needed: compared immediately below)
    # The original code first built this list from unfiltered model.parameters()
    # and then immediately overwrote it — that dead assignment is removed here.
    new_student_weights = _trainable_weights(student, clone=False)

    # check if weights have changed: no before/after pair may be equal
    assert not any(
        torch.equal(old_weight, new_weight)
        for old_weight, new_weight in zip(student_weights, new_student_weights)
    )