import json
import os
import pathlib
import re
from dataclasses import dataclass, field
from datetime import datetime

from azure.storage.blob import BlobServiceClient, ContainerClient
from transformers import AutoConfig

from ..hpo.grid_searchspace_auto import HF_MODEL_LIST
from ..utils import get_wandb_azure_key


@dataclass
class JobID:
    dat: list = field(default=None)
    subdat: str = field(default=None)
    mod: str = field(default=None)
    spa: str = field(default=None)
    arg: str = field(default=None)
    alg: str = field(default=None)
    pru: str = field(default=None)
    pre_full: str = field(default=None)
    pre: str = field(default=None)
    presz: str = field(default=None)
    spt: str = field(default=None)
    rep: int = field(default=0)
    sddt: int = field(default=None)
    sdhf: int = field(default=None)

    def __init__(self,
                 console_args=None):
        if console_args:
            self.set_jobid_from_console_args(console_args)

    def set_unittest_config(self):
        """
        set the JobID config for unit test
        """
        self.dat = ["glue"]
        self.subdat = "mrpc"
        self.mod = "hpo"
        self.spa = "uni_test"
        self.arg = "dft"
        self.alg = "bs"
        self.pru = "None"
        self.pre_full = "google/mobilebert-uncased"
        self.pre = "mobilebert"
        self.presz = "small"
        self.spt = "rspt"
        self.rep = 0
        self.sddt = 43
        self.sdhf = 42

    def is_match(self, partial_jobid):
        """
        Return whether the current JobID matches the partial JobID defined in partial_jobid.
        For example, given

            self = JobID(dat = ['glue'],
                         subdat = 'cola',
                         mod = 'bestnn',
                         spa = 'buni',
                         arg = 'cus',
                         alg = 'bs',
                         pru = 'None',
                         pre = 'funnel',
                         presz = 'xlarge',
                         spt = 'rspt',
                         rep = 0,
                         sddt = 43,
                         sdhf = 42)
            partial_jobid1 = JobID(dat = ['glue'],
                                   subdat = 'cola',
                                   mod = 'hpo')
            partial_jobid2 = JobID(dat = ['glue'],
                                   subdat = 'cola',
                                   mod = 'bestnn')

        is_match returns False for partial_jobid1 and True for partial_jobid2.
        """
        is_not_match = False
        for key, val in partial_jobid.__dict__.items():
            if val is None:
                continue
            if getattr(self, key) != val:
                is_not_match = True
        return not is_not_match

    def to_wandb_string(self):
        """
        Prepare the job ID string used for wandb.
        """
        field_dict = self.__dict__
        keytoval_str = "_".join([JobID.dataset_list_to_str(field_dict[key], key)
                                 if type(field_dict[key]) == list
                                 else str(field_dict[key])
                                 for key in field_dict.keys() if not key.endswith("_full")])
        return keytoval_str

    def to_jobid_string(self):
        """
        convert the current JobID into a blob name string which contains all the fields
        """
        list_keys = list(JobID.__dataclass_fields__.keys())
        field_dict = self.__dict__
        keytoval_str = "_".join([key + "=" + JobID.dataset_list_to_str(field_dict[key], key)
                                 if type(field_dict[key]) == list
                                 else key + "=" + str(field_dict[key])
                                 for key in list_keys if not key.endswith("_full")])
        return keytoval_str
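    # Illustrative example (assuming the unit-test config from set_unittest_config above):
    # to_jobid_string() would produce
    #   dat=glue_subdat=mrpc_mod=hpo_spa=uni_test_arg=dft_alg=bs_pru=None_pre=mobilebert_presz=small_spt=rspt_rep=0_sddt=43_sdhf=42
    # while to_wandb_string() drops the "key=" prefixes and joins the raw values with "_".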

    def to_partial_jobid_string(self):
        """
        convert the current JobID into a blob name string which only contains the fields whose values are not "None"
        """
        list_keys = list(JobID.__dataclass_fields__.keys())
        field_dict = self.__dict__  # field_dict contains only the fields that have been set on this instance
        keytoval_str = "_".join([key + "=" + JobID.dataset_list_to_str(field_dict[key], key)
                                 if type(field_dict[key]) == list
                                 else key + "=" + str(field_dict[key])
                                 for key in list_keys if key in field_dict.keys()])
        return keytoval_str

    @staticmethod
    def blobname_to_jobid_dict(keytoval_str):
        """
        Convert an azure blob name to a JobID config dict, e.g., the blob name
        "dat=glue_subdat=cola_mod=bestnn_spa=buni_arg=cus_alg=bs_pru=None_pre=funnel_presz=xlarge_spt=rspt_rep=0_sddt=43_sdhf=42.json"
        is converted to the jobid dict {dat: ['glue'], subdat: 'cola', mod: 'bestnn',
        spa: 'buni', arg: 'cus', alg: 'bs', pru: 'None', pre: 'funnel', presz: 'xlarge',
        spt: 'rspt', rep: 0, sddt: 43, sdhf: 42}.
        """
        field_keys = [key for key in
                      list(JobID.__dataclass_fields__.keys()) if not key.endswith("_full")]
        regex_expression = ".*" + "_".join([key + "=(?P<" + key + ">.*)" for key in field_keys]) + ".(json|zip)"
        result = re.search(regex_expression, keytoval_str)
        if result:
            result_dict = {}
            for key in field_keys:
                if key == "dat":
                    result_dict[key] = [result.group(key)]
                elif key == "rep":
                    try:
                        result_dict[key] = int(result.group(key))
                    except IndexError:
                        result_dict[key] = -1
                else:
                    result_dict[key] = result.group(key)
            return result_dict
        else:
            return None
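    # Illustrative sketch (not part of the original code): once all fields are set,
    # to_jobid_string() and blobname_to_jobid_dict() are approximate inverses, e.g.
    #   job = JobID(); job.set_unittest_config()
    #   blob = job.to_jobid_string() + ".json"
    #   JobID.blobname_to_jobid_dict(blob)  # -> dict with dat=['glue'], subdat='mrpc', ...
    # Note that dat is re-wrapped in a list and rep is parsed back to an int, while the
    # remaining values come back as strings.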

    @staticmethod
    def dataset_list_to_str(dataset_name, key):
        if key == "dat":
            assert isinstance(dataset_name, list)
            return "-".join(dataset_name)
        else:
            return dataset_name

    def set_jobid_from_arg_list(self,
                                **jobid_list
                                ):
        """
        set the jobid from a dict object
        """
        for key in jobid_list.keys():
            assert key in JobID.__dataclass_fields__.keys()
            setattr(self, key, jobid_list[key])

    @staticmethod
    def convert_blobname_to_jobid(blobname):
        """
        converting a blobname string to a JobID object
        """
        jobconfig_dict = JobID.blobname_to_jobid_dict(blobname)
        if jobconfig_dict:
            jobconfig = JobID()
            jobconfig.set_jobid_from_arg_list(**jobconfig_dict)
            return jobconfig
        else:
            return None

    @staticmethod
    def get_full_data_name(dataset_name, subdataset_name=None):
        """
        convert a dataset name and sub dataset name to a full dataset name
        """
        full_dataset_name = dataset_name
        if subdataset_name:
            full_dataset_name = full_dataset_name + "_" + subdataset_name
        return full_dataset_name

    def get_jobid_full_data_name(self):
        """
        get the full dataset name of the current JobID object
        """
        return JobID.get_full_data_name(JobID.dataset_list_to_str(self.dat, "dat"), self.subdat)

    @staticmethod
    def _extract_model_type_with_keywords_match(pre_full):
        matched_model_type = []
        for each_model_type in HF_MODEL_LIST:
            if each_model_type in pre_full:
                matched_model_type.append(each_model_type)
        assert len(matched_model_type) > 0
        # return the longest matched model type, e.g., prefer a longer name such as
        # "mobilebert" over "bert" when both occur in pre_full
        return max(enumerate(matched_model_type), key=lambda x: len(x[1]))[1]

    @staticmethod
    def extract_model_type(full_model_name):
        model_config = AutoConfig.from_pretrained(full_model_name)
        config_json_file = model_config.get_config_dict(full_model_name)[0]
        try:
            model_type = config_json_file["model_type"]
        except KeyError:
            # fall back to keyword matching when the config does not declare a model_type
            model_type = JobID._extract_model_type_with_keywords_match(full_model_name)
        return model_type
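    # Illustrative example (actual values depend on the Hugging Face hub): for
    # full_model_name = "google/mobilebert-uncased", the downloaded config declares
    # model_type = "mobilebert", so extract_model_type returns "mobilebert"; the keyword
    # fallback is only used for configs that omit the "model_type" entry.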

    def set_jobid_from_console_args(self, console_args):
        self.dat = console_args.dataset_subdataset_name.split(":")[0].split(",")
        self.subdat = console_args.dataset_subdataset_name.split(":")[1]
        self.mod = console_args.algo_mode
        self.spa = console_args.space_mode
        self.arg = console_args.search_alg_args_mode
        self.alg = console_args.algo_name
        self.pru = console_args.pruner
        self.pre_full = console_args.pretrained_model_size.split(":")[0]
        self.pre = JobID.extract_model_type(self.pre_full)
        self.presz = console_args.pretrained_model_size.split(":")[1]
        self.spt = console_args.resplit_mode
        self.rep = console_args.rep_id
        self.sddt = console_args.seed_data
        self.sdhf = console_args.seed_transformers
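    # Expected console argument formats (inferred from the parsing above; the exact CLI flag
    # names are defined elsewhere and assumed here):
    #   dataset_subdataset_name:   "glue:mrpc"   (comma-separated dataset list before the ":")
    #   pretrained_model_size:     "google/mobilebert-uncased:small"
    # The remaining attributes are copied over verbatim from the parsed arguments.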

    def legacy_old_blobname_to_new_blobname(self,
                                            old_blobname):
        spa_id2val = {
            0: "gnr",
            1: "uni"
        }
        alg_id2val = {
            0: "bs",
            1: "optuna",
            2: "cfo"
        }
        pre_id2val = {
            0: "xlnet-base-cased",
            1: "albert-large-v1",
            2: "distilbert-base-uncased",
            3: "microsoft/deberta-base",
            4: "funnel-transformer/small-base",
            5: "microsoft/deberta-large",
            6: "funnel-transformer/large-base",
            7: "funnel-transformer/intermediate-base",
            8: "funnel-transformer/xlarge-base"
        }
        presz_id2val = {
            0: "base",
            1: "small",
            2: "base",
            3: "base",
            4: "base",
            5: "large",
            6: "large",
            7: "intermediate",
            8: "xlarge"
        }
        spt_id2val = {
            0: "rspt",
            1: "ori"
        }
        result_grid = re.search(r".*_mod(el)?(?P<model_id>\d+)_None_None(_spt(?P<split_id>\d+))?_rep(?P<rep_id>\d+).log",
                                old_blobname)
        result = re.search(
            r".*_mod(el)?(?P<model_id>\d+)_(alg)?(?P<algo_id>\d+)_(spa)?"
            r"(?P<space_id>\d+)(_spt(?P<split_id>\d+))?_rep(?P<rep_id>\d+).log",
            old_blobname)
        if result_grid:
            dat = [old_blobname.split("/")[1].split("_")[0]]
            subdat = old_blobname.split("/")[1].split("_")[1]
            mod = "hpo"
            spa = None
            arg = None
            alg = None
            pru = None
            pre = pre_id2val[int(result_grid.group("model_id"))]
            presz = presz_id2val[int(result_grid.group("model_id"))]
            try:
                spt = spt_id2val[int(result_grid.group("split_id"))]
            except (KeyError, TypeError):
                # the optional split id may be missing from legacy blob names
                spt = spt_id2val[0]
            rep = None
            self.set_jobid_from_arg_list(dat=dat, subdat=subdat, mod=mod, spa=spa, arg=arg, alg=alg,
                                         pru=pru, pre=pre, presz=presz, spt=spt, rep=rep)
            return self.to_jobid_string()
        if result:
            dat = [old_blobname.split("/")[1].split("_")[0]]
            subdat = old_blobname.split("/")[1].split("_")[1]
            mod = "hpo"
            spa = spa_id2val[int(result.group("space_id"))]
            arg = "dft"
            alg = alg_id2val[int(result.group("algo_id"))]
            pru = "None"
            pre = pre_id2val[int(result.group("model_id"))]
            presz = presz_id2val[int(result.group("model_id"))]
            try:
                spt = spt_id2val[int(result.group("split_id"))]
            except (KeyError, TypeError):
                spt = spt_id2val[0]
            rep = int(result.group("rep_id"))
            self.set_jobid_from_arg_list(dat=dat, subdat=subdat, mod=mod, spa=spa, arg=arg, alg=alg,
                                         pru=pru, pre=pre, presz=presz, spt=spt, rep=rep)
            return self.to_jobid_string()
        return None


class AzureUtils:
    """
    Helper for uploading/downloading experiment logs and predictions to/from Azure blob storage.
    """

    def __init__(self,
                 root_log_path=None,
                 console_args=None,
                 jobid=None,
                 autohf=None):
        if root_log_path:
            self.root_log_path = root_log_path
        else:
            self.root_log_path = "logs_azure"
        self.jobid = jobid
        self.console_args = console_args
        self.autohf = autohf
        if console_args:
            wandb_key, azure_key, container_name = get_wandb_azure_key(console_args.key_path)
            self._container_name = container_name
            self._azure_key = azure_key

    def _get_complete_connection_string(self):
        return "DefaultEndpointsProtocol=https;AccountName=docws5141197765;AccountKey=" \
               + self._azure_key + ";EndpointSuffix=core.windows.net"

    def _init_azure_clients(self):
        connection_string = self._get_complete_connection_string()
        container_client = ContainerClient.from_connection_string(conn_str=connection_string,
                                                                   container_name=self._container_name)
        return container_client

    def _init_blob_client(self,
                          local_file_path):
        connection_string = self._get_complete_connection_string()
        blob_service_client = BlobServiceClient.from_connection_string(connection_string)
        blob_client = blob_service_client.get_blob_client(container=self._container_name, blob=local_file_path)
        return blob_client

    def upload_local_file_to_azure(self, local_file_path):
        blob_client = self._init_blob_client(local_file_path)
        with open(local_file_path, "rb") as fin:
            blob_client.upload_blob(fin, overwrite=True)

    def download_azure_blob(self, blobname):
        blob_client = self._init_blob_client(blobname)
        pathlib.Path(re.search("(?P<parent_path>^.*)/[^/]+$", blobname).group("parent_path")).mkdir(
            parents=True, exist_ok=True)
        with open(blobname, "wb") as fout:
            fout.write(blob_client.download_blob().readall())

    def write_exception(self):
        result_json = {
            # serialize the timestamp so that json.dumps downstream does not fail on a datetime object
            "timestamp": datetime.now().isoformat(),
        }
        local_file_path = self.generate_local_json_path()
        self.create_local_json_and_upload(result_json, local_file_path)

    def extract_log_from_analysis(self,
                                  analysis):
        """
        Extract a list of json entries storing the key information returned from tune.run.
        """
        json_log = []
        for each_trial in analysis.trials:
            trial_id = each_trial.trial_id
            start_time = each_trial.start_time
            last_update_time = each_trial.last_update_time
            config = each_trial.config
            try:
                metric_score = each_trial.metric_analysis["eval_" + analysis.default_metric]
                time_stamp = each_trial.metric_analysis['timestamp']
                json_log.append({"trial_id": trial_id,
                                 "start_time": start_time,
                                 "last_update_time": last_update_time,
                                 "config": config,
                                 "metric_score": metric_score,
                                 "time_stamp": time_stamp})
            except KeyError:
                # skip trials that did not report the evaluation metric
                pass
        return json_log

    def write_autohf_output(self,
                            json_log=None,
                            valid_metric=None,
                            predictions=None,
                            duration=None):
        """
        write the key info from a job and upload to azure blob storage
        """
        local_file_path = self.generate_local_json_path()
        output_json = {}
        if json_log:
            output_json["val_log"] = json_log
        if valid_metric:
            output_json["valid_metric"] = valid_metric
        if duration:
            output_json["duration"] = duration
        if len(output_json) > 0:
            self.create_local_json_and_upload(output_json, local_file_path)
        if predictions is not None:
            self.create_local_prediction_and_upload(local_file_path, predictions)

    def generate_local_json_path(self):
        """
        return a path string for storing the json file locally
        """
        full_dataset_name = self.jobid.get_jobid_full_data_name()
        jobid_str = self.jobid.to_jobid_string()
        local_file_path = os.path.join(self.root_log_path, full_dataset_name, jobid_str + ".json")
        pathlib.Path(os.path.join(self.root_log_path, full_dataset_name)).mkdir(parents=True, exist_ok=True)
        return local_file_path
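    # Illustrative example (assuming root_log_path="logs_azure" and the unit-test JobID):
    # the generated path would be
    #   logs_azure/glue_mrpc/dat=glue_subdat=mrpc_mod=hpo_..._sdhf=42.json
    # and the same relative path is later reused as the blob name when uploading.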

    def create_local_json_and_upload(self, result_json, local_file_path):
        with open(local_file_path, "w") as fout:
            fout.write(json.dumps(result_json))
            fout.flush()
        self.upload_local_file_to_azure(local_file_path)

    def legacy_to_json(self):
        container_client = self._init_azure_clients()
        for old_blob in container_client.list_blobs():
            new_jobid_str = self.jobid.legacy_old_blobname_to_new_blobname(old_blob.name)
            if new_jobid_str:
                self.download_azure_blob(old_blob.name)
                with open(old_blob.name, "r") as fin:
                    alllines = fin.readlines()
                wandb_group_name = alllines[0].rstrip("\n:")
                timestamp = re.search(
                    r"timestamp:(?P<timestamp>.*):",
                    alllines[1].strip("\n")).group("timestamp")
                duration = re.search(
                    r"duration:(?P<duration>.*)$",
                    alllines[3].strip("\n")).group("duration")
                sample_num = int(re.search(
                    r"sample_num: (?P<sample_num>\d+)$",
                    alllines[4].strip("\n")).group("sample_num"))
                validation = {"accuracy": float(re.search(
                    "validation accuracy: (?P<validation>.*)$",
                    alllines[2].strip("\n")).group("validation"))}
                test = None
                if len(alllines) > 6:
                    result_test = re.search("test accuracy:(?P<test>.*)$", alllines[6].strip("\n"))
                    if result_test:
                        test = json.loads(result_test.group("test"))
                yml_file = None
                if len(alllines) > 8:
                    if alllines[8].startswith("aml"):
                        yml_file = alllines[8].strip("\n")
                new_json = {"wandb_group_name": wandb_group_name,
                            "validation": validation,
                            "test": test,
                            "timestamp": timestamp,
                            "duration": duration,
                            "sample_num": sample_num,
                            "yml_file": yml_file}
                full_dataset_name = self.jobid.get_jobid_full_data_name()
                new_blobname = os.path.join("logs_azure/", full_dataset_name, new_jobid_str + ".json")
                self.create_local_json_and_upload(new_json, new_blobname)

    def create_local_prediction_and_upload(self,
                                           local_json_file,
                                           predictions):
        """
        store predictions (a .zip file) locally and upload
        """
        azure_save_file_name = local_json_file.split("/")[-1][:-5]
        local_archive_path = self.autohf.output_prediction(predictions,
                                                           output_prediction_path=self.console_args.data_root_dir + "result/",
                                                           output_zip_file_name=azure_save_file_name)
        self.upload_local_file_to_azure(local_archive_path)

    def get_ranked_configs(self, metric_mode):
        """
        extract the configs (ranked in descending order by the score) from the azure file
        of the current object (defined by self.jobid)
        """
        azure_file_path = self.generate_local_json_path()
        self.download_azure_blob(azure_file_path)

        json_log = json.load(open(azure_file_path, "r"))
        assert "val_log" in json_log

        trialid_to_score = {}
        trialid_to_config = {}

        for each_entry in json_log["val_log"]:
            trial_id = each_entry["trial_id"]
            config = each_entry["config"]
            this_score = each_entry["metric_score"][metric_mode]
            trialid_to_config[trial_id] = config
            trialid_to_score[trial_id] = this_score

        sorted_trialid_to_score = sorted(trialid_to_score.items(), key=lambda x: x[1], reverse=True)
        return [trialid_to_config[entry[0]] for entry in sorted_trialid_to_score]

    @staticmethod
    def is_after_earliest_time(this_blob, earliest_time):
        import pytz
        utc = pytz.UTC
        if this_blob.last_modified >= utc.localize(datetime(earliest_time[0], earliest_time[1], earliest_time[2])):
            return True
        return False

    def get_blob_list_matching_partial_jobid(self, root_log_path, partial_jobid, earliest_time=None):
        """
        get all blobs whose jobid configs match the partial_jobid
        """
        blob_list = []
        container_client = self._init_azure_clients()
        jobid_config = JobID()
        for each_blob in container_client.list_blobs():
            if each_blob.name.startswith(root_log_path):
                each_jobconfig = jobid_config.convert_blobname_to_jobid(each_blob.name)
                is_append = False
                if each_jobconfig:
                    if each_jobconfig.is_match(partial_jobid):
                        is_append = True
                    if earliest_time and not AzureUtils.is_after_earliest_time(each_blob, earliest_time):
                        is_append = False
                    if is_append:
                        blob_list.append((each_jobconfig, each_blob))
        return blob_list
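    # Minimal usage sketch (hypothetical instance and argument values): build a partial JobID
    # and list the matching blobs under a log prefix; earliest_time is an optional
    # (year, month, day) tuple used to drop blobs last modified before that date.
    #   partial = JobID()
    #   partial.set_jobid_from_arg_list(dat=["glue"], subdat="mrpc", mod="hpo")
    #   blobs = azure_utils.get_blob_list_matching_partial_jobid("logs_azure/", partial)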

    @staticmethod
    def extract_config_and_score(blobname):
        data_json = json.load(open(blobname, "r"))
        return [(x['config'], x['metric_score']["max"], x['start_time']) for x in data_json['val_log']]

    def get_config_and_score_from_partial_jobid(self,
                                                root_log_path,
                                                partial_jobid,
                                                group_attrs,
                                                method,
                                                earliest_time=None):
        """
        get the best config and best score for each job matching the partial_jobid
        """
        matched_blob_list = self.get_blob_list_matching_partial_jobid(
            root_log_path,
            partial_jobid,
            earliest_time=earliest_time)
        group_dict = {}
        for (each_jobconfig, each_blob) in matched_blob_list:
            self.download_azure_blob(each_blob.name)
            config_and_score = AzureUtils.extract_config_and_score(each_blob.name)
            if method == "unsorted":
                sorted_config_and_score = config_and_score
            elif method == "sort_time":
                sorted_config_and_score = sorted(config_and_score, key=lambda x: x[2], reverse=False)
            else:
                sorted_config_and_score = sorted(config_and_score, key=lambda x: x[1], reverse=True)
            group_attr_list = []
            for each_attr in group_attrs:
                group_val = getattr(each_jobconfig, each_attr)
                if isinstance(group_val, list):
                    group_attr_list.append(JobID.dataset_list_to_str(group_val, each_attr))
                else:
                    group_attr_list.append(group_val)
            group_attr_tuple = tuple(group_attr_list)
            group_dict.setdefault(group_attr_tuple, [])
            group_dict[group_attr_tuple].append([(config, score, each_blob.name)
                                                 for (config, score, ts) in sorted_config_and_score])
        return group_dict
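    # Sketch of the returned structure (hypothetical values): with group_attrs=["pre", "alg"],
    # the result maps a tuple such as ("mobilebert", "bs") to a list with one entry per matching
    # blob, each entry being a list of (config, score, blob_name) tuples ordered according to
    # `method` ("unsorted", "sort_time", or anything else for descending score).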

    def get_validation_perf(self, console_args=None, partial_jobid_config=None):
        """
        get the validation score for all blobs matching the partial_jobid_config
        """
        if partial_jobid_config.pre == "electra":
            dataset_namelist = ["wnli", "rte", "mrpc", "cola", "stsb", "sst2", "qnli", "mnli"]
        else:
            dataset_namelist = ["wnli", "rte", "mrpc", "cola", "stsb", "sst2"]
        dataset_vallist1 = [0] * len(dataset_namelist)
        dataset_vallist2 = [0] * len(dataset_namelist)

        matched_blob_list = self.get_blob_list_matching_partial_jobid(console_args.azure_root_log_path,
                                                                      partial_jobid_config)
        for (each_jobconfig, each_blob) in matched_blob_list:
            subdat_name = each_jobconfig.subdat
            self.download_azure_blob(each_blob.name)
            data_json = json.load(open(each_blob.name, "r"))
            print(len(data_json["val_log"]))
            validation_metric = data_json['valid_metric']
            try:
                dataset_idx = dataset_namelist.index(subdat_name)
                dataset_vallist1[dataset_idx], dataset_vallist2[dataset_idx] \
                    = self.get_validation_metricstr(validation_metric)
            except ValueError:
                pass
        # print(" & ".join(dataset_vallist1))
        # print(", ,".join(dataset_vallist2))

    def get_validation_metricstr(self, validation_metric):
        """
        get two strings representing the validation scores, formatted for pasting into a Google spreadsheet
        """
        validation_str1 = validation_str2 = ""
        is_first = True
        for key in ["f1", "accuracy", "pearson", "spearmanr", "matthews_correlation"]:
            if "eval_" + key in validation_metric.keys():
                if is_first:
                    validation_str1 += str("%.1f" % (validation_metric["eval_" + key] * 100))
                    validation_str2 += str(validation_metric["eval_" + key] * 100)
                    is_first = False
                else:
                    validation_str1 += "/" + str("%.1f" % (validation_metric["eval_" + key] * 100))
                    validation_str2 += "," + str(validation_metric["eval_" + key] * 100)
        return validation_str1, validation_str2
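    # Illustrative output (hypothetical scores): for validation_metric containing
    # {"eval_f1": 0.875, "eval_accuracy": 0.75}, the method returns
    #   ("87.5/75.0", "87.5,75.0")
    # i.e. a slash-joined string with one decimal place and a comma-joined raw string.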

    def get_test_perf(self, partial_jobid_config=None, result_root_dir=None):
        """
        get the test scores for all blobs matching the partial_jobid_config
        """
        import shutil
        from flaml.nlp.dataset.submission_auto import file_name_mapping_glue, output_blank_tsv
        matched_blob_list = self.get_blob_list_matching_partial_jobid("data/", partial_jobid_config)
        partial_jobid_str = partial_jobid_config.to_partial_jobid_string()
        output_dir = os.path.join(result_root_dir, partial_jobid_str)
        if os.path.exists(output_dir):
            assert os.path.isdir(output_dir)
        else:
            os.mkdir(output_dir)
        output_blank_tsv(output_dir)

        for (each_jobconfig, each_blob) in matched_blob_list:
            subdat_name = each_jobconfig.subdat
            self.download_azure_blob(each_blob.name)
            import zipfile
            if os.path.exists(each_blob.name[:-4]):
                assert os.path.isdir(each_blob.name[:-4])
            else:
                os.mkdir(each_blob.name[:-4])
            with zipfile.ZipFile(each_blob.name, 'r') as zip_ref:
                zip_ref.extractall(each_blob.name[:-4])
            src = os.path.join(each_blob.name[:-4], file_name_mapping_glue[subdat_name][0])
            dst = os.path.join(output_dir, file_name_mapping_glue[subdat_name][0])
            shutil.copy(src, dst)
        shutil.make_archive(os.path.join(output_dir), 'zip', output_dir)

    def get_best_perf_config(self, console_args, jobid_config):
        """
        get the config of the best-performing trial
        """
        matched_blob_list = self.get_blob_list_matching_partial_jobid(console_args.azure_root_log_path, jobid_config)
        try:
            assert len(matched_blob_list) == 1
        except AssertionError:
            import pdb
            pdb.set_trace()

        each_jobconfig, each_blob = matched_blob_list[0]
        self.download_azure_blob(each_blob.name)
        data_json = json.load(open(each_blob.name, "r"))

        sorted_entries = sorted(data_json['val_log'], key=lambda x: x['metric_score']['max'], reverse=True)
        best_config = sorted_entries[0]['config']
        if jobid_config.subdat != "mrpc":
            best_score = sorted_entries[0]['metric_score']['max']
        else:
            best_score = (data_json["valid_metric"]["eval_f1"], data_json["valid_metric"]["eval_accuracy"])
        return best_config, best_score
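
# Overall usage sketch (hypothetical wiring; the surrounding FLAML/autohf plumbing is assumed):
#   jobid = JobID(console_args)                      # parse the run identity from CLI args
#   azure_utils = AzureUtils(console_args=console_args, jobid=jobid, autohf=autohf)
#   json_log = azure_utils.extract_log_from_analysis(analysis)   # analysis returned by tune.run
#   azure_utils.write_autohf_output(json_log=json_log,
#                                   valid_metric=valid_metric,
#                                   predictions=predictions,
#                                   duration=duration)
# The blob name is derived from the JobID via to_jobid_string(), so later queries such as
# get_config_and_score_from_partial_jobid() can locate the run from a partial JobID.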