# !
# * Copyright (c) Microsoft Corporation. All rights reserved.
# * Licensed under the MIT License. See LICENSE file in the
# * project root for license information.
from typing import Dict, Optional, List, Tuple, Callable, Union
import numpy as np
import time
import pickle

try:
    from ray import __version__ as ray_version

    assert ray_version >= "1.0.0"
    from ray.tune.suggest import Searcher
    from ray.tune.suggest.optuna import OptunaSearch as GlobalSearch
except (ImportError, AssertionError):
    from .suggestion import Searcher
    from .suggestion import OptunaSearch as GlobalSearch
from ..tune.trial import unflatten_dict, flatten_dict
from ..tune import INCUMBENT_RESULT
from .search_thread import SearchThread
from .flow2 import FLOW2
from ..tune.space import add_cost_to_space, indexof, normalize, define_by_run_func
import logging

logger = logging.getLogger(__name__)


class BlendSearch(Searcher):
    """class for BlendSearch algorithm."""

    cost_attr = "time_total_s"  # cost attribute in result
    lagrange = "_lagrange"  # suffix for lagrange-modified metric
    penalty = 1e10  # penalty term for constraints
    LocalSearch = FLOW2

    def __init__(
        self,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        space: Optional[dict] = None,
        low_cost_partial_config: Optional[dict] = None,
        cat_hp_cost: Optional[dict] = None,
        points_to_evaluate: Optional[List[dict]] = None,
        evaluated_rewards: Optional[List] = None,
        time_budget_s: Optional[Union[int, float]] = None,
        num_samples: Optional[int] = None,
        resource_attr: Optional[str] = None,
        min_resource: Optional[float] = None,
        max_resource: Optional[float] = None,
        reduction_factor: Optional[float] = None,
        global_search_alg: Optional[Searcher] = None,
        config_constraints: Optional[
            List[Tuple[Callable[[dict], float], str, float]]
        ] = None,
        metric_constraints: Optional[List[Tuple[str, str, float]]] = None,
        seed: Optional[int] = 20,
        experimental: Optional[bool] = False,
        use_incumbent_result_in_evaluation=False,
    ):
"""Constructor.
|
2021-02-05 21:41:14 -08:00
|
|
|
|
|
|
|
Args:
|
|
|
|
metric: A string of the metric name to optimize for.
|
|
|
|
mode: A string in ['min', 'max'] to specify the objective as
|
2021-06-02 22:08:24 -04:00
|
|
|
minimization or maximization.
|
2021-02-05 21:41:14 -08:00
|
|
|
space: A dictionary to specify the search space.
|
2021-04-08 09:29:55 -07:00
|
|
|
low_cost_partial_config: A dictionary from a subset of
|
2021-04-06 11:37:52 -07:00
|
|
|
controlled dimensions to the initial low-cost values.
|
2021-04-08 09:29:55 -07:00
|
|
|
e.g.,
|
2021-04-06 11:37:52 -07:00
|
|
|
|
2021-02-05 21:41:14 -08:00
|
|
|
.. code-block:: python
|
|
|
|
|
2021-04-06 11:37:52 -07:00
|
|
|
{'n_estimators': 4, 'max_leaves': 4}
|
2021-04-08 09:29:55 -07:00
|
|
|
|
2021-02-05 21:41:14 -08:00
|
|
|
cat_hp_cost: A dictionary from a subset of categorical dimensions
|
2021-04-08 09:29:55 -07:00
|
|
|
to the relative cost of each choice.
|
2021-02-05 21:41:14 -08:00
|
|
|
e.g.,
|
2021-04-08 09:29:55 -07:00
|
|
|
|
2021-02-05 21:41:14 -08:00
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
{'tree_method': [1, 1, 2]}
|
2021-04-08 09:29:55 -07:00
|
|
|
|
|
|
|
i.e., the relative cost of the
|
2021-02-05 21:41:14 -08:00
|
|
|
three choices of 'tree_method' is 1, 1 and 2 respectively.
|
2021-09-04 01:42:21 -07:00
|
|
|
points_to_evaluate: Initial parameter suggestions to be run first.
|
|
|
|
evaluated_rewards (list): If you have previously evaluated the
|
|
|
|
parameters passed in as points_to_evaluate you can avoid
|
|
|
|
re-running those trials by passing in the reward attributes
|
|
|
|
as a list so the optimiser can be told the results without
|
|
|
|
needing to re-compute the trial. Must be the same length as
|
|
|
|
points_to_evaluate.
|
2021-09-14 18:36:10 -07:00
|
|
|
time_budget_s: int or float | Time budget in seconds.
|
|
|
|
num_samples: int | The number of configs to try.
|
2021-12-04 21:52:20 -05:00
|
|
|
resource_attr: A string to specify the resource dimension and the best
|
|
|
|
performance is assumed to be at the max_resource.
|
|
|
|
min_resource: A float of the minimal resource to use for the resource_attr.
|
|
|
|
max_resource: A float of the maximal resource to use for the resource_attr.
|
2021-02-05 21:41:14 -08:00
|
|
|
reduction_factor: A float of the reduction factor used for
|
|
|
|
incremental pruning.
|
|
|
|
global_search_alg: A Searcher instance as the global search
|
|
|
|
instance. If omitted, Optuna is used. The following algos have
|
|
|
|
known issues when used as global_search_alg:
|
|
|
|
- HyperOptSearch raises exception sometimes
|
|
|
|
- TuneBOHB has its own scheduler
|
2021-05-18 15:57:42 -07:00
|
|
|
config_constraints: A list of config constraints to be satisfied.
|
|
|
|
e.g.,
|
|
|
|
|
|
|
|
.. code-block: python
|
|
|
|
|
|
|
|
config_constraints = [(mem_size, '<=', 1024**3)]
|
|
|
|
|
|
|
|
mem_size is a function which produces a float number for the bytes
|
|
|
|
needed for a config.
|
|
|
|
It is used to skip configs which do not fit in memory.
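
                For illustration, a `mem_size` callable could be a rough
                estimate like the sketch below (the formula and field names
                are assumptions, not part of FLAML):

                .. code-block:: python

                    def mem_size(config):
                        # assumed: roughly 8 bytes per leaf per tree
                        return config['n_estimators'] * config['max_leaves'] * 8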

            metric_constraints: A list of metric constraints to be satisfied.
                e.g., `[('precision', '>=', 0.9)]`
            seed: An integer of the random seed.
            experimental: A bool of whether to use experimental features.
            use_incumbent_result_in_evaluation: A bool of whether to pass the
                incumbent result of the chosen local search thread to the trial
                (attached to the suggested config under the key `INCUMBENT_RESULT`).
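
        A minimal usage sketch (illustrative only; the evaluation function,
        search space and budgets below are assumptions, and the exact
        `tune.run` arguments may differ across versions):

        .. code-block:: python

            from flaml import BlendSearch, tune

            def evaluate(config):
                # hypothetical objective: smaller is better
                return {"score": (config["x"] - 2) ** 2 + 1e-4 * config["n_estimators"]}

            algo = BlendSearch(
                metric="score",
                mode="min",
                space={
                    "x": tune.uniform(-10, 10),
                    "n_estimators": tune.randint(4, 1000),
                },
                low_cost_partial_config={"n_estimators": 4},
            )
            analysis = tune.run(
                evaluate, search_alg=algo, num_samples=100, time_budget_s=60
            )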
        """
        self._metric, self._mode = metric, mode
        self._use_incumbent_result_in_evaluation = use_incumbent_result_in_evaluation
        init_config = low_cost_partial_config or {}
        if not init_config:
            logger.info(
                "No low-cost partial config given to the search algorithm. "
                "For cost-frugal search, "
                "consider providing low-cost values for cost-related hps via "
                "'low_cost_partial_config'. More info can be found at "
                "https://github.com/microsoft/FLAML/wiki/About-%60low_cost_partial_config%60"
            )
        if evaluated_rewards and mode:
            self._points_to_evaluate = []
            self._evaluated_rewards = []
            best = max(evaluated_rewards) if mode == "max" else min(evaluated_rewards)
            # only keep the best points as start points
            for i, r in enumerate(evaluated_rewards):
                if r == best:
                    p = points_to_evaluate[i]
                    self._points_to_evaluate.append(p)
                    self._evaluated_rewards.append(r)
        else:
            self._points_to_evaluate = points_to_evaluate or []
            self._evaluated_rewards = evaluated_rewards or []
        self._config_constraints = config_constraints
        self._metric_constraints = metric_constraints
        if self._metric_constraints:
            # metric modified by lagrange
            metric += self.lagrange
        self._cat_hp_cost = cat_hp_cost or {}
        if space:
            add_cost_to_space(space, init_config, self._cat_hp_cost)
        self._ls = self.LocalSearch(
            init_config,
            metric,
            mode,
            space,
            resource_attr,
            min_resource,
            max_resource,
            reduction_factor,
            self.cost_attr,
            seed,
        )
        if global_search_alg is not None:
            self._gs = global_search_alg
        elif getattr(self, "__name__", None) != "CFO":
            if space and self._ls.hierarchical:
                from functools import partial

                gs_space = partial(define_by_run_func, space=space)
                evaluated_rewards = None  # not supported by define-by-run
            else:
                gs_space = space
            gs_seed = seed - 10 if (seed - 10) >= 0 else seed - 11 + (1 << 32)
            if experimental:
                import optuna as ot

                sampler = ot.samplers.TPESampler(
                    seed=seed, multivariate=True, group=True
                )
            else:
                sampler = None
            try:
                self._gs = GlobalSearch(
                    space=gs_space,
                    metric=metric,
                    mode=mode,
                    seed=gs_seed,
                    sampler=sampler,
                    points_to_evaluate=points_to_evaluate,
                    evaluated_rewards=evaluated_rewards,
                )
            except ValueError:
                self._gs = GlobalSearch(
                    space=gs_space,
                    metric=metric,
                    mode=mode,
                    seed=gs_seed,
                    sampler=sampler,
                )
            self._gs.space = space
        else:
            self._gs = None
        self._experimental = experimental
        if (
            getattr(self, "__name__", None) == "CFO"
            and points_to_evaluate
            and len(self._points_to_evaluate) > 1
        ):
            # use the best config in points_to_evaluate as the start point
            self._candidate_start_points = {}
            self._started_from_low_cost = not low_cost_partial_config
        else:
            self._candidate_start_points = None
        self._time_budget_s, self._num_samples = time_budget_s, num_samples
        if space is not None:
            self._init_search()

    def set_search_properties(
        self,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        config: Optional[Dict] = None,
        setting: Optional[Dict] = None,
    ) -> bool:
        metric_changed = mode_changed = False
        if metric and self._metric != metric:
            metric_changed = True
            self._metric = metric
            if self._metric_constraints:
                # metric modified by lagrange
                metric += self.lagrange
                # TODO: don't change metric for global search methods that
                # can handle constraints already
        if mode and self._mode != mode:
            mode_changed = True
            self._mode = mode
        if not self._ls.space:
            # the search space can be set only once
            if self._gs is not None:
                # define-by-run is not supported via set_search_properties
                self._gs.set_search_properties(metric, mode, config)
                self._gs.space = config
            if config:
                add_cost_to_space(config, self._ls.init_config, self._cat_hp_cost)
            self._ls.set_search_properties(metric, mode, config)
            self._init_search()
        else:
            if metric_changed or mode_changed:
                # reset search when metric or mode changed
                self._ls.set_search_properties(metric, mode)
                if self._gs is not None:
                    self._gs = GlobalSearch(
                        space=self._gs._space,
                        metric=metric,
                        mode=mode,
                        sampler=self._gs._sampler,
                    )
                    self._gs.space = self._ls.space
                self._init_search()
        if setting:
            # CFO doesn't need these settings
            if "time_budget_s" in setting:
                self._time_budget_s = setting["time_budget_s"]  # budget from now
                now = time.time()
                self._time_used += now - self._start_time
                self._start_time = now
                self._set_deadline()
            if "metric_target" in setting:
                self._metric_target = setting.get("metric_target")
            if "num_samples" in setting:
                self._num_samples = (
                    setting["num_samples"]
                    + len(self._result)
                    + len(self._trial_proposed_by)
                )
        return True

    def _set_deadline(self):
        if self._time_budget_s is not None:
            self._deadline = self._time_budget_s + self._start_time
            SearchThread.set_eps(self._time_budget_s)
        else:
            self._deadline = np.inf

    def _init_search(self):
        """initialize the search"""
        self._start_time = time.time()
        self._time_used = 0
        self._set_deadline()
        self._is_ls_ever_converged = False
        self._subspace = {}  # the subspace for each trial id
        self._metric_target = np.inf * self._ls.metric_op
        self._search_thread_pool = {
            # id: int -> thread: SearchThread
            0: SearchThread(self._ls.mode, self._gs)
        }
        self._thread_count = 1  # total # threads created
        self._init_used = self._ls.init_config is None
        self._trial_proposed_by = {}  # trial_id: str -> thread_id: int
        self._ls_bound_min = normalize(
            self._ls.init_config.copy(),
            self._ls.space,
            self._ls.init_config,
            {},
            recursive=True,
        )
        self._ls_bound_max = normalize(
            self._ls.init_config.copy(),
            self._ls.space,
            self._ls.init_config,
            {},
            recursive=True,
        )
        self._gs_admissible_min = self._ls_bound_min.copy()
        self._gs_admissible_max = self._ls_bound_max.copy()
        self._result = {}  # config_signature: tuple -> result: Dict
        if self._metric_constraints:
            self._metric_constraint_satisfied = False
            self._metric_constraint_penalty = [
                self.penalty for _ in self._metric_constraints
            ]
        else:
            self._metric_constraint_satisfied = True
            self._metric_constraint_penalty = None
        self.best_resource = self._ls.min_resource

    def save(self, checkpoint_path: str):
        """save states to a checkpoint path."""
        self._time_used += time.time() - self._start_time
        self._start_time = time.time()
        save_object = self
        with open(checkpoint_path, "wb") as outputFile:
            pickle.dump(save_object, outputFile)

    def restore(self, checkpoint_path: str):
        """restore states from checkpoint."""
        with open(checkpoint_path, "rb") as inputFile:
            state = pickle.load(inputFile)
        self.__dict__ = state.__dict__
        self._start_time = time.time()
        self._set_deadline()
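
    # Illustrative sketch of checkpointing (the file name is an assumption):
    # `save` pickles the whole searcher, and `restore` loads that state back
    # into an existing instance so a search can be resumed later.
    #
    #     searcher = BlendSearch(metric="score", mode="min", space=space)
    #     searcher.save("blendsearch.pkl")
    #     ...
    #     searcher.restore("blendsearch.pkl")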

    @property
    def metric_target(self):
        return self._metric_target

    @property
    def is_ls_ever_converged(self):
        return self._is_ls_ever_converged

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """search thread updater and cleaner."""
        metric_constraint_satisfied = True
        if result and not error and self._metric_constraints:
            # account for metric constraints if any
            objective = result[self._metric]
            for i, constraint in enumerate(self._metric_constraints):
                metric_constraint, sign, threshold = constraint
                value = result.get(metric_constraint)
                if value:
                    # sign is <= or >=
                    sign_op = 1 if sign == "<=" else -1
                    violation = (value - threshold) * sign_op
                    if violation > 0:
                        # add penalty term to the metric
                        objective += (
                            self._metric_constraint_penalty[i]
                            * violation
                            * self._ls.metric_op
                        )
                        metric_constraint_satisfied = False
                        if self._metric_constraint_penalty[i] < self.penalty:
                            self._metric_constraint_penalty[i] += violation
            result[self._metric + self.lagrange] = objective
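            # Worked example (assumed numbers, for illustration only): with the
            # constraint ("pred_time", "<=", 1.0), a result {"pred_time": 1.5}
            # and a current penalty of 10, the violation is 0.5, so
            # 10 * 0.5 * self._ls.metric_op is added to the objective stored
            # under self._metric + self.lagrange.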
            if metric_constraint_satisfied and not self._metric_constraint_satisfied:
                # found a feasible point
                self._metric_constraint_penalty = [1 for _ in self._metric_constraints]
            self._metric_constraint_satisfied |= metric_constraint_satisfied
        thread_id = self._trial_proposed_by.get(trial_id)
        if thread_id in self._search_thread_pool:
            self._search_thread_pool[thread_id].on_trial_complete(
                trial_id, result, error
            )
            del self._trial_proposed_by[trial_id]
        if result:
            config = result.get("config", {})
            if not config:
                for key, value in result.items():
                    if key.startswith("config/"):
                        config[key[7:]] = value
            signature = self._ls.config_signature(
                config, self._subspace.get(trial_id, {})
            )
            if error:  # remove from result cache
                del self._result[signature]
            else:  # add to result cache
                self._result[signature] = result
                # update target metric if improved
                objective = result[self._ls.metric]
                if (objective - self._metric_target) * self._ls.metric_op < 0:
                    self._metric_target = objective
                    if self._ls.resource:
                        # record the resource level of the best config found so far
                        self.best_resource = config[self._ls.resource_attr]
                if thread_id:
                    if not self._metric_constraint_satisfied:
                        # no point has been found to satisfy metric constraint
                        self._expand_admissible_region(
                            self._ls_bound_min,
                            self._ls_bound_max,
                            self._subspace.get(trial_id, self._ls.space),
                        )
                    if (
                        self._gs is not None
                        and self._experimental
                        and (not self._ls.hierarchical)
                    ):
                        self._gs.add_evaluated_point(flatten_dict(config), objective)
                        # TODO: recover when supported
                        # converted = convert_key(config, self._gs.space)
                        # logger.info(converted)
                        # self._gs.add_evaluated_point(converted, objective)
                elif metric_constraint_satisfied and self._create_condition(result):
                    # thread creator
                    thread_id = self._thread_count
                    self._started_from_given = (
                        self._candidate_start_points
                        and trial_id in self._candidate_start_points
                    )
                    if self._started_from_given:
                        del self._candidate_start_points[trial_id]
                    else:
                        self._started_from_low_cost = True
                    self._create_thread(
                        config, result, self._subspace.get(trial_id, self._ls.space)
                    )
                    # reset admissible region to ls bounding box
                    self._gs_admissible_min.update(self._ls_bound_min)
                    self._gs_admissible_max.update(self._ls_bound_max)
        # cleaner
        if thread_id and thread_id in self._search_thread_pool:
            # local search thread
            self._clean(thread_id)
        if trial_id in self._subspace and not (
            self._candidate_start_points and trial_id in self._candidate_start_points
        ):
            del self._subspace[trial_id]

    def _create_thread(self, config, result, space):
        self._search_thread_pool[self._thread_count] = SearchThread(
            self._ls.mode,
            self._ls.create(
                config,
                result[self._ls.metric],
                cost=result.get(self.cost_attr, 1),
                space=space,
            ),
            self.cost_attr,
        )
        self._thread_count += 1
        self._update_admissible_region(
            unflatten_dict(config),
            self._ls_bound_min,
            self._ls_bound_max,
            space,
            self._ls.space,
        )

    def _update_admissible_region(
        self,
        config,
        admissible_min,
        admissible_max,
        subspace: Dict = {},
        space: Dict = {},
    ):
        # update admissible region
        normalized_config = normalize(config, subspace, config, {})
        for key in admissible_min:
            value = normalized_config[key]
            if isinstance(admissible_max[key], list):
                domain = space[key]
                choice = indexof(domain, value)
                self._update_admissible_region(
                    value,
                    admissible_min[key][choice],
                    admissible_max[key][choice],
                    subspace[key],
                    domain[choice],
                )
                if len(admissible_max[key]) > len(domain.categories):
                    # points + index
                    normal = (choice + 0.5) / len(domain.categories)
                    admissible_max[key][-1] = max(normal, admissible_max[key][-1])
                    admissible_min[key][-1] = min(normal, admissible_min[key][-1])
            elif isinstance(value, dict):
                self._update_admissible_region(
                    value,
                    admissible_min[key],
                    admissible_max[key],
                    subspace[key],
                    space[key],
                )
            else:
                if value > admissible_max[key]:
                    admissible_max[key] = value
                elif value < admissible_min[key]:
                    admissible_min[key] = value

    def _create_condition(self, result: Dict) -> bool:
        """create thread condition"""
        if len(self._search_thread_pool) < 2:
            return True
        obj_median = np.median(
            [thread.obj_best1 for id, thread in self._search_thread_pool.items() if id]
        )
        return result[self._ls.metric] * self._ls.metric_op < obj_median

    def _clean(self, thread_id: int):
        """delete thread and increase admissible region if converged,
        merge local threads if they are close
        """
        assert thread_id
        todelete = set()
        for id in self._search_thread_pool:
            if id and id != thread_id:
                if self._inferior(id, thread_id):
                    todelete.add(id)
        for id in self._search_thread_pool:
            if id and id != thread_id:
                if self._inferior(thread_id, id):
                    todelete.add(thread_id)
                    break
        create_new = False
        if self._search_thread_pool[thread_id].converged:
            self._is_ls_ever_converged = True
            todelete.add(thread_id)
            self._expand_admissible_region(
                self._ls_bound_min,
                self._ls_bound_max,
                self._search_thread_pool[thread_id].space,
            )
            if self._candidate_start_points:
                if not self._started_from_given:
                    # remove start points whose perf is worse than the converged
                    obj = self._search_thread_pool[thread_id].obj_best1
                    worse = [
                        trial_id
                        for trial_id, r in self._candidate_start_points.items()
                        if r and r[self._ls.metric] * self._ls.metric_op >= obj
                    ]
                    # logger.info(f"remove candidate start points {worse} than {obj}")
                    for trial_id in worse:
                        del self._candidate_start_points[trial_id]
                if self._candidate_start_points and self._started_from_low_cost:
                    create_new = True
        for id in todelete:
            del self._search_thread_pool[id]
        if create_new:
            self._create_thread_from_best_candidate()

    def _create_thread_from_best_candidate(self):
        # find the best start point
        best_trial_id = None
        obj_best = None
        for trial_id, r in self._candidate_start_points.items():
            if r and (
                best_trial_id is None
                or r[self._ls.metric] * self._ls.metric_op < obj_best
            ):
                best_trial_id = trial_id
                obj_best = r[self._ls.metric] * self._ls.metric_op
        if best_trial_id:
            # create a new thread
            config = {}
            result = self._candidate_start_points[best_trial_id]
            for key, value in result.items():
                if key.startswith("config/"):
                    config[key[7:]] = value
            self._started_from_given = True
            del self._candidate_start_points[best_trial_id]
            self._create_thread(
                config, result, self._subspace.get(best_trial_id, self._ls.space)
            )

    def _expand_admissible_region(self, lower, upper, space):
        """expand the admissible region for the subspace `space`"""
        for key in upper:
            ub = upper[key]
            if isinstance(ub, list):
                choice = space[key]["_choice_"]
                self._expand_admissible_region(
                    lower[key][choice], upper[key][choice], space[key]
                )
            elif isinstance(ub, dict):
                self._expand_admissible_region(lower[key], ub, space[key])
            else:
                upper[key] += self._ls.STEPSIZE
                lower[key] -= self._ls.STEPSIZE

    def _inferior(self, id1: int, id2: int) -> bool:
        """whether thread id1 is inferior to id2"""
        t1 = self._search_thread_pool[id1]
        t2 = self._search_thread_pool[id2]
        if t1.obj_best1 < t2.obj_best2:
            return False
        elif t1.resource and t1.resource < t2.resource:
            return False
        elif t2.reach(t1):
            return True
        return False

    def on_trial_result(self, trial_id: str, result: Dict):
        """receive intermediate result."""
        if trial_id not in self._trial_proposed_by:
            return
        thread_id = self._trial_proposed_by[trial_id]
        if thread_id not in self._search_thread_pool:
            return
        if result and self._metric_constraints:
            result[self._metric + self.lagrange] = result[self._metric]
        self._search_thread_pool[thread_id].on_trial_result(trial_id, result)

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """choose thread, suggest a valid config."""
        if self._init_used and not self._points_to_evaluate:
            choice, backup = self._select_thread()
            # if choice < 0:  # timeout
            #     return None
            config = self._search_thread_pool[choice].suggest(trial_id)
            if not choice and config is not None and self._ls.resource:
                config[self._ls.resource_attr] = self.best_resource
            elif choice and config is None:
                # local search thread finishes
                if self._search_thread_pool[choice].converged:
                    self._expand_admissible_region(
                        self._ls_bound_min,
                        self._ls_bound_max,
                        self._search_thread_pool[choice].space,
                    )
                    del self._search_thread_pool[choice]
                return None
            # preliminary check; not checking config validation
            space = self._search_thread_pool[choice].space
            skip = self._should_skip(choice, trial_id, config, space)
            use_rs = 0
            if skip:
                if choice:
                    return None
                # use rs when BO fails to suggest a config
                config, space = self._ls.complete_config({})
                skip = self._should_skip(-1, trial_id, config, space)
                if skip:
                    return None
                use_rs = 1
            if choice or self._valid(
                config,
                self._ls.space,
                space,
                self._gs_admissible_min,
                self._gs_admissible_max,
            ):
                # LS or valid or no backup choice
                self._trial_proposed_by[trial_id] = choice
                self._search_thread_pool[choice].running += use_rs
            else:  # invalid config proposed by GS
                if choice == backup:
                    # use CFO's init point
                    init_config = self._ls.init_config
                    config, space = self._ls.complete_config(
                        init_config, self._ls_bound_min, self._ls_bound_max
                    )
                    self._trial_proposed_by[trial_id] = choice
                    self._search_thread_pool[choice].running += 1
                else:
                    thread = self._search_thread_pool[backup]
                    config = thread.suggest(trial_id)
                    space = thread.space
                    skip = self._should_skip(backup, trial_id, config, space)
                    if skip:
                        return None
                    self._trial_proposed_by[trial_id] = backup
                    choice = backup
            if not choice:  # global search
                # temporarily relax admissible region for parallel proposals
                self._update_admissible_region(
                    config,
                    self._gs_admissible_min,
                    self._gs_admissible_max,
                    space,
                    self._ls.space,
                )
            else:
                self._update_admissible_region(
                    config,
                    self._ls_bound_min,
                    self._ls_bound_max,
                    space,
                    self._ls.space,
                )
                self._gs_admissible_min.update(self._ls_bound_min)
                self._gs_admissible_max.update(self._ls_bound_max)
            signature = self._ls.config_signature(config, space)
            self._result[signature] = {}
            self._subspace[trial_id] = space
        else:  # use init config
            if self._candidate_start_points is not None and self._points_to_evaluate:
                self._candidate_start_points[trial_id] = None
            reward = None
            if self._points_to_evaluate:
                init_config = self._points_to_evaluate.pop(0)
                if self._evaluated_rewards:
                    reward = self._evaluated_rewards.pop(0)
            else:
                init_config = self._ls.init_config
            config, space = self._ls.complete_config(
                init_config, self._ls_bound_min, self._ls_bound_max
            )
            if reward is None:
                config_signature = self._ls.config_signature(config, space)
                result = self._result.get(config_signature)
                if result:  # tried before
                    return None
                elif result is None:  # not tried before
                    self._result[config_signature] = {}
                else:  # running but no result yet
                    return None
            self._init_used = True
            self._trial_proposed_by[trial_id] = 0
            self._search_thread_pool[0].running += 1
            self._subspace[trial_id] = space
            if reward is not None:
                result = {self._metric: reward, self.cost_attr: 1, "config": config}
                self.on_trial_complete(trial_id, result)
                return None
        if self._use_incumbent_result_in_evaluation:
            if self._trial_proposed_by[trial_id] > 0:
                choice_thread = self._search_thread_pool[
                    self._trial_proposed_by[trial_id]
                ]
                config[INCUMBENT_RESULT] = choice_thread.best_result
        return config

    def _should_skip(self, choice, trial_id, config, space) -> bool:
        """if config is None or config's result is known or constraints are violated
        return True; o.w. return False
        """
        if config is None:
            return True
        config_signature = self._ls.config_signature(config, space)
        exists = config_signature in self._result
        # check constraints
        if not exists and self._config_constraints:
            for constraint in self._config_constraints:
                func, sign, threshold = constraint
                value = func(config)
                if (
                    sign == "<="
                    and value > threshold
                    or sign == ">="
                    and value < threshold
                ):
                    self._result[config_signature] = {
                        self._metric: np.inf * self._ls.metric_op,
                        "time_total_s": 1,
                    }
                    exists = True
                    break
        if exists:  # suggested before
            if choice >= 0:  # not fallback to rs
                result = self._result.get(config_signature)
                if result:  # finished
                    self._search_thread_pool[choice].on_trial_complete(
                        trial_id, result, error=False
                    )
                    if choice:
                        # local search thread
                        self._clean(choice)
                # else:  # running
                #     # tell the thread there is an error
                #     self._search_thread_pool[choice].on_trial_complete(
                #         trial_id, {}, error=True)
            return True
        return False

    def _select_thread(self) -> Tuple:
        """thread selector; use can_suggest to check LS availability"""
        # update priority
        now = time.time()
        min_eci = self._deadline - now
        if min_eci <= 0:
            # return -1, -1
            # keep proposing new configs assuming no budget left
            min_eci = 0
        elif self._num_samples and self._num_samples > 0:
            # estimate time left according to num_samples limitation
            num_finished = len(self._result)
            num_proposed = num_finished + len(self._trial_proposed_by)
            num_left = max(self._num_samples - num_proposed, 0)
            if num_proposed > 0:
                time_used = now - self._start_time + self._time_used
                min_eci = min(min_eci, time_used / num_finished * num_left)
            # print(f"{min_eci}, {time_used / num_finished * num_left}, {num_finished}, {num_left}")
        max_speed = 0
        for thread in self._search_thread_pool.values():
            if thread.speed > max_speed:
                max_speed = thread.speed
        for thread in self._search_thread_pool.values():
            thread.update_eci(self._metric_target, max_speed)
            if thread.eci < min_eci:
                min_eci = thread.eci
        for thread in self._search_thread_pool.values():
            thread.update_priority(min_eci)

        top_thread_id = backup_thread_id = 0
        priority1 = priority2 = self._search_thread_pool[0].priority
        for thread_id, thread in self._search_thread_pool.items():
            if thread_id and thread.can_suggest:
                priority = thread.priority
                if priority > priority1:
                    priority1 = priority
                    top_thread_id = thread_id
                if priority > priority2 or backup_thread_id == 0:
                    priority2 = priority
                    backup_thread_id = thread_id
        return top_thread_id, backup_thread_id

    def _valid(
        self, config: Dict, space: Dict, subspace: Dict, lower: Dict, upper: Dict
    ) -> bool:
        """config validator"""
        normalized_config = normalize(config, subspace, config, {})
        for key, lb in lower.items():
            if key in config:
                value = normalized_config[key]
                if isinstance(lb, list):
                    domain = space[key]
                    index = indexof(domain, value)
                    nestedspace = subspace[key]
                    lb = lb[index]
                    ub = upper[key][index]
                elif isinstance(lb, dict):
                    nestedspace = subspace[key]
                    domain = space[key]
                    ub = upper[key]
                else:
                    nestedspace = None
                if nestedspace:
                    valid = self._valid(value, domain, nestedspace, lb, ub)
                    if not valid:
                        return False
                elif (
                    value + self._ls.STEPSIZE < lower[key]
                    or value > upper[key] + self._ls.STEPSIZE
                ):
                    return False
        return True


try:
    from ray import __version__ as ray_version

    assert ray_version >= "1.0.0"
    from ray.tune import (
        uniform,
        quniform,
        choice,
        randint,
        qrandint,
        randn,
        qrandn,
        loguniform,
        qloguniform,
    )
except (ImportError, AssertionError):
    from ..tune.sample import (
        uniform,
        quniform,
        choice,
        randint,
        qrandint,
        randn,
        qrandn,
        loguniform,
        qloguniform,
    )

try:
    from nni.tuner import Tuner as NNITuner
    from nni.utils import extract_scalar_reward
except ImportError:

    class NNITuner:
        pass

    def extract_scalar_reward(x: Dict):
        return x.get("default")


class BlendSearchTuner(BlendSearch, NNITuner):
    """Tuner class for NNI."""

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        """Receive trial's final result.

        Args:
            parameter_id: int.
            parameters: object created by `generate_parameters()`.
            value: final metrics of the trial, including default metric.
        """
        result = {
            "config": parameters,
            self._metric: extract_scalar_reward(value),
            self.cost_attr: 1
            if isinstance(value, float)
            else value.get(self.cost_attr, value.get("sequence", 1))
            # if nni does not report training cost,
            # using sequence as an approximation.
            # if no sequence, using a constant 1
        }
        self.on_trial_complete(str(parameter_id), result)
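
    # Illustrative `value` shapes handled above (assumptions about what NNI may
    # report, not guarantees): a bare float such as 0.93, or a dict such as
    # {"default": 0.93, "sequence": 5}; in the dict case "sequence" approximates
    # the cost when no "time_total_s" entry is reported.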

    ...

    def generate_parameters(self, parameter_id, **kwargs) -> Dict:
        """Returns a set of trial (hyper-)parameters, as a serializable object.

        Args:
            parameter_id: int.
        """
        return self.suggest(str(parameter_id))

    ...

    def update_search_space(self, search_space):
        """Required by NNI.

        Tuners are advised to support updating search space at run-time.
        If a tuner can only set search space once before generating first hyper-parameters,
        it should explicitly document this behaviour.

        Args:
            search_space: JSON object created by experiment owner.
        """
        config = {}
        for key, value in search_space.items():
            v = value.get("_value")
            _type = value["_type"]
            if _type == "choice":
                config[key] = choice(v)
            elif _type == "randint":
                config[key] = randint(*v)
            elif _type == "uniform":
                config[key] = uniform(*v)
            elif _type == "quniform":
                config[key] = quniform(*v)
            elif _type == "loguniform":
                config[key] = loguniform(*v)
            elif _type == "qloguniform":
                config[key] = qloguniform(*v)
            elif _type == "normal":
                config[key] = randn(*v)
            elif _type == "qnormal":
                config[key] = qrandn(*v)
            else:
                raise ValueError(f"unsupported type in search_space {_type}")
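        # For illustration (an assumed NNI search-space entry, not defined here):
        #   {"lr": {"_type": "loguniform", "_value": [1e-5, 1e-1]}}
        # is converted above into {"lr": loguniform(1e-5, 1e-1)}.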
        # low_cost_partial_config is passed to constructor,
        # which is before update_search_space() is called
        init_config = self._ls.init_config
        add_cost_to_space(config, init_config, self._cat_hp_cost)
        self._ls = self.LocalSearch(
            init_config,
            self._ls.metric,
            self._mode,
            config,
            self._ls.resource_attr,
            self._ls.min_resource,
            self._ls.max_resource,
            self._ls.resource_multiple_factor,
            cost_attr=self.cost_attr,
            seed=self._ls.seed,
        )
        if self._gs is not None:
            self._gs = GlobalSearch(
                space=config,
                metric=self._metric,
                mode=self._mode,
                sampler=self._gs._sampler,
            )
            self._gs.space = config
        self._init_search()


class CFO(BlendSearchTuner):
    """class for CFO algorithm."""

    __name__ = "CFO"
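
    # Illustrative usage sketch (the evaluation function, space and budget below
    # are assumptions):
    #
    #     algo = CFO(metric="val_loss", mode="min", space=space,
    #                low_cost_partial_config={"n_estimators": 4})
    #     tune.run(evaluate, search_alg=algo, num_samples=100)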

    def suggest(self, trial_id: str) -> Optional[Dict]:
        # Number of threads is 1 or 2. Thread 0 is a vacuous thread
        assert len(self._search_thread_pool) < 3, len(self._search_thread_pool)
        if len(self._search_thread_pool) < 2:
            # When a local thread converges, the number of threads is 1
            # Need to restart
            self._init_used = False
        return super().suggest(trial_id)

    def _select_thread(self) -> Tuple:
        for key in self._search_thread_pool:
            if key:
                return key, key

    def _create_condition(self, result: Dict) -> bool:
        """create thread condition"""
        if self._points_to_evaluate:
            # still evaluating user-specified init points
            # we evaluate all candidate start points before we
            # create the first local search thread
            return False
        if len(self._search_thread_pool) == 2:
            return False
        if self._candidate_start_points and self._thread_count == 1:
            # result needs to match or exceed the best candidate start point
            obj_best = min(
                (
                    self._ls.metric_op * r[self._ls.metric]
                    for r in self._candidate_start_points.values()
                    if r
                ),
                default=-np.inf,
            )
            return result[self._ls.metric] * self._ls.metric_op <= obj_best
        else:
            return True

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        super().on_trial_complete(trial_id, result, error)
        if self._candidate_start_points and trial_id in self._candidate_start_points:
            # the trial is a candidate start point
            self._candidate_start_points[trial_id] = result
            if len(self._search_thread_pool) < 2 and not self._points_to_evaluate:
                self._create_thread_from_best_candidate()


class RandomSearch(CFO):
    """Class for random search."""

    def suggest(self, trial_id: str) -> Optional[Dict]:
        if self._points_to_evaluate:
            return super().suggest(trial_id)
        config, _ = self._ls.complete_config({})
        return config

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        return

    def on_trial_result(self, trial_id: str, result: Dict):
        return