random search (#213)

* random search as a child class of CFO

* random search in sequential search of AutoML

* time to find best model as a property of AutoML
This commit is contained in:
Chi Wang 2021-09-19 11:19:23 -07:00 committed by GitHub
parent 0ba58e0ace
commit f3e50136e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 122 additions and 59 deletions

View File

@ -363,7 +363,7 @@ class AutoML:
Returns: Returns:
An object with `predict()` and `predict_proba()` method (for An object with `predict()` and `predict_proba()` method (for
classification), storing the best trained model for estimator_name. classification), storing the best trained model for estimator_name.
""" """
state = self._search_states.get(estimator_name) state = self._search_states.get(estimator_name)
return state and getattr(state, "trained_estimator", None) return state and getattr(state, "trained_estimator", None)
@ -414,6 +414,11 @@ class AutoML:
return attr.classes_.tolist() return attr.classes_.tolist()
return None return None
@property
def time_to_find_best_model(self) -> float:
    """Seconds taken to find the best model.

    The value is read straight from the instance dictionary under the key
    `_time_taken_best_iter`; when that key has not been set on this
    instance the lookup yields None instead of raising.
    """
    instance_attrs = self.__dict__
    return instance_attrs.get("_time_taken_best_iter")
def predict(self, X_test): def predict(self, X_test):
"""Predict label from features. """Predict label from features.
@ -1374,8 +1379,7 @@ class AutoML:
a simple customized search space. When set to 'bs', BlendSearch a simple customized search space. When set to 'bs', BlendSearch
is used. BlendSearch can be tried when the search space is is used. BlendSearch can be tried when the search space is
complex, for example, containing multiple disjoint, discontinuous complex, for example, containing multiple disjoint, discontinuous
subspaces. When set to 'random' and the argument subspaces. When set to 'random', random search is used.
`n_concurrent_trials` is larger than 1, random search is used.
starting_points: A dictionary to specify the starting hyperparameter starting_points: A dictionary to specify the starting hyperparameter
config for the estimators. config for the estimators.
Keys are the name of the estimators, and values are the starting Keys are the name of the estimators, and values are the starting
@ -1717,6 +1721,8 @@ class AutoML:
from .searcher.suggestion import OptunaSearch as SearchAlgo from .searcher.suggestion import OptunaSearch as SearchAlgo
elif "bs" == self._hpo_method: elif "bs" == self._hpo_method:
from flaml import BlendSearch as SearchAlgo from flaml import BlendSearch as SearchAlgo
elif "random" == self._hpo_method:
from flaml.searcher import RandomSearch as SearchAlgo
elif "cfocat" == self._hpo_method: elif "cfocat" == self._hpo_method:
from flaml.searcher.cfo_cat import CFOCat as SearchAlgo from flaml.searcher.cfo_cat import CFOCat as SearchAlgo
else: else:
@ -1784,7 +1790,7 @@ class AutoML:
else [search_state.init_config] else [search_state.init_config]
) )
low_cost_partial_config = search_state.low_cost_partial_config low_cost_partial_config = search_state.low_cost_partial_config
if self._hpo_method in ("bs", "cfo", "grid", "cfocat"): if self._hpo_method in ("bs", "cfo", "grid", "cfocat", "random"):
algo = SearchAlgo( algo = SearchAlgo(
metric="val_loss", metric="val_loss",
mode="min", mode="min",

View File

@ -1,3 +1,3 @@
from .blendsearch import CFO, BlendSearch, BlendSearchTuner from .blendsearch import CFO, BlendSearch, BlendSearchTuner, RandomSearch
from .flow2 import FLOW2 from .flow2 import FLOW2
from .online_searcher import ChampionFrontierSearcher from .online_searcher import ChampionFrontierSearcher

View File

@ -1024,3 +1024,19 @@ class CFO(BlendSearchTuner):
self._candidate_start_points[trial_id] = result self._candidate_start_points[trial_id] = result
if len(self._search_thread_pool) < 2 and not self._points_to_evaluate: if len(self._search_thread_pool) < 2 and not self._points_to_evaluate:
self._create_thread_from_best_candidate() self._create_thread_from_best_candidate()
class RandomSearch(CFO):
    """Random search exposed through the CFO searcher interface.

    While `self._points_to_evaluate` is non-empty, suggestions are produced
    by the parent class. Afterwards each suggestion comes from
    `self._ls.complete_config({})`, and both trial-feedback hooks are
    overridden with no-ops.
    """

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Return the next configuration to try for `trial_id`."""
        if not self._points_to_evaluate:
            # Nothing queued: complete an empty partial config and keep only
            # the first element of the returned pair.
            # NOTE(review): presumably complete_config fills in every
            # hyperparameter by random sampling - confirm in the `_ls` class.
            sampled_config, _ = self._ls.complete_config({})
            return sampled_config
        # Pending points are still handled by the parent implementation.
        return super().suggest(trial_id)

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """Override the parent hook with a no-op; the arguments are unused."""
        return None

    def on_trial_result(self, trial_id: str, result: Dict):
        """Override the parent hook with a no-op; the arguments are unused."""
        return None

View File

@ -1,50 +1,66 @@
from openml.exceptions import OpenMLServerException from openml.exceptions import OpenMLServerException
def test_automl(budget=5, dataset_format='dataframe', hpo_method=None): def test_automl(budget=5, dataset_format="dataframe", hpo_method=None):
from flaml.data import load_openml_dataset from flaml.data import load_openml_dataset
try: try:
X_train, X_test, y_train, y_test = load_openml_dataset( X_train, X_test, y_train, y_test = load_openml_dataset(
dataset_id=1169, data_dir='test/', dataset_format=dataset_format) dataset_id=1169, data_dir="test/", dataset_format=dataset_format
)
except OpenMLServerException: except OpenMLServerException:
print("OpenMLServerException raised") print("OpenMLServerException raised")
return return
''' import AutoML class from flaml package ''' """ import AutoML class from flaml package """
from flaml import AutoML from flaml import AutoML
automl = AutoML() automl = AutoML()
settings = { settings = {
"time_budget": budget, # total running time in seconds "time_budget": budget, # total running time in seconds
"metric": 'accuracy', # primary metrics can be chosen from: ['accuracy','roc_auc','roc_auc_ovr','roc_auc_ovo','f1','log_loss','mae','mse','r2'] "metric": "accuracy", # primary metrics can be chosen from: ['accuracy','roc_auc','roc_auc_ovr','roc_auc_ovo','f1','log_loss','mae','mse','r2']
"task": 'classification', # task type "task": "classification", # task type
"log_file_name": 'airlines_experiment.log', # flaml log file "log_file_name": "airlines_experiment.log", # flaml log file
"seed": 7654321, # random seed "seed": 7654321, # random seed
'hpo_method': hpo_method "hpo_method": hpo_method,
} }
'''The main flaml automl API''' """The main flaml automl API"""
automl.fit(X_train=X_train, y_train=y_train, **settings) automl.fit(X_train=X_train, y_train=y_train, **settings)
''' retrieve best config and best learner''' """ retrieve best config and best learner """
print('Best ML leaner:', automl.best_estimator) print("Best ML leaner:", automl.best_estimator)
print('Best hyperparmeter config:', automl.best_config) print("Best hyperparmeter config:", automl.best_config)
print('Best accuracy on validation data: {0:.4g}'.format(1 - automl.best_loss)) print("Best accuracy on validation data: {0:.4g}".format(1 - automl.best_loss))
print('Training duration of best run: {0:.4g} s'.format(automl.best_config_train_time)) print(
"Training duration of best run: {0:.4g} s".format(automl.best_config_train_time)
)
print(automl.model.estimator) print(automl.model.estimator)
''' pickle and save the automl object ''' print("time taken to find best model:", automl.time_to_find_best_model)
""" pickle and save the automl object """
import pickle import pickle
with open('automl.pkl', 'wb') as f:
with open("automl.pkl", "wb") as f:
pickle.dump(automl, f, pickle.HIGHEST_PROTOCOL) pickle.dump(automl, f, pickle.HIGHEST_PROTOCOL)
''' compute predictions of testing dataset ''' """ compute predictions of testing dataset """
y_pred = automl.predict(X_test) y_pred = automl.predict(X_test)
print('Predicted labels', y_pred) print("Predicted labels", y_pred)
print('True labels', y_test) print("True labels", y_test)
y_pred_proba = automl.predict_proba(X_test)[:, 1] y_pred_proba = automl.predict_proba(X_test)[:, 1]
''' compute different metric values on testing dataset''' """ compute different metric values on testing dataset """
from flaml.ml import sklearn_metric_loss_score from flaml.ml import sklearn_metric_loss_score
print('accuracy', '=', 1 - sklearn_metric_loss_score('accuracy', y_pred, y_test))
print('roc_auc', '=', 1 - sklearn_metric_loss_score('roc_auc', y_pred_proba, y_test)) print("accuracy", "=", 1 - sklearn_metric_loss_score("accuracy", y_pred, y_test))
print('log_loss', '=', sklearn_metric_loss_score('log_loss', y_pred_proba, y_test)) print(
"roc_auc", "=", 1 - sklearn_metric_loss_score("roc_auc", y_pred_proba, y_test)
)
print("log_loss", "=", sklearn_metric_loss_score("log_loss", y_pred_proba, y_test))
from flaml.data import get_output_from_log from flaml.data import get_output_from_log
time_history, best_valid_loss_history, valid_loss_history, config_history, metric_history = \
get_output_from_log(filename=settings['log_file_name'], time_budget=60) (
time_history,
best_valid_loss_history,
valid_loss_history,
config_history,
metric_history,
) = get_output_from_log(filename=settings["log_file_name"], time_budget=60)
for config in config_history: for config in config_history:
print(config) print(config)
print(automl.prune_attr) print(automl.prune_attr)
@ -53,37 +69,40 @@ def test_automl(budget=5, dataset_format='dataframe', hpo_method=None):
def test_automl_array(): def test_automl_array():
test_automl(5, 'array', 'bs') test_automl(5, "array", "bs")
def test_mlflow(): def test_mlflow():
import subprocess import subprocess
import sys import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "mlflow"]) subprocess.check_call([sys.executable, "-m", "pip", "install", "mlflow"])
import mlflow import mlflow
from flaml.data import load_openml_task from flaml.data import load_openml_task
try: try:
X_train, X_test, y_train, y_test = load_openml_task( X_train, X_test, y_train, y_test = load_openml_task(
task_id=7592, data_dir='test/') task_id=7592, data_dir="test/"
)
except OpenMLServerException: except OpenMLServerException:
print("OpenMLServerException raised") print("OpenMLServerException raised")
return return
''' import AutoML class from flaml package ''' """ import AutoML class from flaml package """
from flaml import AutoML from flaml import AutoML
automl = AutoML() automl = AutoML()
settings = { settings = {
"time_budget": 5, # total running time in seconds "time_budget": 5, # total running time in seconds
"metric": 'accuracy', # primary metrics can be chosen from: ['accuracy','roc_auc','roc_auc_ovr','roc_auc_ovo','f1','log_loss','mae','mse','r2'] "metric": "accuracy", # primary metrics can be chosen from: ['accuracy','roc_auc','roc_auc_ovr','roc_auc_ovo','f1','log_loss','mae','mse','r2']
"estimator_list": ['lgbm', 'rf', 'xgboost'], # list of ML learners "estimator_list": ["lgbm", "rf", "xgboost"], # list of ML learners
"task": 'classification', # task type "task": "classification", # task type
"sample": False, # whether to subsample training data "sample": False, # whether to subsample training data
"log_file_name": 'adult.log', # flaml log file "log_file_name": "adult.log", # flaml log file
} }
mlflow.set_experiment("flaml") mlflow.set_experiment("flaml")
with mlflow.start_run(): with mlflow.start_run():
'''The main flaml automl API''' """The main flaml automl API"""
automl.fit( automl.fit(X_train=X_train, y_train=y_train, **settings)
X_train=X_train, y_train=y_train, **settings)
# subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "mlflow"]) # subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "mlflow"])
automl._mem_thres = 0 automl._mem_thres = 0
print(automl.trainable(automl.points_to_evaluate[0])) print(automl.trainable(automl.points_to_evaluate[0]))

View File

@ -12,58 +12,63 @@ dataset = "credit-g"
class XGBoost2D(XGBoostSklearnEstimator): class XGBoost2D(XGBoostSklearnEstimator):
@classmethod @classmethod
def search_space(cls, data_size, task): def search_space(cls, data_size, task):
upper = min(32768, int(data_size)) upper = min(32768, int(data_size))
return { return {
'n_estimators': { "n_estimators": {
'domain': tune.lograndint(lower=4, upper=upper), "domain": tune.lograndint(lower=4, upper=upper),
'low_cost_init_value': 4, "low_cost_init_value": 4,
}, },
'max_leaves': { "max_leaves": {
'domain': tune.lograndint(lower=4, upper=upper), "domain": tune.lograndint(lower=4, upper=upper),
'low_cost_init_value': 4, "low_cost_init_value": 4,
}, },
} }
def test_simple(method=None): def test_simple(method=None):
automl = AutoML() automl = AutoML()
automl.add_learner(learner_name='XGBoost2D', automl.add_learner(learner_name="XGBoost2D", learner_class=XGBoost2D)
learner_class=XGBoost2D)
automl_settings = { automl_settings = {
"estimator_list": ['XGBoost2D'], "estimator_list": ["XGBoost2D"],
"task": 'classification', "task": "classification",
"log_file_name": f"test/xgboost2d_{dataset}_{method}.log", "log_file_name": f"test/xgboost2d_{dataset}_{method}.log",
"n_jobs": 1, "n_jobs": 1,
"hpo_method": method, "hpo_method": method,
"log_type": "all", "log_type": "all",
"retrain_full": "budget", "retrain_full": "budget",
"keep_search_state": True, "keep_search_state": True,
"time_budget": 1 "time_budget": 1,
} }
from sklearn.externals._arff import ArffException from sklearn.externals._arff import ArffException
try: try:
X, y = fetch_openml(name=dataset, return_X_y=True) X, y = fetch_openml(name=dataset, return_X_y=True)
except (ArffException, ValueError): except (ArffException, ValueError):
from sklearn.datasets import load_wine from sklearn.datasets import load_wine
X, y = load_wine(return_X_y=True) X, y = load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split( X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42) X, y, test_size=0.33, random_state=42
)
automl.fit(X_train=X_train, y_train=y_train, **automl_settings) automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
print(automl.estimator_list) print(automl.estimator_list)
print(automl.search_space) print(automl.search_space)
print(automl.points_to_evaluate) print(automl.points_to_evaluate)
config = automl.best_config.copy() config = automl.best_config.copy()
config['learner'] = automl.best_estimator config["learner"] = automl.best_estimator
automl.trainable(config) automl.trainable(config)
from flaml import tune from flaml import tune
from flaml.automl import size from flaml.automl import size
from functools import partial from functools import partial
analysis = tune.run( analysis = tune.run(
automl.trainable, automl.search_space, metric='val_loss', mode="min", automl.trainable,
automl.search_space,
metric="val_loss",
mode="min",
low_cost_partial_config=automl.low_cost_partial_config, low_cost_partial_config=automl.low_cost_partial_config,
points_to_evaluate=automl.points_to_evaluate, points_to_evaluate=automl.points_to_evaluate,
cat_hp_cost=automl.cat_hp_cost, cat_hp_cost=automl.cat_hp_cost,
@ -71,8 +76,10 @@ def test_simple(method=None):
min_resource=automl.min_resource, min_resource=automl.min_resource,
max_resource=automl.max_resource, max_resource=automl.max_resource,
time_budget_s=automl._state.time_budget, time_budget_s=automl._state.time_budget,
config_constraints=[(partial(size, automl._state), '<=', automl._mem_thres)], config_constraints=[(partial(size, automl._state), "<=", automl._mem_thres)],
metric_constraints=automl.metric_constraints, num_samples=5) metric_constraints=automl.metric_constraints,
num_samples=5,
)
print(analysis.trials[-1]) print(analysis.trials[-1])
@ -80,6 +87,10 @@ def test_optuna():
test_simple(method="optuna") test_simple(method="optuna")
def test_random():
    """Run the shared `test_simple` scenario with method set to "random"."""
    search_method = "random"
    test_simple(method=search_method)
def test_grid(): def test_grid():
test_simple(method="grid") test_simple(method="grid")

View File

@ -1,4 +1,3 @@
from flaml.searcher.blendsearch import CFO
import numpy as np import numpy as np
try: try:
@ -8,8 +7,9 @@ try:
from ray.tune import sample from ray.tune import sample
except (ImportError, AssertionError): except (ImportError, AssertionError):
from flaml.tune import sample from flaml.tune import sample
from flaml.searcher.suggestion import OptunaSearch, Searcher, ConcurrencyLimiter from flaml.searcher.suggestion import OptunaSearch, Searcher, ConcurrencyLimiter
from flaml.searcher.blendsearch import BlendSearch from flaml.searcher.blendsearch import BlendSearch, CFO, RandomSearch
def define_search_space(trial): def define_search_space(trial):
trial.suggest_float("a", 6, 8) trial.suggest_float("a", 6, 8)
@ -135,3 +135,14 @@ except (ImportError, AssertionError):
}, },
} }
) )
# Seed numpy's global RNG before constructing the searcher.
# NOTE(review): presumably this makes the sampled configs reproducible -
# confirm that RandomSearch draws from numpy's global RNG.
np.random.seed(7654321)
# Build a RandomSearch over `config`, supplying two explicit points to evaluate.
searcher = RandomSearch(
space=config,
points_to_evaluate=[{"a": 7, "b": 1e-3}, {"a": 6, "b": 3e-4}],
)
# Request four suggestions - two more than the number of points supplied
# above - and print each one.
print(searcher.suggest("t1"))
print(searcher.suggest("t2"))
print(searcher.suggest("t3"))
print(searcher.suggest("t4"))
# Call both trial-feedback hooks once.
# NOTE(review): the first argument here is a set ({"t1"} / {"t2"}), not a
# plain string trial id - confirm this is intentional.
searcher.on_trial_complete({"t1"}, {})
searcher.on_trial_result({"t2"}, {})