update config if n_estimators is modified (#225)

* update the config if n_estimators is modified by budget-constrained training

* cast predictions to int before inverse-transforming labels

* handle the case n_estimators <= 0

* if already trained and there is no budget to train more, return the trained model

* split_type=group for classification & regression (see the usage sketch below)
Chi Wang 2021-09-27 21:30:49 -07:00 committed by GitHub
parent 7d9e28f02d
commit a99e939404
6 changed files with 163 additions and 108 deletions
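For orientation before the file diffs: a minimal usage sketch of the new group split. The dataset and group labels here are illustrative, not from the commit; per the split-type decision logic changed below, passing groups makes split_type default to "group" for classification and regression (it can also be requested explicitly via split_type="group").

    import numpy as np
    from sklearn.datasets import load_wine
    from flaml import AutoML

    X, y = load_wine(return_X_y=True)
    # Illustrative group labels: rows that share a label stay on the same
    # side of the train/validation split (GroupShuffleSplit under holdout).
    groups = np.random.randint(low=0, high=10, size=len(y))

    automl = AutoML()
    automl.fit(
        X_train=X,
        y_train=y,
        task="classification",
        time_budget=2,
        eval_method="holdout",
        groups=groups,
    )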

flaml/automl.py

@@ -117,9 +117,17 @@ class SearchState:
time2eval = result["time_total_s"]
trained_estimator = result["trained_estimator"]
del result["trained_estimator"] # free up RAM
n_iter = trained_estimator and trained_estimator.params.get("n_estimators")
if (
n_iter is not None
and "n_estimators" in config
and n_iter >= self._search_space_domain["n_estimators"].lower
):
config["n_estimators"] = n_iter
n_iter = None
else:
obj, time2eval, trained_estimator = np.inf, 0.0, None
metric_for_logging = config = None
metric_for_logging = config = n_iter = None
self.trial_time = time2eval
self.total_time_used += time_used
self.total_iter += 1
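The write-back above is the commit's title change. A compact restatement (a sketch: the function name is illustrative, and lower stands for self._search_space_domain["n_estimators"].lower in the diff):

    def maybe_update_config(config, trained_estimator, lower):
        # Budget-capped training may shrink n_estimators; record the value
        # actually used, as long as it is still inside the search space.
        n_iter = trained_estimator and trained_estimator.params.get("n_estimators")
        if n_iter is not None and "n_estimators" in config and n_iter >= lower:
            config["n_estimators"] = n_iter
            n_iter = None  # config now carries the value; no separate override needed
        return config, n_iter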
@@ -147,8 +155,10 @@ class SearchState:
self.trained_estimator.cleanup()
if trained_estimator:
self.trained_estimator = trained_estimator
self.best_n_iter = n_iter
self.metric_for_logging = metric_for_logging
self.val_loss, self.config = obj, config
self.n_iter = n_iter
def get_hist_config_sig(self, sample_size, config):
config_values = tuple([config[k] for k in self._hp_names])
@@ -251,7 +261,9 @@ class AutoMLState:
# tune.report(**result)
return result
def _train_with_config(self, estimator, config_w_resource, sample_size=None):
def _train_with_config(
self, estimator, config_w_resource, sample_size=None, n_iter=None
):
if not sample_size:
sample_size = config_w_resource.get(
"FLAML_sample_size", len(self.y_train_all)
@@ -288,6 +300,7 @@ class AutoMLState:
self.n_jobs,
self.learner_classes.get(estimator),
budget,
n_iter,
self.fit_kwargs,
)
if sampled_weight is not None:
@@ -444,7 +457,9 @@ class AutoML:
if y_pred.ndim > 1 and isinstance(y_pred, np.ndarray):
y_pred = y_pred.flatten()
if self._label_transformer:
return self._label_transformer.inverse_transform(pd.Series(y_pred))
return self._label_transformer.inverse_transform(
pd.Series(y_pred.astype(int))
)
else:
return y_pred
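The astype(int) cast above backs the "prediction as int" bullet: estimators can return float class codes, and an inverse label transform indexes the class list by integer code. A sketch of the failure mode using scikit-learn's LabelEncoder as a stand-in for FLAML's label transformer (an analogy, not the commit's code):

    import numpy as np
    import pandas as pd
    from sklearn.preprocessing import LabelEncoder

    le = LabelEncoder().fit(["cat", "dog"])
    y_pred = np.array([0.0, 1.0])  # float predictions from an estimator
    # le.inverse_transform(pd.Series(y_pred)) raises an indexing error,
    # since float arrays cannot index the classes_ array.
    le.inverse_transform(pd.Series(y_pred).astype(int))  # ['cat', 'dog']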
@@ -606,7 +621,7 @@ class AutoML:
if (
self._state.task in ("binary", "multi")
and self._state.fit_kwargs.get("sample_weight") is None
and self._split_type != "time"
and self._split_type not in ["time", "group"]
):
# logger.info(f"label {pd.unique(y_train_all)}")
label_set, counts = np.unique(y_train_all, return_counts=True)
@@ -695,12 +710,12 @@ class AutoML:
test_size=split_ratio,
shuffle=False,
)
elif self._state.task == "rank":
elif self._split_type == "group":
gss = GroupShuffleSplit(
n_splits=1, test_size=split_ratio, random_state=RANDOM_SEED
)
for train_idx, val_idx in gss.split(
X_train_all, y_train_all, self._state.groups
X_train_all, y_train_all, self._state.groups_all
):
if self._df:
X_train = X_train_all.iloc[train_idx]
@@ -708,8 +723,8 @@ class AutoML:
else:
X_train, X_val = X_train_all[train_idx], X_train_all[val_idx]
y_train, y_val = y_train_all[train_idx], y_train_all[val_idx]
self._state.groups = self._state.groups[train_idx]
self._state.groups_val = self._state.groups[val_idx]
self._state.groups = self._state.groups_all[train_idx]
self._state.groups_val = self._state.groups_all[val_idx]
elif self._state.task in ("binary", "multi"):
# for classification, make sure the labels are complete in both
# training and validation data
@@ -920,7 +935,7 @@ class AutoML:
n_splits: An integer of the number of folds for cross-validation.
split_type: str or None, default=None | the data split type.
For classification tasks, valid choices are [
None, 'stratified', 'uniform', 'time']. None -> stratified.
None, 'stratified', 'uniform', 'time', 'group']. None -> stratified.
For regression tasks, valid choices are [None, 'uniform', 'time'].
None -> uniform.
For time series forecasting, must be None or 'time'.
@@ -1007,7 +1022,7 @@ class AutoML:
self._state.time_budget = None
self._state.n_jobs = n_jobs
self._trained_estimator = self._state._train_with_config(
best_estimator, best_config, sample_size
best_estimator, best_config, sample_size, best.n_iter
)[0]
logger.info("retrain from log succeeded")
return training_duration
@@ -1018,10 +1033,12 @@ class AutoML:
len(np.unique(self._y_train_all))
)
if self._state.task in ("binary", "multi"):
assert split_type in [None, "stratified", "uniform", "time"]
self._split_type = split_type or "stratified"
assert split_type in [None, "stratified", "uniform", "time", "group"]
self._split_type = (
split_type or self._state.groups is None and "stratified" or "group"
)
elif self._state.task == "regression":
assert split_type in [None, "uniform", "time"]
assert split_type in [None, "uniform", "time", "group"]
self._split_type = split_type or "uniform"
elif self._state.task == "forecast":
assert split_type in [None, "time"]
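The and/or chain picking the classification split type above is terse; an equivalent, explicit form (a restatement, not code from the commit):

    if split_type is not None:
        split = split_type
    elif self._state.groups is None:
        split = "stratified"
    else:
        split = "group"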
@@ -1420,15 +1437,16 @@ class AutoML:
self.verbose = verbose
if verbose == 0:
logger.setLevel(logging.WARNING)
self._decide_split_type(split_type)
if eval_method == "auto" or self._state.X_val is not None:
eval_method = self._decide_eval_method(time_budget)
self._state.eval_method = eval_method
if (not mlflow or not mlflow.active_run()) and not logger.handlers:
# Add the console handler.
_ch = logging.StreamHandler()
_ch.setFormatter(logger_formatter)
logger.addHandler(_ch)
self._decide_split_type(split_type)
logger.info(f"Data split method: {self._split_type}")
if eval_method == "auto" or self._state.X_val is not None:
eval_method = self._decide_eval_method(time_budget)
self._state.eval_method = eval_method
logger.info("Evaluation method: {}".format(eval_method))
self._retrain_in_budget = retrain_full == "budget" and (
@@ -1697,10 +1715,9 @@ class AutoML:
self._state.time_from_start,
search_state.val_loss,
config,
self._state.best_loss,
search_state.best_config,
estimator,
search_state.sample_size,
search_state.n_iter,
)
def _search_sequential(self):
@@ -1909,10 +1926,9 @@ class AutoML:
self._state.time_from_start,
search_state.val_loss,
search_state.config,
search_state.best_loss,
search_state.best_config,
estimator,
search_state.sample_size,
search_state.n_iter,
)
if mlflow is not None and mlflow.active_run():
with mlflow.start_run(nested=True):
@@ -1985,10 +2001,12 @@ class AutoML:
<= est_retrain_time + next_trial_time
)
):
state = self._search_states[self._best_estimator]
self._trained_estimator, retrain_time = self._state._train_with_config(
self._best_estimator,
self._search_states[self._best_estimator].best_config,
state.best_config,
self.data_size_full,
state.best_n_iter,
)
logger.info(
"retrain {} for {:.1f}s".format(self._best_estimator, retrain_time)
@@ -2093,13 +2111,15 @@ class AutoML:
> self._selected.est_retrain_time(self.data_size_full)
and self._selected.best_config_sample_size == self._state.data_size
):
state = self._search_states[self._best_estimator]
(
self._trained_estimator,
retrain_time,
) = self._state._train_with_config(
self._best_estimator,
self._search_states[self._best_estimator].best_config,
state.best_config,
self.data_size_full,
state.best_n_iter,
)
logger.info(
"retrain {} for {:.1f}s".format(

flaml/ml.py

@@ -465,11 +465,14 @@ def train_estimator(
n_jobs=1,
estimator_class=None,
budget=None,
n_iter=None,
fit_kwargs={},
):
start_time = time.time()
estimator_class = estimator_class or get_estimator_class(task, estimator_name)
estimator = estimator_class(**config_dic, task=task, n_jobs=n_jobs)
if n_iter is not None:
estimator.params["n_estimators"] = n_iter
if X_train is not None:
train_time = estimator.fit(X_train, y_train, budget, **fit_kwargs)
else:

flaml/model.py

@@ -316,15 +316,18 @@ class LGBMEstimator(BaseEstimator):
def fit(self, X_train, y_train, budget=None, **kwargs):
start_time = time.time()
n_iter = self.params["n_estimators"]
trained = False
if (
not self._time_per_iter or abs(self._train_size - X_train.shape[0]) > 4
) and budget is not None:
(not self._time_per_iter or abs(self._train_size - X_train.shape[0]) > 4)
and budget is not None
and n_iter > 1
):
self.params["n_estimators"] = 1
self._t1 = self._fit(X_train, y_train, **kwargs)
if self._t1 >= budget:
self.params["n_estimators"] = n_iter
# self.params["n_estimators"] = n_iter
return self._t1
self.params["n_estimators"] = 4
self.params["n_estimators"] = min(n_iter, 4)
self._t2 = self._fit(X_train, y_train, **kwargs)
self._time_per_iter = (
(self._t2 - self._t1) / (self.params["n_estimators"] - 1)
@@ -335,19 +338,24 @@ class LGBMEstimator(BaseEstimator):
)
self._train_size = X_train.shape[0]
if self._t1 + self._t2 >= budget or n_iter == self.params["n_estimators"]:
self.params["n_estimators"] = n_iter
# self.params["n_estimators"] = n_iter
return time.time() - start_time
if budget is not None:
self.params["n_estimators"] = min(
trained = True
if budget is not None and n_iter > 1:
max_iter = min(
n_iter,
int(
(budget - time.time() + start_time - self._t1) / self._time_per_iter
+ 1
),
)
if trained and max_iter <= self.params["n_estimators"]:
return time.time() - start_time
self.params["n_estimators"] = max_iter
if self.params["n_estimators"] > 0:
self._fit(X_train, y_train, **kwargs)
self.params["n_estimators"] = n_iter
else:
self.params["n_estimators"] = n_iter
train_time = time.time() - start_time
return train_time
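The reshaped fit above probes with 1 tree, then min(n_iter, 4) trees, to estimate the time per boosting iteration, and converts the remaining budget into an iteration cap (the same pattern appears for CatBoost below). A minimal restatement of the cap (the function name is illustrative; t1 and time_per_iter mirror the private attributes in the diff):

    import time

    def estimate_cap(n_iter, budget, start_time, t1, time_per_iter):
        # Budget left after elapsed time and the first probe fit,
        # converted into an iteration count, capped by the configured value.
        remaining = budget - (time.time() - start_time) - t1
        return min(n_iter, int(remaining / time_per_iter + 1))

For example, with budget=10s, 1s already elapsed, t1=0.5s and 0.1s per iteration, the cap is min(n_iter, int(8.5 / 0.1 + 1)) = min(n_iter, 86).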
@@ -787,10 +795,15 @@ class CatBoostEstimator(BaseEstimator):
cat_features = []
# from catboost import CatBoostError
# try:
trained = False
if (
not CatBoostEstimator._time_per_iter
or abs(CatBoostEstimator._train_size - len(y_train)) > 4
) and budget:
(
not CatBoostEstimator._time_per_iter
or abs(CatBoostEstimator._train_size - len(y_train)) > 4
)
and budget
and n_iter > 4
):
# measure the time per iteration
self.params["n_estimators"] = 1
CatBoostEstimator._smallmodel = self.estimator_class(
@@ -801,11 +814,11 @@ class CatBoostEstimator(BaseEstimator):
)
CatBoostEstimator._t1 = time.time() - start_time
if CatBoostEstimator._t1 >= budget:
self.params["n_estimators"] = n_iter
# self.params["n_estimators"] = n_iter
self._model = CatBoostEstimator._smallmodel
shutil.rmtree(train_dir, ignore_errors=True)
return CatBoostEstimator._t1
self.params["n_estimators"] = 4
self.params["n_estimators"] = min(n_iter, 4)
CatBoostEstimator._smallmodel = self.estimator_class(
train_dir=train_dir, **self.params
)
@@ -822,13 +835,14 @@ class CatBoostEstimator(BaseEstimator):
time.time() - start_time >= budget
or n_iter == self.params["n_estimators"]
):
self.params["n_estimators"] = n_iter
# self.params["n_estimators"] = n_iter
self._model = CatBoostEstimator._smallmodel
shutil.rmtree(train_dir, ignore_errors=True)
return time.time() - start_time
if budget:
trained = True
if budget and n_iter > 4:
train_times = 1
self.params["n_estimators"] = min(
max_iter = min(
n_iter,
int(
(budget - time.time() + start_time - CatBoostEstimator._t1)
@@ -838,6 +852,9 @@ class CatBoostEstimator(BaseEstimator):
),
)
self._model = CatBoostEstimator._smallmodel
if trained and max_iter <= self.params["n_estimators"]:
return time.time() - start_time
self.params["n_estimators"] = max_iter
if self.params["n_estimators"] > 0:
n = max(int(len(y_train) * 0.9), len(y_train) - 1000)
X_tr, y_tr = X_train[:n], y_train[:n]
@@ -863,9 +880,10 @@ class CatBoostEstimator(BaseEstimator):
if weight is not None:
kwargs["sample_weight"] = weight
self._model = model
else:
self.params["n_estimators"] = n_iter
# except CatBoostError:
# self._model = None
self.params["n_estimators"] = n_iter
train_time = time.time() - start_time
return train_time
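Both estimators now share the "no budget to train more" guard from the commit message: once the probe model has been fit (trained = True) and the budget-derived cap max_iter does not exceed the iterations already trained, the probe model is returned instead of refitting. As a one-line predicate (the function name is illustrative):

    def can_reuse_probe(trained: bool, max_iter: int, current_n_estimators: int) -> bool:
        # True when refitting could not add trees beyond what the probe model has.
        return trained and max_iter <= current_n_estimators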

flaml/training_log.py

@@ -1,7 +1,7 @@
'''!
* Copyright (c) 2020-2021 Microsoft Corporation. All rights reserved.
"""!
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
'''
"""
import json
from typing import IO
@@ -10,19 +10,19 @@ import warnings
class TrainingLogRecord(object):
def __init__(self,
record_id: int,
iter_per_learner: int,
logged_metric: float,
trial_time: float,
wall_clock_time: float,
validation_loss,
config,
best_validation_loss,
best_config,
learner,
sample_size):
def __init__(
self,
record_id: int,
iter_per_learner: int,
logged_metric: float,
trial_time: float,
wall_clock_time: float,
validation_loss: float,
config: dict,
learner: str,
sample_size: int,
n_iter: int,
):
self.record_id = record_id
self.iter_per_learner = iter_per_learner
self.logged_metric = logged_metric
@@ -30,10 +30,9 @@ class TrainingLogRecord(object):
self.wall_clock_time = wall_clock_time
self.validation_loss = validation_loss
self.config = config
self.best_validation_loss = best_validation_loss
self.best_config = best_config
self.learner = learner
self.sample_size = sample_size
self.n_iter = n_iter # n_estimators for catboost
def dump(self, fp: IO[str]):
d = vars(self)
@@ -49,75 +48,78 @@ class TrainingLogRecord(object):
class TrainingLogCheckPoint(TrainingLogRecord):
def __init__(self, curr_best_record_id: int):
self.curr_best_record_id = curr_best_record_id
class TrainingLogWriter(object):
def __init__(self, output_filename: str):
self.output_filename = output_filename
self.file = None
self.current_best_loss_record_id = None
self.current_best_loss = float('+inf')
self.current_best_loss = float("+inf")
self.current_sample_size = None
self.current_record_id = 0
def open(self):
self.file = open(self.output_filename, 'w')
self.file = open(self.output_filename, "w")
def append_open(self):
self.file = open(self.output_filename, 'a')
self.file = open(self.output_filename, "a")
def append(self,
it_counter: int,
train_loss: float,
trial_time: float,
wall_clock_time: float,
validation_loss,
config,
best_validation_loss,
best_config,
learner,
sample_size):
def append(
self,
it_counter: int,
train_loss: float,
trial_time: float,
wall_clock_time: float,
validation_loss,
config,
learner,
sample_size,
n_iter,
):
if self.file is None:
raise IOError("Call open() to open the output file first.")
if validation_loss is None:
raise ValueError('TEST LOSS NONE ERROR!!!')
record = TrainingLogRecord(self.current_record_id,
it_counter,
train_loss,
trial_time,
wall_clock_time,
validation_loss,
config,
best_validation_loss,
best_config,
learner,
sample_size)
if validation_loss < self.current_best_loss or \
validation_loss == self.current_best_loss and \
self.current_sample_size is not None and \
sample_size > self.current_sample_size:
raise ValueError("TEST LOSS NONE ERROR!!!")
record = TrainingLogRecord(
self.current_record_id,
it_counter,
train_loss,
trial_time,
wall_clock_time,
validation_loss,
config,
learner,
sample_size,
n_iter,
)
if (
validation_loss < self.current_best_loss
or validation_loss == self.current_best_loss
and self.current_sample_size is not None
and sample_size > self.current_sample_size
):
self.current_best_loss = validation_loss
self.current_sample_size = sample_size
self.current_best_loss_record_id = self.current_record_id
self.current_record_id += 1
record.dump(self.file)
self.file.write('\n')
self.file.write("\n")
self.file.flush()
def checkpoint(self):
if self.file is None:
raise IOError("Call open() to open the output file first.")
if self.current_best_loss_record_id is None:
warnings.warn("checkpoint() called before any record is written, "
"skipped.")
warnings.warn(
"checkpoint() called before any record is written, " "skipped."
)
return
record = TrainingLogCheckPoint(self.current_best_loss_record_id)
record.dump(self.file)
self.file.write('\n')
self.file.write("\n")
self.file.flush()
def close(self):
@@ -127,7 +129,6 @@ class TrainingLogWriter(object):
class TrainingLogReader(object):
def __init__(self, filename: str):
self.filename = filename
self.file = None

flaml/version.py

@@ -1 +1 @@
__version__ = "0.6.5"
__version__ = "0.6.6"

test/test_split.py

@@ -17,7 +17,7 @@ def _test(split_type):
automl_settings = {
"time_budget": 2,
# "metric": 'accuracy',
"task": 'classification',
"task": "classification",
"log_file_name": "test/{}.log".format(dataset),
"model_history": True,
"log_training_metric": True,
@@ -28,13 +28,16 @@ def _test(split_type):
X, y = fetch_openml(name=dataset, return_X_y=True)
except (ArffException, ValueError):
from sklearn.datasets import load_wine
X, y = load_wine(return_X_y=True)
if split_type != 'time':
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33,
random_state=42)
if split_type != "time":
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42
)
else:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33,
shuffle=False)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, shuffle=False
)
automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
pred = automl.predict(X_test)
@@ -53,36 +56,45 @@ def test_time():
def test_groups():
from sklearn.externals._arff import ArffException
try:
X, y = fetch_openml(name=dataset, return_X_y=True)
except (ArffException, ValueError):
from sklearn.datasets import load_wine
X, y = load_wine(return_X_y=True)
import numpy as np
automl = AutoML()
automl_settings = {
"time_budget": 2,
"task": 'classification',
"task": "classification",
"log_file_name": "test/{}.log".format(dataset),
"model_history": True,
"eval_method": "cv",
"groups": np.random.randint(low=0, high=10, size=len(y)),
"estimator_list": ['lgbm', 'rf', 'xgboost', 'kneighbor'], # list of ML learners
"estimator_list": ["lgbm", "rf", "xgboost", "kneighbor"],
"learner_selector": "roundrobin",
}
automl.fit(X, y, **automl_settings)
automl_settings["eval_method"] = "holdout"
automl.fit(X, y, **automl_settings)
def test_rank():
from sklearn.externals._arff import ArffException
try:
X, y = fetch_openml(name=dataset, return_X_y=True)
except (ArffException, ValueError):
from sklearn.datasets import load_wine
X, y = load_wine(return_X_y=True)
y = y.cat.codes
import numpy as np
automl = AutoML()
automl_settings = {
"time_budget": 2,
@@ -90,8 +102,9 @@ def test_rank():
"log_file_name": "test/{}.log".format(dataset),
"model_history": True,
"eval_method": "cv",
"groups": np.array( # group labels
[0] * 200 + [1] * 200 + [2] * 200 + [3] * 200 + [4] * 100 + [5] * 100),
"groups": np.array( # group labels
[0] * 200 + [1] * 200 + [2] * 200 + [3] * 200 + [4] * 100 + [5] * 100
),
"learner_selector": "roundrobin",
}
automl.fit(X, y, **automl_settings)
@@ -100,10 +113,10 @@ def test_rank():
automl_settings = {
"time_budget": 2,
"task": "rank",
"metric": "ndcg@5", # 5 can be replaced by any number
"metric": "ndcg@5", # 5 can be replaced by any number
"log_file_name": "test/{}.log".format(dataset),
"model_history": True,
"groups": [200] * 4 + [100] * 2, # alternative way: group counts
"groups": [200] * 4 + [100] * 2, # alternative way: group counts
# "estimator_list": ['lgbm', 'xgboost'], # list of ML learners
"learner_selector": "roundrobin",
}