# !
#  * Copyright (c) Microsoft Corporation. All rights reserved.
#  * Licensed under the MIT License. See LICENSE file in the
#  * project root for license information.
from typing import Dict, Optional
import numpy as np

try:
    from ray import __version__ as ray_version

    assert ray_version >= "1.10.0"
    if ray_version.startswith("1."):
        from ray.tune.suggest import Searcher
    else:
        from ray.tune.search import Searcher
except (ImportError, AssertionError):
    from .suggestion import Searcher
from .flow2 import FLOW2
from ..space import add_cost_to_space, unflatten_hierarchical

import logging

logger = logging.getLogger(__name__)


class SearchThread:
    """Class of global or local search thread."""

    def __init__(
        self,
        mode: str = "min",
        search_alg: Optional[Searcher] = None,
        cost_attr: Optional[str] = "time_total_s",
        eps: Optional[float] = 1.0,
    ):
        """When search_alg is omitted, use local search FLOW2."""
        self._search_alg = search_alg
        self._is_ls = isinstance(search_alg, FLOW2)
        self._mode = mode
        self._metric_op = 1 if mode == "min" else -1
        self.cost_best = self.cost_last = self.cost_total = self.cost_best1 = getattr(
            search_alg, "cost_incumbent", 0
        )
        self._eps = eps
        self.cost_best2 = 0
        self.obj_best1 = self.obj_best2 = getattr(search_alg, "best_obj", np.inf)  # inherently minimize
        self.best_result = None
        # eci: estimated cost for improvement
        self.eci = self.cost_best
        self.priority = self.speed = 0
        self._init_config = True
        self.running = 0  # the number of running trials from the thread
        self.cost_attr = cost_attr
        if search_alg:
            self.space = self._space = search_alg.space  # unflattened space
            if self.space and not isinstance(search_alg, FLOW2) and isinstance(search_alg._space, dict):
                # remember const config
                self._const = add_cost_to_space(self.space, {}, {})

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Use the suggest() of the underlying search algorithm."""
        if isinstance(self._search_alg, FLOW2):
            config = self._search_alg.suggest(trial_id)
        else:
            try:
                config = self._search_alg.suggest(trial_id)
                if isinstance(self._search_alg._space, dict):
                    config.update(self._const)
                else:
                    # define by run
                    config, self.space = unflatten_hierarchical(config, self._space)
            except FloatingPointError:
                logger.warning("The global search method raises FloatingPointError. Ignoring for this iteration.")
                config = None
        if config is not None:
            self.running += 1
        return config
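
    # Note on the scheduling statistics maintained below (summary comment, not
    # part of the original source): `speed` estimates objective improvement per
    # unit of cost, `eci` estimates the cost this thread needs to improve over
    # a target objective, and `priority` is the optimistic projection
    # eci * speed - obj_best1, i.e. the negated objective the thread is
    # projected to reach if granted `eci` more cost. Objectives are multiplied
    # by `_metric_op` so that smaller is always better.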

    def update_priority(self, eci: Optional[float] = 0):
        # optimistic projection
        self.priority = eci * self.speed - self.obj_best1

    def update_eci(self, metric_target: float, max_speed: Optional[float] = np.inf):
        # calculate eci: estimated cost for improvement over metric_target
        best_obj = metric_target * self._metric_op
        if not self.speed:
            self.speed = max_speed
        self.eci = max(self.cost_total - self.cost_best1, self.cost_best1 - self.cost_best2)
        if self.obj_best1 > best_obj and self.speed > 0:
            self.eci = max(self.eci, 2 * (self.obj_best1 - best_obj) / self.speed)

    def _update_speed(self):
        # calculate speed; use 0 for invalid speed temporarily
        if self.obj_best2 > self.obj_best1:
            # discount the speed if there are unfinished trials
            self.speed = (
                (self.obj_best2 - self.obj_best1)
                / self.running
                / max(self.cost_total - self.cost_best2, self._eps)
            )
        else:
            self.speed = 0

    def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error: bool = False):
        """Update the statistics of the thread."""
        if not self._search_alg:
            return
        if not hasattr(self._search_alg, "_ot_trials") or (not error and trial_id in self._search_alg._ot_trials):
            # optuna doesn't handle error
            if self._is_ls or not self._init_config:
                try:
                    self._search_alg.on_trial_complete(trial_id, result, error)
                except RuntimeError as e:
                    # rs is used in place of optuna sometimes
                    if not str(e).endswith("has already finished and can not be updated."):
                        raise e
            else:
                # init config is not proposed by self._search_alg under this thread
                self._init_config = False
        if result:
            self.cost_last = result.get(self.cost_attr, 1)
            self.cost_total += self.cost_last
            if self._search_alg.metric in result and (
                not hasattr(self._search_alg, "lexico_objectives") or self._search_alg.lexico_objectives is None
            ):
                # TODO: Improve this behavior. When lexico_objectives is provided to CFO,
                # related variables are not callable.
                obj = result[self._search_alg.metric] * self._metric_op
                if obj < self.obj_best1 or self.best_result is None:
                    self.cost_best2 = self.cost_best1
                    self.cost_best1 = self.cost_total
                    self.obj_best2 = obj if np.isinf(self.obj_best1) else self.obj_best1
                    self.obj_best1 = obj
                    self.cost_best = self.cost_last
                    self.best_result = result
            if not hasattr(self._search_alg, "lexico_objectives") or self._search_alg.lexico_objectives is None:
                # TODO: Improve this behavior. When lexico_objectives is provided to CFO,
                # related variables are not callable.
                self._update_speed()
        self.running -= 1
        assert self.running >= 0
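
    # Worked example of the statistics above (illustrative numbers, not from
    # the original source): with obj_best2 = 0.50, obj_best1 = 0.40,
    # cost_total = cost_best1 = 10.0, cost_best2 = 6.0, running = 1, _eps = 1.0:
    #   speed = (0.50 - 0.40) / 1 / max(10.0 - 6.0, 1.0) = 0.025
    # then update_eci(metric_target=0.30) yields
    #   eci = max(10.0 - 10.0, 10.0 - 6.0, 2 * (0.40 - 0.30) / 0.025) = 8.0
    # and update_priority(eci=8.0) yields
    #   priority = 8.0 * 0.025 - 0.40 = -0.20
    # i.e. the thread optimistically projects reaching objective 0.20.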

    def on_trial_result(self, trial_id: str, result: Dict):
        # TODO: update the statistics of the thread with partial result?
        if not self._search_alg:
            return
        if not hasattr(self._search_alg, "_ot_trials") or (trial_id in self._search_alg._ot_trials):
            try:
                self._search_alg.on_trial_result(trial_id, result)
            except RuntimeError as e:
                # rs is used in place of optuna sometimes
                if not str(e).endswith("has already finished and can not be updated."):
                    raise e
        new_cost = result.get(self.cost_attr, 1)
        if self.cost_last < new_cost:
            self.cost_last = new_cost
            # self._update_speed()

    @property
    def converged(self) -> bool:
        return self._search_alg.converged

    @property
    def resource(self) -> float:
        return self._search_alg.resource

    def reach(self, thread) -> bool:
        """Whether the incumbent can reach the incumbent of thread."""
        return self._search_alg.reach(thread._search_alg)

    @property
    def can_suggest(self) -> bool:
        """Whether the thread can suggest new configs."""
        return self._search_alg.can_suggest
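
# A minimal usage sketch (illustrative only: the FLOW2 arguments shown are an
# assumption -- see flow2.py for the actual signature). A scheduler such as
# BlendSearch drives a thread roughly like this:
#
#     local = FLOW2(init_config={}, metric="loss", mode="min", space=space)
#     thread = SearchThread(mode="min", search_alg=local)
#     config = thread.suggest("trial_0")  # increments thread.running
#     # ... run the trial, then report back ...
#     thread.on_trial_complete("trial_0", {"loss": 0.3, "time_total_s": 2.0})
#     thread.update_eci(metric_target=0.25)  # estimated cost for improvement
#     thread.update_priority(thread.eci)     # priority used to pick a thread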