autogen/flaml/tune/trial.py
2021-11-06 09:37:33 -07:00

148 lines
6.0 KiB
Python

# Copyright 2020 The Ray Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This source file is adapted here because ray does not fully support Windows.
# Copyright (c) Microsoft Corporation.
import uuid
import time
from numbers import Number
from collections import deque
def _delimiter_in_key_error(delimiter):
    """Build the ValueError raised when a key already contains the delimiter."""
    return ValueError(
        (
            "Found delimiter `{}` in key when trying to flatten array. "
            "Please avoid using the delimiter in your specification."
        ).format(delimiter)
    )


def flatten_dict(dt, delimiter="/", prevent_delimiter=False):
    """Flatten a nested dict into a single level with delimiter-joined keys.

    Dict values are expanded one nesting level per pass until no value is a
    dict; each expanded entry gets the key ``<parent><delimiter><child>``.
    The input is shallow-copied first, so ``dt`` itself is not modified.
    A nested dict that is empty simply disappears from the result.

    Args:
        dt: The (possibly nested) dict to flatten.
        delimiter: String inserted between a parent key and a child key.
        prevent_delimiter: If True, refuse keys that already contain
            ``delimiter`` (they would be indistinguishable from joined keys).

    Returns:
        A flattened copy of ``dt``.

    Raises:
        ValueError: If ``prevent_delimiter`` is True and a top-level key or
            any nested subkey contains ``delimiter``.
    """
    dt = dt.copy()
    if prevent_delimiter and any(delimiter in key for key in dt):
        # Raise if delimiter is in any of the top-level keys.
        raise _delimiter_in_key_error(delimiter)

    while any(isinstance(v, dict) for v in dt.values()):
        # Collect changes separately: the dict must not change size while
        # it is being iterated.
        remove = []
        add = {}
        for key, value in dt.items():
            if not isinstance(value, dict):
                continue
            for subkey, v in value.items():
                # Subkeys are stringified when joined below, so run the
                # delimiter check on the same string form; this also keeps
                # non-string subkeys (e.g. ints) from raising TypeError.
                subkey_str = str(subkey)
                if prevent_delimiter and delimiter in subkey_str:
                    raise _delimiter_in_key_error(delimiter)
                add[delimiter.join([key, subkey_str])] = v
            remove.append(key)
        dt.update(add)
        for k in remove:
            del dt[k]
    return dt
def unflatten_dict(dt, delimiter="/"):
    """Rebuild a nested mapping from one keyed by delimiter-joined paths.

    Every key is split on ``delimiter``. All components but the last name
    intermediate containers, which are created on demand as instances of
    ``type(dt)``; the last component receives the value. Lists are not
    reconstructed.

    Args:
        dt: Flat mapping whose keys are path strings.
        delimiter: Separator between path components.

    Returns:
        A new nested mapping of the same type as ``dt``.
    """
    container_cls = type(dt)
    nested = container_cls()
    for flat_key, value in dt.items():
        # ``split`` always yields at least one piece, so ``leaf`` is defined.
        *parents, leaf = flat_key.split(delimiter)
        node = nested
        for part in parents:
            node = node.setdefault(part, container_cls())
        node[leaf] = value
    return nested
class Trial:
    """A trial object holds the state for one model training run.

    Trials are themselves managed by the TrialRunner class, which implements
    the event loop for submitting trial runs to a Ray cluster.

    Trials start in the PENDING state, and transition to RUNNING once started.
    On error it transitions to ERROR, otherwise TERMINATED on success.

    No initializer is defined in this class body. The methods below read
    ``experiment_tag``, ``metric_analysis``, ``metric_n_steps``, ``n_steps``,
    ``status`` and ``start_time``, so whoever creates the instance must set
    them beforehand.

    Attributes:
        trainable_name (str): Name of the trainable object to be executed.
        config (dict): Provided configuration dictionary with evaluated params.
        trial_id (str): Unique identifier for the trial.
        local_dir (str): Local_dir as passed to tune.run.
        logdir (str): Directory where the trial logs are saved.
        evaluated_params (dict): Evaluated parameters by search algorithm.
        experiment_tag (str): Identifying trial name to show in the console.
        resources (Resources): Amount of resources that this trial will use.
        status (str): One of PENDING, RUNNING, PAUSED, TERMINATED, ERROR.
        error_file (str): Path to the errors that this trial has raised.
    """

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    PAUSED = "PAUSED"
    TERMINATED = "TERMINATED"
    ERROR = "ERROR"

    @classmethod
    def generate_id(cls):
        """Return the first 8 hex characters of a fresh ``uuid.uuid1()``."""
        return uuid.uuid1().hex[:8]

    def update_last_result(self, result):
        """Store ``result`` as the latest result and refresh metric statistics.

        If ``self.experiment_tag`` is truthy it is written into ``result``
        (the caller's dict is mutated). ``result`` is then flattened and, for
        every numeric value, ``self.metric_analysis[metric]`` is updated with
        "max", "min", "avg", "last" and one "last-<n>-avg" entry per ``n`` in
        ``self.n_steps``. The recent values behind each "last-<n>-avg" are
        kept in ``self.metric_n_steps[metric][str(n)]`` as a bounded deque.

        Args:
            result (dict): Latest result. From the second update of a metric
                onwards it must contain "training_iteration", which is used
                as the step count of the running average (a falsy value is
                treated as 1).

        Raises:
            KeyError: If a previously seen metric is updated and ``result``
                has no "training_iteration" key.
        """
        if self.experiment_tag:
            result.update(experiment_tag=self.experiment_tag)
        self.last_result = result
        self.last_update_time = time.time()

        for metric, value in flatten_dict(result).items():
            if not isinstance(value, Number):
                continue

            if metric not in self.metric_analysis:
                # First observation: every statistic starts at this value.
                stats = dict.fromkeys(("max", "min", "avg", "last"), value)
                self.metric_analysis[metric] = stats
                windows = {}
                self.metric_n_steps[metric] = windows
                for n in self.n_steps:
                    stats["last-{:d}-avg".format(n)] = value
                    # Store n as string for correct restore.
                    windows[str(n)] = deque([value], maxlen=n)
                continue

            step = result["training_iteration"] or 1
            stats = self.metric_analysis[metric]
            stats["max"] = max(value, stats["max"])
            stats["min"] = min(value, stats["min"])
            # Incremental mean: fold the new value into the previous average.
            stats["avg"] = 1 / step * (value + (step - 1) * stats["avg"])
            stats["last"] = value
            for n in self.n_steps:
                avg_key = "last-{:d}-avg".format(n)
                window = self.metric_n_steps[metric][str(n)]
                # The deque's maxlen drops the oldest value automatically.
                window.append(value)
                stats[avg_key] = sum(window) / len(window)

    def set_status(self, status):
        """Set the trial status; stamp ``start_time`` on the first RUNNING."""
        self.status = status
        if status == Trial.RUNNING and self.start_time is None:
            self.start_time = time.time()

    def is_finished(self):
        """Return True when the status is ERROR or TERMINATED."""
        return self.status in (Trial.ERROR, Trial.TERMINATED)