| 
									
										
										
										
											2023-04-28 17:08:41 +02:00
										 |  |  | import torch | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | from haystack.nodes import FARMReader | 
					
						
							| 
									
										
										
										
											2022-01-25 14:54:34 +01:00
										 |  |  | from haystack.modeling.data_handler.processor import UnlabeledTextProcessor | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-01-26 18:12:55 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-12-23 14:54:02 +01:00
										 |  |  | def create_checkpoint(model): | 
					
						
							|  |  |  |     weights = [] | 
					
						
							|  |  |  |     for name, weight in model.inferencer.model.named_parameters(): | 
					
						
							|  |  |  |         if "weight" in name and weight.requires_grad: | 
					
						
							|  |  |  |             weights.append(torch.clone(weight)) | 
					
						
							|  |  |  |     return weights | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-12-23 14:54:02 +01:00
										 |  |  | def assert_weight_change(weights, new_weights): | 
					
						
							|  |  |  |     print([torch.equal(old_weight, new_weight) for old_weight, new_weight in zip(weights, new_weights)]) | 
					
						
							|  |  |  |     assert not any(torch.equal(old_weight, new_weight) for old_weight, new_weight in zip(weights, new_weights)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  | def test_prediction_layer_distillation(samples_path): | 
					
						
							| 
									
										
										
										
											2023-01-25 19:02:11 +01:00
										 |  |  |     student = FARMReader(model_name_or_path="prajjwal1/bert-mini", num_processes=0) | 
					
						
							| 
									
										
										
										
											2021-12-22 17:20:23 +01:00
										 |  |  |     teacher = FARMReader(model_name_or_path="prajjwal1/bert-small", num_processes=0) | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # create a checkpoint of weights before distillation | 
					
						
							| 
									
										
										
										
											2021-12-23 14:54:02 +01:00
										 |  |  |     student_weights = create_checkpoint(student) | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-25 19:02:11 +01:00
										 |  |  |     assert len(student_weights) == 38 | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     student_weights.pop(-2)  # pooler is not updated due to different attention head | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  |     student.distil_prediction_layer_from(teacher, data_dir=samples_path / "squad", train_filename="tiny.json") | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # create new checkpoint | 
					
						
							| 
									
										
										
										
											2021-12-23 14:54:02 +01:00
										 |  |  |     new_student_weights = create_checkpoint(student) | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-25 19:02:11 +01:00
										 |  |  |     assert len(new_student_weights) == 38 | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     new_student_weights.pop(-2)  # pooler is not updated due to different attention head | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # check if weights have changed | 
					
						
							| 
									
										
										
										
											2021-12-23 14:54:02 +01:00
										 |  |  |     assert_weight_change(student_weights, new_student_weights) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  | def test_intermediate_layer_distillation(samples_path): | 
					
						
							| 
									
										
										
										
											2021-12-23 14:54:02 +01:00
										 |  |  |     student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D") | 
					
						
							|  |  |  |     teacher = FARMReader(model_name_or_path="bert-base-uncased") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # create a checkpoint of weights before distillation | 
					
						
							|  |  |  |     student_weights = create_checkpoint(student) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert len(student_weights) == 38 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     student_weights.pop(-1)  # last layer is not affected by tinybert loss | 
					
						
							|  |  |  |     student_weights.pop(-1)  # pooler is not updated due to different attention head | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     student.distil_intermediate_layers_from( | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  |         teacher_model=teacher, data_dir=samples_path / "squad", train_filename="tiny.json" | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-12-23 14:54:02 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # create new checkpoint | 
					
						
							|  |  |  |     new_student_weights = create_checkpoint(student) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert len(new_student_weights) == 38 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     new_student_weights.pop(-1)  # last layer is not affected by tinybert loss | 
					
						
							|  |  |  |     new_student_weights.pop(-1)  # pooler is not updated due to different attention head | 
					
						
							| 
									
										
										
										
											2021-12-23 14:54:02 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-01-25 14:54:34 +01:00
										 |  |  |     # check if weights have changed | 
					
						
							|  |  |  |     assert_weight_change(student_weights, new_student_weights) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  | def test_intermediate_layer_distillation_from_scratch(samples_path): | 
					
						
							| 
									
										
										
										
											2022-01-25 14:54:34 +01:00
										 |  |  |     student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D") | 
					
						
							|  |  |  |     teacher = FARMReader(model_name_or_path="bert-base-uncased") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # create a checkpoint of weights before distillation | 
					
						
							|  |  |  |     student_weights = create_checkpoint(student) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert len(student_weights) == 38 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     student_weights.pop(-1)  # last layer is not affected by tinybert loss | 
					
						
							|  |  |  |     student_weights.pop(-1)  # pooler is not updated due to different attention head | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     processor = UnlabeledTextProcessor( | 
					
						
							|  |  |  |         tokenizer=teacher.inferencer.processor.tokenizer, | 
					
						
							|  |  |  |         max_seq_len=128, | 
					
						
							|  |  |  |         train_filename="doc_2.txt", | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  |         data_dir=samples_path / "docs", | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     ) | 
					
						
							|  |  |  |     student.distil_intermediate_layers_from( | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  |         teacher_model=teacher, data_dir=samples_path / "squad", train_filename="tiny.json", processor=processor | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2022-01-25 14:54:34 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # create new checkpoint | 
					
						
							|  |  |  |     new_student_weights = create_checkpoint(student) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert len(new_student_weights) == 38 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     new_student_weights.pop(-1)  # last layer is not affected by tinybert loss | 
					
						
							|  |  |  |     new_student_weights.pop(-1)  # pooler is not updated due to different attention head | 
					
						
							| 
									
										
										
										
											2022-01-25 14:54:34 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-12-23 14:54:02 +01:00
										 |  |  |     # check if weights have changed | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     assert_weight_change(student_weights, new_student_weights) |