haystack/test/test_distillation.py
MichelBartels 84147edcca
Model Distillation (#1758)
* initial commit

* Add latest docstring and tutorial changes

* added comments and fixed bug

* fixed bugs, added benchmark and added documentation

* Add latest docstring and tutorial changes

* fix type: ignore comment

* fix logging in benchmark

* fixed distillation config

* Add latest docstring and tutorial changes

* added type annotations

* fixed distillation loss calculation

* added type annotations

* fixed distillation mse loss

* improved model distillation benchmark config loading

* added temperature for model distillation

* removed unnecessary imports, added comments, added named parameter calls

* Add latest docstring and tutorial changes

* added some more comments

* added distillation test

* fixed distillation test

* removed unnecessary import

* fix softmax dimension

* add grid search

* improved model distillation benchmark config

* fixed model distillation hyperparameter search

* added doc strings and type hints for model distillation

* Add latest docstring and tutorial changes

* fixed type hints

* fixed type hints

* fixed type hints

* wrote out params instead of kwargs in DistillationDataSilo initializer

* fixed type hints

* fixed typo

* fixed typo

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2021-11-26 18:49:30 +01:00

33 lines
1.3 KiB
Python

from haystack.nodes import FARMReader
import torch
def _trainable_weights(model, clone: bool = False) -> list:
    """Collect the trainable weight tensors of ``model`` in parameter order.

    Only parameters whose name contains ``"weight"`` and that have
    ``requires_grad`` set are returned, so biases and frozen parameters are skipped.

    :param model: Any object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
    :param clone: If True, return detached copies. The snapshot is then unaffected
        by later in-place parameter updates (such as optimizer steps).
    :return: List of tensors, one per selected parameter.
    """
    return [
        weight.detach().clone() if clone else weight
        for name, weight in model.named_parameters()
        if "weight" in name and weight.requires_grad
    ]


def test_distillation():
    """Check that ``FARMReader.distil_from`` updates the student's trainable weights.

    A tiny BERT student is distilled from a small BERT teacher on
    ``samples/squad/tiny.json``. Afterwards, every trainable weight tensor
    except the pooler's must differ from its pre-distillation snapshot.
    """
    student = FARMReader(model_name_or_path="prajjwal1/bert-tiny")
    teacher = FARMReader(model_name_or_path="prajjwal1/bert-small")

    # Snapshot the weights before distillation. Copies are required because
    # training modifies the live parameter tensors.
    student_weights = _trainable_weights(student.inferencer.model, clone=True)
    assert len(student_weights) == 22
    # The second-to-last weight is the pooler's. Distillation does not update it
    # (presumably because the QA prediction head does not consume the pooled
    # output — NOTE(review): confirm), so exclude it from the comparison.
    student_weights.pop(-2)

    student.distil_from(teacher, data_dir="samples/squad", train_filename="tiny.json")

    # No copy needed here: these tensors are only read for the comparison below.
    new_student_weights = _trainable_weights(student.inferencer.model)
    assert len(new_student_weights) == 22
    new_student_weights.pop(-2)  # drop the pooler weight, as above

    # Every remaining weight tensor must have changed. Report which ones did not.
    unchanged = [
        index
        for index, (old_weight, new_weight) in enumerate(zip(student_weights, new_student_weights))
        if torch.equal(old_weight, new_weight)
    ]
    assert not unchanged, f"weights at indices {unchanged} were not updated by distillation"