| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | from haystack.nodes import FARMReader | 
					
						
							|  |  |  | import json | 
					
						
							|  |  |  | import requests | 
					
						
							|  |  |  | from pathlib import Path | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from typing import Union, List | 
					
						
							|  |  |  | import logging | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | logger = logging.getLogger(__name__) | 
					
						
							|  |  |  | logger.setLevel(logging.DEBUG) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | download_links = { | 
					
						
							|  |  |  |     "squad2": { | 
					
						
							|  |  |  |         "train": "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |         "test": "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json", | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  |     }, | 
					
						
							|  |  |  |     "squad": { | 
					
						
							|  |  |  |         "train": "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json", | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |         "test": "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", | 
					
						
							|  |  |  |     }, | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-08 15:34:43 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | # loading json config file | 
					
						
							|  |  |  | def load_config(path: str) -> dict: | 
					
						
							|  |  |  |     with open(path, "r") as f: | 
					
						
							|  |  |  |         return json.load(f) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | # returns all possible combinations of hyperparameters for grid search | 
					
						
							|  |  |  | def combine_config(configs: dict) -> List[dict]: | 
					
						
							|  |  |  |     combinations_list = [[]] | 
					
						
							|  |  |  |     for config_key, config in configs.items(): | 
					
						
							|  |  |  |         if not isinstance(config, list): | 
					
						
							|  |  |  |             config = [config] | 
					
						
							|  |  |  |         current_combinations = combinations_list | 
					
						
							|  |  |  |         combinations_list = [] | 
					
						
							|  |  |  |         for item in config: | 
					
						
							|  |  |  |             combinations_list += [c + [(config_key, item)] for c in current_combinations] | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  |     combinations = [] | 
					
						
							|  |  |  |     for combination in combinations_list: | 
					
						
							|  |  |  |         combinations.append(dict(combination)) | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  |     return combinations | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | def download_file(url: str, path: Path): | 
					
						
							|  |  |  |     request = requests.get(url, allow_redirects=True) | 
					
						
							|  |  |  |     with path.open("wb") as f: | 
					
						
							|  |  |  |         f.write(request.content) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | def download_dataset(dataset: Union[dict, str], download_folder: Path): | 
					
						
							|  |  |  |     train_file = "train.json" | 
					
						
							|  |  |  |     test_file = "test.json" | 
					
						
							|  |  |  |     # checking if dataset is already downloaded | 
					
						
							|  |  |  |     if download_folder.exists(): | 
					
						
							|  |  |  |         assert download_folder.is_dir() | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |         if (download_folder / train_file).is_file() and (download_folder / test_file).is_file(): | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  |             return train_file, test_file | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     if type(dataset) is str:  # check if dataset needs to be looked up | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  |         dataset = download_links[dataset] | 
					
						
							|  |  |  |     train = dataset["train"] | 
					
						
							|  |  |  |     test = dataset["test"] | 
					
						
							|  |  |  |     download_folder.mkdir(parents=True, exist_ok=True) | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     download_file(train, download_folder / train_file) | 
					
						
							|  |  |  |     download_file(test, download_folder / test_file) | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  |     return train_file, test_file | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | def eval(model: FARMReader, download_folder: Path, test_file: str): | 
					
						
							|  |  |  |     return model.eval_on_file(data_dir=download_folder, test_filename=test_file) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | def train_student(student: dict, download_folder: Path, train_file: str, test_file: str, **kwargs) -> dict: | 
					
						
							|  |  |  |     # loading student model | 
					
						
							|  |  |  |     model = FARMReader(model_name_or_path=student["model_name_or_path"]) | 
					
						
							|  |  |  |     # training student model | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     model.train( | 
					
						
							|  |  |  |         data_dir=download_folder, train_filename=train_file, batch_size=student["batch_size"], caching=True, **kwargs | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  |     return eval(model, download_folder, test_file) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | def train_student_with_distillation( | 
					
						
							|  |  |  |     student: dict, teacher: dict, download_folder: Path, train_file: str, test_file: str, **kwargs | 
					
						
							|  |  |  | ) -> dict: | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  |     # loading student and teacher models | 
					
						
							|  |  |  |     student_model = FARMReader(model_name_or_path=student["model_name_or_path"]) | 
					
						
							|  |  |  |     teacher_model = FARMReader(model_name_or_path=teacher["model_name_or_path"]) | 
					
						
							|  |  |  |     # distilling | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     student_model.distil_from( | 
					
						
							|  |  |  |         teacher_model, | 
					
						
							|  |  |  |         data_dir=download_folder, | 
					
						
							|  |  |  |         train_filename=train_file, | 
					
						
							|  |  |  |         student_batch_size=student["batch_size"], | 
					
						
							|  |  |  |         teacher_batch_size=teacher["batch_size"], | 
					
						
							|  |  |  |         caching=True, | 
					
						
							|  |  |  |         **kwargs, | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  |     return eval(student_model, download_folder, test_file) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | def main(): | 
					
						
							|  |  |  |     # loading config | 
					
						
							|  |  |  |     parent = Path(__file__).parent.resolve() | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     config = load_config(parent / "distillation_config.json") | 
					
						
							|  |  |  |     download_folder = parent / config["download_folder"] | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  |     student = config["student_model"] | 
					
						
							|  |  |  |     teacher = config["teacher_model"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     distillation_settings = config["distillation_settings"] | 
					
						
							|  |  |  |     training_settings = config["training_settings"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # loading dataset | 
					
						
							|  |  |  |     logger.info("Downloading dataset") | 
					
						
							|  |  |  |     train_file, test_file = download_dataset(config["dataset"], download_folder) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     results = [] | 
					
						
							|  |  |  |     descriptions = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for current_config in combine_config(distillation_settings): | 
					
						
							|  |  |  |         descriptions.append(f"Results of student with distillation (config: {current_config}") | 
					
						
							|  |  |  |         # distillation training | 
					
						
							| 
									
										
										
										
											2022-09-19 18:18:32 +02:00
										 |  |  |         logger.info("Training student with distillation (config: %s)", current_config) | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |         results.append( | 
					
						
							|  |  |  |             train_student_with_distillation( | 
					
						
							|  |  |  |                 student, teacher, download_folder, train_file, test_file, **current_config, **training_settings | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # baseline | 
					
						
							|  |  |  |     if config["evaluate_student_without_distillation"]: | 
					
						
							|  |  |  |         logger.info("Training student without distillation as a baseline") | 
					
						
							|  |  |  |         descriptions.append("Results of student without distillation") | 
					
						
							|  |  |  |         results.append(train_student(student, download_folder, train_file, test_file, **training_settings)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if config["evaluate_teacher"]: | 
					
						
							|  |  |  |         # evaluating teacher as upper bound for performance | 
					
						
							|  |  |  |         logger.info("Evaluating teacher") | 
					
						
							|  |  |  |         descriptions.append("Results of teacher") | 
					
						
							|  |  |  |         results.append(eval(FARMReader(model_name_or_path=teacher["model_name_or_path"]), download_folder, test_file)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # printing evaluation results | 
					
						
							|  |  |  |     logger.info("Evaluation results:") | 
					
						
							|  |  |  |     for result, description in zip(results, descriptions): | 
					
						
							|  |  |  |         logger.info(description) | 
					
						
							| 
									
										
										
										
											2022-09-19 18:18:32 +02:00
										 |  |  |         logger.info("EM: %s", result["EM"]) | 
					
						
							|  |  |  |         logger.info("F1: %s", result["f1"]) | 
					
						
							|  |  |  |         logger.info("Top n accuracy: %s", result["top_n_accuracy"]) | 
					
						
							| 
									
										
										
										
											2021-11-26 18:49:30 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     main() |