2021-02-05 21:41:14 -08:00
|
|
|
import unittest
|
|
|
|
|
|
|
|
from sklearn.datasets import fetch_openml
|
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
from flaml.automl import AutoML
|
|
|
|
from flaml.model import XGBoostSklearnEstimator
|
|
|
|
from flaml import tune
|
|
|
|
|
|
|
|
|
|
|
|
# Name of the OpenML dataset that this module's tests load; it is also
# embedded in the AutoML log file name.
dataset = "credit-g"
|
|
|
|
|
|
|
|
|
|
|
|
class XGBoost2D(XGBoostSklearnEstimator):
    """XGBoost learner whose search space has only two hyperparameters.

    The space returned by ``search_space`` contains just ``n_estimators`` and
    ``max_leaves``, which makes this a small learner for exercising AutoML's
    custom-learner and HPO plumbing.
    """

    @classmethod
    def search_space(cls, data_size, task):
        """Return the 2-D search space over ``n_estimators`` and ``max_leaves``.

        Args:
            data_size: Size information about the training data. Only
                ``data_size[0]`` is read (presumably the number of training
                rows; confirm against the caller), and it caps the upper
                bound of both ranges.
            task: Unused here; kept for signature compatibility with the
                parent estimator's ``search_space``.

        Returns:
            dict: Maps each hyperparameter name to a dict holding its sampling
            ``domain`` and its ``low_cost_init_value``.
        """
        # Cap the bound at 32768 and at data_size[0], but never let it drop
        # below 5. Without the clamp, a dataset with fewer than 5 rows would
        # give an upper bound <= the lower bound of 4, and lograndint would
        # receive an empty/invalid range.
        upper = max(5, min(32768, int(data_size[0])))
        return {
            "n_estimators": {
                "domain": tune.lograndint(lower=4, upper=upper),
                # Start the search from the cheapest setting.
                "low_cost_init_value": 4,
            },
            "max_leaves": {
                "domain": tune.lograndint(lower=4, upper=upper),
                "low_cost_init_value": 4,
            },
        }
|
|
|
|
|
|
|
|
|
|
|
|
def _load_data():
    """Return ``(X, y)`` for the OpenML ``dataset``, falling back to wine data.

    The fallback is used when ``fetch_openml`` raises an ARFF parsing error,
    a ``ValueError``, or an ``OSError``. ``OSError`` covers ``urllib`` network
    failures such as ``URLError``, so the test can still run without network
    access.
    """
    from sklearn.externals._arff import ArffException

    try:
        return fetch_openml(name=dataset, return_X_y=True)
    except (ArffException, ValueError, OSError):
        from sklearn.datasets import load_wine

        return load_wine(return_X_y=True)


def test_simple(method=None):
    """Smoke-test AutoML with the custom ``XGBoost2D`` learner, then run tune.

    The test proceeds in four steps:

    1. Register ``XGBoost2D`` and fit AutoML on it for a 1-second budget.
    2. Print the estimator list, search space, and points to evaluate that
       AutoML exposes.
    3. Evaluate the best configuration again through ``automl.trainable``.
    4. Call ``tune.run`` directly with the search state that AutoML kept
       (``keep_search_state=True``).

    Args:
        method: Passed through unchanged as AutoML's ``hpo_method``
            (e.g. ``"optuna"``, ``"random"``, ``"grid"``, or ``None``).
    """
    automl = AutoML()
    automl.add_learner(learner_name="XGBoost2D", learner_class=XGBoost2D)

    automl_settings = {
        "estimator_list": ["XGBoost2D"],
        "task": "classification",
        "log_file_name": f"test/xgboost2d_{dataset}_{method}.log",
        "n_jobs": 1,
        "hpo_method": method,
        "log_type": "all",
        "retrain_full": "budget",
        "keep_search_state": True,
        "time_budget": 1,
    }

    X, y = _load_data()
    # Only the training split is used; the held-out part is discarded.
    X_train, _, y_train, _ = train_test_split(
        X, y, test_size=0.33, random_state=42
    )
    automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
    print(automl.estimator_list)
    print(automl.search_space)
    print(automl.points_to_evaluate)
    # NOTE(review): presumably empty when the tiny budget yields no finished
    # trial; nothing further to exercise in that case.
    if not automl.best_config:
        return

    # Tag the best config with the chosen learner name, then evaluate it
    # through AutoML's trainable.
    config = automl.best_config.copy()
    config["learner"] = automl.best_estimator
    automl.trainable(config)

    from flaml.automl import size
    from functools import partial

    # Run tune directly. The search space, cost hints, resource schedule,
    # budget, and constraints are all taken from the AutoML object's state.
    analysis = tune.run(
        automl.trainable,
        automl.search_space,
        metric="val_loss",
        mode="min",
        low_cost_partial_config=automl.low_cost_partial_config,
        points_to_evaluate=automl.points_to_evaluate,
        cat_hp_cost=automl.cat_hp_cost,
        resource_attr=automl.resource_attr,
        min_resource=automl.min_resource,
        max_resource=automl.max_resource,
        time_budget_s=automl._state.time_budget,
        config_constraints=[(partial(size, automl._state), "<=", automl._mem_thres)],
        metric_constraints=automl.metric_constraints,
        num_samples=5,
    )
    print(analysis.trials[-1])
|
2021-02-05 21:41:14 -08:00
|
|
|
|
|
|
|
|
2021-09-04 01:42:21 -07:00
|
|
|
def test_optuna():
    """Run the XGBoost2D smoke test with the ``"optuna"`` HPO method."""
    hpo_method = "optuna"
    test_simple(method=hpo_method)
|
|
|
|
|
|
|
|
|
2021-09-19 11:19:23 -07:00
|
|
|
def test_random():
    """Run the XGBoost2D smoke test with the ``"random"`` HPO method."""
    hpo_method = "random"
    test_simple(method=hpo_method)
|
|
|
|
|
|
|
|
|
2021-02-05 21:41:14 -08:00
|
|
|
def test_grid():
    """Run the XGBoost2D smoke test with the ``"grid"`` HPO method."""
    hpo_method = "grid"
    test_simple(method=hpo_method)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # The tests in this module are plain pytest-style functions, not
    # unittest.TestCase methods. unittest.main() would therefore discover
    # and run nothing, so call the test functions directly instead.
    for run_test in (test_simple, test_optuna, test_random, test_grid):
        run_test()
|