consider num_samples in bs thread priority (#207)

* consider num_samples in bs thread priority

* continue search for bs
Chi Wang 2021-09-14 18:36:10 -07:00 committed by GitHub
parent ea6c6ded2f
commit a9d39b71da
3 changed files with 376 additions and 231 deletions

flaml/searcher/blendsearch.py

@@ -1,9 +1,9 @@
-'''!
+"""!
- * Copyright (c) 2020-2021 Microsoft Corporation. All rights reserved.
+ * Copyright (c) Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the
  * project root for license information.
-'''
+"""
-from typing import Dict, Optional, List, Tuple, Callable
+from typing import Dict, Optional, List, Tuple, Callable, Union
 import numpy as np
 import time
 import pickle
@@ -11,7 +11,8 @@ import pickle
 try:
     from ray import __version__ as ray_version
-    assert ray_version >= '1.0.0'
+
+    assert ray_version >= "1.0.0"
     from ray.tune.suggest import Searcher
     from ray.tune.suggest.optuna import OptunaSearch as GlobalSearch
 except (ImportError, AssertionError):
@@ -20,42 +21,45 @@ except (ImportError, AssertionError):
 from ..tune.trial import unflatten_dict, flatten_dict
 from .search_thread import SearchThread
 from .flow2 import FLOW2
-from ..tune.space import (
-    add_cost_to_space, indexof, normalize, define_by_run_func)
+from ..tune.space import add_cost_to_space, indexof, normalize, define_by_run_func
 import logging
+
 logger = logging.getLogger(__name__)
 
 
 class BlendSearch(Searcher):
-    '''class for BlendSearch algorithm
-    '''
+    """class for BlendSearch algorithm"""
 
     cost_attr = "time_total_s"  # cost attribute in result
-    lagrange = '_lagrange'  # suffix for lagrange-modified metric
-    penalty = 1e+10  # penalty term for constraints
+    lagrange = "_lagrange"  # suffix for lagrange-modified metric
+    penalty = 1e10  # penalty term for constraints
     LocalSearch = FLOW2
 
-    def __init__(self,
-                 metric: Optional[str] = None,
-                 mode: Optional[str] = None,
-                 space: Optional[dict] = None,
-                 low_cost_partial_config: Optional[dict] = None,
-                 cat_hp_cost: Optional[dict] = None,
-                 points_to_evaluate: Optional[List[dict]] = None,
-                 evaluated_rewards: Optional[List] = None,
-                 prune_attr: Optional[str] = None,
-                 min_resource: Optional[float] = None,
-                 max_resource: Optional[float] = None,
-                 reduction_factor: Optional[float] = None,
-                 global_search_alg: Optional[Searcher] = None,
-                 config_constraints: Optional[
-                     List[Tuple[Callable[[dict], float], str, float]]] = None,
-                 metric_constraints: Optional[
-                     List[Tuple[str, str, float]]] = None,
-                 seed: Optional[int] = 20,
-                 experimental: Optional[bool] = False):
-        '''Constructor
+    def __init__(
+        self,
+        metric: Optional[str] = None,
+        mode: Optional[str] = None,
+        space: Optional[dict] = None,
+        low_cost_partial_config: Optional[dict] = None,
+        cat_hp_cost: Optional[dict] = None,
+        points_to_evaluate: Optional[List[dict]] = None,
+        evaluated_rewards: Optional[List] = None,
+        time_budget_s: Union[int, float] = None,
+        num_samples: Optional[int] = None,
+        prune_attr: Optional[str] = None,
+        min_resource: Optional[float] = None,
+        max_resource: Optional[float] = None,
+        reduction_factor: Optional[float] = None,
+        global_search_alg: Optional[Searcher] = None,
+        config_constraints: Optional[
+            List[Tuple[Callable[[dict], float], str, float]]
+        ] = None,
+        metric_constraints: Optional[List[Tuple[str, str, float]]] = None,
+        seed: Optional[int] = 20,
+        experimental: Optional[bool] = False,
+    ):
+        """Constructor
 
         Args:
             metric: A string of the metric name to optimize for.
@@ -87,6 +91,8 @@ class BlendSearch(Searcher):
                 as a list so the optimiser can be told the results without
                 needing to re-compute the trial. Must be the same length as
                 points_to_evaluate.
+            time_budget_s: int or float | Time budget in seconds.
+            num_samples: int | The number of configs to try.
             prune_attr: A string of the attribute used for pruning.
                 Not necessarily in space.
                 When prune_attr is in space, it is a hyperparameter, e.g.,
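
A minimal usage sketch of the two new constructor arguments (the space, metric name, and values below are illustrative, not taken from this commit):

from flaml import BlendSearch, tune

# a made-up ray.tune-style search space
space = {
    "learning_rate": tune.loguniform(1e-4, 1e-1),
    "num_leaves": tune.randint(4, 64),
}
search_alg = BlendSearch(
    metric="val_loss",
    mode="min",
    space=space,
    low_cost_partial_config={"num_leaves": 4},
    time_budget_s=600,  # new: wall-clock budget, in seconds
    num_samples=100,    # new: number of configs to try
)
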
@@ -119,7 +125,7 @@ class BlendSearch(Searcher):
                 e.g., `['precision', '>=', 0.9]`
             seed: An integer of the random seed.
             experimental: A bool of whether to use experimental features.
-        '''
+        """
         self._metric, self._mode = metric, mode
         init_config = low_cost_partial_config or {}
         if not init_config:
@@ -132,8 +138,7 @@ class BlendSearch(Searcher):
         if evaluated_rewards and mode:
             self._points_to_evaluate = []
             self._evaluated_rewards = []
-            best = max(evaluated_rewards) if mode == 'max' else min(
-                evaluated_rewards)
+            best = max(evaluated_rewards) if mode == "max" else min(evaluated_rewards)
             # only keep the best points as start points
             for i, r in enumerate(evaluated_rewards):
                 if r == best:
@@ -152,51 +157,78 @@ class BlendSearch(Searcher):
         if space:
             add_cost_to_space(space, init_config, self._cat_hp_cost)
         self._ls = self.LocalSearch(
-            init_config, metric, mode, space, prune_attr,
-            min_resource, max_resource, reduction_factor, self.cost_attr, seed)
+            init_config,
+            metric,
+            mode,
+            space,
+            prune_attr,
+            min_resource,
+            max_resource,
+            reduction_factor,
+            self.cost_attr,
+            seed,
+        )
         if global_search_alg is not None:
             self._gs = global_search_alg
-        elif getattr(self, '__name__', None) != 'CFO':
+        elif getattr(self, "__name__", None) != "CFO":
             if space and self._ls.hierarchical:
                 from functools import partial
 
                 gs_space = partial(define_by_run_func, space=space)
                 evaluated_rewards = None  # not supported by define-by-run
             else:
                 gs_space = space
             gs_seed = seed - 10 if (seed - 10) >= 0 else seed - 11 + (1 << 32)
             if experimental:
                 import optuna as ot
 
-                sampler = ot.samplers.TPESampler(
-                    seed=seed, multivariate=True, group=True)
+                sampler = ot.samplers.TPESampler(
+                    seed=seed, multivariate=True, group=True
+                )
             else:
                 sampler = None
             try:
                 self._gs = GlobalSearch(
-                    space=gs_space, metric=metric, mode=mode, seed=gs_seed,
-                    sampler=sampler, points_to_evaluate=points_to_evaluate,
-                    evaluated_rewards=evaluated_rewards)
+                    space=gs_space,
+                    metric=metric,
+                    mode=mode,
+                    seed=gs_seed,
+                    sampler=sampler,
+                    points_to_evaluate=points_to_evaluate,
+                    evaluated_rewards=evaluated_rewards,
+                )
             except ValueError:
                 self._gs = GlobalSearch(
-                    space=gs_space, metric=metric, mode=mode, seed=gs_seed,
-                    sampler=sampler)
+                    space=gs_space,
+                    metric=metric,
+                    mode=mode,
+                    seed=gs_seed,
+                    sampler=sampler,
+                )
             self._gs.space = space
         else:
             self._gs = None
         self._experimental = experimental
-        if getattr(self, '__name__', None) == 'CFO' and points_to_evaluate and len(
-                self._points_to_evaluate) > 1:
+        if (
+            getattr(self, "__name__", None) == "CFO"
+            and points_to_evaluate
+            and len(self._points_to_evaluate) > 1
+        ):
             # use the best config in points_to_evaluate as the start point
             self._candidate_start_points = {}
             self._started_from_low_cost = not low_cost_partial_config
         else:
             self._candidate_start_points = None
+        self._time_budget_s, self._num_samples = time_budget_s, num_samples
         if space:
             self._init_search()
 
-    def set_search_properties(self,
-                              metric: Optional[str] = None,
-                              mode: Optional[str] = None,
-                              config: Optional[Dict] = None) -> bool:
+    def set_search_properties(
+        self,
+        metric: Optional[str] = None,
+        mode: Optional[str] = None,
+        config: Optional[Dict] = None,
+    ) -> bool:
         metric_changed = mode_changed = False
         if metric and self._metric != metric:
             metric_changed = True
@@ -216,34 +248,54 @@ class BlendSearch(Searcher):
                 self._gs.set_search_properties(metric, mode, config)
                 self._gs.space = config
             if config:
-                add_cost_to_space(
-                    config, self._ls.init_config, self._cat_hp_cost)
+                add_cost_to_space(config, self._ls.init_config, self._cat_hp_cost)
             self._ls.set_search_properties(metric, mode, config)
             self._init_search()
-        elif metric_changed or mode_changed:
-            # reset search when metric or mode changed
-            self._ls.set_search_properties(metric, mode)
-            if self._gs is not None:
-                self._gs = GlobalSearch(
-                    space=self._gs._space, metric=metric, mode=mode,
-                    sampler=self._gs._sampler)
-                self._gs.space = self._ls.space
-            self._init_search()
-        if config:
-            if 'time_budget_s' in config:
-                time_budget_s = config['time_budget_s']
-                if time_budget_s is not None:
-                    self._deadline = time_budget_s + time.time()
-                    SearchThread.set_eps(time_budget_s)
-            if 'metric_target' in config:
-                self._metric_target = config.get('metric_target')
+        else:
+            if metric_changed or mode_changed:
+                # reset search when metric or mode changed
+                self._ls.set_search_properties(metric, mode)
+                if self._gs is not None:
+                    self._gs = GlobalSearch(
+                        space=self._gs._space,
+                        metric=metric,
+                        mode=mode,
+                        sampler=self._gs._sampler,
+                    )
+                    self._gs.space = self._ls.space
+                self._init_search()
+            if config:
+                # CFO doesn't need these settings
+                if "time_budget_s" in config:
+                    self._time_budget_s = config["time_budget_s"]  # budget from now
+                    now = time.time()
+                    self._time_used += now - self._start_time
+                    self._start_time = now
+                    self._set_deadline()
+                if "metric_target" in config:
+                    self._metric_target = config.get("metric_target")
+                if "num_samples" in config:
+                    self._num_samples = (
+                        config["num_samples"]
+                        + len(self._result)
+                        + len(self._trial_proposed_by)
+                    )
         return True
 
+    def _set_deadline(self):
+        if self._time_budget_s is not None:
+            self._deadline = self._time_budget_s + self._start_time
+            SearchThread.set_eps(self._time_budget_s)
+        else:
+            self._deadline = np.inf
+
     def _init_search(self):
-        '''initialize the search
-        '''
+        """initialize the search"""
+        self._start_time = time.time()
+        self._time_used = 0
+        self._set_deadline()
         self._is_ls_ever_converged = False
         self._subspace = {}  # the subspace for each trial id
         self._metric_target = np.inf * self._ls.metric_op
         self._search_thread_pool = {
             # id: int -> thread: SearchThread
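
A sketch of the runtime path these branches serve: once the space is set, a later call (as flaml.tune.run now makes, see the second file below) can update the budgets, and the searcher re-bases its clock. Values are illustrative:

search_alg.set_search_properties(
    None, None, config={"time_budget_s": 60, "num_samples": 50}
)
# Per the diff above: elapsed time is folded into _time_used, _start_time is
# reset to now, _set_deadline() recomputes _deadline from the new budget, and
# _num_samples becomes 50 plus the trials already finished or proposed.
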
@@ -253,35 +305,41 @@ class BlendSearch(Searcher):
         self._init_used = self._ls.init_config is None
         self._trial_proposed_by = {}  # trial_id: str -> thread_id: int
         self._ls_bound_min = normalize(
-            self._ls.init_config.copy(), self._ls.space, self._ls.init_config,
-            {}, recursive=True)
+            self._ls.init_config.copy(),
+            self._ls.space,
+            self._ls.init_config,
+            {},
+            recursive=True,
+        )
         self._ls_bound_max = self._ls_bound_min.copy()
         self._gs_admissible_min = self._ls_bound_min.copy()
         self._gs_admissible_max = self._ls_bound_max.copy()
         self._result = {}  # config_signature: tuple -> result: Dict
-        self._deadline = np.inf
         if self._metric_constraints:
             self._metric_constraint_satisfied = False
             self._metric_constraint_penalty = [
-                self.penalty for _ in self._metric_constraints]
+                self.penalty for _ in self._metric_constraints
+            ]
         else:
             self._metric_constraint_satisfied = True
             self._metric_constraint_penalty = None
         self.best_resource = self._ls.min_resource
 
     def save(self, checkpoint_path: str):
-        ''' save states to a checkpoint path
-        '''
+        """save states to a checkpoint path"""
+        self._time_used += time.time() - self._start_time
+        self._start_time = time.time()
         save_object = self
         with open(checkpoint_path, "wb") as outputFile:
             pickle.dump(save_object, outputFile)
 
     def restore(self, checkpoint_path: str):
-        ''' restore states from checkpoint
-        '''
+        """restore states from checkpoint"""
         with open(checkpoint_path, "rb") as inputFile:
             state = pickle.load(inputFile)
         self.__dict__ = state.__dict__
+        self._start_time = time.time()
+        self._set_deadline()
 
     @property
     def metric_target(self):
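
The save/restore changes keep the wall-clock accounting consistent across checkpoints: save() folds elapsed time into _time_used, and restore() restarts the clock and recomputes the deadline. A small sketch; the path and surrounding pattern are illustrative:

search_alg.save("checkpoint/blendsearch.pkl")
# ... later, possibly in a fresh process ...
search_alg2 = BlendSearch(metric="val_loss", mode="min", space=space)
search_alg2.restore("checkpoint/blendsearch.pkl")
# restore() swaps in the pickled state, then resets _start_time to now and
# calls _set_deadline(), so any remaining budget is measured from the restore.
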
@@ -291,10 +349,10 @@ class BlendSearch(Searcher):
     def is_ls_ever_converged(self):
         return self._is_ls_ever_converged
 
-    def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
-                          error: bool = False):
-        ''' search thread updater and cleaner
-        '''
+    def on_trial_complete(
+        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
+    ):
+        """search thread updater and cleaner"""
         metric_constraint_satisfied = True
         if result and not error and self._metric_constraints:
             # account for metric constraints if any
@@ -304,12 +362,15 @@ class BlendSearch(Searcher):
                 value = result.get(metric_constraint)
                 if value:
                     # sign is <= or >=
-                    sign_op = 1 if sign == '<=' else -1
+                    sign_op = 1 if sign == "<=" else -1
                     violation = (value - threshold) * sign_op
                     if violation > 0:
                         # add penalty term to the metric
-                        objective += self._metric_constraint_penalty[
-                            i] * violation * self._ls.metric_op
+                        objective += (
+                            self._metric_constraint_penalty[i]
+                            * violation
+                            * self._ls.metric_op
+                        )
                         metric_constraint_satisfied = False
                         if self._metric_constraint_penalty[i] < self.penalty:
                             self._metric_constraint_penalty[i] += violation
@@ -321,16 +382,18 @@ class BlendSearch(Searcher):
         thread_id = self._trial_proposed_by.get(trial_id)
         if thread_id in self._search_thread_pool:
             self._search_thread_pool[thread_id].on_trial_complete(
-                trial_id, result, error)
+                trial_id, result, error
+            )
             del self._trial_proposed_by[trial_id]
         if result:
-            config = result.get('config', {})
+            config = result.get("config", {})
             if not config:
                 for key, value in result.items():
-                    if key.startswith('config/'):
+                    if key.startswith("config/"):
                         config[key[7:]] = value
             signature = self._ls.config_signature(
-                config, self._subspace.get(trial_id, {}))
+                config, self._subspace.get(trial_id, {})
+            )
             if error:  # remove from result cache
                 del self._result[signature]
             else:  # add to result cache
@@ -345,28 +408,34 @@ class BlendSearch(Searcher):
                 if not self._metric_constraint_satisfied:
                     # no point has been found to satisfy metric constraint
                     self._expand_admissible_region(
-                        self._ls_bound_min, self._ls_bound_max,
-                        self._subspace.get(trial_id, self._ls.space))
-                if self._gs is not None and self._experimental and (
-                        not self._ls.hierarchical):
-                    self._gs.add_evaluated_point(
-                        flatten_dict(config), objective)
+                        self._ls_bound_min,
+                        self._ls_bound_max,
+                        self._subspace.get(trial_id, self._ls.space),
+                    )
+                if (
+                    self._gs is not None
+                    and self._experimental
+                    and (not self._ls.hierarchical)
+                ):
+                    self._gs.add_evaluated_point(flatten_dict(config), objective)
                     # TODO: recover when supported
                     # converted = convert_key(config, self._gs.space)
                     # logger.info(converted)
                     # self._gs.add_evaluated_point(converted, objective)
-                elif metric_constraint_satisfied and self._create_condition(
-                        result):
+                elif metric_constraint_satisfied and self._create_condition(result):
                     # thread creator
                     thread_id = self._thread_count
-                    self._started_from_given = self._candidate_start_points \
-                        and trial_id in self._candidate_start_points
+                    self._started_from_given = (
+                        self._candidate_start_points
+                        and trial_id in self._candidate_start_points
+                    )
                     if self._started_from_given:
                         del self._candidate_start_points[trial_id]
                     else:
                         self._started_from_low_cost = True
-                    self._create_thread(config, result, self._subspace.get(
-                        trial_id, self._ls.space))
+                    self._create_thread(
+                        config, result, self._subspace.get(trial_id, self._ls.space)
+                    )
                     # reset admissible region to ls bounding box
                     self._gs_admissible_min.update(self._ls_bound_min)
                     self._gs_admissible_max.update(self._ls_bound_max)
@@ -374,26 +443,38 @@ class BlendSearch(Searcher):
             if thread_id and thread_id in self._search_thread_pool:
                 # local search thread
                 self._clean(thread_id)
-        if trial_id in self._subspace and not (self._candidate_start_points
-                                               and trial_id in self._candidate_start_points):
+        if trial_id in self._subspace and not (
+            self._candidate_start_points and trial_id in self._candidate_start_points
+        ):
             del self._subspace[trial_id]
 
     def _create_thread(self, config, result, space):
         self._search_thread_pool[self._thread_count] = SearchThread(
             self._ls.mode,
             self._ls.create(
-                config, result[self._ls.metric],
-                cost=result.get(self.cost_attr, 1), space=space),
-            self.cost_attr
+                config,
+                result[self._ls.metric],
+                cost=result.get(self.cost_attr, 1),
+                space=space,
+            ),
+            self.cost_attr,
         )
         self._thread_count += 1
         self._update_admissible_region(
-            unflatten_dict(config), self._ls_bound_min, self._ls_bound_max, space,
-            self._ls.space)
+            unflatten_dict(config),
+            self._ls_bound_min,
+            self._ls_bound_max,
+            space,
+            self._ls.space,
+        )
 
     def _update_admissible_region(
-        self, config, admissible_min, admissible_max, subspace: Dict = {},
-        space: Dict = {}
+        self,
+        config,
+        admissible_min,
+        admissible_max,
+        subspace: Dict = {},
+        space: Dict = {},
     ):
         # update admissible region
         normalized_config = normalize(config, subspace, config, {})
@@ -404,13 +485,19 @@ class BlendSearch(Searcher):
                 choice = indexof(domain, value)
                 self._update_admissible_region(
                     value,
-                    admissible_min[key][choice], admissible_max[key][choice],
-                    subspace[key], domain[choice]
+                    admissible_min[key][choice],
+                    admissible_max[key][choice],
+                    subspace[key],
+                    domain[choice],
                 )
             elif isinstance(value, dict):
                 self._update_admissible_region(
-                    value, admissible_min[key], admissible_max[key],
-                    subspace[key], space[key])
+                    value,
+                    admissible_min[key],
+                    admissible_max[key],
+                    subspace[key],
+                    space[key],
+                )
             else:
                 if value > admissible_max[key]:
                     admissible_max[key] = value
@@ -418,19 +505,18 @@ class BlendSearch(Searcher):
                     admissible_min[key] = value
 
     def _create_condition(self, result: Dict) -> bool:
-        ''' create thread condition
-        '''
+        """create thread condition"""
         if len(self._search_thread_pool) < 2:
             return True
         obj_median = np.median(
-            [thread.obj_best1 for id, thread in self._search_thread_pool.items()
-             if id])
+            [thread.obj_best1 for id, thread in self._search_thread_pool.items() if id]
+        )
         return result[self._ls.metric] * self._ls.metric_op < obj_median
 
     def _clean(self, thread_id: int):
-        ''' delete thread and increase admissible region if converged,
+        """delete thread and increase admissible region if converged,
         merge local threads if they are close
-        '''
+        """
         assert thread_id
         todelete = set()
         for id in self._search_thread_pool:
@@ -447,8 +533,10 @@ class BlendSearch(Searcher):
             self._is_ls_ever_converged = True
             todelete.add(thread_id)
             self._expand_admissible_region(
-                self._ls_bound_min, self._ls_bound_max,
-                self._search_thread_pool[thread_id].space)
+                self._ls_bound_min,
+                self._ls_bound_max,
+                self._search_thread_pool[thread_id].space,
+            )
             if self._candidate_start_points:
                 if not self._started_from_given:
                     # remove start points whose perf is worse than the converged
@@ -456,7 +544,8 @@ class BlendSearch(Searcher):
                     worse = [
                         trial_id
                         for trial_id, r in self._candidate_start_points.items()
-                        if r and r[self._ls.metric] * self._ls.metric_op >= obj]
+                        if r and r[self._ls.metric] * self._ls.metric_op >= obj
+                    ]
                     # logger.info(f"remove candidate start points {worse} than {obj}")
                     for trial_id in worse:
                         del self._candidate_start_points[trial_id]
@@ -472,8 +561,10 @@ class BlendSearch(Searcher):
         best_trial_id = None
         obj_best = None
         for trial_id, r in self._candidate_start_points.items():
-            if r and (best_trial_id is None
-                      or r[self._ls.metric] * self._ls.metric_op < obj_best):
+            if r and (
+                best_trial_id is None
+                or r[self._ls.metric] * self._ls.metric_op < obj_best
+            ):
                 best_trial_id = trial_id
                 obj_best = r[self._ls.metric] * self._ls.metric_op
         if best_trial_id:
@@ -481,20 +572,22 @@ class BlendSearch(Searcher):
             config = {}
             result = self._candidate_start_points[best_trial_id]
             for key, value in result.items():
-                if key.startswith('config/'):
+                if key.startswith("config/"):
                     config[key[7:]] = value
             self._started_from_given = True
             del self._candidate_start_points[best_trial_id]
-            self._create_thread(config, result, self._subspace.get(
-                best_trial_id, self._ls.space))
+            self._create_thread(
+                config, result, self._subspace.get(best_trial_id, self._ls.space)
+            )
 
     def _expand_admissible_region(self, lower, upper, space):
         for key in upper:
             ub = upper[key]
             if isinstance(ub, list):
-                choice = space[key]['_choice_']
+                choice = space[key]["_choice_"]
                 self._expand_admissible_region(
-                    lower[key][choice], upper[key][choice], space[key])
+                    lower[key][choice], upper[key][choice], space[key]
+                )
             elif isinstance(ub, dict):
                 self._expand_admissible_region(lower[key], ub, space[key])
             else:
@@ -502,8 +595,7 @@ class BlendSearch(Searcher):
                 lower[key] -= self._ls.STEPSIZE
 
     def _inferior(self, id1: int, id2: int) -> bool:
-        ''' whether thread id1 is inferior to id2
-        '''
+        """whether thread id1 is inferior to id2"""
         t1 = self._search_thread_pool[id1]
         t2 = self._search_thread_pool[id2]
         if t1.obj_best1 < t2.obj_best2:
@@ -515,8 +607,7 @@ class BlendSearch(Searcher):
         return False
 
     def on_trial_result(self, trial_id: str, result: Dict):
-        ''' receive intermediate result
-        '''
+        """receive intermediate result"""
        if trial_id not in self._trial_proposed_by:
            return
        thread_id = self._trial_proposed_by[trial_id]
@@ -527,8 +618,7 @@ class BlendSearch(Searcher):
         self._search_thread_pool[thread_id].on_trial_result(trial_id, result)
 
     def suggest(self, trial_id: str) -> Optional[Dict]:
-        ''' choose thread, suggest a valid config
-        '''
+        """choose thread, suggest a valid config"""
         if self._init_used and not self._points_to_evaluate:
             choice, backup = self._select_thread()
             # if choice < 0:  # timeout
@@ -540,8 +630,10 @@ class BlendSearch(Searcher):
                 # local search thread finishes
                 if self._search_thread_pool[choice].converged:
                     self._expand_admissible_region(
-                        self._ls_bound_min, self._ls_bound_max,
-                        self._search_thread_pool[choice].space)
+                        self._ls_bound_min,
+                        self._ls_bound_max,
+                        self._search_thread_pool[choice].space,
+                    )
                     del self._search_thread_pool[choice]
                 return None
             # preliminary check; not checking config validation
@@ -558,8 +650,12 @@ class BlendSearch(Searcher):
                 return None
             use_rs = 1
             if choice or self._valid(
-                    config, self._ls.space, space, self._gs_admissible_min,
-                    self._gs_admissible_max):
+                config,
+                self._ls.space,
+                space,
+                self._gs_admissible_min,
+                self._gs_admissible_max,
+            ):
                 # LS or valid or no backup choice
                 self._trial_proposed_by[trial_id] = choice
                 self._search_thread_pool[choice].running += use_rs
@@ -568,7 +664,8 @@ class BlendSearch(Searcher):
                     # use CFO's init point
                     init_config = self._ls.init_config
                     config, space = self._ls.complete_config(
-                        init_config, self._ls_bound_min, self._ls_bound_max)
+                        init_config, self._ls_bound_min, self._ls_bound_max
+                    )
                     self._trial_proposed_by[trial_id] = choice
                     self._search_thread_pool[choice].running += 1
                 else:
@@ -583,12 +680,20 @@ class BlendSearch(Searcher):
             if not choice:  # global search
                 # temporarily relax admissible region for parallel proposals
                 self._update_admissible_region(
-                    config, self._gs_admissible_min, self._gs_admissible_max,
-                    space, self._ls.space)
+                    config,
+                    self._gs_admissible_min,
+                    self._gs_admissible_max,
+                    space,
+                    self._ls.space,
+                )
             else:
                 self._update_admissible_region(
-                    config, self._ls_bound_min, self._ls_bound_max, space,
-                    self._ls.space)
+                    config,
+                    self._ls_bound_min,
+                    self._ls_bound_max,
+                    space,
+                    self._ls.space,
+                )
                 self._gs_admissible_min.update(self._ls_bound_min)
                 self._gs_admissible_max.update(self._ls_bound_max)
             signature = self._ls.config_signature(config, space)
@@ -605,7 +710,8 @@ class BlendSearch(Searcher):
             else:
                 init_config = self._ls.init_config
             config, space = self._ls.complete_config(
-                init_config, self._ls_bound_min, self._ls_bound_max)
+                init_config, self._ls_bound_min, self._ls_bound_max
+            )
             if reward is None:
                 config_signature = self._ls.config_signature(config, space)
                 result = self._result.get(config_signature)
@@ -620,18 +726,15 @@ class BlendSearch(Searcher):
             self._search_thread_pool[0].running += 1
             self._subspace[trial_id] = space
         if reward is not None:
-            result = {
-                self._metric: reward, self.cost_attr: 1,
-                'config': config
-            }
+            result = {self._metric: reward, self.cost_attr: 1, "config": config}
             self.on_trial_complete(trial_id, result)
             return None
         return config
 
     def _should_skip(self, choice, trial_id, config, space) -> bool:
-        ''' if config is None or config's result is known or constraints are violated
+        """if config is None or config's result is known or constraints are violated
         return True; o.w. return False
-        '''
+        """
         if config is None:
             return True
         config_signature = self._ls.config_signature(config, space)
@@ -641,11 +744,15 @@ class BlendSearch(Searcher):
             for constraint in self._config_constraints:
                 func, sign, threshold = constraint
                 value = func(config)
-                if (sign == '<=' and value > threshold
-                        or sign == '>=' and value < threshold):
+                if (
+                    sign == "<="
+                    and value > threshold
+                    or sign == ">="
+                    and value < threshold
+                ):
                     self._result[config_signature] = {
                         self._metric: np.inf * self._ls.metric_op,
-                        'time_total_s': 1,
+                        "time_total_s": 1,
                     }
                     exists = True
                     break
@@ -654,7 +761,8 @@ class BlendSearch(Searcher):
             result = self._result.get(config_signature)
             if result:  # finished
                 self._search_thread_pool[choice].on_trial_complete(
-                    trial_id, result, error=False)
+                    trial_id, result, error=False
+                )
                 if choice:
                     # local search thread
                     self._clean(choice)
@@ -666,14 +774,23 @@ class BlendSearch(Searcher):
         return False
 
     def _select_thread(self) -> Tuple:
-        ''' thread selector; use can_suggest to check LS availability
-        '''
+        """thread selector; use can_suggest to check LS availability"""
         # update priority
-        min_eci = self._deadline - time.time()
+        now = time.time()
+        min_eci = self._deadline - now
         if min_eci <= 0:
             # return -1, -1
             # keep proposing new configs assuming no budget left
             min_eci = 0
+        elif self._num_samples and self._num_samples > 0:
+            # estimate time left according to num_samples limitation
+            num_finished = len(self._result)
+            num_proposed = num_finished + len(self._trial_proposed_by)
+            num_left = max(self._num_samples - num_proposed, 0)
+            if num_proposed > 0:
+                time_used = now - self._start_time + self._time_used
+                min_eci = min(min_eci, time_used / num_finished * num_left)
+            # print(f"{min_eci}, {time_used / num_finished * num_left}, {num_finished}, {num_left}")
         max_speed = 0
         for thread in self._search_thread_pool.values():
             if thread.speed > max_speed:
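
This new elif is the heart of the commit: the horizon fed into thread priorities is now capped by the time extrapolated for the remaining num_samples trials, not only by the wall-clock deadline. A worked sketch of that arithmetic with made-up numbers:

# 20 trials finished, 4 more proposed, num_samples = 40, 100s spent so far
num_finished, num_proposed, num_samples, time_used = 20, 24, 40, 100.0
num_left = max(num_samples - num_proposed, 0)  # 16 configs left to try
cap = time_used / num_finished * num_left      # 100 / 20 * 16 = 80.0 seconds
# min_eci = min(deadline - now, 80.0); thread priorities use this tighter bound
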
@@ -698,10 +815,10 @@ class BlendSearch(Searcher):
                 backup_thread_id = thread_id
         return top_thread_id, backup_thread_id
 
-    def _valid(self, config: Dict, space: Dict, subspace: Dict,
-               lower: Dict, upper: Dict) -> bool:
-        ''' config validator
-        '''
+    def _valid(
+        self, config: Dict, space: Dict, subspace: Dict, lower: Dict, upper: Dict
+    ) -> bool:
+        """config validator"""
         normalized_config = normalize(config, subspace, config, {})
         for key, lb in lower.items():
             if key in config:
@@ -719,114 +836,143 @@ class BlendSearch(Searcher):
             else:
                 nestedspace = None
             if nestedspace:
-                valid = self._valid(
-                    value, domain, nestedspace, lb, ub)
+                valid = self._valid(value, domain, nestedspace, lb, ub)
                 if not valid:
                     return False
-            elif (value + self._ls.STEPSIZE < lower[key]
-                    or value > upper[key] + self._ls.STEPSIZE):
+            elif (
+                value + self._ls.STEPSIZE < lower[key]
+                or value > upper[key] + self._ls.STEPSIZE
+            ):
                 return False
         return True
 
 
 try:
     from ray import __version__ as ray_version
-    assert ray_version >= '1.0.0'
-    from ray.tune import (uniform, quniform, choice, randint, qrandint, randn,
-                          qrandn, loguniform, qloguniform)
+
+    assert ray_version >= "1.0.0"
+    from ray.tune import (
+        uniform,
+        quniform,
+        choice,
+        randint,
+        qrandint,
+        randn,
+        qrandn,
+        loguniform,
+        qloguniform,
+    )
 except (ImportError, AssertionError):
-    from ..tune.sample import (uniform, quniform, choice, randint, qrandint, randn,
-                               qrandn, loguniform, qloguniform)
+    from ..tune.sample import (
+        uniform,
+        quniform,
+        choice,
+        randint,
+        qrandint,
+        randn,
+        qrandn,
+        loguniform,
+        qloguniform,
+    )
 
 try:
     from nni.tuner import Tuner as NNITuner
     from nni.utils import extract_scalar_reward
 except ImportError:
 
     class NNITuner:
         pass
 
     def extract_scalar_reward(x: Dict):
-        return x.get('reward')
+        return x.get("reward")
 
 
 class BlendSearchTuner(BlendSearch, NNITuner):
-    '''Tuner class for NNI
-    '''
+    """Tuner class for NNI"""
 
-    def receive_trial_result(self, parameter_id, parameters, value,
-                             **kwargs):
-        '''
+    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
+        """
         Receive trial's final result.
         parameter_id: int
         parameters: object created by 'generate_parameters()'
         value: final metrics of the trial, including default metric
-        '''
+        """
         result = {
-            'config': parameters, self._metric: extract_scalar_reward(value),
-            self.cost_attr: 1 if isinstance(value, float) else value.get(
-                self.cost_attr, value.get('sequence', 1))
+            "config": parameters,
+            self._metric: extract_scalar_reward(value),
+            self.cost_attr: 1
+            if isinstance(value, float)
+            else value.get(self.cost_attr, value.get("sequence", 1))
             # if nni does not report training cost,
             # using sequence as an approximation.
             # if no sequence, using a constant 1
         }
         self.on_trial_complete(str(parameter_id), result)
 
     ...
 
     def generate_parameters(self, parameter_id, **kwargs) -> Dict:
-        '''
+        """
         Returns a set of trial (hyper-)parameters, as a serializable object
         parameter_id: int
-        '''
+        """
         return self.suggest(str(parameter_id))
 
     ...
 
     def update_search_space(self, search_space):
-        '''
+        """
         Tuners are advised to support updating search space at run-time.
         If a tuner can only set search space once before generating first hyper-parameters,
         it should explicitly document this behaviour.
         search_space: JSON object created by experiment owner
-        '''
+        """
         config = {}
         for key, value in search_space.items():
             v = value.get("_value")
-            _type = value['_type']
-            if _type == 'choice':
+            _type = value["_type"]
+            if _type == "choice":
                 config[key] = choice(v)
-            elif _type == 'randint':
+            elif _type == "randint":
                 config[key] = randint(*v)
-            elif _type == 'uniform':
+            elif _type == "uniform":
                 config[key] = uniform(*v)
-            elif _type == 'quniform':
+            elif _type == "quniform":
                 config[key] = quniform(*v)
-            elif _type == 'loguniform':
+            elif _type == "loguniform":
                 config[key] = loguniform(*v)
-            elif _type == 'qloguniform':
+            elif _type == "qloguniform":
                 config[key] = qloguniform(*v)
-            elif _type == 'normal':
+            elif _type == "normal":
                 config[key] = randn(*v)
-            elif _type == 'qnormal':
+            elif _type == "qnormal":
                 config[key] = qrandn(*v)
             else:
-                raise ValueError(
-                    f'unsupported type in search_space {_type}')
+                raise ValueError(f"unsupported type in search_space {_type}")
         add_cost_to_space(config, {}, {})
         self._ls = self.LocalSearch(
-            {}, self._ls.metric, self._mode, config, cost_attr=self.cost_attr,
-            seed=self._ls.seed)
+            {},
+            self._ls.metric,
+            self._mode,
+            config,
+            cost_attr=self.cost_attr,
+            seed=self._ls.seed,
+        )
         if self._gs is not None:
             self._gs = GlobalSearch(
-                space=config, metric=self._metric, mode=self._mode,
-                sampler=self._gs._sampler)
+                space=config,
+                metric=self._metric,
+                mode=self._mode,
+                sampler=self._gs._sampler,
+            )
             self._gs.space = config
         self._init_search()
 
 
 class CFO(BlendSearchTuner):
-    ''' class for CFO algorithm
-    '''
+    """class for CFO algorithm"""
 
-    __name__ = 'CFO'
+    __name__ = "CFO"
 
     def suggest(self, trial_id: str) -> Optional[Dict]:
         # Number of threads is 1 or 2. Thread 0 is a vacuous thread
@@ -843,8 +989,7 @@ class CFO(BlendSearchTuner):
                 return key, key
 
     def _create_condition(self, result: Dict) -> bool:
-        ''' create thread condition
-        '''
+        """create thread condition"""
         if self._points_to_evaluate:
             # still evaluating user-specified init points
             # we evaluate all candidate start points before we
@@ -856,16 +1001,18 @@ class CFO(BlendSearchTuner):
             # result needs to match or exceed the best candidate start point
             obj_best = min(
                 self._ls.metric_op * r[self._ls.metric]
-                for r in self._candidate_start_points.values() if r)
+                for r in self._candidate_start_points.values()
+                if r
+            )
             return result[self._ls.metric] * self._ls.metric_op <= obj_best
         else:
             return True
 
-    def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
-                          error: bool = False):
+    def on_trial_complete(
+        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
+    ):
         super().on_trial_complete(trial_id, result, error)
-        if self._candidate_start_points \
-                and trial_id in self._candidate_start_points:
+        if self._candidate_start_points and trial_id in self._candidate_start_points:
             # the trial is a candidate start point
             self._candidate_start_points[trial_id] = result
             if len(self._search_thread_pool) < 2 and not self._points_to_evaluate:
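
For reference, the candidate-start-point branches above come into play when CFO receives several points_to_evaluate. A sketch with illustrative values:

from flaml import CFO

search_alg = CFO(
    metric="loss",
    mode="min",
    space=space,  # same illustrative space as in the earlier sketch
    points_to_evaluate=[
        {"learning_rate": 0.01, "num_leaves": 8},
        {"learning_rate": 0.1, "num_leaves": 32},
    ],
)
# With more than one start point, CFO keeps finished trials as candidates and
# later starts its local search thread from the best-performing candidate.
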

flaml/tune/tune.py

@@ -83,8 +83,6 @@ def report(_metric=None, **kwargs):
         return tune.report(_metric, **kwargs)
     else:
         result = kwargs
-        if _verbose == 2:
-            logger.info(f"result: {kwargs}")
         if _metric:
             result[DEFAULT_METRIC] = _metric
         trial = _runner.running_trial
@@ -114,7 +112,7 @@ def run(
     cat_hp_cost: Optional[dict] = None,
     metric: Optional[str] = None,
     mode: Optional[str] = None,
-    time_budget_s: Union[int, float, datetime.timedelta] = None,
+    time_budget_s: Union[int, float] = None,
     points_to_evaluate: Optional[List[dict]] = None,
     evaluated_rewards: Optional[List] = None,
     prune_attr: Optional[str] = None,
@@ -184,7 +182,7 @@ def run(
         metric: A string of the metric name to optimize for.
         mode: A string in ['min', 'max'] to specify the objective as
             minimization or maximization.
-        time_budget_s: A float of the time budget in seconds.
+        time_budget_s: int or float | The time budget in seconds.
         points_to_evaluate: A list of initial hyperparameter
             configurations to run first.
         evaluated_rewards (list): If you have previously evaluated the
@@ -291,6 +289,8 @@ def run(
             evaluated_rewards=evaluated_rewards,
             low_cost_partial_config=low_cost_partial_config,
             cat_hp_cost=cat_hp_cost,
+            time_budget_s=time_budget_s,
+            num_samples=num_samples,
             prune_attr=prune_attr,
             min_resource=min_resource,
             max_resource=max_resource,
@@ -303,10 +303,12 @@ def run(
     if metric is None or mode is None:
         metric = metric or search_alg.metric
         mode = mode or search_alg.mode
-    if time_budget_s:
+    if time_budget_s or num_samples > 0:
         search_alg.set_search_properties(
-            None, None, config={"time_budget_s": time_budget_s}
-        )
+            None,
+            None,
+            config={"time_budget_s": time_budget_s, "num_samples": num_samples},
+        )
     scheduler = None
     if report_intermediate_result:
         params = {}
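
End to end, both budgets now flow from tune.run into BlendSearch via set_search_properties. A runnable sketch with a toy objective; names and values are illustrative:

from flaml import tune

def evaluate_config(config):
    # toy objective, reported back to the searcher
    score = (config["x"] - 3) ** 2 + abs(config["y"] - 10)
    tune.report(score=score)

analysis = tune.run(
    evaluate_config,
    config={"x": tune.uniform(-10, 10), "y": tune.randint(1, 100)},
    metric="score",
    mode="min",
    num_samples=64,    # trial-count budget, forwarded to the searcher
    time_budget_s=10,  # wall-clock budget, deadline measured from now
)
print(analysis.best_config)  # assuming the usual ExperimentAnalysis interface
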

test/test_restore.py

@@ -6,7 +6,6 @@ import numpy as np
 from flaml.searcher.suggestion import ConcurrencyLimiter
 from flaml import tune
 from flaml import CFO
-from flaml import BlendSearch
 
 
 class AbstractWarmStartTest:
@@ -27,28 +26,24 @@ class AbstractWarmStartTest:
         search_alg, cost = self.set_basic_conf()
         search_alg = ConcurrencyLimiter(search_alg, 1)
         results_exp_1 = tune.run(
-            cost,
-            num_samples=5,
-            search_alg=search_alg,
-            verbose=0,
-            local_dir=self.tmpdir)
+            cost, num_samples=5, search_alg=search_alg, verbose=0, local_dir=self.tmpdir
+        )
         checkpoint_path = os.path.join(self.tmpdir, self.experiment_name)
         search_alg.save(checkpoint_path)
         return results_exp_1, np.random.get_state(), checkpoint_path
 
     def run_explicit_restore(self, random_state, checkpoint_path):
-        np.random.set_state(random_state)
         search_alg2, cost = self.set_basic_conf()
         search_alg2 = ConcurrencyLimiter(search_alg2, 1)
         search_alg2.restore(checkpoint_path)
+        np.random.set_state(random_state)
         return tune.run(cost, num_samples=5, search_alg=search_alg2, verbose=0)
 
     def run_full(self):
         np.random.seed(162)
         search_alg3, cost = self.set_basic_conf()
         search_alg3 = ConcurrencyLimiter(search_alg3, 1)
-        return tune.run(
-            cost, num_samples=10, search_alg=search_alg3, verbose=0)
+        return tune.run(cost, num_samples=10, search_alg=search_alg3, verbose=0)
 
     def testReproduce(self):
         results_exp_1, _, _ = self.run_part_from_scratch()
@@ -75,7 +70,7 @@ class CFOWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
         }
 
         def cost(param):
-            tune.report(loss=(param["height"] - 14)**2 - abs(param["width"] - 3))
+            tune.report(loss=(param["height"] - 14) ** 2 - abs(param["width"] - 3))
 
         search_alg = CFO(
             space=space,
@@ -86,6 +81,7 @@ class CFOWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
         return search_alg, cost
 
+
 # # # Not doing test for BS because of problems with random seed in OptunaSearch
 # class BlendsearchWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
 #     def set_basic_conf(self):