autogen/flaml/searcher/suggestion.py

705 lines
29 KiB
Python
Raw Normal View History

'''
Copyright 2020 The Ray Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
This source file is adapted here because ray does not fully support Windows.
Copyright (c) Microsoft Corporation.
'''
import copy
import logging
from typing import Any, Dict, Optional, Union, List, Tuple
import pickle
from .variant_generator import parse_spec_vars
from ..tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
Quantized, Uniform
from ..tune.trial import flatten_dict, unflatten_dict
logger = logging.getLogger(__name__)

# Message templates shared by the searchers in this module. Each one is
# filled in with ``str.format`` using its named placeholders.
UNRESOLVED_SEARCH_SPACE = (
    "You passed a `{par}` parameter to {cls} that contained unresolved search "
    "space definitions. {cls} should however be instantiated with fully "
    "configured search spaces only. To use Ray Tune's automatic search space "
    "conversion, pass the space definition as part of the `config` argument "
    "to `tune.run()` instead."
)

UNDEFINED_SEARCH_SPACE = (
    "Trying to sample a configuration from {cls}, but no search "
    "space has been defined. Either pass the `{space}` argument when "
    "instantiating the search algorithm, or pass a `config` to "
    "`tune.run()`."
)

UNDEFINED_METRIC_MODE = (
    "Trying to sample a configuration from {cls}, but the `metric` "
    "({metric}) or `mode` ({mode}) parameters have not been set. "
    "Either pass these arguments when instantiating the search algorithm, "
    "or pass them to `tune.run()`."
)
class Searcher:
    """Abstract class for wrapping suggesting algorithms.

    Custom algorithms can extend this class easily by overriding the
    `suggest` method to provide generated parameters for the trials.

    Any subclass that implements ``__init__`` must also call the
    constructor of this class: ``super(Subclass, self).__init__(...)``.

    To track suggestions and their corresponding evaluations, the method
    `suggest` will be passed a trial_id, which will be used in
    subsequent notifications.

    Not all implementations support multi objectives.

    Args:
        metric (str or list): The training result objective value attribute.
            If a list, then a list of training result objective value
            attributes.
        mode (str or list): If a string, one of {min, max}. If a list, each
            entry is one of {min, max, obs} and the list must have the same
            length as ``metric``. Determines whether each objective is
            minimized or maximized. Must match the type of ``metric``.

    .. code-block:: python

        class ExampleSearch(Searcher):
            def __init__(self, metric="mean_loss", mode="min", **kwargs):
                super(ExampleSearch, self).__init__(
                    metric=metric, mode=mode, **kwargs)
                self.optimizer = Optimizer()
                self.configurations = {}

            def suggest(self, trial_id):
                configuration = self.optimizer.query()
                self.configurations[trial_id] = configuration
                return configuration

            def on_trial_complete(self, trial_id, result, **kwargs):
                configuration = self.configurations[trial_id]
                if result and self.metric in result:
                    self.optimizer.update(configuration, result[self.metric])

        tune.run(trainable_function, search_alg=ExampleSearch())
    """
    FINISHED = "FINISHED"
    CKPT_FILE_TMPL = "searcher-state-{}.pkl"

    def __init__(self,
                 metric: Optional[Union[str, List[str]]] = None,
                 mode: Optional[Union[str, List[str]]] = None,
                 max_concurrent: Optional[int] = None,
                 use_early_stopped_trials: Optional[bool] = None):
        """Store ``metric``/``mode`` and validate them when both are given.

        Raises:
            DeprecationWarning: if ``use_early_stopped_trials`` is ``False``.
            AssertionError: if ``metric``/``mode`` types or values disagree.
            ValueError: if ``mode`` is neither a string nor a list.
        """
        if use_early_stopped_trials is False:
            raise DeprecationWarning(
                "Early stopped trials are now always used. If this is a "
                "problem, file an issue: https://github.com/ray-project/ray.")
        if max_concurrent is not None:
            logger.warning(
                "DeprecationWarning: `max_concurrent` is deprecated for this "
                "search algorithm. Use tune.suggest.ConcurrencyLimiter() "
                "instead. This will raise an error in future versions of Ray.")
        self._metric = metric
        self._mode = mode
        if not mode or not metric:
            # Either may be supplied later (e.g. via set_search_properties),
            # so skip validation until both are known.
            return
        assert isinstance(
            metric, type(mode)), "metric and mode must be of the same type"
        if isinstance(mode, str):
            assert mode in ["min", "max"
                            ], "if `mode` is a str must be 'min' or 'max'!"
        elif isinstance(mode, list):
            assert len(mode) == len(
                metric), "Metric and mode must be the same length"
            assert all(mod in ["min", "max", "obs"] for mod in
                       mode), "All of mode must be 'min' or 'max' or 'obs'!"
        else:
            raise ValueError("Mode must either be a list or string")

    def set_search_properties(self, metric: Optional[str], mode: Optional[str],
                              config: Dict) -> bool:
        """Pass search properties to searcher.

        This method acts as an alternative to instantiating search algorithms
        with their own specific search spaces. Instead they can accept a
        Tune config through this method. A searcher should return ``True``
        if setting the config was successful, or ``False`` if it was
        unsuccessful, e.g. when the search space has already been set.

        Args:
            metric (str): Metric to optimize
            mode (str): One of ["min", "max"]. Direction to optimize.
            config (dict): Tune config dict.

        Returns:
            bool: Always ``False`` in this base implementation.
        """
        return False

    def on_trial_result(self, trial_id: str, result: Dict):
        """Optional notification for result during training.

        Note that by default, the result dict may include NaNs or
        may not include the optimization metric. It is up to the
        subclass implementation to preprocess the result to
        avoid breaking the optimization process.

        Args:
            trial_id (str): A unique string ID for the trial.
            result (dict): Dictionary of metrics for current training progress.
                Note that the result dict may include NaNs or
                may not include the optimization metric. It is up to the
                subclass implementation to preprocess the result to
                avoid breaking the optimization process.
        """
        pass

    def on_trial_complete(self,
                          trial_id: str,
                          result: Optional[Dict] = None,
                          error: bool = False):
        """Notification for the completion of trial.

        Typically, this method is used for notifying the underlying
        optimizer of the result.

        Args:
            trial_id (str): A unique string ID for the trial.
            result (dict): Dictionary of metrics for current training progress.
                Note that the result dict may include NaNs or
                may not include the optimization metric. It is up to the
                subclass implementation to preprocess the result to
                avoid breaking the optimization process. Upon errors, this
                may also be None.
            error (bool): True if the training process raised an error.
        """
        raise NotImplementedError

    def on_pause(self, trial_id: str):
        """Optional notification that a trial has been paused.

        The default implementation does nothing; subclasses may override it.

        Args:
            trial_id (str): A unique string ID for the trial.
        """
        pass

    def on_unpause(self, trial_id: str):
        """Optional notification that a paused trial has been resumed.

        The default implementation does nothing; subclasses may override it.

        Args:
            trial_id (str): A unique string ID for the trial.
        """
        pass

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Queries the algorithm to retrieve the next set of parameters.

        Arguments:
            trial_id (str): Trial ID used for subsequent notifications.

        Returns:
            dict | FINISHED | None: Configuration for a trial, if possible.
                If FINISHED is returned, Tune will be notified that
                no more suggestions/configurations will be provided.
                If None is returned, Tune will skip the querying of the
                searcher for this step.
        """
        raise NotImplementedError

    def save(self, checkpoint_path: str):
        """Save state to path for this search algorithm.

        Args:
            checkpoint_path (str): File where the search algorithm
                state is saved. This path should be used later when
                restoring from file.

        Example:

        .. code-block:: python

            search_alg = Searcher(...)

            analysis = tune.run(
                cost,
                num_samples=5,
                search_alg=search_alg,
                name=self.experiment_name,
                local_dir=self.tmpdir)

            search_alg.save("./my_favorite_path.pkl")

        .. versionchanged:: 0.8.7
            Save is automatically called by `tune.run`. You can use
            `restore_from_dir` to restore from an experiment directory
            such as `~/ray_results/trainable`.
        """
        raise NotImplementedError

    def restore(self, checkpoint_path: str):
        """Restore state for this search algorithm

        Args:
            checkpoint_path (str): File where the search algorithm
                state is saved. This path should be the same
                as the one provided to "save".

        Example:

        .. code-block:: python

            search_alg.save("./my_favorite_path.pkl")

            search_alg2 = Searcher(...)
            search_alg2 = ConcurrencyLimiter(search_alg2, 1)
            search_alg2.restore(checkpoint_path)
            tune.run(cost, num_samples=5, search_alg=search_alg2)
        """
        raise NotImplementedError

    def get_state(self) -> Dict:
        """Return a dict describing this searcher's state (subclass hook)."""
        raise NotImplementedError

    def set_state(self, state: Dict):
        """Restore state previously produced by ``get_state`` (subclass hook)."""
        raise NotImplementedError

    @property
    def metric(self) -> Optional[Union[str, List[str]]]:
        """The training result objective value attribute."""
        return self._metric

    @property
    def mode(self) -> Optional[Union[str, List[str]]]:
        """Specifies if minimizing or maximizing the metric."""
        return self._mode
class ConcurrencyLimiter(Searcher):
    """A wrapper algorithm for limiting the number of concurrent trials.

    Args:
        searcher (Searcher): Searcher object that the
            ConcurrencyLimiter will manage.
        max_concurrent (int): Maximum concurrent samples from the underlying
            searcher.
        batch (bool): Whether to wait for all concurrent samples
            to finish before updating the underlying searcher. In batch
            mode the wrapped searcher is only notified once
            ``max_concurrent`` trials have completed.

    Example:

    .. code-block:: python

        from ray.tune.suggest import ConcurrencyLimiter
        search_alg = HyperOptSearch(metric="accuracy")
        search_alg = ConcurrencyLimiter(search_alg, max_concurrent=2)
        tune.run(trainable, search_alg=search_alg)
    """

    def __init__(self,
                 searcher: Searcher,
                 max_concurrent: int,
                 batch: bool = False):
        assert type(max_concurrent) is int and max_concurrent > 0
        self.searcher = searcher
        self.max_concurrent = max_concurrent
        self.batch = batch
        # IDs of trials handed out by this limiter whose completion has not
        # yet been forwarded to the wrapped searcher.
        self.live_trials = set()
        # Batch mode only: trial_id -> (result, error) awaiting a full batch.
        self.cached_results = {}
        super(ConcurrencyLimiter, self).__init__(
            metric=self.searcher.metric, mode=self.searcher.mode)

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Return the wrapped searcher's suggestion unless the limit is hit.

        Returns ``None`` (no suggestion this step) while ``max_concurrent``
        trials are live; otherwise returns whatever the wrapped searcher
        returns, and tracks the trial only if a real config was produced.
        """
        assert trial_id not in self.live_trials, (
            f"Trial ID {trial_id} must be unique: already found in set.")
        if len(self.live_trials) >= self.max_concurrent:
            logger.debug(
                "Not providing a suggestion for %s due to "
                "concurrency limit: %s/%s.", trial_id, len(self.live_trials),
                self.max_concurrent)
            return
        suggestion = self.searcher.suggest(trial_id)
        if suggestion not in (None, Searcher.FINISHED):
            self.live_trials.add(trial_id)
        return suggestion

    def on_trial_result(self, trial_id: str, result: Dict):
        """Forward intermediate results of trials this limiter handed out.

        Without this override the inherited no-op would silently drop the
        notification instead of passing it to the wrapped searcher.
        """
        if trial_id in self.live_trials:
            self.searcher.on_trial_result(trial_id, result)

    def on_trial_complete(self,
                          trial_id: str,
                          result: Optional[Dict] = None,
                          error: bool = False):
        """Forward completion to the wrapped searcher (batched if enabled).

        Completions for trials this limiter did not hand out are ignored.
        """
        if trial_id not in self.live_trials:
            return
        if not self.batch:
            self.searcher.on_trial_complete(
                trial_id, result=result, error=error)
            self.live_trials.remove(trial_id)
            return
        self.cached_results[trial_id] = (result, error)
        if len(self.cached_results) == self.max_concurrent:
            # Update the underlying searcher once the full batch is completed.
            # Distinct loop names avoid shadowing this method's parameters.
            for cached_id, (cached_result, cached_error) in \
                    self.cached_results.items():
                self.searcher.on_trial_complete(
                    cached_id, result=cached_result, error=cached_error)
                self.live_trials.remove(cached_id)
            self.cached_results = {}

    def get_state(self) -> Dict:
        """Return a deep copy of the limiter's own state (without searcher)."""
        state = self.__dict__.copy()
        del state["searcher"]
        return copy.deepcopy(state)

    def set_state(self, state: Dict):
        """Restore attributes captured by ``get_state``."""
        self.__dict__.update(state)

    def save(self, checkpoint_path: str):
        """Delegate checkpointing to the wrapped searcher."""
        self.searcher.save(checkpoint_path)

    def restore(self, checkpoint_path: str):
        """Delegate restoring to the wrapped searcher."""
        self.searcher.restore(checkpoint_path)

    def on_pause(self, trial_id: str):
        """Forward a pause notification if the wrapped searcher supports it."""
        # Pause hooks are optional on searchers; don't raise AttributeError
        # for wrapped searchers that do not define them.
        handler = getattr(self.searcher, "on_pause", None)
        if handler is not None:
            handler(trial_id)

    def on_unpause(self, trial_id: str):
        """Forward an unpause notification if the wrapped searcher supports it."""
        handler = getattr(self.searcher, "on_unpause", None)
        if handler is not None:
            handler(trial_id)

    def set_search_properties(self, metric: Optional[str], mode: Optional[str],
                              config: Dict) -> bool:
        """Delegate search-property setup to the wrapped searcher."""
        return self.searcher.set_search_properties(metric, mode, config)
try:
    import optuna as ot
    from optuna.samplers import BaseSampler
    from optuna.trial import TrialState as OptunaTrialState
except ImportError:
    # Optuna is optional: bind every name to None so this module still
    # imports without it. OptunaSearch checks ``ot`` before using it.
    ot = OptunaTrialState = BaseSampler = None

# (Optional) Name of the anonymous metric used when reporting via
# ``tune.report(x)``.
DEFAULT_METRIC = "_metric"
# (Auto-filled) Result key that carries the index of the training iteration.
TRAINING_ITERATION = "training_iteration"
class OptunaSearch(Searcher):
    """A wrapper around Optuna to provide trial suggestions.

    `Optuna <https://optuna.org/>`_ is a hyperparameter optimization library.
    In contrast to other libraries, it employs define-by-run style
    hyperparameter definitions.

    This Searcher is a thin wrapper around Optuna's search algorithms.
    You can pass any Optuna sampler, which will be used to generate
    hyperparameter suggestions.

    Please note that this wrapper does not support define-by-run, so the
    search space will be configured before running the optimization. You will
    also need to use a Tune trainable (e.g. using the function API) with
    this wrapper.

    For defining the search space, use ``ray.tune.suggest.optuna.param``
    (see example).

    Args:
        space (dict or list): Hyperparameter search space definition for
            Optuna's sampler. A dict maps (possibly nested) parameter names
            to ``optuna.distributions`` objects; the deprecated list form
            holds ``param.suggest_*()`` calls whose samples are obtained in
            order.
        metric (str): The training result objective value attribute. If None
            but a mode was passed, the anonymous metric `_metric` will be used
            per default.
        mode (str): One of {min, max}. Determines whether objective is
            minimizing or maximizing the metric attribute.
        points_to_evaluate (list): Initial parameter suggestions to be run
            first. This is for when you already have some good parameters
            you want to run first to help the algorithm make better suggestions
            for future parameters. Needs to be a list of dicts containing the
            configurations.
        sampler (optuna.samplers.BaseSampler): Optuna sampler used to
            draw hyperparameter configurations. Defaults to ``TPESampler``.
        seed (int): Seed to initialize sampler with. This parameter is only
            used when ``sampler=None``. In all other cases, the sampler
            you pass should be initialized with the seed already.
        evaluated_rewards (list): If you have previously evaluated the
            parameters passed in as points_to_evaluate you can avoid
            re-running those trials by passing in the reward attributes
            as a list so the optimiser can be told the results without
            needing to re-compute the trial. Must be the same length as
            points_to_evaluate.

    Tune automatically converts search spaces to Optuna's format:

    .. code-block:: python

        from ray.tune.suggest.optuna import OptunaSearch

        config = {
            "a": tune.uniform(6, 8),
            "b": tune.loguniform(1e-4, 1e-2)
        }

        optuna_search = OptunaSearch(
            metric="loss",
            mode="min")

        tune.run(trainable, config=config, search_alg=optuna_search)

    If you would like to pass the search space manually, the code would
    look like this:

    .. code-block:: python

        from ray.tune.suggest.optuna import OptunaSearch
        import optuna

        space = {
            "a": optuna.distributions.UniformDistribution(6, 8),
            "b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2),
        }

        optuna_search = OptunaSearch(
            space,
            metric="loss",
            mode="min")

        tune.run(trainable, search_alg=optuna_search)

    .. versionadded:: 0.8.8
    """

    def __init__(self,
                 space: Optional[Union[Dict, List[Tuple]]] = None,
                 metric: Optional[str] = None,
                 mode: Optional[str] = None,
                 points_to_evaluate: Optional[List[Dict]] = None,
                 sampler: Optional[BaseSampler] = None,
                 seed: Optional[int] = None,
                 evaluated_rewards: Optional[List] = None):
        assert ot is not None, (
            "Optuna must be installed! Run `pip install optuna`.")
        super(OptunaSearch, self).__init__(
            metric=metric,
            mode=mode,
            max_concurrent=None,
            use_early_stopped_trials=None)

        if isinstance(space, dict) and space:
            resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
            if domain_vars or grid_vars:
                # Tune-style domains were passed directly; convert them.
                logger.warning(
                    UNRESOLVED_SEARCH_SPACE.format(
                        par="space", cls=type(self).__name__))
                space = self.convert_search_space(space)
            else:
                # Flatten to support nested dicts
                space = flatten_dict(space, "/")

        # Deprecate: 1.5
        if isinstance(space, list):
            logger.warning(
                "Passing lists of `param.suggest_*()` calls to OptunaSearch "
                "as a search space is deprecated and will be removed in "
                "a future release of Ray. Please pass a dict mapping "
                "to `optuna.distributions` objects instead.")

        self._space = space
        self._points_to_evaluate = points_to_evaluate or []
        self._evaluated_rewards = evaluated_rewards
        self._study_name = "optuna"  # Fixed study name for in-memory storage

        # ``seed is not None`` so that an explicit ``seed=0`` also warns.
        if sampler and seed is not None:
            logger.warning(
                "You passed an initialized sampler to `OptunaSearch`. The "
                "`seed` parameter has to be passed to the sampler directly "
                "and will be ignored.")
        self._sampler = sampler or ot.samplers.TPESampler(seed=seed)
        assert isinstance(self._sampler, BaseSampler), \
            "You can only pass an instance of `optuna.samplers.BaseSampler` " \
            "as a sampler to `OptunaSearcher`."

        # trial_id -> Optuna trial object returned by ``study.ask``.
        self._ot_trials = {}
        self._ot_study = None
        if self._space:
            self._setup_study(self._mode)

    def _setup_study(self, mode: str):
        """Create the in-memory Optuna study and seed it with warm-start points.

        Args:
            mode (str): "min" creates a minimizing study; any other value
                (including None) creates a maximizing one.
        """
        if self._metric is None and self._mode:
            # If only a mode was passed, use anonymous metric
            self._metric = DEFAULT_METRIC

        pruner = ot.pruners.NopPruner()
        storage = ot.storages.InMemoryStorage()

        self._ot_study = ot.study.create_study(
            storage=storage,
            sampler=self._sampler,
            pruner=pruner,
            study_name=self._study_name,
            direction="minimize" if mode == "min" else "maximize",
            load_if_exists=True)

        if self._points_to_evaluate:
            if self._evaluated_rewards:
                # Already-evaluated points are added as finished trials.
                for point, reward in zip(self._points_to_evaluate,
                                         self._evaluated_rewards):
                    self.add_evaluated_point(point, reward)
            else:
                # Otherwise they are queued to be suggested first.
                for point in self._points_to_evaluate:
                    self._ot_study.enqueue_trial(point)

    def set_search_properties(self, metric: Optional[str], mode: Optional[str],
                              config: Dict) -> bool:
        """Adopt a Tune ``config`` as search space if none was set yet.

        ``metric``/``mode`` only override the constructor values when given.

        Returns:
            bool: ``False`` if a search space was already set, else ``True``.
        """
        if self._space:
            return False
        space = self.convert_search_space(config)
        self._space = space
        if metric:
            self._metric = metric
        if mode:
            self._mode = mode

        # Use the resolved mode: ``mode`` may be None here while a mode was
        # given to the constructor, which must still decide the direction.
        self._setup_study(self._mode)
        return True

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Ask Optuna for a configuration for ``trial_id``.

        Raises:
            RuntimeError: if no search space or no metric/mode is set.
        """
        if not self._space:
            raise RuntimeError(
                UNDEFINED_SEARCH_SPACE.format(
                    cls=self.__class__.__name__, space="space"))
        if not self._metric or not self._mode:
            raise RuntimeError(
                UNDEFINED_METRIC_MODE.format(
                    cls=self.__class__.__name__,
                    metric=self._metric,
                    mode=self._mode))

        if isinstance(self._space, list):
            # Keep for backwards compatibility
            # Deprecate: 1.5
            if trial_id not in self._ot_trials:
                self._ot_trials[trial_id] = self._ot_study.ask()

            ot_trial = self._ot_trials[trial_id]

            # getattr will fetch the trial.suggest_ function on Optuna trials
            params = {
                args[0] if len(args) > 0 else kwargs["name"]: getattr(
                    ot_trial, fn)(*args, **kwargs)
                for (fn, args, kwargs) in self._space
            }
        else:
            # Use Optuna ask interface (since version 2.6.0)
            if trial_id not in self._ot_trials:
                self._ot_trials[trial_id] = self._ot_study.ask(
                    fixed_distributions=self._space)
            ot_trial = self._ot_trials[trial_id]
            params = ot_trial.params

        # Parameter names are "/"-joined paths; rebuild the nested dict.
        return unflatten_dict(params)

    def on_trial_result(self, trial_id: str, result: Dict):
        """Report an intermediate metric value to the Optuna trial.

        Results lacking the metric or the training iteration are skipped:
        the ``Searcher`` contract allows results without the metric.
        """
        metric = result.get(self.metric)
        step = result.get(TRAINING_ITERATION)
        if metric is None or step is None:
            return
        ot_trial = self._ot_trials[trial_id]
        ot_trial.report(metric, step)

    def on_trial_complete(self,
                          trial_id: str,
                          result: Optional[Dict] = None,
                          error: bool = False):
        """Tell Optuna the final outcome of ``trial_id``.

        A missing metric marks the trial FAIL if ``error`` is set, else PRUNED.
        """
        ot_trial = self._ot_trials[trial_id]

        val = result.get(self.metric, None) if result else None
        ot_trial_state = OptunaTrialState.COMPLETE
        if val is None:
            if error:
                ot_trial_state = OptunaTrialState.FAIL
            else:
                ot_trial_state = OptunaTrialState.PRUNED
        try:
            self._ot_study.tell(ot_trial, val, state=ot_trial_state)
        except ValueError as exc:
            logger.warning(exc)  # E.g. if NaN was reported

    def add_evaluated_point(self,
                            parameters: Dict,
                            value: float,
                            error: bool = False,
                            pruned: bool = False,
                            intermediate_values: Optional[List[float]] = None):
        """Add an already-evaluated configuration to the study as a trial.

        Raises:
            RuntimeError: if no search space or no metric/mode is set.
        """
        if not self._space:
            raise RuntimeError(
                UNDEFINED_SEARCH_SPACE.format(
                    cls=self.__class__.__name__, space="space"))
        if not self._metric or not self._mode:
            raise RuntimeError(
                UNDEFINED_METRIC_MODE.format(
                    cls=self.__class__.__name__,
                    metric=self._metric,
                    mode=self._mode))

        ot_trial_state = OptunaTrialState.COMPLETE
        if error:
            ot_trial_state = OptunaTrialState.FAIL
        elif pruned:
            ot_trial_state = OptunaTrialState.PRUNED

        # Map step index -> intermediate value (no shadowing of ``value``).
        if intermediate_values:
            intermediate_values_dict = dict(enumerate(intermediate_values))
        else:
            intermediate_values_dict = None

        trial = ot.trial.create_trial(
            state=ot_trial_state,
            value=value,
            params=parameters,
            distributions=self._space,
            intermediate_values=intermediate_values_dict)

        self._ot_study.add_trial(trial)

    def save(self, checkpoint_path: str):
        """Pickle sampler, trials, study and warm-start data to a file."""
        save_object = (self._sampler, self._ot_trials, self._ot_study,
                       self._points_to_evaluate, self._evaluated_rewards)
        with open(checkpoint_path, "wb") as outputFile:
            pickle.dump(save_object, outputFile)

    def restore(self, checkpoint_path: str):
        """Load state written by ``save`` (also accepts the older 4-tuple).

        NOTE(security): this unpickles the file; only restore checkpoints
        from trusted locations.
        """
        with open(checkpoint_path, "rb") as inputFile:
            save_object = pickle.load(inputFile)
        if len(save_object) == 5:
            self._sampler, self._ot_trials, self._ot_study, \
                self._points_to_evaluate, self._evaluated_rewards = save_object
        else:
            # Backwards compatibility
            self._sampler, self._ot_trials, self._ot_study, \
                self._points_to_evaluate = save_object

    @staticmethod
    def convert_search_space(spec: Dict) -> Dict[str, Any]:
        """Convert a Tune search-space spec into Optuna distributions.

        Returns:
            dict: "/"-joined parameter path -> Optuna distribution; empty if
            ``spec`` contains no search domains.

        Raises:
            ValueError: for grid-search entries or unsupported domain/sampler
                combinations.
        """
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

        if not domain_vars and not grid_vars:
            return {}

        if grid_vars:
            raise ValueError(
                "Grid search parameters cannot be automatically converted "
                "to an Optuna search space.")

        # Flatten and resolve again after checking for grid search.
        spec = flatten_dict(spec, prevent_delimiter=True)
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

        def resolve_value(domain: Domain) -> ot.distributions.BaseDistribution:
            quantize = None

            sampler = domain.get_sampler()
            if isinstance(sampler, Quantized):
                quantize = sampler.q
                sampler = sampler.sampler
                if isinstance(sampler, LogUniform):
                    logger.warning(
                        "Optuna does not handle quantization in loguniform "
                        "sampling. The parameter will be passed but it will "
                        "probably be ignored.")

            if isinstance(domain, Float):
                if isinstance(sampler, LogUniform):
                    if quantize:
                        logger.warning(
                            "Optuna does not support both quantization and "
                            "sampling from LogUniform. Dropped quantization.")
                    return ot.distributions.LogUniformDistribution(
                        domain.lower, domain.upper)

                elif isinstance(sampler, Uniform):
                    if quantize:
                        return ot.distributions.DiscreteUniformDistribution(
                            domain.lower, domain.upper, quantize)
                    return ot.distributions.UniformDistribution(
                        domain.lower, domain.upper)

            elif isinstance(domain, Integer):
                if isinstance(sampler, LogUniform):
                    return ot.distributions.IntLogUniformDistribution(
                        domain.lower, domain.upper - 1, step=quantize or 1)
                elif isinstance(sampler, Uniform):
                    # Upper bound should be inclusive for quantization and
                    # exclusive otherwise
                    return ot.distributions.IntUniformDistribution(
                        domain.lower,
                        domain.upper - int(bool(not quantize)),
                        step=quantize or 1)
            elif isinstance(domain, Categorical):
                if isinstance(sampler, Uniform):
                    return ot.distributions.CategoricalDistribution(
                        domain.categories)

            raise ValueError(
                "Optuna search does not support parameters of type "
                "`{}` with samplers of type `{}`".format(
                    type(domain).__name__,
                    type(domain.sampler).__name__))

        # Parameter name is e.g. "a/b/c" for nested dicts
        values = {
            "/".join(path): resolve_value(domain)
            for path, domain in domain_vars
        }

        return values