2021-02-05 21:41:14 -08:00
|
|
|
'''!
|
|
|
|
* Copyright (c) 2020-2021 Microsoft Corporation. All rights reserved.
|
|
|
|
* Licensed under the MIT License. See LICENSE file in the
|
|
|
|
* project root for license information.
|
|
|
|
'''
|
|
|
|
from typing import Dict, Optional
|
|
|
|
import numpy as np
|
|
|
|
try:
|
|
|
|
from ray.tune.suggest import Searcher
|
|
|
|
except ImportError:
|
|
|
|
from .suggestion import Searcher
|
|
|
|
from .flow2 import FLOW2
|
|
|
|
|
|
|
|
import logging
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class SearchThread:
    '''Class of global or local search thread.

    Wraps one search algorithm (a local FLOW2 searcher or a global searcher)
    and tracks the statistics used to schedule it: accumulated cost, the two
    best objective values seen so far and the cost at which they were reached,
    the improvement speed, the estimated cost for improvement (eci), a
    priority score, and the number of suggested trials not yet completed.

    Objective values are kept in minimization form: the raw metric is
    multiplied by +1 when mode is "min" and by -1 otherwise.
    '''

    # key of the cost entry read from trial result dicts
    cost_attr = 'time_total_s'
    # lower bound on the cost denominator used in _update_speed(); see set_eps()
    _eps = 1.0
    # suffix of the RuntimeError message that is tolerated when forwarding a
    # trial update to the underlying search algorithm
    _FINISHED_TRIAL_MSG = "has already finished and can not be updated."

    def __init__(self, mode: str = "min",
                 search_alg: Optional[Searcher] = None):
        '''Create a search thread.

        Args:
            mode: "min" to minimize the metric; any other value maximizes it.
            search_alg: the search algorithm driven by this thread. A FLOW2
                instance marks this as a local search thread. If omitted
                (None), the thread wraps no algorithm: on_trial_complete()
                and on_trial_result() return immediately, and the methods
                that delegate to the algorithm (suggest, converged, resource,
                reach, can_suggest) must not be called.
        '''
        self._search_alg = search_alg
        self._is_ls = isinstance(search_alg, FLOW2)
        self._mode = mode
        # internal objectives are minimized: obj = metric * _metric_op
        self._metric_op = 1 if mode == 'min' else -1
        # start cost bookkeeping from the algorithm's incumbent cost, if any
        self.cost_best = self.cost_last = self.cost_total = self.cost_best1 = \
            getattr(search_alg, 'cost_incumbent', 0)
        self.cost_best2 = 0
        self.obj_best1 = self.obj_best2 = getattr(
            search_alg, 'best_obj', np.inf)  # inherently minimize
        # eci: estimated cost for improvement
        self.eci = self.cost_best
        self.priority = self.speed = 0
        # cleared by the first completed trial of a non-local-search thread;
        # see on_trial_complete()
        self._init_config = True
        self.running = 0  # the number of running trials from the thread

    @classmethod
    def set_eps(cls, time_budget_s: float):
        '''Set the lower bound on the cost denominator of _update_speed().

        The bound is time_budget_s / 1000, clamped to [1e-10, 1.0].
        '''
        cls._eps = max(min(time_budget_s / 1000.0, 1.0), 1e-10)

    def suggest(self, trial_id: str) -> Optional[Dict]:
        '''Ask the underlying search algorithm for a config for trial_id.

        A FloatingPointError raised by a global (non-FLOW2) algorithm is
        logged and turned into a None result; for FLOW2 it propagates.
        Each non-None config counts as one more running trial.

        Returns:
            The suggested config, or None.
        '''
        try:
            config = self._search_alg.suggest(trial_id)
        except FloatingPointError:
            if isinstance(self._search_alg, FLOW2):
                raise
            logger.warning(
                'The global search method raises FloatingPointError. '
                'Ignoring for this iteration.')
            config = None
        if config is not None:
            self.running += 1
        return config

    def update_priority(self, eci: float = 0):
        '''Set priority to an optimistic projection of the objective.

        priority = eci * speed - obj_best1: the negated best (minimization
        form) objective, improved by the progress projected from spending
        eci more cost at the current speed.
        '''
        # optimistic projection
        self.priority = eci * self.speed - self.obj_best1

    def update_eci(self, metric_target: float,
                   max_speed: float = np.inf):
        '''Update eci, the estimated cost for improvement over metric_target.

        Args:
            metric_target: target value in the metric's own direction; it is
                converted to minimization form internally.
            max_speed: speed assumed when this thread's speed is 0.
        '''
        best_obj = metric_target * self._metric_op
        if not self.speed:
            self.speed = max_speed
        # larger of: cost spent since the last improvement, and cost spent
        # between the last two improvements
        self.eci = max(self.cost_total - self.cost_best1,
                       self.cost_best1 - self.cost_best2)
        if self.obj_best1 > best_obj and self.speed > 0:
            # cost to cover twice the gap to the target at the current speed
            self.eci = max(self.eci, 2 * (self.obj_best1 - best_obj) / self.speed)

    def _update_speed(self):
        '''Recompute speed as objective improvement per unit cost.

        speed = (obj_best2 - obj_best1) / running
                / max(cost_total - cost_best2, _eps)
        and 0 (used temporarily as an invalid speed) when obj_best2 is not
        strictly worse than obj_best1. This divides by self.running;
        on_trial_complete() calls it before decrementing that count, so it
        is expected to be >= 1 there.
        '''
        if self.obj_best2 > self.obj_best1:
            # discount the speed if there are unfinished trials
            self.speed = (self.obj_best2 - self.obj_best1) / self.running / (
                max(self.cost_total - self.cost_best2, SearchThread._eps))
        else:
            self.speed = 0

    def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
                          error: bool = False):
        '''Update the statistics of the thread with a completed trial.

        Forwards the completion to the underlying search algorithm when
        appropriate, updates cost and best-objective statistics from result,
        and decrements the running-trial count.

        Raises:
            RuntimeError: re-raised from the search algorithm unless its
                message ends with _FINISHED_TRIAL_MSG.
        '''
        if not self._search_alg:
            return
        if not hasattr(self._search_alg, '_ot_trials') or (
                not error and trial_id in self._search_alg._ot_trials):
            # optuna doesn't handle error
            if self._is_ls or not self._init_config:
                try:
                    self._search_alg.on_trial_complete(trial_id, result, error)
                except RuntimeError as e:
                    # rs is used in place of optuna sometimes
                    if not str(e).endswith(self._FINISHED_TRIAL_MSG):
                        raise
            else:
                # init config is not proposed by self._search_alg
                # under this thread
                self._init_config = False
        if result:
            if self.cost_attr in result:
                self.cost_last = result[self.cost_attr]
                self.cost_total += self.cost_last
            if self._search_alg.metric in result:
                obj = result[self._search_alg.metric] * self._metric_op
                if obj < self.obj_best1:
                    # shift the previous best into the second slot; with no
                    # previous best (inf), obj fills both slots
                    self.cost_best2 = self.cost_best1
                    self.cost_best1 = self.cost_total
                    self.obj_best2 = obj if np.isinf(
                        self.obj_best1) else self.obj_best1
                    self.obj_best1 = obj
                    self.cost_best = self.cost_last
                self._update_speed()
        self.running -= 1
        assert self.running >= 0

    def on_trial_result(self, trial_id: str, result: Dict):
        '''Forward an intermediate result and track the latest cost.

        TODO update the statistics of the thread with partial result?

        Raises:
            RuntimeError: re-raised from the search algorithm unless its
                message ends with _FINISHED_TRIAL_MSG.
        '''
        if not self._search_alg:
            return
        if not hasattr(self._search_alg, '_ot_trials') or (
                trial_id in self._search_alg._ot_trials):
            try:
                self._search_alg.on_trial_result(trial_id, result)
            except RuntimeError as e:
                # rs is used in place of optuna sometimes
                if not str(e).endswith(self._FINISHED_TRIAL_MSG):
                    raise
        # keep the largest cost reported so far for the current trial
        if self.cost_attr in result and self.cost_last < result[self.cost_attr]:
            self.cost_last = result[self.cost_attr]

    @property
    def converged(self) -> bool:
        '''Convergence flag reported by the underlying search algorithm.'''
        return self._search_alg.converged

    @property
    def resource(self) -> float:
        '''Resource value reported by the underlying search algorithm.'''
        return self._search_alg.resource

    def reach(self, thread: 'SearchThread') -> bool:
        '''Whether the incumbent can reach the incumbent of thread.

        Delegates to the underlying search algorithm's reach(), passing the
        other thread's search algorithm.
        '''
        return self._search_alg.reach(thread._search_alg)

    @property
    def can_suggest(self) -> bool:
        '''Whether the thread can suggest new configs.'''
        return self._search_alg.can_suggest
|