This commit is contained in:
skzhang1 2022-08-13 18:51:33 +00:00
parent e3c9da50da
commit fc633ef15e

View File

@ -427,23 +427,7 @@ def get_val_loss(
train_time = time.time() - start train_time = time.time() - start
return val_loss, metric_for_logging, train_time, pred_time return val_loss, metric_for_logging, train_time, pred_time
def default_cv_score_agg_func(metrics_across_folds):
def evaluate_model_CV(
config,
estimator,
X_train_all,
y_train_all,
budget,
kf,
task,
eval_metric,
best_val_loss,
cv_score_agg_func,
log_training_metric=False,
fit_kwargs={},
):
if cv_score_agg_func is None:
def cv_score_agg_func(metrics_across_folds):
metric_to_minimize = sum([tem[0] for tem in metrics_across_folds])/len(metrics_across_folds) metric_to_minimize = sum([tem[0] for tem in metrics_across_folds])/len(metrics_across_folds)
metrics_to_log = None metrics_to_log = None
for single_fold in metrics_across_folds: for single_fold in metrics_across_folds:
@ -457,13 +441,29 @@ def evaluate_model_CV(
n = len(metrics_across_folds) n = len(metrics_across_folds)
metrics_to_log = {k: v / n for k, v in metrics_to_log.items()} metrics_to_log = {k: v / n for k, v in metrics_to_log.items()}
return metric_to_minimize, metrics_to_log return metric_to_minimize, metrics_to_log
def evaluate_model_CV(
config,
estimator,
X_train_all,
y_train_all,
budget,
kf,
task,
eval_metric,
best_val_loss,
cv_score_agg_func = None,
log_training_metric=False,
fit_kwargs={},
):
if cv_score_agg_func is None:
cv_score_agg_func = default_cv_score_agg_func
start_time = time.time() start_time = time.time()
val_loss_folds = [] val_loss_folds = []
log_metric_folds = [] log_metric_folds = []
total_metric = None
metric = None metric = None
train_time = pred_time = 0 train_time = pred_time = 0
valid_fold_num = total_fold_num = 0 total_fold_num = 0
n = kf.get_n_splits() n = kf.get_n_splits()
X_train_split, y_train_split = X_train_all, y_train_all X_train_split, y_train_split = X_train_all, y_train_all
if task in CLASSIFICATION: if task in CLASSIFICATION:
@ -485,7 +485,6 @@ def evaluate_model_CV(
else: else:
kf = kf.split(X_train_split) kf = kf.split(X_train_split)
rng = np.random.RandomState(2020) rng = np.random.RandomState(2020)
val_loss_list = []
budget_per_train = budget / n budget_per_train = budget / n
if "sample_weight" in fit_kwargs: if "sample_weight" in fit_kwargs:
weight = fit_kwargs["sample_weight"] weight = fit_kwargs["sample_weight"]
@ -530,32 +529,20 @@ def evaluate_model_CV(
) )
if weight is not None: if weight is not None:
fit_kwargs["sample_weight"] = weight fit_kwargs["sample_weight"] = weight
valid_fold_num += 1
total_fold_num += 1 total_fold_num += 1
val_loss_folds.append(val_loss_i) val_loss_folds.append(val_loss_i)
if log_training_metric or not isinstance(eval_metric, str): if log_training_metric or not isinstance(eval_metric, str):
if isinstance(metric_i, dict): if isinstance(metric_i, dict):
log_metric_folds.append(metric_i) log_metric_folds.append(metric_i)
elif total_metric is not None:
total_metric += metric_i
else:
total_metric = metric_i
train_time += train_time_i train_time += train_time_i
pred_time += pred_time_i pred_time += pred_time_i
if valid_fold_num == n: if time.time() - start_time >= budget:
val_loss_list.append(cv_score_agg_func(list(zip(val_loss_folds,[None]*len(val_loss_folds))))[0])
valid_fold_num = 0
val_loss_folds = []
elif time.time() - start_time >= budget:
val_loss_list.append(cv_score_agg_func(list(zip(val_loss_folds,[None]*len(val_loss_folds))))[0])
break break
val_loss = np.max(val_loss_list)
n = total_fold_num
if log_training_metric or not isinstance(eval_metric, str): if log_training_metric or not isinstance(eval_metric, str):
if len(log_metric_folds): val_loss, metric = cv_score_agg_func(list(zip([0]*len(log_metric_folds),log_metric_folds)))
metric = cv_score_agg_func(list(zip([0]*len(log_metric_folds),log_metric_folds)))[1]
else: else:
metric = total_metric / n val_loss, metric = cv_score_agg_func(list(zip(val_loss_folds,[None]*len(val_loss_folds))))
n = total_fold_num
pred_time /= n pred_time /= n
return val_loss, metric, train_time, pred_time return val_loss, metric, train_time, pred_time