mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-25 06:48:43 +00:00 
			
		
		
		
	 84147edcca
			
		
	
	
		84147edcca
		
			
		
	
	
	
	
		
			
			* initial commit * Add latest docstring and tutorial changes * added comments and fixed bug * fixed bugs, added benchmark and added documentation * Add latest docstring and tutorial changes * fix type: ignore comment * fix logging in benchmark * fixed distillation config * Add latest docstring and tutorial changes * added type annotations * fixed distillation loss calculation * added type annotations * fixed distillation mse loss * improved model distillation benchmark config loading * added temperature for model distillation * removed uncessary imports, added comments, added named parameter calls * Add latest docstring and tutorial changes * added some more comments * added distillation test * fixed distillation test * removed unnecessary import * fix softmax dimension * add grid search * improved model distillation benchmark config * fixed model distillation hyperparameter search * added doc strings and type hints for model distillation * Add latest docstring and tutorial changes * fixed type hints * fixed type hints * fixed type hints * wrote out params instead of kwargs in DistillationDataSilo initializer * fixed type hints * fixed typo * fixed typo Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
		
			
				
	
	
		
			135 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			135 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from haystack.nodes import FARMReader
 | |
| import json
 | |
| import requests
 | |
| from pathlib import Path
 | |
| 
 | |
| from typing import Union, List
 | |
| import logging
 | |
| 
 | |
# Module-level logger; DEBUG so that every message emitted by this script passes the logger's own level filter.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Train/test download URLs of SQuAD 2.0
_SQUAD_V2_FILES = {
    "train": "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json",
    "test": "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json",
}

# Train/test download URLs of SQuAD 1.1
_SQUAD_V1_FILES = {
    "train": "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json",
    "test": "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json",
}

# Lookup table used by download_dataset when the dataset is given by name instead of as a dict of URLs.
download_links = {
    "squad2": _SQUAD_V2_FILES,
    "squad": _SQUAD_V1_FILES,
}
 | |
| 
 | |
# loading json config file
def load_config(path: Union[str, Path]) -> dict:
    """
    Read a JSON configuration file and return its parsed content.

    :param path: Location of the JSON file, either as a string or as a ``Path``.
    :return: The deserialized JSON document.
    :raises OSError: If the file cannot be opened.
    :raises json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification, so don't depend on the platform's locale encoding
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
 | |
| 
 | |
# returns all possible combinations of hyperparameters for grid search
def combine_config(configs: dict) -> List[dict]:
    """
    Expand a grid-search configuration into one dict per hyperparameter combination.

    Every value that is a list is treated as a set of alternatives; any other value
    is treated as a single fixed choice. The value of the first key varies fastest
    in the returned list.

    :param configs: Mapping from parameter name to a value or a list of candidate values.
    :return: List of dicts, each mapping every parameter name to exactly one value.
             An empty ``configs`` yields ``[{}]``; an empty candidate list yields ``[]``.
    """
    expanded: List[dict] = [{}]
    for name, candidates in configs.items():
        options = candidates if isinstance(candidates, list) else [candidates]
        previous = expanded
        # option is the outer loop so that earlier parameters keep cycling fastest
        expanded = [{**partial, name: option} for option in options for partial in previous]
    return expanded
 | |
|         
 | |
| 
 | |
def download_file(url: str, path: Path, timeout: float = 60.0):
    """
    Download the resource at ``url`` and store its raw bytes at ``path``.

    :param url: Address of the file to fetch; redirects are followed.
    :param path: Destination file, which is created or overwritten.
    :param timeout: Seconds to wait for the server to connect/send data before giving up.
    :raises requests.HTTPError: If the server answers with an error status code.
    :raises requests.RequestException: On connection problems or timeouts.
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    # Fail before touching the disk: an error page saved as "train.json" would
    # otherwise be picked up as an already-downloaded dataset on the next run.
    response.raise_for_status()
    with path.open("wb") as f:
        f.write(response.content)
 | |
| 
 | |
def download_dataset(dataset: Union[dict, str], download_folder: Path):
    """
    Make sure the train and test files of a dataset are present in ``download_folder``.

    If both ``train.json`` and ``test.json`` already exist in the folder, nothing is downloaded.

    :param dataset: Either a key of ``download_links`` or a dict with the keys
                    ``"train"`` and ``"test"`` holding the download URLs.
    :param download_folder: Directory the files are stored in; created if missing.
    :return: Tuple ``(train_file, test_file)`` with the file names relative to ``download_folder``.
    :raises NotADirectoryError: If ``download_folder`` exists but is not a directory.
    :raises KeyError: If ``dataset`` is a name that is not listed in ``download_links``.
    """
    train_file = "train.json"
    test_file = "test.json"
    # checking if dataset is already downloaded
    if download_folder.exists():
        # explicit exception instead of assert, which would be stripped under "python -O"
        if not download_folder.is_dir():
            raise NotADirectoryError(f"Download folder '{download_folder}' exists but is not a directory")
        if (download_folder / train_file).is_file() and (download_folder / test_file).is_file():
            return train_file, test_file
    if isinstance(dataset, str):  # check if dataset needs to be looked up
        try:
            dataset = download_links[dataset]
        except KeyError:
            raise KeyError(
                f"Unknown dataset '{dataset}'. Available datasets: {sorted(download_links)}"
            ) from None
    train_url = dataset["train"]
    test_url = dataset["test"]
    download_folder.mkdir(parents=True, exist_ok=True)
    download_file(train_url, download_folder / train_file)
    download_file(test_url, download_folder / test_file)
    return train_file, test_file
 | |
| 
 | |
def eval(model: FARMReader, download_folder: Path, test_file: str):
    """
    Evaluate ``model`` on a test file and return the resulting metrics.

    :param model: Reader to evaluate.
    :param download_folder: Directory containing the test file.
    :param test_file: Name of the test file inside ``download_folder``.
    :return: Whatever ``model.eval_on_file`` returns (``main`` reads the keys
             ``"EM"``, ``"f1"`` and ``"top_n_accuracy"`` from it).
    """
    metrics = model.eval_on_file(data_dir=download_folder, test_filename=test_file)
    return metrics
 | |
| 
 | |
def train_student(student: dict, download_folder: Path, train_file: str, test_file: str, **kwargs) -> dict:
    """
    Train the student model without distillation and evaluate it (baseline run).

    :param student: Student settings; the keys ``"model_name_or_path"`` and ``"batch_size"`` are read.
    :param download_folder: Directory containing the train and test files.
    :param train_file: Name of the training file inside ``download_folder``.
    :param test_file: Name of the test file inside ``download_folder``.
    :param kwargs: Further training settings, forwarded unchanged to ``FARMReader.train``.
    :return: Evaluation result of the trained student on the test file.
    """
    # loading student model
    reader = FARMReader(model_name_or_path=student["model_name_or_path"])
    # training student model
    reader.train(
        data_dir=download_folder,
        train_filename=train_file,
        batch_size=student["batch_size"],
        caching=True,
        **kwargs
    )
    return eval(reader, download_folder, test_file)
 | |
| 
 | |
def train_student_with_distillation(student: dict, teacher: dict, download_folder: Path, train_file: str, test_file: str, **kwargs) -> dict:
    """
    Distil the teacher model into the student model and evaluate the student.

    :param student: Student settings; the keys ``"model_name_or_path"`` and ``"batch_size"`` are read.
    :param teacher: Teacher settings; the keys ``"model_name_or_path"`` and ``"batch_size"`` are read.
    :param download_folder: Directory containing the train and test files.
    :param train_file: Name of the training file inside ``download_folder``.
    :param test_file: Name of the test file inside ``download_folder``.
    :param kwargs: Further distillation/training settings, forwarded unchanged to ``FARMReader.distil_from``.
    :return: Evaluation result of the distilled student on the test file.
    """
    # loading student and teacher models
    student_reader = FARMReader(model_name_or_path=student["model_name_or_path"])
    teacher_reader = FARMReader(model_name_or_path=teacher["model_name_or_path"])
    # distilling
    student_reader.distil_from(
        teacher_reader,
        data_dir=download_folder,
        train_filename=train_file,
        student_batch_size=student["batch_size"],
        teacher_batch_size=teacher["batch_size"],
        caching=True,
        **kwargs
    )
    return eval(student_reader, download_folder, test_file)
 | |
| 
 | |
def main():
    """
    Run the model distillation benchmark described by ``distillation_config.json``.

    The config file is read from the directory of this script. The keys
    ``download_folder``, ``student_model``, ``teacher_model``, ``distillation_settings``,
    ``training_settings``, ``dataset``, ``evaluate_student_without_distillation`` and
    ``evaluate_teacher`` are used. The student is distilled once for every combination
    of the distillation settings; optionally a student trained without distillation
    and the teacher itself are evaluated as well. All results are logged at the end.
    """
    # loading config
    parent = Path(__file__).parent.resolve()
    config = load_config(parent / "distillation_config.json")
    download_folder = parent / config["download_folder"]
    student = config["student_model"]
    teacher = config["teacher_model"]

    distillation_settings = config["distillation_settings"]
    training_settings = config["training_settings"]

    # loading dataset
    logger.info("Downloading dataset")
    train_file, test_file = download_dataset(config["dataset"], download_folder)

    # results[i] belongs to descriptions[i]
    results = []
    descriptions = []

    # grid search: only the distillation settings are expanded, the training settings stay fixed
    for current_config in combine_config(distillation_settings):
        descriptions.append(f"Results of student with distillation (config: {current_config})")
        # distillation training
        logger.info(f"Training student with distillation (config: {current_config})")
        results.append(
            train_student_with_distillation(
                student, teacher, download_folder, train_file, test_file, **current_config, **training_settings
            )
        )

    # baseline
    if config["evaluate_student_without_distillation"]:
        logger.info("Training student without distillation as a baseline")
        descriptions.append("Results of student without distillation")
        results.append(train_student(student, download_folder, train_file, test_file, **training_settings))

    if config["evaluate_teacher"]:
        # evaluating teacher as upper bound for performance
        logger.info("Evaluating teacher")
        descriptions.append("Results of teacher")
        results.append(eval(FARMReader(model_name_or_path=teacher["model_name_or_path"]), download_folder, test_file))

    # printing evaluation results
    logger.info("Evaluation results:")
    for result, description in zip(results, descriptions):
        logger.info(description)
        logger.info(f"EM: {result['EM']}")
        logger.info(f"F1: {result['f1']}")
        logger.info(f"Top n accuracy: {result['top_n_accuracy']}")


if __name__ == "__main__":
    main()