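"""Tests for FARMReader model distillation: prediction-layer distillation and
TinyBERT-style intermediate-layer distillation, with labeled and unlabeled data."""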
from pathlib import Path
from haystack.nodes import FARMReader
from haystack.modeling.data_handler.processor import UnlabeledTextProcessor
import torch
from conftest import SAMPLES_PATH


def create_checkpoint(model):
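    """Snapshot all trainable weight matrices of the model wrapped by the reader, for comparison after training."""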
weights = []
for name, weight in model.inferencer.model.named_parameters():
if "weight" in name and weight.requires_grad:
weights.append(torch.clone(weight))
return weights


def assert_weight_change(weights, new_weights):
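    """Assert that every weight tensor differs from its pre-training snapshot."""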
    unchanged = [torch.equal(old_weight, new_weight) for old_weight, new_weight in zip(weights, new_weights)]
    assert not any(unchanged), f"some weights were not updated during training (True = unchanged): {unchanged}"


def test_prediction_layer_distillation():
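    """Prediction-layer distillation from a larger BERT teacher into a smaller student must update the student's weights."""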
student = FARMReader(model_name_or_path="prajjwal1/bert-tiny", num_processes=0)
teacher = FARMReader(model_name_or_path="prajjwal1/bert-small", num_processes=0)
# create a checkpoint of weights before distillation
student_weights = create_checkpoint(student)
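    # expected for prajjwal1/bert-tiny: 3 embedding matrices + embedding LayerNorm
    # + 2 layers x 8 weight matrices + pooler + QA head = 22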
assert len(student_weights) == 22

    student_weights.pop(-2)  # pooler is not updated due to different attention head

student.distil_prediction_layer_from(teacher, data_dir=SAMPLES_PATH / "squad", train_filename="tiny.json")
# create new checkpoint
new_student_weights = create_checkpoint(student)
assert len(new_student_weights) == 22

    new_student_weights.pop(-2)  # pooler is not updated due to different attention head

# check if weights have changed
assert_weight_change(student_weights, new_student_weights)


def test_intermediate_layer_distillation():
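    """Intermediate-layer (TinyBERT) distillation on labeled SQuAD-style data must update the student's weights."""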
student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
teacher = FARMReader(model_name_or_path="bert-base-uncased")
# create a checkpoint of weights before distillation
student_weights = create_checkpoint(student)
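    # expected for TinyBERT_General_4L_312D: 3 embedding matrices + embedding LayerNorm
    # + 4 layers x 8 weight matrices + pooler + QA head = 38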
assert len(student_weights) == 38
    student_weights.pop(-1)  # last layer is not affected by tinybert loss
    student_weights.pop(-1)  # pooler is not updated due to different attention head

student.distil_intermediate_layers_from(
teacher_model=teacher, data_dir=SAMPLES_PATH / "squad", train_filename="tiny.json"
)
# create new checkpoint
new_student_weights = create_checkpoint(student)
assert len(new_student_weights) == 38
    new_student_weights.pop(-1)  # last layer is not affected by tinybert loss
    new_student_weights.pop(-1)  # pooler is not updated due to different attention head

# check if weights have changed
assert_weight_change(student_weights, new_student_weights)


def test_intermediate_layer_distillation_from_scratch():
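    """Intermediate-layer distillation from scratch on unlabeled text (via UnlabeledTextProcessor) must update the student's weights."""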
student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
teacher = FARMReader(model_name_or_path="bert-base-uncased")
# create a checkpoint of weights before distillation
student_weights = create_checkpoint(student)
assert len(student_weights) == 38
    student_weights.pop(-1)  # last layer is not affected by tinybert loss
    student_weights.pop(-1)  # pooler is not updated due to different attention head

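    # distillation from scratch needs no labels, so a plain text file is sufficient training data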
processor = UnlabeledTextProcessor(
tokenizer=teacher.inferencer.processor.tokenizer,
max_seq_len=128,
train_filename="doc_2.txt",
data_dir=SAMPLES_PATH / "docs",
)
student.distil_intermediate_layers_from(
teacher_model=teacher, data_dir=SAMPLES_PATH / "squad", train_filename="tiny.json", processor=processor
)
# create new checkpoint
new_student_weights = create_checkpoint(student)
assert len(new_student_weights) == 38
    new_student_weights.pop(-1)  # last layer is not affected by tinybert loss
    new_student_weights.pop(-1)  # pooler is not updated due to different attention head

# check if weights have changed
assert_weight_change(student_weights, new_student_weights)