import argparse import json import os import pathlib import re from dataclasses import dataclass, field def dataset_subdataset_name_format_check(val_str): regex = re.compile(r"^[^:]*:[^:]*$") if not regex.match(val_str): raise argparse.ArgumentTypeError("dataset_subdataset_name must be in the format {data_name}:{subdata_name}") return val_str def pretrained_model_size_format_check(val_str): regex = re.compile(r"^[^:]*:(small|base|large|xlarge)") if not regex.match(val_str): raise argparse.ArgumentTypeError("pretrained_model_size must be in the format {model_name}:{model_size}," "where {model_name} is the name from huggingface.co/models, {model_size}" "is chosen from small, base, large, xlarge") return val_str def load_console_args(**custom_data_args): arg_parser = argparse.ArgumentParser() arg_parser.add_argument('--server_name', type=str, help='server name', required=False, choices=["tmdev", "dgx", "azureml"], default="tmdev") arg_parser.add_argument('--algo_mode', type=str, help='hpo or grid search', required=False, choices=["grid", "gridbert", "hpo", "hfhpo", "list_s", "list", "bestnn"], default="hpo") arg_parser.add_argument('--data_root_dir', type=str, help='data dir', required=False, default="data/") arg_parser.add_argument('--dataset_subdataset_name', type=dataset_subdataset_name_format_check, help='dataset and subdataset name', required=False, default=None) arg_parser.add_argument('--space_mode', type=str, help='space mode', required=False, choices=["gnr", "uni", "uni_test", "cus", "buni"], default="uni") arg_parser.add_argument('--search_alg_args_mode', type=str, help='search algorithm args mode', required=False, choices=["dft", "exp", "cus"], default="dft") arg_parser.add_argument('--algo_name', type=str, help='algorithm', required=False, choices=["bs", "optuna", "cfo", "rs"], default="bs") arg_parser.add_argument('--pruner', type=str, help='pruner', required=False, choices=["asha", "None"], default="None") arg_parser.add_argument('--pretrained_model_size', type=pretrained_model_size_format_check, help='pretrained model', required=False, default=None) arg_parser.add_argument('--sample_num', type=int, help='sample num', required=False, default=None) arg_parser.add_argument('--time_budget', type=int, help='time budget', required=False, default=None) arg_parser.add_argument('--time_as_grid', type=int, help='time as grid search', required=False, default=None) arg_parser.add_argument('--rep_id', type=int, help='rep id', required=False, default=0) arg_parser.add_argument('--azure_key', type=str, help='azure key', required=False, default=None) arg_parser.add_argument('--resplit_mode', type=str, help='resplit mode', required=False, choices=["rspt", "ori"], default="ori") arg_parser.add_argument('--ds_config', type=str, help='deep speed config file path', required=False, default=None) arg_parser.add_argument('--yml_file', type=str, help='yml file path', required=False, default="test.yml") arg_parser.add_argument('--key_path', type=str, help='path for key.json', required=False, default=None) arg_parser.add_argument('--root_log_path', type=str, help='root path for log', required=False, default="logs_azure") arg_parser.add_argument('--round_idx', type=int, help='round idx for acl experiments', required=False, default=0) arg_parser.add_argument('--seed_data', type=int, help='seed of data shuffling', required=False, default=43) arg_parser.add_argument('--seed_transformers', type=int, help='seed of transformers', required=False, default=42) args, unknown = arg_parser.parse_known_args() for each_key in custom_data_args.keys(): if args.__contains__(each_key): try: check_key_format_func = globals()[each_key + "_format_check"] check_key_format_func(custom_data_args[each_key]) except KeyError: pass setattr(args, each_key, custom_data_args[each_key]) return args def get_wandb_azure_key(key_path): key_json = json.load(open(os.path.join(key_path, "key.json"), "r")) wandb_key = key_json["wandb_key"] azure_key = key_json["azure_key"] azure_container_name = key_json["container_name"] return wandb_key, azure_key, azure_container_name def merge_dicts(dict1, dict2): for key2 in dict2.keys(): if key2 in dict1: dict1_vals = set(dict1[key2]) dict2_vals = set(dict2[key2]) dict1[key2] = list(dict1_vals.union(dict2_vals)) else: dict1[key2] = dict2[key2] return dict1 def _check_dict_keys_overlaps(dict1: dict, dict2: dict): dict1_keys = set(dict1.keys()) dict2_keys = set(dict2.keys()) return len(dict1_keys.intersection(dict2_keys)) > 0 def _variable_override_default_alternative(logger, obj_ref, var_name, default_value, all_values, overriding_value=None): """ Setting the value of var. If overriding_value is specified, var is set to overriding_value; If overriding_value is not specified, var is set to default_value meanwhile showing all_values """ assert isinstance(all_values, list) if overriding_value: setattr(obj_ref, var_name, overriding_value) logger.warning("The value for {} is specified as {}".format(var_name, overriding_value)) else: setattr(obj_ref, var_name, default_value) logger.warning("The value for {} is not specified, setting it to the default value {}. " "Alternatively, you can set it to {}".format(var_name, default_value, ",".join(all_values))) @dataclass class PathUtils: hpo_ckpt_path: str = field(metadata={"help": "the directory for hpo output"}) hpo_result_path: str = field(metadata={"help": "the directory for hpo result"}) hpo_log_path: str = field(metadata={"help": "the directory for log"}) hpo_config_path: str = field(metadata={"help": "the directory for log"}) log_dir_per_run: str = field(metadata={"help": "log directory for each run."}) result_dir_per_run: str = field(metadata={"help": "result directory for each run."}) ckpt_dir_per_run: str = field(metadata={"help": "checkpoint directory for each run."}) ckpt_dir_per_trial: str = field(metadata={"help": "checkpoint directory for each trial."}) def __init__(self, jobid_config, hpo_data_root_path, ): self.jobid_config = jobid_config self.hpo_data_root_path = hpo_data_root_path self.hpo_ckpt_path = os.path.join(hpo_data_root_path, "checkpoint") self.hpo_result_path = os.path.join(hpo_data_root_path, "result") self.hpo_log_path = self.hpo_result_path @staticmethod def init_and_make_one_dir(dir_path): assert dir_path if not os.path.exists(dir_path): pathlib.Path(dir_path).mkdir(parents=True, exist_ok=True) def make_dir_per_run(self): jobid_str = self.jobid_config.to_jobid_string() self.ckpt_dir_per_run = os.path.join(self.hpo_ckpt_path, jobid_str) PathUtils.init_and_make_one_dir(self.ckpt_dir_per_run) self.result_dir_per_run = os.path.join(self.hpo_result_path, jobid_str) PathUtils.init_and_make_one_dir(self.result_dir_per_run) self.log_dir_per_run = os.path.join(self.hpo_log_path, jobid_str) PathUtils.init_and_make_one_dir(self.log_dir_per_run) def make_dir_per_trial(self, trial_id): jobid_str = self.jobid_config.to_jobid_string() ckpt_dir_per_run = os.path.join(self.hpo_ckpt_path, jobid_str) self.ckpt_dir_per_trial = os.path.join(ckpt_dir_per_run, jobid_str, trial_id) PathUtils.init_and_make_one_dir(self.ckpt_dir_per_trial)