constraints (#88)

* pre-training constraints
* metric constraints after training

This commit is contained in:
parent 3083229e40
commit 0925e2b308
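
In short, the commit replaces the old memory-only mechanism (resources_per_trial['mem'] plus mem_size) with two general constraint lists that share a (quantity, sign, threshold) triple form. A minimal sketch of the two lists, assuming a hypothetical mem_size helper that is not part of this commit:

    # config constraints are checked on a config before it is trained;
    # metric constraints are checked on the reported result after training
    def mem_size(config):
        # hypothetical estimate of the bytes a config would need
        return config.get("n_estimators", 100) * 8 * 1024**2

    config_constraints = [(mem_size, '<=', 1024**3)]    # pre-training
    metric_constraints = [('precision', '>=', 0.9)]     # post-training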
@@ -1036,9 +1036,8 @@ class AutoML:
                 prune_attr=prune_attr,
                 min_resource=min_resource,
                 max_resource=max_resource,
-                resources_per_trial={"cpu": self._state.n_jobs,
-                                     "mem": self._mem_thres},
-                mem_size=learner_class.size)
+                config_constraints=[(learner_class.size, '<=', self._mem_thres)]
+            )
         else:
             algo = SearchAlgo(
                 metric='val_loss', mode='min', space=search_space,
@@ -237,8 +237,8 @@ class DataTransformer:
                         SimpleImputer(missing_values=np.nan, strategy='median'),
                         X_num.columns)])
                 X[num_columns] = self.transformer.fit_transform(X_num)
-            self._cat_columns, self._num_columns, self._datetime_columns = cat_columns, \
-                num_columns, datetime_columns
+            self._cat_columns, self._num_columns, self._datetime_columns = \
+                cat_columns, num_columns, datetime_columns
             self._drop = drop

         if task == 'regression':
@@ -275,4 +275,3 @@ class DataTransformer:
             X_num.columns = range(X_num.shape[1])
             X[num_columns] = self.transformer.transform(X_num)
         return X
-
@@ -39,9 +39,11 @@ class BlendSearch(Searcher):
                  min_resource: Optional[float] = None,
                  max_resource: Optional[float] = None,
                  reduction_factor: Optional[float] = None,
-                 resources_per_trial: Optional[dict] = None,
                  global_search_alg: Optional[Searcher] = None,
-                 mem_size: Callable[[dict], float] = None,
+                 config_constraints: Optional[
+                     List[Tuple[Callable[[dict], float], str, float]]] = None,
+                 metric_constraints: Optional[
+                     List[Tuple[str, str, float]]] = None,
                  seed: Optional[int] = 20):
         '''Constructor

@@ -82,14 +84,23 @@ class BlendSearch(Searcher):
                 prune_attr; only valid if prune_attr is not in space.
             reduction_factor: A float of the reduction factor used for
                 incremental pruning.
-            resources_per_trial: A dictionary of the resources permitted per
-                trial, such as 'mem'.
             global_search_alg: A Searcher instance as the global search
                 instance. If omitted, Optuna is used. The following algos have
                 known issues when used as global_search_alg:
                 - HyperOptSearch raises exception sometimes
                 - TuneBOHB has its own scheduler
-            mem_size: A function to estimate the memory size for a given config.
+            config_constraints: A list of config constraints to be satisfied.
+                e.g.,
+
+                .. code-block: python
+
+                    config_constraints = [(mem_size, '<=', 1024**3)]
+
+                mem_size is a function which produces a float number for the bytes
+                needed for a config.
+                It is used to skip configs which do not fit in memory.
+            metric_constraints: A list of metric constraints to be satisfied.
+                e.g., `['precision', '>=', 0.9]`
             seed: An integer of the random seed.
         '''
         self._metric, self._mode = metric, mode
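
For orientation, a hedged sketch of constructing BlendSearch directly with the two new arguments. Only keyword names that appear in this diff are used; the search space, mem_size estimate, and thresholds are invented for illustration, and it assumes flaml.tune provides the usual uniform sampling helper:

    from flaml import BlendSearch, tune

    def mem_size(config):
        # hypothetical per-config memory estimate in bytes
        return config["a"] * 1024**2

    search_space = {"a": tune.uniform(1, 1000), "b": tune.uniform(0, 1)}
    algo = BlendSearch(
        metric='val_loss', mode='min', space=search_space,
        config_constraints=[(mem_size, '<=', 1024**3)],   # skip configs estimated over 1 GiB
        metric_constraints=[('precision', '>=', 0.9)])    # penalize low-precision trials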
@@ -104,10 +115,8 @@ class BlendSearch(Searcher):
         self._ls = LocalSearch(
             init_config, metric, mode, cat_hp_cost, space,
             prune_attr, min_resource, max_resource, reduction_factor, seed)
-        self._resources_per_trial = resources_per_trial
-        self._mem_size = mem_size
-        self._mem_threshold = resources_per_trial.get(
-            'mem') if resources_per_trial else None
+        self._config_constraints = config_constraints
+        self._metric_constraints = metric_constraints
         self._init_search()

     def set_search_properties(self,
@@ -171,9 +180,8 @@ class BlendSearch(Searcher):
         self._points_to_evaluate = state._points_to_evaluate
         self._gs = state._gs
         self._ls = state._ls
-        self._resources_per_trial = state._resources_per_trial
-        self._mem_size = state._mem_size
-        self._mem_threshold = state._mem_threshold
+        self._config_constraints = state._config_constraints
+        self._metric_constraints = state._metric_constraints

     def restore_from_dir(self, checkpoint_dir: str):
         super.restore_from_dir(checkpoint_dir)
@@ -182,6 +190,20 @@ class BlendSearch(Searcher):
                           error: bool = False):
         ''' search thread updater and cleaner
         '''
+        if result and not error and self._metric_constraints:
+            # accout for metric constraints if any
+            objective = result[self._metric]
+            for constraint in self._metric_constraints:
+                metric_constraint, sign, threshold = constraint
+                value = result.get(metric_constraint)
+                if value:
+                    # sign is <= or >=
+                    sign_op = 1 if sign == '<=' else -1
+                    violation = (value - threshold) * sign_op
+                    if violation > 0:
+                        # add penalty term to the metric
+                        objective += 1e+10 * violation * self._ls.metric_op
+            result[self._metric] = objective
         thread_id = self._trial_proposed_by.get(trial_id)
         if thread_id in self._search_thread_pool:
             self._search_thread_pool[thread_id].on_trial_complete(
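
Restated outside the class, the penalty logic added above amounts to the following self-contained sketch; metric_op is +1 when minimizing and -1 when maximizing, as in the local searcher, and all names are local to the example:

    def penalized_objective(result, metric, metric_constraints, metric_op):
        # inflate the objective by a large penalty for each violated metric constraint
        objective = result[metric]
        for name, sign, threshold in metric_constraints:
            value = result.get(name)
            if value:
                sign_op = 1 if sign == '<=' else -1        # sign is '<=' or '>='
                violation = (value - threshold) * sign_op  # positive means violated
                if violation > 0:
                    objective += 1e+10 * violation * metric_op
        return objective

    # minimizing 'val_loss' (metric_op = 1) while requiring precision >= 0.9
    result = {'val_loss': 0.12, 'precision': 0.85}
    print(penalized_objective(result, 'val_loss', [('precision', '>=', 0.9)], 1))
    # 500000000.12: the violated constraint makes this trial look much worse to the searcher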
@@ -196,23 +218,24 @@ class BlendSearch(Searcher):
                 del self._result[self._ls.config_signature(config)]
             else:  # add to result cache
                 self._result[self._ls.config_signature(config)] = result
             # update target metric if improved
-            if (result[self._metric] - self._metric_target) * self._ls.metric_op < 0:
-                self._metric_target = result[self._metric]
+            objective = result[self._metric]
+            if (objective - self._metric_target) * self._ls.metric_op < 0:
+                self._metric_target = objective
             if not thread_id and self._create_condition(result):
                 # thread creator
                 self._search_thread_pool[self._thread_count] = SearchThread(
                     self._ls.mode,
-                    self._ls.create(config, result[self._metric], cost=result[
-                        self.cost_attr])
+                    self._ls.create(
+                        config, objective, cost=result[self.cost_attr])
                 )
                 thread_id = self._thread_count
                 self._thread_count += 1
                 self._update_admissible_region(
                     config, self._ls_bound_min, self._ls_bound_max)
                 # reset admissible region to ls bounding box
                 self._gs_admissible_min.update(self._ls_bound_min)
                 self._gs_admissible_max.update(self._ls_bound_max)
         # cleaner
         if thread_id and thread_id in self._search_thread_pool:
             # local search thread
@@ -262,7 +285,7 @@ class BlendSearch(Searcher):
     def _expand_admissible_region(self):
         for key in self._ls_bound_max:
             self._ls_bound_max[key] += self._ls.STEPSIZE
             self._ls_bound_min[key] -= self._ls.STEPSIZE

     def _inferior(self, id1: int, id2: int) -> bool:
         ''' whether thread id1 is inferior to id2
@@ -362,20 +385,26 @@ class BlendSearch(Searcher):
         return config

     def _should_skip(self, choice, trial_id, config) -> bool:
-        ''' if config is None or config's result is known or above mem threshold
+        ''' if config is None or config's result is known or constraints are violated
             return True; o.w. return False
         '''
         if config is None:
             return True
         config_signature = self._ls.config_signature(config)
         exists = config_signature in self._result
-        # check mem constraint
-        if not exists and self._mem_threshold and self._mem_size(
-                config) > self._mem_threshold:
-            self._result[config_signature] = {
-                self._metric: np.inf * self._ls.metric_op, 'time_total_s': 1
-            }
-            exists = True
+        # check constraints
+        if not exists and self._config_constraints:
+            for constraint in self._config_constraints:
+                func, sign, threshold = constraint
+                value = func(config)
+                if (sign == '<=' and value > threshold
+                        or sign == '>=' and value < threshold):
+                    self._result[config_signature] = {
+                        self._metric: np.inf * self._ls.metric_op,
+                        'time_total_s': 1,
+                    }
+                    exists = True
+                    break
         if exists:
             if not self._use_rs:
                 result = self._result.get(config_signature)
@@ -544,7 +544,7 @@ class FLOW2(Searcher):
         self._configs[trial_id] = (config, self.step)
         self._num_proposedby_incumbent += 1
         if self._init_phrase:
             if self._direction_tried is None:
                 if self._same:
                     # check if the new config is different from self.best_config
                     same = True
@@ -566,17 +566,17 @@ class FLOW2(Searcher):
                         break
                 self._same = same
             if self._num_proposedby_incumbent == self.dir and (
                     not self._resource or self._resource == self.max_resource):
                 # check stuck condition if using max resource
                 self._num_proposedby_incumbent -= 2
                 self._init_phrase = False
                 if self.step >= self.step_lower_bound:
                     # decrease step size
                     self._oldK = self._K if self._K else self._iter_best_config
                     self._K = self.trial_count_proposed + 1
                     self.step *= np.sqrt(self._oldK / self._K)
                 else:
                     return None
         return unflatten_dict(config)

     def _project(self, config):
@@ -3,7 +3,7 @@
  * Licensed under the MIT License. See LICENSE file in the
  * project root for license information.
 '''
-from typing import Optional, Union, List, Callable
+from typing import Optional, Union, List, Callable, Tuple
 import datetime
 import time
 try:
@@ -118,7 +118,10 @@ def run(training_function,
         local_dir: Optional[str] = None,
         num_samples: Optional[int] = 1,
         resources_per_trial: Optional[dict] = None,
-        mem_size: Callable[[dict], float] = None,
+        config_constraints: Optional[
+            List[Tuple[Callable[[dict], float], str, float]]] = None,
+        metric_constraints: Optional[
+            List[Tuple[str, str, float]]] = None,
         use_ray: Optional[bool] = False):
     '''The trigger for HPO.

@@ -210,11 +213,19 @@ def run(training_function,
             used; or a local dir to save the tuning log.
         num_samples: An integer of the number of configs to try. Defaults to 1.
         resources_per_trial: A dictionary of the hardware resources to allocate
-            per trial, e.g., `{'mem': 1024**3}`. When not using ray backend,
-            only 'mem' is used as approximate resource constraints
-            (in conjunction with mem_size).
-        mem_size: A function to estimate the memory size for a given config.
+            per trial, e.g., `{'cpu': 1}`. Only valid when using ray backend.
+        config_constraints: A list of config constraints to be satisfied.
+            e.g.,
+
+            .. code-block: python
+
+                config_constraints = [(mem_size, '<=', 1024**3)]
+
+            mem_size is a function which produces a float number for the bytes
+            needed for a config.
             It is used to skip configs which do not fit in memory.
+        metric_constraints: A list of metric constraints to be satisfied.
+            e.g., `['precision', '>=', 0.9]`
         use_ray: A boolean of whether to use ray as the backend
     '''
     global _use_ray
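
Putting the two docstring entries together, a hedged end-to-end sketch of calling flaml's tune.run with both constraint kinds. The training function, search space, and mem_size estimate are invented; the argument names follow this diff and the test_nested test further down, and it assumes flaml.tune exposes uniform and report as usual:

    from flaml import tune

    def train(config):
        # toy objective; report the tuned metric plus any metric used in a constraint
        obj = (config["a"] - 4) ** 2 + (config["b"] - config["a"]) ** 2
        tune.report(obj=obj, ab=config["a"] * config["b"])

    def mem_size(config):
        # hypothetical bytes needed for a config
        return config["a"] * 1024**2

    analysis = tune.run(
        train,
        config={"a": tune.uniform(1, 10), "b": tune.uniform(0, 10)},
        metric="obj", mode="min",
        config_constraints=[(mem_size, '<=', 1024**3)],  # skip configs estimated over 1 GiB
        metric_constraints=[("ab", "<=", 4)],            # penalize trials reporting ab > 4
        num_samples=-1, time_budget_s=1, local_dir='logs/')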
@@ -252,8 +263,8 @@ def run(training_function,
         prune_attr=prune_attr,
         min_resource=min_resource, max_resource=max_resource,
         reduction_factor=reduction_factor,
-        resources_per_trial=resources_per_trial,
-        mem_size=mem_size)
+        config_constraints=config_constraints,
+        metric_constraints=metric_constraints)
     if time_budget_s:
         search_alg.set_search_properties(metric, mode, config={
             'time_budget_s': time_budget_s})
@@ -1 +1 @@
-__version__ = "0.3.6"
+__version__ = "0.4.0"
@@ -70,7 +70,7 @@ class MyRegularizedGreedyForest(SKLearnEstimator):

 def logregobj(preds, dtrain):
     labels = dtrain.get_label()
     preds = 1.0 / (1.0 + np.exp(-preds))  # transform raw leaf weight
     grad = preds - labels
     hess = preds * (1.0 - preds)
     return grad, hess
@@ -81,7 +81,7 @@ class MyXGB1(XGBoostEstimator):
     '''

     def __init__(self, **params):
         super().__init__(objective=logregobj, **params)


 class MyXGB2(XGBoostEstimator):
@@ -226,32 +226,34 @@ class TestAutoML(unittest.TestCase):

         automl_experiment = AutoML()
         automl_settings = {
             "time_budget": 2,
             "metric": 'mse',
             "task": 'regression',
             "log_file_name": "test/datetime_columns.log",
             "log_training_metric": True,
             "n_jobs": 1,
             "model_history": True
         }

         fake_df = pd.DataFrame({'A': [datetime(1900, 2, 3), datetime(1900, 3, 4)]})
         y = np.array([0, 1])
-        automl_experiment.fit(X_train=fake_df, X_val=fake_df, y_train=y, y_val=y, **automl_settings)
+        automl_experiment.fit(
+            X_train=fake_df, X_val=fake_df, y_train=y, y_val=y, **automl_settings)

         y_pred = automl_experiment.predict(fake_df)
+        print(y_pred)

     def test_micro_macro_f1(self):
         automl_experiment = AutoML()
         automl_experiment_macro = AutoML()

         automl_settings = {
             "time_budget": 2,
             "task": 'classification',
             "log_file_name": "test/micro_macro_f1.log",
             "log_training_metric": True,
             "n_jobs": 1,
             "model_history": True
         }

         X_train, y_train = load_iris(return_X_y=True)
test/tune/__init__.py (new empty file)
@@ -1,19 +1,21 @@
 '''Require: pip install flaml[test,ray]
 '''
-import unittest
 import time
+import os
 from sklearn.model_selection import train_test_split
 import sklearn.metrics
 import sklearn.datasets
 try:
     from ray.tune.integration.xgboost import TuneReportCheckpointCallback
 except ImportError:
-    print("skip test_tune because ray tune cannot be imported.")
+    print("skip test_xgboost because ray tune cannot be imported.")
 import xgboost as xgb

 import logging
 logger = logging.getLogger(__name__)
-logger.addHandler(logging.FileHandler('test/tune_xgboost.log'))
+os.makedirs('logs', exist_ok=True)
+logger.addHandler(logging.FileHandler('logs/tune_xgboost.log'))
+logger.setLevel(logging.INFO)


 def train_breast_cancer(config: dict):
@@ -61,6 +63,7 @@ def _test_xgboost(method='BlendSearch'):
     for n_cpu in [8]:
         start_time = time.time()
         ray.init(num_cpus=n_cpu, num_gpus=0)
+        # ray.init(address='auto')
         if method == 'BlendSearch':
             analysis = tune.run(
                 train_breast_cancer,
@@ -163,21 +166,28 @@ def test_nested():
     }

     def simple_func(config):
-        tune.report(metric=(config["cost_related"]["a"] - 4)**2
-                    * (config["b"] - 0.7)**2)
+        obj = (config["cost_related"]["a"] - 4)**2 \
+            + (config["b"] - config["cost_related"]["a"])**2
+        tune.report(obj=obj)
+        tune.report(obj=obj, ab=config["cost_related"]["a"] * config["b"])

-    tune.run(
+    analysis = tune.run(
         simple_func,
         config=search_space,
         low_cost_partial_config={
             "cost_related": {"a": 1}
         },
-        metric="metric",
+        metric="obj",
         mode="min",
+        metric_constraints=[("ab", "<=", 4)],
         local_dir='logs/',
         num_samples=-1,
         time_budget_s=1)

+    best_trial = analysis.get_best_trial()
+    logger.info(f"Best config: {best_trial.config}")
+    logger.info(f"Best result: {best_trial.last_result}")
+

 def test_xgboost_bs():
     _test_xgboost()
@@ -224,4 +234,4 @@ def _test_xgboost_bohb():


 if __name__ == "__main__":
-    unittest.main()
+    test_xgboost_bs()