random search (#213)

* random search as a child class of CFO

* random search in sequential search of AutoML

* time to find best model as a property of AutoML
This commit is contained in:
Chi Wang 2021-09-19 11:19:23 -07:00 committed by GitHub
parent 0ba58e0ace
commit f3e50136e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 122 additions and 59 deletions

View File

@ -363,7 +363,7 @@ class AutoML:
Returns:
An object with `predict()` and `predict_proba()` method (for
classification), storing the best trained model for estimator_name.
classification), storing the best trained model for estimator_name.
"""
state = self._search_states.get(estimator_name)
return state and getattr(state, "trained_estimator", None)
@ -414,6 +414,11 @@ class AutoML:
return attr.classes_.tolist()
return None
@property
def time_to_find_best_model(self) -> float:
    """Time taken to find the best model, in seconds.

    The value is looked up under the key `_time_taken_best_iter` in the
    instance's own attribute dictionary; when that key is absent the
    lookup yields None instead of raising.
    """
    instance_state = vars(self)
    return instance_state.get("_time_taken_best_iter")
def predict(self, X_test):
"""Predict label from features.
@ -1374,8 +1379,7 @@ class AutoML:
a simple customized search space. When set to 'bs', BlendSearch
is used. BlendSearch can be tried when the search space is
complex, for example, containing multiple disjoint, discontinuous
subspaces. When set to 'random' and the argument
`n_concurrent_trials` is larger than 1, random search is used.
subspaces. When set to 'random', random search is used.
starting_points: A dictionary to specify the starting hyperparameter
config for the estimators.
Keys are the name of the estimators, and values are the starting
@ -1717,6 +1721,8 @@ class AutoML:
from .searcher.suggestion import OptunaSearch as SearchAlgo
elif "bs" == self._hpo_method:
from flaml import BlendSearch as SearchAlgo
elif "random" == self._hpo_method:
from flaml.searcher import RandomSearch as SearchAlgo
elif "cfocat" == self._hpo_method:
from flaml.searcher.cfo_cat import CFOCat as SearchAlgo
else:
@ -1784,7 +1790,7 @@ class AutoML:
else [search_state.init_config]
)
low_cost_partial_config = search_state.low_cost_partial_config
if self._hpo_method in ("bs", "cfo", "grid", "cfocat"):
if self._hpo_method in ("bs", "cfo", "grid", "cfocat", "random"):
algo = SearchAlgo(
metric="val_loss",
mode="min",

View File

@ -1,3 +1,3 @@
from .blendsearch import CFO, BlendSearch, BlendSearchTuner
from .blendsearch import CFO, BlendSearch, BlendSearchTuner, RandomSearch
from .flow2 import FLOW2
from .online_searcher import ChampionFrontierSearcher

View File

@ -1024,3 +1024,19 @@ class CFO(BlendSearchTuner):
self._candidate_start_points[trial_id] = result
if len(self._search_thread_pool) < 2 and not self._points_to_evaluate:
self._create_thread_from_best_candidate()
class RandomSearch(CFO):
    """Random search, implemented as a child class of CFO.

    While `_points_to_evaluate` is non-empty, suggestions are delegated to
    the parent class. Afterwards every suggestion comes from
    `self._ls.complete_config({})`, and trial feedback is ignored.
    """

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Return the next config to try for `trial_id`.

        Args:
            trial_id: Identifier of the trial requesting a config.

        Returns:
            The parent's suggestion while there are pending points to
            evaluate; otherwise the config produced by completing an
            empty partial config.
        """
        if not self._points_to_evaluate:
            # presumably complete_config fills every hyperparameter by
            # sampling the search space -- confirm against `_ls`.
            sampled_config, _unused = self._ls.complete_config({})
            return sampled_config
        return super().suggest(trial_id)

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """Do nothing: the completed trial's outcome is not recorded here."""
        return None

    def on_trial_result(self, trial_id: str, result: Dict):
        """Do nothing: the reported result is not recorded here."""
        return None

View File

@ -1,50 +1,66 @@
from openml.exceptions import OpenMLServerException
def test_automl(budget=5, dataset_format='dataframe', hpo_method=None):
def test_automl(budget=5, dataset_format="dataframe", hpo_method=None):
from flaml.data import load_openml_dataset
try:
X_train, X_test, y_train, y_test = load_openml_dataset(
dataset_id=1169, data_dir='test/', dataset_format=dataset_format)
dataset_id=1169, data_dir="test/", dataset_format=dataset_format
)
except OpenMLServerException:
print("OpenMLServerException raised")
return
''' import AutoML class from flaml package '''
""" import AutoML class from flaml package """
from flaml import AutoML
automl = AutoML()
settings = {
"time_budget": budget, # total running time in seconds
"metric": 'accuracy', # primary metrics can be chosen from: ['accuracy','roc_auc','roc_auc_ovr','roc_auc_ovo','f1','log_loss','mae','mse','r2']
"task": 'classification', # task type
"log_file_name": 'airlines_experiment.log', # flaml log file
"seed": 7654321, # random seed
'hpo_method': hpo_method
"metric": "accuracy", # primary metrics can be chosen from: ['accuracy','roc_auc','roc_auc_ovr','roc_auc_ovo','f1','log_loss','mae','mse','r2']
"task": "classification", # task type
"log_file_name": "airlines_experiment.log", # flaml log file
"seed": 7654321, # random seed
"hpo_method": hpo_method,
}
'''The main flaml automl API'''
"""The main flaml automl API"""
automl.fit(X_train=X_train, y_train=y_train, **settings)
''' retrieve best config and best learner'''
print('Best ML leaner:', automl.best_estimator)
print('Best hyperparmeter config:', automl.best_config)
print('Best accuracy on validation data: {0:.4g}'.format(1 - automl.best_loss))
print('Training duration of best run: {0:.4g} s'.format(automl.best_config_train_time))
""" retrieve best config and best learner """
print("Best ML leaner:", automl.best_estimator)
print("Best hyperparmeter config:", automl.best_config)
print("Best accuracy on validation data: {0:.4g}".format(1 - automl.best_loss))
print(
"Training duration of best run: {0:.4g} s".format(automl.best_config_train_time)
)
print(automl.model.estimator)
''' pickle and save the automl object '''
print("time taken to find best model:", automl.time_to_find_best_model)
""" pickle and save the automl object """
import pickle
with open('automl.pkl', 'wb') as f:
with open("automl.pkl", "wb") as f:
pickle.dump(automl, f, pickle.HIGHEST_PROTOCOL)
''' compute predictions of testing dataset '''
""" compute predictions of testing dataset """
y_pred = automl.predict(X_test)
print('Predicted labels', y_pred)
print('True labels', y_test)
print("Predicted labels", y_pred)
print("True labels", y_test)
y_pred_proba = automl.predict_proba(X_test)[:, 1]
''' compute different metric values on testing dataset'''
""" compute different metric values on testing dataset """
from flaml.ml import sklearn_metric_loss_score
print('accuracy', '=', 1 - sklearn_metric_loss_score('accuracy', y_pred, y_test))
print('roc_auc', '=', 1 - sklearn_metric_loss_score('roc_auc', y_pred_proba, y_test))
print('log_loss', '=', sklearn_metric_loss_score('log_loss', y_pred_proba, y_test))
print("accuracy", "=", 1 - sklearn_metric_loss_score("accuracy", y_pred, y_test))
print(
"roc_auc", "=", 1 - sklearn_metric_loss_score("roc_auc", y_pred_proba, y_test)
)
print("log_loss", "=", sklearn_metric_loss_score("log_loss", y_pred_proba, y_test))
from flaml.data import get_output_from_log
time_history, best_valid_loss_history, valid_loss_history, config_history, metric_history = \
get_output_from_log(filename=settings['log_file_name'], time_budget=60)
(
time_history,
best_valid_loss_history,
valid_loss_history,
config_history,
metric_history,
) = get_output_from_log(filename=settings["log_file_name"], time_budget=60)
for config in config_history:
print(config)
print(automl.prune_attr)
@ -53,37 +69,40 @@ def test_automl(budget=5, dataset_format='dataframe', hpo_method=None):
def test_automl_array():
test_automl(5, 'array', 'bs')
test_automl(5, "array", "bs")
def test_mlflow():
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "mlflow"])
import mlflow
from flaml.data import load_openml_task
try:
X_train, X_test, y_train, y_test = load_openml_task(
task_id=7592, data_dir='test/')
task_id=7592, data_dir="test/"
)
except OpenMLServerException:
print("OpenMLServerException raised")
return
''' import AutoML class from flaml package '''
""" import AutoML class from flaml package """
from flaml import AutoML
automl = AutoML()
settings = {
"time_budget": 5, # total running time in seconds
"metric": 'accuracy', # primary metrics can be chosen from: ['accuracy','roc_auc','roc_auc_ovr','roc_auc_ovo','f1','log_loss','mae','mse','r2']
"estimator_list": ['lgbm', 'rf', 'xgboost'], # list of ML learners
"task": 'classification', # task type
"metric": "accuracy", # primary metrics can be chosen from: ['accuracy','roc_auc','roc_auc_ovr','roc_auc_ovo','f1','log_loss','mae','mse','r2']
"estimator_list": ["lgbm", "rf", "xgboost"], # list of ML learners
"task": "classification", # task type
"sample": False, # whether to subsample training data
"log_file_name": 'adult.log', # flaml log file
"log_file_name": "adult.log", # flaml log file
}
mlflow.set_experiment("flaml")
with mlflow.start_run():
'''The main flaml automl API'''
automl.fit(
X_train=X_train, y_train=y_train, **settings)
"""The main flaml automl API"""
automl.fit(X_train=X_train, y_train=y_train, **settings)
# subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "mlflow"])
automl._mem_thres = 0
print(automl.trainable(automl.points_to_evaluate[0]))

View File

@ -12,58 +12,63 @@ dataset = "credit-g"
class XGBoost2D(XGBoostSklearnEstimator):
@classmethod
def search_space(cls, data_size, task):
upper = min(32768, int(data_size))
return {
'n_estimators': {
'domain': tune.lograndint(lower=4, upper=upper),
'low_cost_init_value': 4,
"n_estimators": {
"domain": tune.lograndint(lower=4, upper=upper),
"low_cost_init_value": 4,
},
'max_leaves': {
'domain': tune.lograndint(lower=4, upper=upper),
'low_cost_init_value': 4,
"max_leaves": {
"domain": tune.lograndint(lower=4, upper=upper),
"low_cost_init_value": 4,
},
}
def test_simple(method=None):
automl = AutoML()
automl.add_learner(learner_name='XGBoost2D',
learner_class=XGBoost2D)
automl.add_learner(learner_name="XGBoost2D", learner_class=XGBoost2D)
automl_settings = {
"estimator_list": ['XGBoost2D'],
"task": 'classification',
"estimator_list": ["XGBoost2D"],
"task": "classification",
"log_file_name": f"test/xgboost2d_{dataset}_{method}.log",
"n_jobs": 1,
"hpo_method": method,
"log_type": "all",
"retrain_full": "budget",
"keep_search_state": True,
"time_budget": 1
"time_budget": 1,
}
from sklearn.externals._arff import ArffException
try:
X, y = fetch_openml(name=dataset, return_X_y=True)
except (ArffException, ValueError):
from sklearn.datasets import load_wine
X, y = load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42)
X, y, test_size=0.33, random_state=42
)
automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
print(automl.estimator_list)
print(automl.search_space)
print(automl.points_to_evaluate)
config = automl.best_config.copy()
config['learner'] = automl.best_estimator
config["learner"] = automl.best_estimator
automl.trainable(config)
from flaml import tune
from flaml.automl import size
from functools import partial
analysis = tune.run(
automl.trainable, automl.search_space, metric='val_loss', mode="min",
automl.trainable,
automl.search_space,
metric="val_loss",
mode="min",
low_cost_partial_config=automl.low_cost_partial_config,
points_to_evaluate=automl.points_to_evaluate,
cat_hp_cost=automl.cat_hp_cost,
@ -71,8 +76,10 @@ def test_simple(method=None):
min_resource=automl.min_resource,
max_resource=automl.max_resource,
time_budget_s=automl._state.time_budget,
config_constraints=[(partial(size, automl._state), '<=', automl._mem_thres)],
metric_constraints=automl.metric_constraints, num_samples=5)
config_constraints=[(partial(size, automl._state), "<=", automl._mem_thres)],
metric_constraints=automl.metric_constraints,
num_samples=5,
)
print(analysis.trials[-1])
@ -80,6 +87,10 @@ def test_optuna():
test_simple(method="optuna")
def test_random():
    """Run the shared `test_simple` scenario with method "random"."""
    test_simple(method="random")
def test_grid():
    """Run the shared `test_simple` scenario with method "grid"."""
    test_simple(method="grid")

View File

@ -1,4 +1,3 @@
from flaml.searcher.blendsearch import CFO
import numpy as np
try:
@ -8,8 +7,9 @@ try:
from ray.tune import sample
except (ImportError, AssertionError):
from flaml.tune import sample
from flaml.searcher.suggestion import OptunaSearch, Searcher, ConcurrencyLimiter
from flaml.searcher.blendsearch import BlendSearch
from flaml.searcher.blendsearch import BlendSearch, CFO, RandomSearch
def define_search_space(trial):
trial.suggest_float("a", 6, 8)
@ -135,3 +135,14 @@ except (ImportError, AssertionError):
},
}
)
np.random.seed(7654321)
searcher = RandomSearch(
space=config,
points_to_evaluate=[{"a": 7, "b": 1e-3}, {"a": 6, "b": 3e-4}],
)
print(searcher.suggest("t1"))
print(searcher.suggest("t2"))
print(searcher.suggest("t3"))
print(searcher.suggest("t4"))
searcher.on_trial_complete({"t1"}, {})
searcher.on_trial_result({"t2"}, {})