import argparse
import json
import os
import pathlib
import re
from dataclasses import dataclass, field
from datetime import datetime
from typing import Tuple, List, Union, Optional


# @dataclass is required for the field(...) defaults below to take effect;
# without it the class attributes would be bare dataclasses.Field objects.
@dataclass
class ConfigScore:
    trial_id: str = field(default=None)
    start_time: float = field(default=None)
    last_update_time: float = field(default=None)
    config: dict = field(default=None)
    metric_score: dict = field(default=None)
    time_stamp: float = field(default=None)

    def __init__(self,
                 trial_id: str = None,
                 start_time: float = None,
                 last_update_time: float = None,
                 config: dict = None,
                 metric_score: dict = None,
                 time_stamp: float = None
                 ):
        self.trial_id = trial_id
        self.start_time = start_time
        self.last_update_time = last_update_time
        self.config = config
        self.metric_score = metric_score
        self.time_stamp = time_stamp


class ConfigScoreList:

    def __init__(self,
                 config_score_list: List[ConfigScore],
                 jobid_config=None,
                 blob_file=None,
                 ):
        self._config_score_list = config_score_list
        self._blob_file = blob_file
        self._jobid_config = jobid_config

    def sorted(self, sort_method="unsorted", metric_mode="max"):
        if sort_method == "unsorted":
            pass  # keep the original trial order
        elif sort_method == "sort_time":
            self._config_score_list = sorted(self._config_score_list, key=lambda x: x.start_time, reverse=False)
        else:
            self._config_score_list = sorted(self._config_score_list,
                                             key=lambda x: x.metric_score[metric_mode], reverse=True)

    def get_best_config(self,
                        metric_mode="max"):
        return max(self._config_score_list, key=lambda x: x.metric_score[metric_mode])

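# A minimal usage sketch for ConfigScoreList (hypothetical scores):
#   score_list = ConfigScoreList([ConfigScore(trial_id="t0", metric_score={"max": 0.81}),
#                                 ConfigScore(trial_id="t1", metric_score={"max": 0.86})])
#   score_list.sorted(sort_method="sort_metric")   # any non-reserved name sorts by metric_score
#   best = score_list.get_best_config()            # -> the ConfigScore with score 0.86

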
@dataclass
class JobID:
    """
    The identifier of a job, with one abbreviated field per configurable component
    (see set_jobid_from_console_args): dat/subdat = dataset and sub-dataset name,
    mod = algo mode, spa = space mode, arg = search-alg args mode, alg = algo name,
    pru = pruner, pre_full/pre/presz = pretrained model full name, type and size,
    spt = resplit mode, rep = repetition id, sddt = data seed,
    sdhf = transformers seed, var1/var2 = optional arguments.
    """
    dat: list = field(default=None)
    subdat: str = field(default=None)
    mod: str = field(default=None)
    spa: str = field(default=None)
    arg: str = field(default=None)
    alg: str = field(default=None)
    pru: str = field(default=None)
    pre_full: str = field(default=None)
    pre: str = field(default=None)
    presz: str = field(default=None)
    spt: str = field(default=None)
    rep: int = field(default=0)
    sddt: int = field(default=None)
    sdhf: int = field(default=None)
    var1: Optional[float] = field(default=None)
    var2: Optional[float] = field(default=None)

    def __init__(self,
                 console_args=None):
        if console_args:
            self.set_jobid_from_console_args(console_args)

    def set_unittest_config(self):
        """
        Set the JobID config for unit tests.
        """
        self.dat = ["glue"]
        self.subdat = "mrpc"
        self.mod = "hpo"
        self.spa = "uni_test"
        self.arg = "cus"
        self.alg = "bs"
        self.pru = "None"
        self.pre_full = "google/mobilebert-uncased"
        self.pre = "mobilebert"
        self.presz = "small"
        self.spt = "rspt"
        self.rep = 0
        self.sddt = 43
        self.sdhf = 42
        self.var1 = None
        self.var2 = None

    def is_match(self, partial_jobid):
        """
        Return whether the current object matches the partial JobID defined in
        partial_jobid: it matches if and only if every field of partial_jobid
        that is not None equals the corresponding field of the current object.
        For example, if self has
            dat = ['glue'], subdat = 'cola', mod = 'bestnn', spa = 'buni',
            arg = 'cus', alg = 'bs', pru = 'None', pre = 'funnel',
            presz = 'xlarge', spt = 'rspt', rep = 0, sddt = 43, sdhf = 42
        then a partial JobID with (dat = ['glue'], subdat = 'cola', mod = 'hpo')
        does not match (returns False), while one with
        (dat = ['glue'], subdat = 'cola', mod = 'bestnn') does (returns True).
        """
        for key, val in partial_jobid.__dict__.items():
            if val is None:
                continue
            if getattr(self, key) != val:
                return False
        return True

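    # A quick sketch of the matching semantics (hypothetical values):
    #   full = JobID(); full.set_unittest_config()    # sets mod = "hpo", among others
    #   partial = JobID(); partial.mod = "hpo"
    #   full.is_match(partial)                        # -> True: all non-None fields agree
    #   partial.mod = "bestnn"
    #   full.is_match(partial)                        # -> False
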
    def to_wandb_string(self):
        """
        Prepare the job ID string for wandb.
        """
        field_dict = self.__dict__
        keytoval_str = "_".join([JobID.dataset_list_to_str(field_dict[key])
                                 if isinstance(field_dict[key], list)
                                 else str(field_dict[key])
                                 for key in field_dict.keys() if not key.endswith("_full")])
        return keytoval_str

    def to_jobid_string(self):
        """
        Convert the current JobID into a blob name string which contains all the fields.
        """
        list_keys = list(JobID.__dataclass_fields__.keys())
        field_dict = self.__dict__
        keytoval_str = "_".join([key + "=" + JobID.dataset_list_to_str(field_dict[key])
                                 if isinstance(field_dict[key], list)
                                 else key + "=" + str(field_dict[key])
                                 for key in list_keys if not key.endswith("_full")])
        return keytoval_str

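    # For the unit-test config above, to_jobid_string() returns (sketch, middle elided):
    #   "dat=glue_subdat=mrpc_mod=hpo_spa=uni_test_..._sddt=43_sdhf=42_var1=None_var2=None"
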
    def to_partial_jobid_string(self):
        """
        Convert the current JobID into a blob name string which only contains the
        fields whose values have been set.
        """
        list_keys = list(JobID.__dataclass_fields__.keys())
        # self.__dict__ only contains the instance fields that have been set
        field_dict = self.__dict__
        keytoval_str = "_".join([key + "=" + JobID.dataset_list_to_str(field_dict[key])
                                 if isinstance(field_dict[key], list)
                                 else key + "=" + str(field_dict[key])
                                 for key in list_keys if key in field_dict.keys()])
        return keytoval_str

    @staticmethod
    def blobname_to_jobid_dict(keytoval_str):
        """
        Convert an azure blob name to a JobID config dict,
        e.g., blobname = "dat=glue_subdat=cola_mod=bestnn_spa=buni_arg=cus_
        alg=bs_pru=None_pre=funnel_presz=xlarge_spt=rspt_rep=0.json"
        is converted to the jobid dict {dat: ['glue'], subdat: 'cola', mod: 'bestnn',
        spa: 'buni', arg: 'cus', alg: 'bs', pru: 'None', pre: 'funnel',
        presz: 'xlarge', spt: 'rspt', rep: 0}; the seed and optional-variable
        fields (sddt, sdhf, var1, var2) are optional in the blob name and parse
        to None when absent, as in this example.
        """
        field_keys = [key for key in list(JobID.__dataclass_fields__.keys()) if not key.endswith("_full")]
        regex_expression = ".*"
        is_first = True
        for key in field_keys:
            if is_first:
                prefix = ""
                is_first = False
            else:
                prefix = "_"
            if key.startswith("sd") or key.startswith("var"):
                # seed and optional-variable fields may be absent from older blob names
                regex_expression += "(" + prefix + key + "=(?P<" + key + ">[^_]*))?"
            else:
                regex_expression += prefix + key + "=(?P<" + key + ">[^_]*)"
        # escape the dot so it only matches the file-extension separator
        regex_expression += r"\.(json|zip)"
        result = re.search(regex_expression, keytoval_str)
        if result:
            result_dict = {}
            for key in field_keys:
                if key == "dat":
                    result_dict[key] = [result.group(key)]
                elif key == "rep":
                    try:
                        try:
                            result_dict[key] = int(result.group(key))
                        except IndexError:
                            print("No group {} in the regex result".format(key))
                            result_dict[key] = -1
                    except ValueError:
                        print("Cannot parse integer {}".format(result.group(key)))
                        result_dict[key] = -1
                else:
                    result_dict[key] = result.group(key)
            return result_dict
        else:
            return None

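    # Parsing sketch, reusing the docstring example (seed fields absent, so they parse to None):
    #   JobID.blobname_to_jobid_dict("dat=glue_subdat=cola_mod=bestnn_spa=buni_arg=cus"
    #                                "_alg=bs_pru=None_pre=funnel_presz=xlarge_spt=rspt_rep=0.json")
    #   # -> {"dat": ["glue"], "subdat": "cola", ..., "rep": 0, "sddt": None, "sdhf": None, ...}
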
    @staticmethod
    def dataset_list_to_str(dataset_name, key="dat"):
        if isinstance(dataset_name, list):
            return "-".join(dataset_name)
        else:
            return dataset_name

    def set_jobid_from_arg_list(self,
                                **jobid_list
                                ):
        """
        Set the jobid from a dict object.
        """
        for key in jobid_list.keys():
            assert key in JobID.__dataclass_fields__.keys()
            setattr(self, key, jobid_list[key])
        if self.mod == "grid":
            self.alg = "grid"

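    # Sketch: building a JobID programmatically (hypothetical values):
    #   jobid = JobID()
    #   jobid.set_jobid_from_arg_list(dat=["glue"], subdat="cola", mod="grid")
    #   jobid.alg    # -> "grid", forced because mod == "grid"
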
    @staticmethod
    def convert_blobname_to_jobid(blobname):
        """
        Convert a blob name string to a JobID object.
        """
        jobconfig_dict = JobID.blobname_to_jobid_dict(blobname)
        if jobconfig_dict:
            jobconfig = JobID()
            jobconfig.set_jobid_from_arg_list(**jobconfig_dict)
            return jobconfig
        else:
            return None

    @staticmethod
    def get_full_data_name(dataset_name: Union[list, str], subdataset_name=None):
        """
        Convert a dataset name and a sub-dataset name to a full dataset name.
        """
        if isinstance(dataset_name, list):
            full_dataset_name = JobID.dataset_list_to_str(dataset_name)
        else:
            full_dataset_name = dataset_name
        if subdataset_name:
            full_dataset_name = full_dataset_name + "_" + subdataset_name
        return full_dataset_name

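    # e.g., JobID.get_full_data_name(["glue"], "mrpc") -> "glue_mrpc"
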
    def get_jobid_full_data_name(self):
        """
        Get the full dataset name of the current JobID object.
        """
        return JobID.get_full_data_name(JobID.dataset_list_to_str(self.dat), self.subdat)

    @staticmethod
    def _extract_model_type_with_keywords_match(pre_full):
        from ..hpo.grid_searchspace_auto import HF_MODEL_LIST
        # collect every known model type that occurs in the full model name,
        # then return the longest match
        matched_model_type = []
        for each_model_type in HF_MODEL_LIST:
            if each_model_type in pre_full:
                matched_model_type.append(each_model_type)
        assert len(matched_model_type) > 0
        return max(enumerate(matched_model_type), key=lambda x: len(x[1]))[1]

    @staticmethod
    def extract_model_type(full_model_name):
        from transformers import AutoConfig
        model_config = AutoConfig.from_pretrained(full_model_name)
        config_json_file = model_config.get_config_dict(full_model_name)[0]
        try:
            model_type = config_json_file["model_type"]
        except KeyError:
            print("config_json_file does not contain model_type, re-extracting with keywords matching")
            model_type = JobID._extract_model_type_with_keywords_match(full_model_name)
        return model_type

    @staticmethod
    def get_attrval_from_arg_or_dict(console_args: Union[argparse.Namespace, dict], each_key):
        if isinstance(console_args, argparse.Namespace):
            return getattr(console_args, each_key)
        else:
            return console_args[each_key]

    def set_jobid_from_console_args(self, console_args: Union[argparse.Namespace, dict]):
        from ..utils import pretrained_model_size_format_check, dataset_subdataset_name_format_check
        console_to_jobid_key_mapping = {
            "pretrained_model_size": "pre",
            "dataset_subdataset_name": "dat",
            "algo_mode": "mod",
            "space_mode": "spa",
            "search_alg_args_mode": "arg",
            "algo_name": "alg",
            "pruner": "pru",
            "resplit_mode": "spt",
            "rep_id": "rep",
            "seed_data": "sddt",
            "seed_transformers": "sdhf",
            "optarg1": "var1",
            "optarg2": "var2"
        }
        for each_key in console_to_jobid_key_mapping.keys():
            try:
                try:
                    if each_key == "dataset_subdataset_name":
                        dataset_subdataset_name_format_check(JobID.get_attrval_from_arg_or_dict(console_args, each_key))
                        self.dat = JobID.get_attrval_from_arg_or_dict(console_args, each_key).split(":")[0].split(",")
                        self.subdat = JobID.get_attrval_from_arg_or_dict(console_args, each_key).split(":")[1]
                    elif each_key == "pretrained_model_size":
                        pretrained_model_size_format_check(JobID.get_attrval_from_arg_or_dict(console_args, each_key))
                        self.pre_full = JobID.get_attrval_from_arg_or_dict(console_args, each_key).split(":")[0]
                        self.pre = JobID.extract_model_type(self.pre_full)
                        self.presz = JobID.get_attrval_from_arg_or_dict(console_args, each_key).split(":")[1]
                    else:
                        jobid_key = console_to_jobid_key_mapping[each_key]
                        attrval = JobID.get_attrval_from_arg_or_dict(console_args, each_key)
                        setattr(self, jobid_key, attrval)
                except AttributeError:
                    print("console_args has no attribute {}, continue".format(each_key))
                    continue
            except KeyError:
                print("console_args has no key {}, continue".format(each_key))
                continue
        if self.mod == "grid":
            self.alg = "grid"

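# Sketch of JobID construction from a dict of console arguments (hypothetical
# values; keys missing from the dict are skipped via the KeyError branch):
#   jobid = JobID()
#   jobid.set_jobid_from_console_args({
#       "dataset_subdataset_name": "glue:mrpc",
#       "algo_mode": "hpo",
#   })
#   jobid.dat, jobid.subdat, jobid.mod   # -> (["glue"], "mrpc", "hpo")

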
class AzureUtils:

    def __init__(self,
                 root_log_path=None,
                 console_args=None,
                 autohf=None):
        from ..utils import get_wandb_azure_key
        if root_log_path:
            self.root_log_path = root_log_path
        else:
            self.root_log_path = "logs_azure"
        # guard against autohf=None (the default), which would otherwise raise AttributeError
        self.jobid = autohf.jobid_config if autohf else None
        self.console_args = console_args
        self.autohf = autohf
        if console_args:
            wandb_key, azure_key, container_name = get_wandb_azure_key(console_args.key_path)
            self._container_name = container_name
            self._azure_key = azure_key
        else:
            self._container_name = self._azure_key = ""

    def _get_complete_connection_string(self):
        try:
            return "DefaultEndpointsProtocol=https;AccountName=docws5141197765;AccountKey=" \
                   + self._azure_key + ";EndpointSuffix=core.windows.net"
        except AttributeError:
            return "DefaultEndpointsProtocol=https;AccountName=docws5141197765;AccountKey=" \
                   ";EndpointSuffix=core.windows.net"

    def _init_azure_clients(self):
        try:
            from azure.storage.blob import ContainerClient
            connection_string = self._get_complete_connection_string()
            try:
                container_client = ContainerClient.from_connection_string(conn_str=connection_string,
                                                                          container_name=self._container_name)
                return container_client
            except ValueError:
                print("AzureUtils._container_name is specified as: {}, "
                      "please correctly specify AzureUtils._container_name".format(self._container_name))
                return None
        except ImportError:
            print("To use the azure storage component in flaml.nlp, run pip install azure-storage-blob")

    def _init_blob_client(self,
                          local_file_path):
        try:
            from azure.storage.blob import BlobServiceClient

            connection_string = self._get_complete_connection_string()
            blob_service_client = BlobServiceClient.from_connection_string(connection_string)
            try:
                blob_client = blob_service_client.get_blob_client(container=self._container_name,
                                                                  blob=local_file_path)
                return blob_client
            except ValueError:
                print("_container_name is unspecified or wrongly specified, "
                      "please specify _container_name in AzureUtils")
                return None
        except ImportError:
            print("To use the azure storage component in flaml.nlp, run pip install azure-storage-blob")

    def upload_local_file_to_azure(self, local_file_path):
        try:
            from azure.core.exceptions import HttpResponseError
            try:
                blob_client = self._init_blob_client(local_file_path)
                if blob_client:
                    with open(local_file_path, "rb") as fin:
                        blob_client.upload_blob(fin, overwrite=True)
            except HttpResponseError as err:
                print("Cannot upload blob due to {}: {}".format("azure.core.exceptions.HttpResponseError",
                                                                err))
        except ImportError:
            print("To use the azure storage component in flaml.nlp, run pip install azure-storage-blob")

    def download_azure_blob(self, blobname):
        blob_client = self._init_blob_client(blobname)
        if blob_client:
            # create the parent directory of the blob path locally before writing
            pathlib.Path(re.search("(?P<parent_path>^.*)/[^/]+$", blobname).group("parent_path")).mkdir(
                parents=True, exist_ok=True)
            with open(blobname, "wb") as fout:
                fout.write(blob_client.download_blob().readall())

    def extract_configscore_list_from_analysis(self,
                                               analysis):
        """
        Extract a list of ConfigScore objects storing the key information returned from tune.run.
        """
        configscore_list = []
        for each_trial in analysis.trials:
            trial_id = each_trial.trial_id
            start_time = each_trial.start_time
            last_update_time = each_trial.last_update_time
            config = each_trial.config
            try:
                metric_score = each_trial.metric_analysis["eval_" + analysis.default_metric]
                time_stamp = each_trial.metric_analysis['timestamp']
            except KeyError:
                print("KeyError, {} does not contain the key {} or {}".format("each_trial.metric_analysis",
                                                                              "eval_" + analysis.default_metric,
                                                                              "timestamp"))
                metric_score = 0
                time_stamp = 0
            configscore_list.append(ConfigScore(
                trial_id=trial_id,
                start_time=start_time,
                last_update_time=last_update_time,
                config=config,
                metric_score=metric_score,
                time_stamp=time_stamp))
        return configscore_list

    def write_autohf_output(self,
                            configscore_list=None,
                            valid_metric=None,
                            predictions=None,
                            duration=None):
        """
        Write the key information from a job and upload it to azure blob storage.
        """
        local_file_path = self.generate_local_json_path()
        output_json = {}
        if configscore_list:
            output_json["val_log"] = [configscore.__dict__ for configscore in configscore_list]
        if valid_metric:
            output_json["valid_metric"] = valid_metric
        if duration:
            output_json["duration"] = duration
        if len(output_json) > 0:
            self.create_local_json_and_upload(output_json, local_file_path)
        if predictions is not None:
            self.create_local_prediction_and_upload(local_file_path, predictions)

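    # Shape of the uploaded JSON (sketch):
    #   {"val_log": [{"trial_id": ..., "config": ..., "metric_score": ...}, ...],
    #    "valid_metric": ..., "duration": ...}
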
    def generate_local_json_path(self):
        """
        Return a path string for storing the json file locally.
        """
        full_dataset_name = self.jobid.get_jobid_full_data_name()
        jobid_str = self.jobid.to_jobid_string()
        local_file_path = os.path.join(self.root_log_path, full_dataset_name, jobid_str + ".json")
        pathlib.Path(os.path.join(self.root_log_path, full_dataset_name)).mkdir(parents=True, exist_ok=True)
        return local_file_path

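    # e.g., with the default root path and the unit-test JobID this returns (middle elided):
    #   "logs_azure/glue_mrpc/dat=glue_subdat=mrpc_..._var2=None.json"
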
    def create_local_json_and_upload(self, result_json, local_file_path):
        with open(local_file_path, "w") as fout:
            fout.write(json.dumps(result_json))
            fout.flush()
        self.upload_local_file_to_azure(local_file_path)

    def create_local_prediction_and_upload(self,
                                           local_json_file,
                                           predictions):
        """
        Store the predictions as a .zip file locally and upload it.
        """
        # strip the directory and the ".json" suffix to get the blob name stem
        azure_save_file_name = local_json_file.split("/")[-1][:-5]
        try:
            output_dir = self.console_args.data_root_dir
        except AttributeError:
            print("console_args does not contain data_root_dir, loading the default value")
            from ..utils import load_dft_args
            console_args = load_dft_args()
            output_dir = getattr(console_args, "data_root_dir")
        local_archive_path = self.autohf.output_prediction(predictions,
                                                           output_prediction_path=output_dir + "result/",
                                                           output_zip_file_name=azure_save_file_name)
        self.upload_local_file_to_azure(local_archive_path)

    @staticmethod
    def is_after_earliest_time(this_blob, earliest_time: Tuple[int, int, int]):
        import pytz
        utc = pytz.UTC
        return this_blob.last_modified >= utc.localize(
            datetime(earliest_time[0], earliest_time[1], earliest_time[2]))

    def get_configblob_from_partial_jobid(self,
                                          root_log_path,
                                          partial_jobid,
                                          earliest_time: Tuple[int, int, int] = None):
        """
        Get all blobs whose jobid configs match the partial_jobid.
        """
        blob_list = []
        container_client = self._init_azure_clients()
        if container_client:
            for each_blob in container_client.list_blobs():
                if each_blob.name.startswith(root_log_path):
                    each_jobconfig = JobID.convert_blobname_to_jobid(each_blob.name)
                    is_append = False
                    if each_jobconfig:
                        if each_jobconfig.is_match(partial_jobid):
                            is_append = True
                        if earliest_time and not AzureUtils.is_after_earliest_time(each_blob, earliest_time):
                            is_append = False
                        if is_append:
                            blob_list.append((each_jobconfig, each_blob))
        return blob_list

    def get_config_and_score_from_partial_jobid(self,
                                                root_log_path: str,
                                                partial_jobid: JobID,
                                                earliest_time: Tuple[int, int, int] = None):
        """
        Extract the config and score lists matching a partial jobid.

        Args:
            root_log_path:
                The root log path in azure blob storage, e.g., "logs_seed/"

            partial_jobid:
                The partial jobid for matching the blob list

            earliest_time (optional):
                The earliest starting time for any matched blob, for filtering out
                out-dated jobs, format: (YYYY, MM, DD)

        Return:
            A list of ConfigScoreList objects, one per matched blob, each storing
            the configs and scores logged in that blob.
        """
        assert isinstance(root_log_path, str), "root_log_path must be of type str"
        assert isinstance(partial_jobid, JobID), "partial_jobid must be of type JobID"
        if earliest_time:
            assert isinstance(earliest_time, tuple), "earliest_time must be a tuple of (YYYY, MM, DD)"

        matched_blob_list = self.get_configblob_from_partial_jobid(
            root_log_path,
            partial_jobid,
            earliest_time=earliest_time)
        return self.get_config_and_score_from_matched_blob_list(matched_blob_list,
                                                                earliest_time)

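    # Usage sketch (hypothetical paths; assumes an initialized AzureUtils instance `utils`):
    #   partial = JobID()
    #   partial.set_jobid_from_arg_list(dat=["glue"], subdat="mrpc", mod="hpo")
    #   lists = utils.get_config_and_score_from_partial_jobid("logs_test/", partial)
    #   best = lists[0].get_best_config() if lists else None
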
    def get_config_and_score_from_matched_blob_list(self,
                                                    matched_blob_list,
                                                    earliest_time: Tuple[int, int, int] = None):
        """
        Extract the config and score lists of one or multiple blobs.

        Args:
            matched_blob_list:
                A list of (jobid config, blob) pairs, e.g., as returned by
                get_configblob_from_partial_jobid

        Return:
            A list of ConfigScoreList objects, one per matched blob, each storing
            the configs and scores logged in that blob.
        """
        matched_config_score_lists = []
        for (each_jobconfig, each_blob) in matched_blob_list:
            self.download_azure_blob(each_blob.name)
            # close the file handle explicitly instead of json.load(open(...))
            with open(each_blob.name, "r") as fin:
                data_json = json.load(fin)
            each_config_and_score_list = ConfigScoreList(
                jobid_config=each_jobconfig,
                blob_file=each_blob,
                config_score_list=[ConfigScore(**each_dict) for each_dict in data_json['val_log']])
            matched_config_score_lists.append(each_config_and_score_list)
        return matched_config_score_lists