autogen/flaml/tune/trial_runner.py
Chi Wang 9128c8811a
handle failing trials (#505)
* handle failing trials

* clarify when to return {}

* skip ensemble in accuracy check
2022-03-28 16:57:52 -07:00

138 lines
4.2 KiB
Python

# !
# * Copyright (c) Microsoft Corporation. All rights reserved.
# * Licensed under the MIT License. See LICENSE file in the
# * project root for license information.
from typing import Optional
# try:
# from ray import __version__ as ray_version
# assert ray_version >= '1.0.0'
# from ray.tune.trial import Trial
# except (ImportError, AssertionError):
from .trial import Trial
import logging
logger = logging.getLogger(__name__)
class Nologger:
    """Result logger stand-in that discards everything it is given."""

    def on_result(self, result):
        """Accept a trial result and deliberately do nothing with it."""
        return None
class SimpleTrial(Trial):
    """A lightweight trial holding a config, a status and bookkeeping fields."""

    def __init__(self, config, trial_id=None):
        """Initialise the trial in the PENDING state.

        Args:
            config: Configuration for this trial; a falsy value is replaced
                by an empty dict.
            trial_id: Identifier to use; when None a fresh one is obtained
                from ``Trial.generate_id()``.
        """
        # Identity and configuration.
        if trial_id is None:
            trial_id = Trial.generate_id()
        self.trial_id = trial_id
        self.config = config or {}

        # Lifecycle state and timing.
        self.status = Trial.PENDING
        self.start_time = None
        self.last_update_time = -float("inf")

        # Results and per-metric statistics.
        self.last_result = {}
        self.metric_analysis = {}
        self.metric_n_steps = {}
        self.n_steps = [5, 10]

        # Naming and logging attributes.
        self.custom_trial_name = None
        self.trainable_name = "trainable"
        self.experiment_tag = "exp"
        self.verbose = False
        self.result_logger = Nologger()
class BaseTrialRunner:
    """Implementation of a simple trial runner.

    Keeps the list of trials and forwards trial events (added, result
    reported, stopped) to the search algorithm and, when one is configured,
    to the scheduler.

    Note that the caller usually should not mutate trial state directly.
    """

    def __init__(
        self,
        search_alg=None,
        scheduler=None,
        metric: Optional[str] = None,
        mode: Optional[str] = "min",
    ):
        """Store the collaborators and start with an empty trial list.

        Args:
            search_alg: Object that receives ``on_trial_result`` and
                ``on_trial_complete`` notifications.
            scheduler: Optional object that receives ``on_trial_add``,
                ``on_trial_result``, ``on_trial_complete`` and
                ``on_trial_remove`` notifications.
            metric: Stored unchanged on the runner.
            mode: Stored unchanged on the runner; defaults to ``"min"``.
        """
        self._search_alg = search_alg
        self._scheduler_alg = scheduler
        self._trials = []
        self._metric = metric
        self._mode = mode

    def get_trials(self):
        """Returns the list of trials managed by this TrialRunner.

        Note that the caller usually should not mutate trial state directly.
        """
        return self._trials

    def add_trial(self, trial):
        """Adds a new trial to this TrialRunner.

        Trials may be added at any time.

        Args:
            trial (Trial): Trial to queue.
        """
        self._trials.append(trial)
        if self._scheduler_alg:
            self._scheduler_alg.on_trial_add(self, trial)

    def process_trial_result(self, trial, result):
        """Record a result for ``trial`` and notify searcher and scheduler.

        Args:
            trial (Trial): Trial that produced the result.
            result (dict): Reported result. A ``"time_total_s"`` entry is
                added in place when the caller did not supply one.
        """
        trial.update_last_result(result)
        if "time_total_s" not in result:
            # Fall back to the elapsed time tracked on the trial itself.
            result["time_total_s"] = trial.last_update_time - trial.start_time
        self._search_alg.on_trial_result(trial.trial_id, result)
        if not self._scheduler_alg:
            return
        decision = self._scheduler_alg.on_trial_result(self, trial, result)
        if decision == "STOP":
            trial.set_status(Trial.TERMINATED)
        elif decision == "PAUSE":
            trial.set_status(Trial.PAUSED)

    def stop_trial(self, trial):
        """Stops trial.

        A trial that is neither errored nor terminated is reported as
        complete to the scheduler (if any) and to the search algorithm, then
        marked TERMINATED. A trial that is already errored or terminated is
        removed from the scheduler (if any); an errored one is additionally
        reported to the search algorithm with ``error=True``.

        Args:
            trial (Trial): Trial to stop.
        """
        if trial.status not in (Trial.ERROR, Trial.TERMINATED):
            if self._scheduler_alg:
                self._scheduler_alg.on_trial_complete(
                    self, trial.trial_id, trial.last_result
                )
            self._search_alg.on_trial_complete(trial.trial_id, trial.last_result)
            trial.set_status(Trial.TERMINATED)
            return

        if self._scheduler_alg:
            self._scheduler_alg.on_trial_remove(self, trial)
        # The search algorithm must learn about a failed trial whether or not
        # a scheduler is configured; otherwise the failure is silently lost.
        if trial.status == Trial.ERROR:
            self._search_alg.on_trial_complete(
                trial.trial_id, trial.last_result, error=True
            )
class SequentialTrialRunner(BaseTrialRunner):
    """Trial runner that hands out one trial at a time."""

    def step(self) -> Trial:
        """Runs one step of the trial event loop.

        Asks the search algorithm for a configuration under a freshly
        generated trial id. Callers should typically invoke this method
        repeatedly in a loop and may inspect or modify the runner's state
        between calls.

        Returns:
            The new trial, already added to the runner and marked RUNNING,
            or None when the search algorithm suggested no configuration.
            The same value is stored in ``self.running_trial``.
        """
        new_id = Trial.generate_id()
        suggestion = self._search_alg.suggest(new_id)
        if suggestion is None:
            # Nothing to run this round.
            self.running_trial = None
            return None

        next_trial = SimpleTrial(suggestion, new_id)
        self.add_trial(next_trial)
        next_trial.set_status(Trial.RUNNING)
        self.running_trial = next_trial
        return next_trial

    def stop_trial(self, trial):
        """Stop ``trial`` via the base class and clear ``running_trial``."""
        super().stop_trial(trial)
        self.running_trial = None