mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-25 09:50:14 +00:00

* Testing black on ui/ * Applying black on docstores * Add latest docstring and tutorial changes * Create a single GH action for Black and docs to reduce commit noise to the minimum, slightly refactor the OpenAPI action too * Remove comments * Relax constraints on pydoc-markdown * Split temporary black from the docs. Pydoc-markdown was obsolete and needs a separate PR to upgrade * Fix a couple of bugs * Add a type: ignore that was missing somehow * Give path to black * Apply Black * Apply Black * Relocate a couple of type: ignore * Update documentation * Make Linux CI run after applying Black * Triggering Black * Apply Black * Remove dependency, does not work well * Remove manually double trailing commas * Update documentation Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
155 lines
5.5 KiB
Python
155 lines
5.5 KiB
Python
from haystack.nodes import FARMReader
|
|
import json
|
|
import requests
|
|
from pathlib import Path
|
|
|
|
from typing import Union, List
|
|
import logging
|
|
|
|
# Logger for this script; its own level is lowered to DEBUG.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Datasets that can be referenced by name. Each entry maps the "train" and "test"
# split to the URL of its JSON file; the "test" split points to the published dev set.
download_links = dict(
    squad2=dict(
        train="https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json",
        test="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json",
    ),
    squad=dict(
        train="https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json",
        test="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json",
    ),
)
|
|
|
|
def load_config(path: Union[str, Path]) -> dict:
    """
    Load a JSON configuration file.

    :param path: Location of the JSON file, given as a string or a `Path`.
    :return: The parsed content of the file.
    """
    # Read as UTF-8 explicitly so the result does not depend on the platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
|
def combine_config(configs: dict) -> List[dict]:
    """
    Expand a grid-search configuration into all possible combinations of its values.

    A value that is a list is treated as the available options for its key; any
    other value is treated as a single fixed option.

    :param configs: Mapping from hyperparameter name to one value or a list of values.
    :return: One dict per combination, each holding exactly one value for every key.
    """
    expanded = [{}]
    for name, choices in configs.items():
        options = choices if isinstance(choices, list) else [choices]
        # The option loop is the outer one, so earlier keys cycle fastest in the result.
        expanded = [{**partial, name: option} for option in options for partial in expanded]
    return expanded
|
|
|
|
|
|
def download_file(url: str, path: Path, timeout: float = 30.0):
    """
    Download the resource at `url` and store it at `path`.

    :param url: Address of the file to download; redirects are followed.
    :param path: Destination file. An existing file is overwritten.
    :param timeout: Seconds to wait for the server to connect or to send data before giving up.
    :raises requests.HTTPError: If the server answers with an error status.
    """
    # Download next to the target and rename at the end, so a failed or interrupted
    # transfer never leaves something that looks like a complete file at `path`.
    partial_path = path.with_name(path.name + ".part")
    with requests.get(url, allow_redirects=True, stream=True, timeout=timeout) as response:
        # Without this check an error page would be saved as if it were the requested file.
        response.raise_for_status()
        with partial_path.open("wb") as f:
            # Stream in chunks instead of holding the whole body in memory.
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
    partial_path.replace(path)
|
|
|
|
|
|
def download_dataset(dataset: Union[dict, str], download_folder: Path):
    """
    Make sure the train and test files of a dataset are present in `download_folder`.

    Nothing is downloaded if both files already exist in the folder.

    :param dataset: Either the name of an entry in `download_links`, or a dict with the
        keys "train" and "test" holding the URLs to download from.
    :param download_folder: Folder the files are stored in; created if it does not exist.
    :return: Tuple of the train and test filenames inside `download_folder`.
    :raises NotADirectoryError: If `download_folder` exists but is not a directory.
    """
    train_file = "train.json"
    test_file = "test.json"

    if download_folder.exists():
        # Raise explicitly instead of asserting: asserts are stripped when Python runs with -O.
        if not download_folder.is_dir():
            raise NotADirectoryError(f"Download folder '{download_folder}' exists but is not a directory")
        # checking if dataset is already downloaded
        if (download_folder / train_file).is_file() and (download_folder / test_file).is_file():
            return train_file, test_file

    if isinstance(dataset, str):  # dataset is given by name and needs to be looked up
        dataset = download_links[dataset]
    train = dataset["train"]
    test = dataset["test"]

    download_folder.mkdir(parents=True, exist_ok=True)
    download_file(train, download_folder / train_file)
    download_file(test, download_folder / test_file)
    return train_file, test_file
|
|
|
|
|
|
def eval(model: FARMReader, download_folder: Path, test_file: str):
    """
    Evaluate a reader on a file of the downloaded dataset.

    :param model: Reader to evaluate.
    :param download_folder: Folder passed to the reader as its data directory.
    :param test_file: Name of the evaluation file, passed to the reader as test filename.
    :return: Whatever `model.eval_on_file` returns for this file.
    """
    # NOTE: this name shadows the builtin `eval` within the module; it is kept
    # because the other functions of this script call it by this name.
    metrics = model.eval_on_file(data_dir=download_folder, test_filename=test_file)
    return metrics
|
|
|
|
|
|
def train_student(student: dict, download_folder: Path, train_file: str, test_file: str, **kwargs) -> dict:
    """
    Train the student model on the training file and evaluate it on the test file.

    :param student: Student settings; the keys "model_name_or_path" and "batch_size" are read.
    :param download_folder: Folder that is passed to the reader as its data directory.
    :param train_file: Name of the training file.
    :param test_file: Name of the evaluation file.
    :param kwargs: Additional keyword arguments forwarded to `FARMReader.train`.
    :return: The evaluation result returned by `eval` for the trained reader.
    """
    reader = FARMReader(model_name_or_path=student["model_name_or_path"])

    reader.train(
        data_dir=download_folder,
        train_filename=train_file,
        batch_size=student["batch_size"],
        caching=True,
        **kwargs,
    )

    return eval(reader, download_folder, test_file)
|
|
|
|
|
|
def train_student_with_distillation(
    student: dict, teacher: dict, download_folder: Path, train_file: str, test_file: str, **kwargs
) -> dict:
    """
    Distil the teacher model into the student model and evaluate the student on the test file.

    :param student: Student settings; the keys "model_name_or_path" and "batch_size" are read.
    :param teacher: Teacher settings; the keys "model_name_or_path" and "batch_size" are read.
    :param download_folder: Folder that is passed to the reader as its data directory.
    :param train_file: Name of the training file.
    :param test_file: Name of the evaluation file.
    :param kwargs: Additional keyword arguments forwarded to `FARMReader.distil_from`.
    :return: The evaluation result returned by `eval` for the distilled student.
    """
    # The student is loaded before the teacher.
    student_reader = FARMReader(model_name_or_path=student["model_name_or_path"])
    teacher_reader = FARMReader(model_name_or_path=teacher["model_name_or_path"])

    student_batch_size = student["batch_size"]
    teacher_batch_size = teacher["batch_size"]
    student_reader.distil_from(
        teacher_reader,
        data_dir=download_folder,
        train_filename=train_file,
        student_batch_size=student_batch_size,
        teacher_batch_size=teacher_batch_size,
        caching=True,
        **kwargs,
    )

    return eval(student_reader, download_folder, test_file)
|
|
|
|
|
|
def main():
    """
    Run the experiments described by `distillation_config.json` and log their results.

    The config file is read from the folder that contains this script. The student is
    trained with distillation once for every combination of the "distillation_settings".
    Depending on the config flags "evaluate_student_without_distillation" and
    "evaluate_teacher", the student is additionally trained without distillation and the
    teacher is evaluated. EM, F1 and top-n accuracy of every run are logged at the end.
    """
    # loading config
    parent = Path(__file__).parent.resolve()
    config = load_config(parent / "distillation_config.json")
    download_folder = parent / config["download_folder"]
    student = config["student_model"]
    teacher = config["teacher_model"]

    distillation_settings = config["distillation_settings"]
    training_settings = config["training_settings"]

    # loading dataset
    logger.info("Downloading dataset")
    train_file, test_file = download_dataset(config["dataset"], download_folder)

    # (description, evaluation result) pairs, in the order the runs were executed
    results = []

    # distillation training
    for current_config in combine_config(distillation_settings):
        logger.info(f"Training student with distillation (config: {current_config})")
        result = train_student_with_distillation(
            student, teacher, download_folder, train_file, test_file, **current_config, **training_settings
        )
        results.append((f"Results of student with distillation (config: {current_config})", result))

    # baseline
    if config["evaluate_student_without_distillation"]:
        logger.info("Training student without distillation as a baseline")
        result = train_student(student, download_folder, train_file, test_file, **training_settings)
        results.append(("Results of student without distillation", result))

    # evaluating teacher as upper bound for performance
    if config["evaluate_teacher"]:
        logger.info("Evaluating teacher")
        result = eval(FARMReader(model_name_or_path=teacher["model_name_or_path"]), download_folder, test_file)
        results.append(("Results of teacher", result))

    # printing evaluation results
    logger.info("Evaluation results:")
    for description, result in results:
        logger.info(description)
        logger.info(f"EM: {result['EM']}")
        logger.info(f"F1: {result['f1']}")
        logger.info(f"Top n accuracy: {result['top_n_accuracy']}")


if __name__ == "__main__":
    main()
|