autogen/flaml/tune/searcher/search_thread.py

# !
# * Copyright (c) Microsoft Corporation. All rights reserved.
# * Licensed under the MIT License. See LICENSE file in the
# * project root for license information.
from typing import Dict, Optional
import numpy as np

try:
    from ray import __version__ as ray_version

    assert ray_version >= "1.10.0"
    if ray_version.startswith("1."):
        from ray.tune.suggest import Searcher
    else:
        from ray.tune.search import Searcher
except (ImportError, AssertionError):
    from .suggestion import Searcher
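# Searcher comes from Ray Tune when a compatible Ray (>=1.10.0) is installed;
# otherwise the local .suggestion module supplies an equivalent base class.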
from .flow2 import FLOW2
from ..space import add_cost_to_space, unflatten_hierarchical
import logging

logger = logging.getLogger(__name__)


class SearchThread:
"""Class of global or local search thread."""
def __init__(
self,
mode: str = "min",
search_alg: Optional[Searcher] = None,
cost_attr: Optional[str] = "time_total_s",
eps: Optional[float] = 1.0,
):
"""When search_alg is omitted, use local search FLOW2."""
self._search_alg = search_alg
self._is_ls = isinstance(search_alg, FLOW2)
self._mode = mode
self._metric_op = 1 if mode == "min" else -1
self.cost_best = self.cost_last = self.cost_total = self.cost_best1 = getattr(
search_alg, "cost_incumbent", 0
)
self._eps = eps
self.cost_best2 = 0
self.obj_best1 = self.obj_best2 = getattr(
search_alg, "best_obj", np.inf
) # inherently minimize
self.best_result = None
# eci: estimated cost for improvement
self.eci = self.cost_best
self.priority = self.speed = 0
self._init_config = True
self.running = 0 # the number of running trials from the thread
self.cost_attr = cost_attr
if search_alg:
self.space = self._space = search_alg.space # unflattened space
if (
self.space
and not isinstance(search_alg, FLOW2)
and isinstance(search_alg._space, dict)
):
# remember const config
self._const = add_cost_to_space(self.space, {}, {})

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Use the suggest() of the underlying search algorithm."""
        if isinstance(self._search_alg, FLOW2):
            config = self._search_alg.suggest(trial_id)
        else:
            try:
                config = self._search_alg.suggest(trial_id)
                if isinstance(self._search_alg._space, dict):
                    config.update(self._const)
                else:
                    # define-by-run: unflatten the flat config back into
                    # the hierarchical space
                    config, self.space = unflatten_hierarchical(config, self._space)
            except FloatingPointError:
                logger.warning(
                    "The global search method raises FloatingPointError. "
                    "Ignoring for this iteration."
                )
                config = None
        if config is not None:
            self.running += 1
        return config

    def update_priority(self, eci: Optional[float] = 0):
        # optimistic projection: assume improvement continues at the
        # current speed for eci more cost units
        self.priority = eci * self.speed - self.obj_best1

    def update_eci(self, metric_target: float, max_speed: Optional[float] = np.inf):
        # calculate eci: estimated cost for improvement over metric_target
        best_obj = metric_target * self._metric_op
        if not self.speed:
            self.speed = max_speed
        self.eci = max(
            self.cost_total - self.cost_best1, self.cost_best1 - self.cost_best2
        )
        if self.obj_best1 > best_obj and self.speed > 0:
            self.eci = max(self.eci, 2 * (self.obj_best1 - best_obj) / self.speed)

    def _update_speed(self):
        # calculate speed; use 0 for invalid speed temporarily
        if self.obj_best2 > self.obj_best1:
            # discount the speed if there are unfinished trials
            self.speed = (
                (self.obj_best2 - self.obj_best1)
                / self.running
                / (max(self.cost_total - self.cost_best2, self._eps))
            )
        else:
            self.speed = 0
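
    # A worked example of the bookkeeping above (illustrative numbers, not
    # taken from the source): two trials finish with objectives 0.5 then 0.3,
    # each costing 1.0, and self.running == 1 when the second one completes.
    # Then obj_best2 = 0.5, obj_best1 = 0.3, cost_total = 2.0, cost_best2 = 1.0:
    #   speed    = (0.5 - 0.3) / 1 / max(2.0 - 1.0, eps) = 0.2
    # With metric_target = 0.3 (no remaining gap to the target):
    #   eci      = max(cost_total - cost_best1, cost_best1 - cost_best2) = 1.0
    #   priority = eci * speed - obj_best1 = 1.0 * 0.2 - 0.3 = -0.1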

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """Update the statistics of the thread."""
        if not self._search_alg:
            return
        if not hasattr(self._search_alg, "_ot_trials") or (
            not error and trial_id in self._search_alg._ot_trials
        ):
            # optuna doesn't handle error
            if self._is_ls or not self._init_config:
                try:
                    self._search_alg.on_trial_complete(trial_id, result, error)
                except RuntimeError as e:
                    # rs is used in place of optuna sometimes
                    if not str(e).endswith(
                        "has already finished and can not be updated."
                    ):
                        raise e
            else:
                # init config is not proposed by self._search_alg
                # under this thread
                self._init_config = False
        if result:
            self.cost_last = result.get(self.cost_attr, 1)
            self.cost_total += self.cost_last
            if self._search_alg.metric in result and (
                not hasattr(self._search_alg, "lexico_objectives")
                or self._search_alg.lexico_objectives is None
            ):
                # TODO: Improve this behavior. When lexico_objectives is provided
                # to CFO, related variables are not callable.
                obj = result[self._search_alg.metric] * self._metric_op
                if obj < self.obj_best1 or self.best_result is None:
                    self.cost_best2 = self.cost_best1
                    self.cost_best1 = self.cost_total
                    self.obj_best2 = obj if np.isinf(self.obj_best1) else self.obj_best1
                    self.obj_best1 = obj
                    self.cost_best = self.cost_last
                    self.best_result = result
            if (
                not hasattr(self._search_alg, "lexico_objectives")
                or self._search_alg.lexico_objectives is None
            ):
                # TODO: Improve this behavior. When lexico_objectives is provided
                # to CFO, related variables are not callable.
                self._update_speed()
        self.running -= 1
        assert self.running >= 0

    def on_trial_result(self, trial_id: str, result: Dict):
        # TODO: update the statistics of the thread with partial result?
        if not self._search_alg:
            return
        if not hasattr(self._search_alg, "_ot_trials") or (
            trial_id in self._search_alg._ot_trials
        ):
            try:
                self._search_alg.on_trial_result(trial_id, result)
            except RuntimeError as e:
                # rs is used in place of optuna sometimes
                if not str(e).endswith("has already finished and can not be updated."):
                    raise e
        new_cost = result.get(self.cost_attr, 1)
        if self.cost_last < new_cost:
            self.cost_last = new_cost
            # self._update_speed()

    @property
    def converged(self) -> bool:
        return self._search_alg.converged

    @property
    def resource(self) -> float:
        return self._search_alg.resource

    def reach(self, thread) -> bool:
        """Whether the incumbent can reach the incumbent of thread."""
        return self._search_alg.reach(thread._search_alg)

    @property
    def can_suggest(self) -> bool:
        """Whether the thread can suggest new configs."""
        return self._search_alg.can_suggest
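
# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the original module): it drives the
# bookkeeping of SearchThread with a hypothetical stub searcher, so the
# speed/eci/priority updates can be observed without Ray or Optuna. The
# _StubSearcher class and the numbers below are illustrative assumptions,
# not FLAML API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _StubSearcher:
        # the minimal attributes SearchThread reads from a wrapped searcher
        space = None  # falsy, so __init__ skips the const-config branch
        metric = "loss"

        def on_trial_complete(self, trial_id, result=None, error=False):
            pass  # no-op; a real searcher would update its model here

    thread = SearchThread(mode="min", search_alg=_StubSearcher())
    for trial_id, loss in [("t0", 0.5), ("t1", 0.3)]:
        thread.running += 1  # normally incremented by suggest()
        thread.on_trial_complete(trial_id, {"time_total_s": 1.0, "loss": loss})
    print(thread.obj_best1, thread.speed)  # ~0.3 and ~0.2 with these numbers
    thread.update_eci(metric_target=0.3)
    thread.update_priority(thread.eci)
    print(thread.eci, thread.priority)  # ~1.0 and ~-0.1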