haystack/test/benchmarks/model_distillation.py
Sara Zan a59bca3661
Apply black formatting (#2115)
* Testing black on ui/

* Applying black on docstores

* Add latest docstring and tutorial changes

* Create a single GH action for Black and docs to reduce commit noise to the minimum, slightly refactor the OpenAPI action too

* Remove comments

* Relax constraints on pydoc-markdown

* Split temporary black from the docs. Pydoc-markdown was obsolete and needs a separate PR to upgrade

* Fix a couple of bugs

* Add a type: ignore that was missing somehow

* Give path to black

* Apply Black

* Apply Black

* Relocate a couple of type: ignore

* Update documentation

* Make Linux CI run after applying Black

* Triggering Black

* Apply Black

* Remove dependency, does not work well

* Remove manually double trailing commas

* Update documentation

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2022-02-03 13:43:18 +01:00

155 lines
5.5 KiB
Python

from haystack.nodes import FARMReader
import json
import requests
from pathlib import Path
from typing import Union, List
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
download_links = {
"squad2": {
"train": "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json",
"test": "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json",
},
"squad": {
"train": "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json",
"test": "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json",
},
}
# loading json config file
def load_config(path: str) -> dict:
with open(path, "r") as f:
return json.load(f)
# returns all possible combinations of hyperparameters for grid search
def combine_config(configs: dict) -> List[dict]:
combinations_list = [[]]
for config_key, config in configs.items():
if not isinstance(config, list):
config = [config]
current_combinations = combinations_list
combinations_list = []
for item in config:
combinations_list += [c + [(config_key, item)] for c in current_combinations]
combinations = []
for combination in combinations_list:
combinations.append(dict(combination))
return combinations
def download_file(url: str, path: Path):
request = requests.get(url, allow_redirects=True)
with path.open("wb") as f:
f.write(request.content)
def download_dataset(dataset: Union[dict, str], download_folder: Path):
train_file = "train.json"
test_file = "test.json"
# checking if dataset is already downloaded
if download_folder.exists():
assert download_folder.is_dir()
if (download_folder / train_file).is_file() and (download_folder / test_file).is_file():
return train_file, test_file
if type(dataset) is str: # check if dataset needs to be looked up
dataset = download_links[dataset]
train = dataset["train"]
test = dataset["test"]
download_folder.mkdir(parents=True, exist_ok=True)
download_file(train, download_folder / train_file)
download_file(test, download_folder / test_file)
return train_file, test_file
def eval(model: FARMReader, download_folder: Path, test_file: str):
return model.eval_on_file(data_dir=download_folder, test_filename=test_file)
def train_student(student: dict, download_folder: Path, train_file: str, test_file: str, **kwargs) -> dict:
# loading student model
model = FARMReader(model_name_or_path=student["model_name_or_path"])
# training student model
model.train(
data_dir=download_folder, train_filename=train_file, batch_size=student["batch_size"], caching=True, **kwargs
)
return eval(model, download_folder, test_file)
def train_student_with_distillation(
student: dict, teacher: dict, download_folder: Path, train_file: str, test_file: str, **kwargs
) -> dict:
# loading student and teacher models
student_model = FARMReader(model_name_or_path=student["model_name_or_path"])
teacher_model = FARMReader(model_name_or_path=teacher["model_name_or_path"])
# distilling
student_model.distil_from(
teacher_model,
data_dir=download_folder,
train_filename=train_file,
student_batch_size=student["batch_size"],
teacher_batch_size=teacher["batch_size"],
caching=True,
**kwargs,
)
return eval(student_model, download_folder, test_file)
def main():
# loading config
parent = Path(__file__).parent.resolve()
config = load_config(parent / "distillation_config.json")
download_folder = parent / config["download_folder"]
student = config["student_model"]
teacher = config["teacher_model"]
distillation_settings = config["distillation_settings"]
training_settings = config["training_settings"]
# loading dataset
logger.info("Downloading dataset")
train_file, test_file = download_dataset(config["dataset"], download_folder)
results = []
descriptions = []
for current_config in combine_config(distillation_settings):
descriptions.append(f"Results of student with distillation (config: {current_config}")
# distillation training
logger.info(f"Training student with distillation (config: {current_config}")
results.append(
train_student_with_distillation(
student, teacher, download_folder, train_file, test_file, **current_config, **training_settings
)
)
# baseline
if config["evaluate_student_without_distillation"]:
logger.info("Training student without distillation as a baseline")
descriptions.append("Results of student without distillation")
results.append(train_student(student, download_folder, train_file, test_file, **training_settings))
if config["evaluate_teacher"]:
# evaluating teacher as upper bound for performance
logger.info("Evaluating teacher")
descriptions.append("Results of teacher")
results.append(eval(FARMReader(model_name_or_path=teacher["model_name_or_path"]), download_folder, test_file))
# printing evaluation results
logger.info("Evaluation results:")
for result, description in zip(results, descriptions):
logger.info(description)
logger.info(f"EM: {result['EM']}")
logger.info(f"F1: {result['f1']}")
logger.info(f"Top n accuracy: {result['top_n_accuracy']}")
if __name__ == "__main__":
main()