mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-13 08:34:29 +00:00
space -> main (#148)
* subspace in flow2 * search space and trainable from AutoML * experimental features: multivariate TPE, grouping, add_evaluated_points * test experimental features * readme * define by run * set time_budget_s for bs Co-authored-by: liususan091219 <Xqq630517> * version * acl * test define_by_run_func * size * constraints Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
46752083a2
commit
eeaf5b5963
194
flaml/automl.py
194
flaml/automl.py
@ -4,6 +4,7 @@
|
|||||||
* project root for license information.
|
* project root for license information.
|
||||||
'''
|
'''
|
||||||
import time
|
import time
|
||||||
|
from typing import Callable, Optional
|
||||||
import warnings
|
import warnings
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -212,10 +213,11 @@ class AutoMLState:
|
|||||||
'val_loss': val_loss,
|
'val_loss': val_loss,
|
||||||
'trained_estimator': trained_estimator
|
'trained_estimator': trained_estimator
|
||||||
}
|
}
|
||||||
with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
|
|
||||||
tune.report(**result)
|
|
||||||
if sampled_weight is not None:
|
if sampled_weight is not None:
|
||||||
self.fit_kwargs['sample_weight'] = weight
|
self.fit_kwargs['sample_weight'] = weight
|
||||||
|
# with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
|
||||||
|
# tune.report(**result)
|
||||||
|
return result
|
||||||
|
|
||||||
def _train_with_config(
|
def _train_with_config(
|
||||||
self, estimator, config_w_resource, sample_size=None
|
self, estimator, config_w_resource, sample_size=None
|
||||||
@ -790,6 +792,177 @@ class AutoML:
|
|||||||
else:
|
else:
|
||||||
return 'holdout'
|
return 'holdout'
|
||||||
|
|
||||||
|
@property
def search_space(self) -> dict:
    '''Search space of the configured learner(s)

    Must be called after fit(...) (use max_iter=0 to prevent actual fitting)

    Returns:
        A dict of the search space. With exactly one estimator it is a
        copy of that estimator's space plus a 'learner' entry; otherwise
        it is {'ml': tune.choice([...])} over one such dict per estimator
    '''
    def tagged_space(name):
        # shallow copy so that adding 'learner' leaves the stored space as is
        tagged = self._search_states[name].search_space.copy()
        tagged['learner'] = name
        return tagged

    learners = self.estimator_list
    if len(learners) != 1:
        return {'ml': tune.choice([tagged_space(name) for name in learners])}
    return tagged_space(learners[0])
|
||||||
|
|
||||||
|
@property
def low_cost_partial_config(self) -> dict:
    '''Low cost partial config

    Returns:
        A dict.
        (a) if there is only one estimator in estimator_list, it is that
        estimator's low_cost_partial_config (each key is a hyperparameter
        name)
        (b) otherwise, it is a nested dict with 'ml' as the key, and the
        list of every learner's low_cost_partial_config, in
        estimator_list order, as the value
    '''
    per_learner = [
        self._search_states[name].low_cost_partial_config
        for name in self.estimator_list]
    if len(per_learner) == 1:
        return per_learner[0]
    return {'ml': per_learner}
|
||||||
|
|
||||||
|
@property
def cat_hp_cost(self) -> dict:
    '''Categorical hyperparameter cost

    Returns:
        A dict.
        (a) if there is only one estimator in estimator_list, it is that
        estimator's cat_hp_cost (each key is a hyperparameter name)
        (b) otherwise, it is a nested dict with 'ml' as the key, and the
        list of every learner's cat_hp_cost, in estimator_list order, as
        the value
    '''
    learners = self.estimator_list
    costs = [self._search_states[name].cat_hp_cost for name in learners]
    return costs[0] if len(learners) == 1 else {'ml': costs}
|
||||||
|
|
||||||
|
@property
def points_to_evalaute(self) -> list:
    '''Initial points to evaluate

    NOTE: the property name is misspelled ("evalaute"); it is left
    unchanged here so that existing attribute accesses keep working.

    Returns:
        A list of dicts. Each dict is the initial point for each learner,
        tagged with a 'learner' key. When estimator_list has more than
        one entry, each point is nested under the 'ml' key
    '''
    nested = len(self.estimator_list) > 1
    points = []
    for estimator in self.estimator_list:
        # Work on a copy: writing 'learner' into the search state's own
        # init_config would leak the key into every later use of it.
        config = dict(self._search_states[estimator].init_config)
        config['learner'] = estimator
        points.append({'ml': config} if nested else config)
    return points
|
||||||
|
|
||||||
|
@property
def prune_attr(self) -> Optional[str]:
    '''Attribute for pruning

    Returns:
        The string 'FLAML_sample_size' when self._sample is truthy,
        otherwise None
    '''
    if self._sample:
        return 'FLAML_sample_size'
    return None
|
||||||
|
|
||||||
|
@property
def min_resource(self) -> Optional[float]:
    '''Minimal resource (sample size) for pruning

    Returns:
        MIN_SAMPLE_TRAIN, the minimal sample size, when self._sample is
        truthy, otherwise None
    '''
    if not self._sample:
        return None
    return MIN_SAMPLE_TRAIN
|
||||||
|
|
||||||
|
@property
def max_resource(self) -> Optional[float]:
    '''Maximal resource (sample size) for pruning

    Returns:
        self._state.data_size, the maximal sample size, when self._sample
        is truthy, otherwise None
    '''
    if not self._sample:
        return None
    return self._state.data_size
|
||||||
|
|
||||||
|
@property
def trainable(self) -> Callable[[dict], Optional[float]]:
    '''Training function

    Accessing this property resets self._state.time_from_start to 0 and
    makes sure every estimator's search state has a training_function
    (AutoMLState._compute_with_config_base bound to the shared state and
    the estimator name) unless one is already set.

    Returns:
        A function that evaluates a config and returns the result of the
        selected learner's training function. The config may be flat or
        nested under the 'ml' key, and must contain a 'learner' entry
    '''
    self._state.time_from_start = 0
    for estimator in self.estimator_list:
        search_state = self._search_states[estimator]
        if not hasattr(search_state, 'training_function'):
            search_state.training_function = partial(
                AutoMLState._compute_with_config_base,
                self._state, estimator)
    states = self._search_states

    def train(config: dict):
        # the sample size lives at the top level even for nested configs
        sample_size = config.get('FLAML_sample_size')
        # copy so that the caller's config is not modified below
        config = config.get('ml', config).copy()
        if sample_size:
            config['FLAML_sample_size'] = sample_size
        # 'learner' selects the estimator and is not a hyperparameter
        estimator = config.pop('learner')
        # hand the result back to the caller instead of dropping it
        return states[estimator].training_function(config)

    return train
|
||||||
|
|
||||||
|
@property
def size(self) -> Callable[[dict], float]:
    '''Size function

    Returns:
        A function that returns the mem size in bytes for a config, as
        computed by the size() of the learner class named by the
        config's 'learner' entry. The config may be flat or nested under
        the 'ml' key
    '''

    def size_func(config: dict) -> float:
        # copy() has to be *called*: without the parentheses `config` is
        # a bound method and the subscript below raises TypeError
        config = config.get('ml', config).copy()
        estimator = config['learner']
        learner_class = self._state.learner_classes.get(estimator)
        return learner_class.size(config)

    return size_func
|
||||||
|
|
||||||
|
@property
def metric_constraints(self) -> list:
    '''Metric constraints

    Returns:
        A list of the metric constraints: a single
        ('pred_time', '<=', limit) tuple when self._pred_time_limit is
        finite, otherwise an empty list
    '''
    if not np.isfinite(self._pred_time_limit):
        return []
    return [('pred_time', '<=', self._pred_time_limit)]
|
||||||
|
|
||||||
def fit(self,
|
def fit(self,
|
||||||
X_train=None,
|
X_train=None,
|
||||||
y_train=None,
|
y_train=None,
|
||||||
@ -969,11 +1142,12 @@ class AutoML:
|
|||||||
)
|
)
|
||||||
logger.info("List of ML learners in AutoML Run: {}".format(
|
logger.info("List of ML learners in AutoML Run: {}".format(
|
||||||
estimator_list))
|
estimator_list))
|
||||||
|
self.estimator_list = estimator_list
|
||||||
self._hpo_method = hpo_method or 'cfo'
|
self._hpo_method = hpo_method or 'cfo'
|
||||||
with training_log_writer(log_file_name) as save_helper:
|
with training_log_writer(log_file_name) as save_helper:
|
||||||
self._training_log = save_helper
|
self._training_log = save_helper
|
||||||
self._state.time_budget = time_budget
|
self._state.time_budget = time_budget
|
||||||
self.estimator_list = estimator_list
|
self._active_estimators = estimator_list.copy()
|
||||||
self._ensemble = ensemble
|
self._ensemble = ensemble
|
||||||
self._max_iter = max_iter
|
self._max_iter = max_iter
|
||||||
self._mem_thres = mem_thres
|
self._mem_thres = mem_thres
|
||||||
@ -1028,9 +1202,9 @@ class AutoML:
|
|||||||
|
|
||||||
for self._track_iter in range(self._max_iter):
|
for self._track_iter in range(self._max_iter):
|
||||||
if self._estimator_index is None:
|
if self._estimator_index is None:
|
||||||
estimator = self.estimator_list[0]
|
estimator = self._active_estimators[0]
|
||||||
else:
|
else:
|
||||||
estimator = self._select_estimator(self.estimator_list)
|
estimator = self._select_estimator(self._active_estimators)
|
||||||
if not estimator:
|
if not estimator:
|
||||||
break
|
break
|
||||||
logger.info(
|
logger.info(
|
||||||
@ -1071,10 +1245,6 @@ class AutoML:
|
|||||||
points_to_evaluate = [search_state.init_config]
|
points_to_evaluate = [search_state.init_config]
|
||||||
low_cost_partial_config = search_state.low_cost_partial_config
|
low_cost_partial_config = search_state.low_cost_partial_config
|
||||||
if self._hpo_method in ('bs', 'cfo', 'grid'):
|
if self._hpo_method in ('bs', 'cfo', 'grid'):
|
||||||
metric_constraints = []
|
|
||||||
if np.isfinite(self._pred_time_limit):
|
|
||||||
metric_constraints.append(
|
|
||||||
('pred_time', '<=', self._pred_time_limit))
|
|
||||||
algo = SearchAlgo(
|
algo = SearchAlgo(
|
||||||
metric='val_loss', mode='min', space=search_space,
|
metric='val_loss', mode='min', space=search_space,
|
||||||
points_to_evaluate=points_to_evaluate,
|
points_to_evaluate=points_to_evaluate,
|
||||||
@ -1086,7 +1256,7 @@ class AutoML:
|
|||||||
config_constraints=[
|
config_constraints=[
|
||||||
(learner_class.size, '<=', self._mem_thres)
|
(learner_class.size, '<=', self._mem_thres)
|
||||||
],
|
],
|
||||||
metric_constraints=metric_constraints,
|
metric_constraints=self.metric_constraints,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
algo = SearchAlgo(
|
algo = SearchAlgo(
|
||||||
@ -1198,7 +1368,7 @@ class AutoML:
|
|||||||
else:
|
else:
|
||||||
logger.info(f"no enough budget for learner {estimator}")
|
logger.info(f"no enough budget for learner {estimator}")
|
||||||
if self._estimator_index is not None:
|
if self._estimator_index is not None:
|
||||||
self.estimator_list.remove(estimator)
|
self._active_estimators.remove(estimator)
|
||||||
self._estimator_index -= 1
|
self._estimator_index -= 1
|
||||||
if self._retrain_full and best_config_sig and not better and (
|
if self._retrain_full and best_config_sig and not better and (
|
||||||
self._search_states[
|
self._search_states[
|
||||||
@ -1217,7 +1387,7 @@ class AutoML:
|
|||||||
est_retrain_time = 0
|
est_retrain_time = 0
|
||||||
self._state.time_from_start = time.time() - self._start_time_flag
|
self._state.time_from_start = time.time() - self._start_time_flag
|
||||||
if (self._state.time_from_start >= self._state.time_budget
|
if (self._state.time_from_start >= self._state.time_budget
|
||||||
or not self.estimator_list):
|
or not self._active_estimators):
|
||||||
break
|
break
|
||||||
if self._ensemble and self._best_estimator:
|
if self._ensemble and self._best_estimator:
|
||||||
time_left = self._state.time_budget - self._state.time_from_start
|
time_left = self._state.time_budget - self._state.time_from_start
|
||||||
|
|||||||
@ -21,24 +21,34 @@ autohf_settings = {"resources_per_trial": {"gpu": 1, "cpu": 1},
|
|||||||
"ckpt_per_epoch": 1,
|
"ckpt_per_epoch": 1,
|
||||||
"fp16": False,
|
"fp16": False,
|
||||||
}
|
}
|
||||||
validation_metric, analysis = \
|
validation_metric, analysis = autohf.fit(**autohf_settings)
|
||||||
autohf.fit(**autohf_settings,)
|
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The current use cases that are supported:
|
The current use cases that are supported:
|
||||||
|
|
||||||
1. A simplified version of fine-tuning the GLUE dataset using HuggingFace;
|
1. A simplified version of fine-tuning the GLUE dataset using HuggingFace;
|
||||||
2. For selecting better search space for fine-tuning the GLUE dataset;
|
2. For selecting better search space for fine-tuning the GLUE dataset;
|
||||||
3. Use the search algorithms in flaml for more efficient fine-tuning of HuggingFace;
|
3. Use the search algorithms in flaml for more efficient fine-tuning of HuggingFace.
|
||||||
|
|
||||||
The use cases that can be supported in future:
|
The use cases that can be supported in future:
|
||||||
1. HPO fine-tuning for text generation;
|
|
||||||
2. HPO fine-tuning for question answering;
|
|
||||||
|
|
||||||
### Troubleshooting fine-tuning HPO for pre-trained language models
|
1. HPO fine-tuning for text generation;
|
||||||
|
2. HPO fine-tuning for question answering.
|
||||||
|
|
||||||
|
## Troubleshooting fine-tuning HPO for pre-trained language models
|
||||||
|
|
||||||
To reproduce the results for our ACL2021 paper:
|
To reproduce the results for our ACL2021 paper:
|
||||||
|
|
||||||
*[An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://arxiv.org/abs/2106.09204). Xueqing Liu, Chi Wang. To appear in ACL-IJCNLP 2021*
|
* [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://arxiv.org/abs/2106.09204). Xueqing Liu, Chi Wang. ACL-IJCNLP 2021.
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@inproceedings{liu2021hpo,
|
||||||
|
title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},
|
||||||
|
author={Xueqing Liu and Chi Wang},
|
||||||
|
year={2021},
|
||||||
|
booktitle={ACL-IJCNLP},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
Please refer to the following jupyter notebook: [Troubleshooting HPO for fine-tuning pre-trained language models](https://github.com/microsoft/FLAML/blob/main/notebook/research/acl2021.ipynb)
|
Please refer to the following jupyter notebook: [Troubleshooting HPO for fine-tuning pre-trained language models](https://github.com/microsoft/FLAML/blob/main/notebook/research/acl2021.ipynb)
|
||||||
@ -455,6 +455,9 @@ class AutoTransformers:
|
|||||||
def _get_search_algo(self,
|
def _get_search_algo(self,
|
||||||
search_algo_name,
|
search_algo_name,
|
||||||
search_algo_args_mode,
|
search_algo_args_mode,
|
||||||
|
time_budget,
|
||||||
|
metric_name,
|
||||||
|
metric_mode_name,
|
||||||
**custom_hpo_args):
|
**custom_hpo_args):
|
||||||
from .hpo.searchalgo_auto import AutoSearchAlgorithm
|
from .hpo.searchalgo_auto import AutoSearchAlgorithm
|
||||||
|
|
||||||
@ -464,6 +467,9 @@ class AutoTransformers:
|
|||||||
search_algo_name,
|
search_algo_name,
|
||||||
search_algo_args_mode,
|
search_algo_args_mode,
|
||||||
self._search_space_hpo,
|
self._search_space_hpo,
|
||||||
|
time_budget,
|
||||||
|
metric_name,
|
||||||
|
metric_mode_name,
|
||||||
**custom_hpo_args)
|
**custom_hpo_args)
|
||||||
return search_algo
|
return search_algo
|
||||||
|
|
||||||
@ -745,7 +751,12 @@ class AutoTransformers:
|
|||||||
ray.init(local_mode=ray_local_mode)
|
ray.init(local_mode=ray_local_mode)
|
||||||
self._set_search_space(**custom_hpo_args)
|
self._set_search_space(**custom_hpo_args)
|
||||||
|
|
||||||
search_algo = self._get_search_algo(self.jobid_config.alg, self.jobid_config.arg, **custom_hpo_args)
|
search_algo = self._get_search_algo(self.jobid_config.alg,
|
||||||
|
self.jobid_config.arg,
|
||||||
|
time_budget,
|
||||||
|
self.metric_name,
|
||||||
|
self.metric_mode_name,
|
||||||
|
**custom_hpo_args)
|
||||||
scheduler = AutoScheduler.from_scheduler_name(self.jobid_config.pru)
|
scheduler = AutoScheduler.from_scheduler_name(self.jobid_config.pru)
|
||||||
self.ckpt_per_epoch = ckpt_per_epoch
|
self.ckpt_per_epoch = ckpt_per_epoch
|
||||||
self.path_utils.make_dir_per_run()
|
self.path_utils.make_dir_per_run()
|
||||||
|
|||||||
@ -35,7 +35,14 @@ class AutoSearchAlgorithm:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_method_name(cls, search_algo_name, search_algo_args_mode, hpo_search_space, **custom_hpo_args):
|
def from_method_name(cls,
|
||||||
|
search_algo_name,
|
||||||
|
search_algo_args_mode,
|
||||||
|
hpo_search_space,
|
||||||
|
time_budget,
|
||||||
|
metric_name,
|
||||||
|
metric_mode_name,
|
||||||
|
**custom_hpo_args):
|
||||||
"""
|
"""
|
||||||
Instantiating one of the search algorithm classes based on the search algorithm name, search algorithm
|
Instantiating one of the search algorithm classes based on the search algorithm name, search algorithm
|
||||||
argument mode, hpo search space and other keyword args
|
argument mode, hpo search space and other keyword args
|
||||||
@ -85,15 +92,26 @@ class AutoSearchAlgorithm:
|
|||||||
"""
|
"""
|
||||||
if search_algo_args_mode == "dft":
|
if search_algo_args_mode == "dft":
|
||||||
this_search_algo_kwargs = DEFAULT_SEARCH_ALGO_ARGS_MAPPING[search_algo_name](
|
this_search_algo_kwargs = DEFAULT_SEARCH_ALGO_ARGS_MAPPING[search_algo_name](
|
||||||
"dft", hpo_search_space=hpo_search_space, **allowed_custom_args)
|
"dft",
|
||||||
|
metric_name,
|
||||||
|
metric_mode_name,
|
||||||
|
hpo_search_space=hpo_search_space,
|
||||||
|
**allowed_custom_args)
|
||||||
elif search_algo_args_mode == "cus":
|
elif search_algo_args_mode == "cus":
|
||||||
this_search_algo_kwargs = DEFAULT_SEARCH_ALGO_ARGS_MAPPING[search_algo_name](
|
this_search_algo_kwargs = DEFAULT_SEARCH_ALGO_ARGS_MAPPING[search_algo_name](
|
||||||
"cus", hpo_search_space=hpo_search_space, **allowed_custom_args)
|
"cus",
|
||||||
|
metric_name,
|
||||||
|
metric_mode_name,
|
||||||
|
hpo_search_space=hpo_search_space,
|
||||||
|
**allowed_custom_args)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
returning the hpo algorithm with the arguments
|
returning the hpo algorithm with the arguments
|
||||||
"""
|
"""
|
||||||
return SEARCH_ALGO_MAPPING[search_algo_name](**this_search_algo_kwargs)
|
search_algo = SEARCH_ALGO_MAPPING[search_algo_name](**this_search_algo_kwargs)
|
||||||
|
if search_algo_name == "bs":
|
||||||
|
search_algo.set_search_properties(config={"time_budget_s": time_budget})
|
||||||
|
return search_algo
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized method {} for this kind of AutoSearchAlgorithm: {}.\n"
|
"Unrecognized method {} for this kind of AutoSearchAlgorithm: {}.\n"
|
||||||
"Method name should be one of {}.".format(
|
"Method name should be one of {}.".format(
|
||||||
@ -109,11 +127,19 @@ class AutoSearchAlgorithm:
|
|||||||
return config_list
|
return config_list
|
||||||
|
|
||||||
|
|
||||||
def get_search_algo_args_optuna(search_args_mode, hpo_search_space=None, **custom_hpo_args):
|
def get_search_algo_args_optuna(search_args_mode,
|
||||||
|
metric_name,
|
||||||
|
metric_mode_name,
|
||||||
|
hpo_search_space=None,
|
||||||
|
**custom_hpo_args):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def default_search_algo_args_bs(search_args_mode, hpo_search_space=None, **custom_hpo_args):
|
def default_search_algo_args_bs(search_args_mode,
|
||||||
|
metric_name,
|
||||||
|
metric_mode_name,
|
||||||
|
hpo_search_space=None,
|
||||||
|
**custom_hpo_args):
|
||||||
assert hpo_search_space, "hpo_search_space needs to be specified for calling AutoSearchAlgorithm.from_method_name"
|
assert hpo_search_space, "hpo_search_space needs to be specified for calling AutoSearchAlgorithm.from_method_name"
|
||||||
if "num_train_epochs" in hpo_search_space and \
|
if "num_train_epochs" in hpo_search_space and \
|
||||||
isinstance(hpo_search_space["num_train_epochs"], ray.tune.sample.Categorical):
|
isinstance(hpo_search_space["num_train_epochs"], ray.tune.sample.Categorical):
|
||||||
@ -126,48 +152,28 @@ def default_search_algo_args_bs(search_args_mode, hpo_search_space=None, **custo
|
|||||||
"num_train_epochs": min_epoch,
|
"num_train_epochs": min_epoch,
|
||||||
"per_device_train_batch_size": max(hpo_search_space["per_device_train_batch_size"].categories),
|
"per_device_train_batch_size": max(hpo_search_space["per_device_train_batch_size"].categories),
|
||||||
},
|
},
|
||||||
|
"space": hpo_search_space,
|
||||||
|
"metric": metric_name,
|
||||||
|
"mode": metric_mode_name
|
||||||
}
|
}
|
||||||
if search_args_mode == "cus":
|
if search_args_mode == "cus":
|
||||||
default_search_algo_args.update(custom_hpo_args)
|
default_search_algo_args.update(custom_hpo_args)
|
||||||
return default_search_algo_args
|
return default_search_algo_args
|
||||||
|
|
||||||
|
|
||||||
def experiment_search_algo_args_bs(hpo_search_space=None):
|
def default_search_algo_args_grid_search(search_args_mode,
|
||||||
if "num_train_epochs" in hpo_search_space and \
|
metric_name,
|
||||||
isinstance(hpo_search_space["num_train_epochs"], ray.tune.sample.Categorical):
|
metric_mode_name,
|
||||||
min_epoch = min(hpo_search_space["num_train_epochs"].categories)
|
hpo_search_space=None,
|
||||||
else:
|
**custom_hpo_args):
|
||||||
assert isinstance(hpo_search_space["num_train_epochs"], ray.tune.sample.Float)
|
|
||||||
min_epoch = hpo_search_space["num_train_epochs"].lower
|
|
||||||
default_search_algo_args = {
|
|
||||||
"low_cost_partial_config": {
|
|
||||||
"num_train_epochs": min_epoch,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return default_search_algo_args
|
|
||||||
|
|
||||||
|
|
||||||
def default_search_algo_args_skopt(hpo_search_space=None):
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def default_search_algo_args_dragonfly(hpo_search_space=None):
|
def default_search_algo_args_random_search(search_args_mode,
|
||||||
return {}
|
metric_name,
|
||||||
|
metric_mode_name,
|
||||||
|
hpo_search_space=None,
|
||||||
def default_search_algo_args_nevergrad(hpo_search_space=None):
|
**custom_hpo_args):
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def default_search_algo_args_hyperopt(hpo_search_space=None):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def default_search_algo_args_grid_search(search_args_mode, hpo_search_space=None, **custom_hpo_args):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def default_search_algo_args_random_search(search_args_mode, hpo_search_space=None, **custom_hpo_args):
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -12,10 +12,11 @@ try:
|
|||||||
from ray.tune.suggest import Searcher
|
from ray.tune.suggest import Searcher
|
||||||
from ray.tune.suggest.optuna import OptunaSearch as GlobalSearch
|
from ray.tune.suggest.optuna import OptunaSearch as GlobalSearch
|
||||||
from ray.tune.suggest.variant_generator import generate_variants
|
from ray.tune.suggest.variant_generator import generate_variants
|
||||||
|
from ray.tune.utils.util import flatten_dict
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from .suggestion import Searcher
|
from .suggestion import Searcher
|
||||||
from .suggestion import OptunaSearch as GlobalSearch
|
from .suggestion import OptunaSearch as GlobalSearch
|
||||||
from .variant_generator import generate_variants
|
from .variant_generator import generate_variants, flatten_dict
|
||||||
from .search_thread import SearchThread
|
from .search_thread import SearchThread
|
||||||
from .flow2 import FLOW2
|
from .flow2 import FLOW2
|
||||||
|
|
||||||
@ -48,7 +49,8 @@ class BlendSearch(Searcher):
|
|||||||
List[Tuple[Callable[[dict], float], str, float]]] = None,
|
List[Tuple[Callable[[dict], float], str, float]]] = None,
|
||||||
metric_constraints: Optional[
|
metric_constraints: Optional[
|
||||||
List[Tuple[str, str, float]]] = None,
|
List[Tuple[str, str, float]]] = None,
|
||||||
seed: Optional[int] = 20):
|
seed: Optional[int] = 20,
|
||||||
|
experimental: Optional[bool] = False):
|
||||||
'''Constructor
|
'''Constructor
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -106,6 +108,7 @@ class BlendSearch(Searcher):
|
|||||||
metric_constraints: A list of metric constraints to be satisfied.
|
metric_constraints: A list of metric constraints to be satisfied.
|
||||||
e.g., `['precision', '>=', 0.9]`
|
e.g., `['precision', '>=', 0.9]`
|
||||||
seed: An integer of the random seed.
|
seed: An integer of the random seed.
|
||||||
|
experimental: A bool of whether to use experimental features.
|
||||||
'''
|
'''
|
||||||
self._metric, self._mode = metric, mode
|
self._metric, self._mode = metric, mode
|
||||||
init_config = low_cost_partial_config or {}
|
init_config = low_cost_partial_config or {}
|
||||||
@ -127,11 +130,20 @@ class BlendSearch(Searcher):
|
|||||||
elif getattr(self, '__name__', None) != 'CFO':
|
elif getattr(self, '__name__', None) != 'CFO':
|
||||||
try:
|
try:
|
||||||
gs_seed = seed - 10 if (seed - 10) >= 0 else seed - 11 + (1 << 32)
|
gs_seed = seed - 10 if (seed - 10) >= 0 else seed - 11 + (1 << 32)
|
||||||
self._gs = GlobalSearch(space=space, metric=metric, mode=mode, seed=gs_seed)
|
if experimental:
|
||||||
|
import optuna as ot
|
||||||
|
sampler = ot.samplers.TPESampler(
|
||||||
|
seed=seed, multivariate=True, group=True)
|
||||||
|
else:
|
||||||
|
sampler = None
|
||||||
|
self._gs = GlobalSearch(
|
||||||
|
space=space, metric=metric, mode=mode, seed=gs_seed,
|
||||||
|
sampler=sampler)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
self._gs = GlobalSearch(space=space, metric=metric, mode=mode)
|
self._gs = GlobalSearch(space=space, metric=metric, mode=mode)
|
||||||
else:
|
else:
|
||||||
self._gs = None
|
self._gs = None
|
||||||
|
self._experimental = experimental
|
||||||
if getattr(self, '__name__', None) == 'CFO' and points_to_evaluate and len(
|
if getattr(self, '__name__', None) == 'CFO' and points_to_evaluate and len(
|
||||||
points_to_evaluate) > 1:
|
points_to_evaluate) > 1:
|
||||||
# use the best config in points_to_evaluate as the start point
|
# use the best config in points_to_evaluate as the start point
|
||||||
@ -292,8 +304,14 @@ class BlendSearch(Searcher):
|
|||||||
objective = result[self._ls.metric]
|
objective = result[self._ls.metric]
|
||||||
if (objective - self._metric_target) * self._ls.metric_op < 0:
|
if (objective - self._metric_target) * self._ls.metric_op < 0:
|
||||||
self._metric_target = objective
|
self._metric_target = objective
|
||||||
if thread_id == 0 and metric_constraint_satisfied \
|
if thread_id:
|
||||||
and self._create_condition(result):
|
if not self._metric_constraint_satisfied:
|
||||||
|
# no point has been found to satisfy metric constraint
|
||||||
|
self._expand_admissible_region()
|
||||||
|
if self._gs is not None and self._experimental:
|
||||||
|
self._gs.add_evaluated_point(flatten_dict(config), objective)
|
||||||
|
elif metric_constraint_satisfied and self._create_condition(
|
||||||
|
result):
|
||||||
# thread creator
|
# thread creator
|
||||||
thread_id = self._thread_count
|
thread_id = self._thread_count
|
||||||
self._started_from_given = self._candidate_start_points \
|
self._started_from_given = self._candidate_start_points \
|
||||||
@ -303,9 +321,6 @@ class BlendSearch(Searcher):
|
|||||||
else:
|
else:
|
||||||
self._started_from_low_cost = True
|
self._started_from_low_cost = True
|
||||||
self._create_thread(config, result)
|
self._create_thread(config, result)
|
||||||
elif thread_id and not self._metric_constraint_satisfied:
|
|
||||||
# no point has been found to satisfy metric constraint
|
|
||||||
self._expand_admissible_region()
|
|
||||||
# reset admissible region to ls bounding box
|
# reset admissible region to ls bounding box
|
||||||
self._gs_admissible_min.update(self._ls_bound_min)
|
self._gs_admissible_min.update(self._ls_bound_min)
|
||||||
self._gs_admissible_max.update(self._ls_bound_max)
|
self._gs_admissible_max.update(self._ls_bound_max)
|
||||||
|
|||||||
@ -290,9 +290,12 @@ class FLOW2(Searcher):
|
|||||||
return unflatten_dict(config)
|
return unflatten_dict(config)
|
||||||
|
|
||||||
def create(self, init_config: Dict, obj: float, cost: float) -> Searcher:
|
def create(self, init_config: Dict, obj: float, cost: float) -> Searcher:
|
||||||
|
flatten_config = flatten_dict(init_config)
|
||||||
|
# use the subspace where the init_config is located
|
||||||
|
space = {k: self.space[k] for k in flatten_config if k in self.space}
|
||||||
flow2 = self.__class__(
|
flow2 = self.__class__(
|
||||||
init_config, self.metric, self.mode, self._cat_hp_cost,
|
init_config, self.metric, self.mode, self._cat_hp_cost,
|
||||||
unflatten_dict(self.space), self.prune_attr,
|
unflatten_dict(space), self.prune_attr,
|
||||||
self.min_resource, self.max_resource,
|
self.min_resource, self.max_resource,
|
||||||
self.resource_multiple_factor, self.cost_attr, self._seed + 1)
|
self.resource_multiple_factor, self.cost_attr, self._seed + 1)
|
||||||
flow2.best_obj = obj * self.metric_op # minimize internally
|
flow2.best_obj = obj * self.metric_op # minimize internally
|
||||||
|
|||||||
@ -71,16 +71,20 @@ config_search_space = {
|
|||||||
low_cost_partial_config={'x':1}
|
low_cost_partial_config={'x':1}
|
||||||
|
|
||||||
# set up CFO
|
# set up CFO
|
||||||
search_alg_cfo = CFO(low_cost_partial_config=low_cost_partial_config)
|
cfo = CFO(low_cost_partial_config=low_cost_partial_config)
|
||||||
|
|
||||||
# set up BlendSearch.
|
# set up BlendSearch
|
||||||
search_alg_blendsearch = BlendSearch(metric="metric",
|
blendsearch = BlendSearch(
|
||||||
mode="min",
|
metric="metric", mode="min",
|
||||||
space=config_search_space,
|
space=config_search_space,
|
||||||
low_cost_partial_config=low_cost_partial_config)
|
low_cost_partial_config=low_cost_partial_config)
|
||||||
# NOTE that when using BlendSearch as a search_alg in ray tune, you need to
|
# NOTE: when using BlendSearch as a search_alg in ray tune, you need to
|
||||||
# configure the 'time_budget_s' for BlendSearch accordingly as follows such that BlendSearch is aware of the time budget. This step is not needed when BlendSearch is used as the search_alg in flaml.tune as it is already done automatically in flaml.
|
# configure the 'time_budget_s' for BlendSearch accordingly as follows such that
|
||||||
search_alg_blendsearch.set_search_properties(config={"time_budget_s": time_budget_s})
|
# BlendSearch is aware of the time budget. This step is not needed when
|
||||||
|
# BlendSearch is used as the search_alg in flaml.tune as it is already done
|
||||||
|
# automatically in flaml. Also, this step needs to be done after the search
|
||||||
|
# space is passed to BlendSearch and before raytune.run.
|
||||||
|
blendsearch.set_search_properties(config={"time_budget_s": time_budget_s})
|
||||||
|
|
||||||
analysis = raytune.run(
|
analysis = raytune.run(
|
||||||
evaluate_config, # the function to evaluate a config
|
evaluate_config, # the function to evaluate a config
|
||||||
@ -90,7 +94,7 @@ analysis = raytune.run(
|
|||||||
num_samples=-1, # the maximal number of configs to try, -1 means infinite
|
num_samples=-1, # the maximal number of configs to try, -1 means infinite
|
||||||
time_budget_s=time_budget_s, # the time budget in seconds
|
time_budget_s=time_budget_s, # the time budget in seconds
|
||||||
local_dir='logs/', # the local directory to store logs
|
local_dir='logs/', # the local directory to store logs
|
||||||
search_alg=search_alg_blendsearch # or search_alg_cfo
|
search_alg=blendsearch # or cfo
|
||||||
)
|
)
|
||||||
|
|
||||||
print(analysis.best_trial.last_result) # the best trial's result
|
print(analysis.best_trial.last_result) # the best trial's result
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
{
|
{
|
||||||
"Component": {
|
"Component": {
|
||||||
"Type": "pip",
|
"Type": "pip",
|
||||||
"pip": {"Name": "ray[tune]", "Version": "1.2.0" }
|
"pip": {"Name": "ray[tune]", "Version": "1.5.1" }
|
||||||
},
|
},
|
||||||
"DevelopmentDependency": false
|
"DevelopmentDependency": false
|
||||||
},
|
},
|
||||||
|
|||||||
78
flaml/tune/space.py
Normal file
78
flaml/tune/space.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
|
||||||
|
try:
|
||||||
|
from ray.tune import sample
|
||||||
|
except ImportError:
|
||||||
|
from . import sample
|
||||||
|
from typing import Dict, Optional, Any
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def define_by_run_func(
    trial, space: Dict, path: str = ""
) -> Optional[Dict[str, Any]]:
    """Define-by-run function to create the search space.

    Walks ``space`` and registers every tunable domain on ``trial`` through
    its ``suggest_float`` / ``suggest_int`` / ``suggest_categorical``
    methods. Entries that are not ``sample.Domain`` instances are treated as
    constants and collected in the returned dict.

    Args:
        trial: Object exposing ``suggest_float``, ``suggest_int`` and
            ``suggest_categorical`` (presumably an Optuna trial - confirm
            against the caller).
        space: Mapping from hyperparameter name to either a constant or a
            ``sample.Domain``.
        path: Prefix for the keys of a nested space; joined to each key
            with '/'. Empty for the top-level space.

    Returns:
        A dict with the constant values found in ``space``. When a
        categorical choice is itself a dict, its constants are nested under
        the key suffixed with ``:<index>``. The dict is empty when there
        are no constants.

    Raises:
        ValueError: If a domain is not a Float, Integer or Categorical.
    """
    config = {}
    for key, domain in space.items():
        if path:
            key = path + '/' + key
        if not isinstance(domain, sample.Domain):
            # Plain value: nothing to sample, just report it as a constant.
            config[key] = domain
            continue
        sampler = domain.get_sampler()
        quantize = None
        if isinstance(sampler, sample.Quantized):
            # Unwrap the quantizer so the checks below see the real sampler.
            quantize = sampler.q
            sampler = sampler.sampler
            if isinstance(sampler, sample.LogUniform):
                logger.warning(
                    "Optuna does not handle quantization in loguniform "
                    "sampling. The parameter will be passed but it will "
                    "probably be ignored.")
        if isinstance(domain, sample.Float):
            if isinstance(sampler, sample.LogUniform):
                if quantize:
                    logger.warning(
                        "Optuna does not support both quantization and "
                        "sampling from LogUniform. Dropped quantization.")
                trial.suggest_float(
                    key, domain.lower, domain.upper, log=True)
            elif isinstance(sampler, sample.Uniform):
                # Register the parameter exactly once: suggesting the same
                # key a second time without `step` would conflict with the
                # quantized distribution registered first.
                if quantize:
                    trial.suggest_float(
                        key, domain.lower, domain.upper, step=quantize)
                else:
                    trial.suggest_float(key, domain.lower, domain.upper)
        elif isinstance(domain, sample.Integer):
            # Upper bound should be inclusive for quantization and
            # exclusive otherwise; suggest_int treats its bound as
            # inclusive, so shift it by one when not quantized.
            upper = domain.upper if quantize else domain.upper - 1
            if isinstance(sampler, sample.LogUniform):
                trial.suggest_int(
                    key, domain.lower, upper, step=quantize or 1, log=True)
            elif isinstance(sampler, sample.Uniform):
                trial.suggest_int(
                    key, domain.lower, upper, step=quantize or 1)
        elif isinstance(domain, sample.Categorical):
            if isinstance(sampler, sample.Uniform):
                if not hasattr(domain, 'choices'):
                    # Sample an index rather than the category itself, since
                    # categories may be dicts (nested spaces).
                    domain.choices = list(range(len(domain.categories)))
                choices = domain.choices
                # This choice needs to be removed from the final config
                index = trial.suggest_categorical(key + '_choice_', choices)
                choice = domain.categories[index]
                if isinstance(choice, dict):
                    key += f":{index}"
                    # the suffix needs to be removed from the final config
                    config[key] = define_by_run_func(trial, choice, key)
        else:
            raise ValueError(
                "Optuna search does not support parameters of type "
                "`{}` with samplers of type `{}`".format(
                    type(domain).__name__,
                    type(domain.sampler).__name__))
    # Return all constants in a dictionary.
    return config
|
||||||
@ -4,30 +4,31 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"Copyright (c) 2020-2021. All rights reserved.\n",
|
"Copyright (c) 2020-2021. All rights reserved.\r\n",
|
||||||
"\n",
|
"\r\n",
|
||||||
"Licensed under the MIT License.\n",
|
"Licensed under the MIT License.\r\n",
|
||||||
"\n",
|
"\r\n",
|
||||||
"# Troubleshooting HPO for fine-tuning pre-trained language models\n",
|
"# Troubleshooting HPO for fine-tuning pre-trained language models\r\n",
|
||||||
"\n",
|
"\r\n",
|
||||||
"## 1. Introduction\n",
|
"## 1. Introduction\r\n",
|
||||||
"\n",
|
"\r\n",
|
||||||
"\n",
|
"In this notebook, we demonstrate a procedure for troubleshooting HPO failure in fine-tuning pre-trained language models (introduced in the following paper):\r\n",
|
||||||
"In this notebook, we demonstrate a procedure for troubleshooting HPO failure in fine-tuning pre-trained language models (introduced in the following paper):\n",
|
"\r\n",
|
||||||
"\n",
|
"*[An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://arxiv.org/abs/2106.09204). Xueqing Liu, Chi Wang. ACL-IJCNLP 2021*\r\n",
|
||||||
"*[An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://arxiv.org/abs/2106.09204). Xueqing Liu, Chi Wang. To appear in ACL-IJCNLP 2021*\n",
|
"\r\n",
|
||||||
"\n",
|
"Notes:\r\n",
|
||||||
"Notes:\n",
|
"\r\n",
|
||||||
"\n",
|
"*In this notebook, we only run each experiment 1 time for simplicity, which is different from the paper (3 times). To reproduce the paper's result, please run 3 repetitions and take the average scores.\r\n",
|
||||||
"*In this notebook, we only run each experiment 1 time for simplicity, which is different from the paper (3 times). To reproduce the paper's result, please run 3 repetitions and take the average scores.\n",
|
"\r\n",
|
||||||
"\n",
|
"*Running this notebook takes about one hour.\r\n",
|
||||||
"*Running this notebook takes about one hour.\n",
|
"\r\n",
|
||||||
"\n",
|
"FLAML requires `Python>=3.6`. To run this notebook example, please install flaml with the `notebook` and `nlp` options:\r\n",
|
||||||
"FLAML requires `Python>=3.6`. To run this notebook example, please install flaml with the `notebook` and `nlp` options:\n",
|
"\r\n",
|
||||||
"```bash\n",
|
"```bash\r\n",
|
||||||
"pip install flaml[nlp]\n",
|
"pip install flaml[nlp]\r\n",
|
||||||
"```\n",
|
"```\r\n",
|
||||||
"Our paper was developed under transformers version 3.4.0. We uninstall and reinstall transformers==3.4.0:"
|
"\r\n",
|
||||||
|
"Our paper was developed under transformers version 3.4.0. We uninstall and reinstall transformers==3.4.0:\r\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -40,9 +41,9 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"!pip install flaml[nlp]\n",
|
"!pip install flaml[nlp]\r\n",
|
||||||
"!pip install transformers==3.4.0\n",
|
"!pip install transformers==3.4.0\r\n",
|
||||||
"from flaml.nlp import AutoTransformers\n"
|
"from flaml.nlp import AutoTransformers\r\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -363,10 +364,10 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"\u001B[2m\u001B[36m(pid=50964)\u001B[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
|
"\u001b[2m\u001b[36m(pid=50964)\u001b[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=50964)\u001B[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
|
"\u001b[2m\u001b[36m(pid=50964)\u001b[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=50948)\u001B[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n",
|
"\u001b[2m\u001b[36m(pid=50948)\u001b[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=50948)\u001B[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n"
|
"\u001b[2m\u001b[36m(pid=50948)\u001b[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -484,12 +485,12 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
"\u001b[2m\u001b[36m(pid=54411)\u001b[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
"\u001b[2m\u001b[36m(pid=54411)\u001b[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
"\u001b[2m\u001b[36m(pid=54411)\u001b[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
|
"\u001b[2m\u001b[36m(pid=54417)\u001b[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
|
"\u001b[2m\u001b[36m(pid=54417)\u001b[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n"
|
"\u001b[2m\u001b[36m(pid=54417)\u001b[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -589,18 +590,18 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
|
"\u001b[2m\u001b[36m(pid=57835)\u001b[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
|
"\u001b[2m\u001b[36m(pid=57835)\u001b[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
|
"\u001b[2m\u001b[36m(pid=57835)\u001b[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
|
"\u001b[2m\u001b[36m(pid=57835)\u001b[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
|
"\u001b[2m\u001b[36m(pid=57836)\u001b[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
|
"\u001b[2m\u001b[36m(pid=57836)\u001b[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
|
"\u001b[2m\u001b[36m(pid=57836)\u001b[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
|
"\u001b[2m\u001b[36m(pid=57836)\u001b[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
|
"\u001b[2m\u001b[36m(pid=57839)\u001b[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
|
"\u001b[2m\u001b[36m(pid=57839)\u001b[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
|
"\u001b[2m\u001b[36m(pid=57839)\u001b[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n"
|
"\u001b[2m\u001b[36m(pid=57839)\u001b[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -700,21 +701,21 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
|
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
|
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
|
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
|
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
|
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
|
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
|
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
|
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
|
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
|
||||||
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n"
|
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -791,22 +792,16 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "bfcd9a6a9254a5e160761a1fd7a9e444f011592c6770d9f4180dde058a9df5dd"
|
||||||
|
},
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3",
|
"display_name": "Python 3.7.7 64-bit ('flaml': conda)",
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"version": ""
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.8.0"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
4
setup.py
4
setup.py
@ -61,7 +61,7 @@ setuptools.setup(
|
|||||||
"optuna==2.8.0"
|
"optuna==2.8.0"
|
||||||
],
|
],
|
||||||
"ray": [
|
"ray": [
|
||||||
"ray[tune]==1.4.1",
|
"ray[tune]==1.5.1",
|
||||||
"pyyaml<5.3.1",
|
"pyyaml<5.3.1",
|
||||||
],
|
],
|
||||||
"azureml": [
|
"azureml": [
|
||||||
@ -74,7 +74,7 @@ setuptools.setup(
|
|||||||
"vowpalwabbit",
|
"vowpalwabbit",
|
||||||
],
|
],
|
||||||
"nlp": [
|
"nlp": [
|
||||||
"ray[tune]>=1.4.1",
|
"ray[tune]>=1.5.1",
|
||||||
"transformers",
|
"transformers",
|
||||||
"datasets==1.4.1",
|
"datasets==1.4.1",
|
||||||
"tensorboardX<=2.2",
|
"tensorboardX<=2.2",
|
||||||
|
|||||||
@ -37,6 +37,9 @@ def test_automl(budget=5):
|
|||||||
get_output_from_log(filename=settings['log_file_name'], time_budget=60)
|
get_output_from_log(filename=settings['log_file_name'], time_budget=60)
|
||||||
for config in config_history:
|
for config in config_history:
|
||||||
print(config)
|
print(config)
|
||||||
|
print(automl.prune_attr)
|
||||||
|
print(automl.max_resource)
|
||||||
|
print(automl.min_resource)
|
||||||
|
|
||||||
|
|
||||||
def test_mlflow():
|
def test_mlflow():
|
||||||
|
|||||||
@ -42,7 +42,16 @@ class TestLogging(unittest.TestCase):
|
|||||||
automl.fit(X_train=X_train[:n], y_train=y_train[:n],
|
automl.fit(X_train=X_train[:n], y_train=y_train[:n],
|
||||||
X_val=X_train[n:], y_val=y_train[n:],
|
X_val=X_train[n:], y_val=y_train[n:],
|
||||||
**automl_settings)
|
**automl_settings)
|
||||||
|
logger.info(automl.search_space)
|
||||||
|
logger.info(automl.low_cost_partial_config)
|
||||||
|
logger.info(automl.points_to_evalaute)
|
||||||
|
import optuna as ot
|
||||||
|
study = ot.create_study()
|
||||||
|
from flaml.tune.space import define_by_run_func
|
||||||
|
logger.info(define_by_run_func(study.ask(), automl.search_space))
|
||||||
|
config = automl.best_config.copy()
|
||||||
|
config['learner'] = automl.best_estimator
|
||||||
|
automl.trainable({"ml": config})
|
||||||
# Check if the log buffer is populated.
|
# Check if the log buffer is populated.
|
||||||
self.assertTrue(len(buf.getvalue()) > 0)
|
self.assertTrue(len(buf.getvalue()) > 0)
|
||||||
|
|
||||||
|
|||||||
@ -19,11 +19,11 @@ class XGBoost2D(XGBoostSklearnEstimator):
|
|||||||
return {
|
return {
|
||||||
'n_estimators': {
|
'n_estimators': {
|
||||||
'domain': tune.lograndint(lower=4, upper=upper),
|
'domain': tune.lograndint(lower=4, upper=upper),
|
||||||
'init_value': 4,
|
'low_cost_init_value': 4,
|
||||||
},
|
},
|
||||||
'max_leaves': {
|
'max_leaves': {
|
||||||
'domain': tune.lograndint(lower=4, upper=upper),
|
'domain': tune.lograndint(lower=4, upper=upper),
|
||||||
'init_value': 4,
|
'low_cost_init_value': 4,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,7 +40,7 @@ def test_simple(method=None):
|
|||||||
"n_jobs": 1,
|
"n_jobs": 1,
|
||||||
"hpo_method": method,
|
"hpo_method": method,
|
||||||
"log_type": "all",
|
"log_type": "all",
|
||||||
"time_budget": 3
|
"time_budget": 1
|
||||||
}
|
}
|
||||||
from sklearn.externals._arff import ArffException
|
from sklearn.externals._arff import ArffException
|
||||||
try:
|
try:
|
||||||
@ -51,6 +51,25 @@ def test_simple(method=None):
|
|||||||
X_train, X_test, y_train, y_test = train_test_split(
|
X_train, X_test, y_train, y_test = train_test_split(
|
||||||
X, y, test_size=0.33, random_state=42)
|
X, y, test_size=0.33, random_state=42)
|
||||||
automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
|
automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
|
||||||
|
print(automl.estimator_list)
|
||||||
|
print(automl.search_space)
|
||||||
|
print(automl.points_to_evalaute)
|
||||||
|
config = automl.best_config.copy()
|
||||||
|
config['learner'] = automl.best_estimator
|
||||||
|
automl.trainable(config)
|
||||||
|
from flaml import tune
|
||||||
|
analysis = tune.run(
|
||||||
|
automl.trainable, automl.search_space, metric='val_loss',
|
||||||
|
low_cost_partial_config=automl.low_cost_partial_config,
|
||||||
|
points_to_evaluate=automl.points_to_evalaute,
|
||||||
|
cat_hp_cost=automl.cat_hp_cost,
|
||||||
|
prune_attr=automl.prune_attr,
|
||||||
|
min_resource=automl.min_resource,
|
||||||
|
max_resource=automl.max_resource,
|
||||||
|
time_budget_s=automl._state.time_budget,
|
||||||
|
config_constraints=[(automl.size, '<=', automl._mem_thres)],
|
||||||
|
metric_constraints=automl.metric_constraints)
|
||||||
|
print(analysis.trials[-1])
|
||||||
|
|
||||||
|
|
||||||
def _test_optuna():
|
def _test_optuna():
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
'''Require: pip install flaml[test,ray]
|
'''Require: pip install flaml[test,ray]
|
||||||
'''
|
'''
|
||||||
|
from flaml.searcher.blendsearch import BlendSearch
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
@ -199,6 +200,28 @@ def test_nested():
|
|||||||
logger.info(f"CFO best config: {best_trial.config}")
|
logger.info(f"CFO best config: {best_trial.config}")
|
||||||
logger.info(f"CFO best result: {best_trial.last_result}")
|
logger.info(f"CFO best result: {best_trial.last_result}")
|
||||||
|
|
||||||
|
analysis = tune.run(
|
||||||
|
simple_func,
|
||||||
|
search_alg=BlendSearch(
|
||||||
|
experimental=True,
|
||||||
|
space=search_space, metric="obj", mode="min",
|
||||||
|
low_cost_partial_config={
|
||||||
|
"cost_related": {"a": 1}
|
||||||
|
},
|
||||||
|
points_to_evaluate=[
|
||||||
|
{"b": .99, "cost_related": {"a": 3}},
|
||||||
|
{"b": .99, "cost_related": {"a": 2}},
|
||||||
|
{"cost_related": {"a": 8}}
|
||||||
|
],
|
||||||
|
metric_constraints=[("ab", "<=", 4)]),
|
||||||
|
local_dir='logs/',
|
||||||
|
num_samples=-1,
|
||||||
|
time_budget_s=.1)
|
||||||
|
|
||||||
|
best_trial = analysis.get_best_trial()
|
||||||
|
logger.info(f"BlendSearch exp best config: {best_trial.config}")
|
||||||
|
logger.info(f"BlendSearch exp best result: {best_trial.last_result}")
|
||||||
|
|
||||||
analysis = tune.run(
|
analysis = tune.run(
|
||||||
simple_func,
|
simple_func,
|
||||||
config=search_space,
|
config=search_space,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user