mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-03 03:10:04 +00:00
148 lines
6.0 KiB
Python
148 lines
6.0 KiB
Python
# Copyright 2020 The Ray Authors.
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# This source file is adapted here because ray does not fully support Windows.
|
|
|
|
# Copyright (c) Microsoft Corporation.
|
|
import uuid
|
|
import time
|
|
from numbers import Number
|
|
from collections import deque
|
|
|
|
|
|
def flatten_dict(dt, delimiter="/", prevent_delimiter=False):
    """Flatten a nested dict into a single-level dict with delimiter-joined keys.

    For example, ``{"a": {"b": 1, "c": {"d": 2}}, "e": 3}`` becomes
    ``{"a/b": 1, "a/c/d": 2, "e": 3}`` with the default delimiter. Only
    ``dict`` values are descended into; lists and other containers are kept
    as leaf values. A nested empty dict contributes no keys, so its parent
    key disappears from the result.

    Args:
        dt (dict): Mapping to flatten. It is not modified: a shallow copy is
            flattened and returned.
        delimiter (str): Separator placed between joined key segments.
        prevent_delimiter (bool): If True, raise ``ValueError`` when any key
            (compared via ``str()``) already contains ``delimiter``.

    Returns:
        dict: The flattened copy. Keys produced by joining are strings (each
        segment is converted with ``str()``); top-level keys that hold
        non-dict values are kept unchanged.

    Raises:
        ValueError: If ``prevent_delimiter`` is True and a key contains
            ``delimiter``.
    """
    message = (
        "Found delimiter `{}` in key when trying to flatten array. "
        "Please avoid using the delimiter in your specification."
    ).format(delimiter)

    dt = dt.copy()
    # Keys that already contain the delimiter could not be unflattened
    # unambiguously; str() keeps non-string keys from raising TypeError.
    if prevent_delimiter and any(delimiter in str(key) for key in dt):
        raise ValueError(message)

    # Each pass lifts one level of nesting into the top level.
    while any(isinstance(v, dict) for v in dt.values()):
        remove = []
        add = {}
        for key, value in dt.items():
            if not isinstance(value, dict):
                continue
            for subkey, v in value.items():
                if prevent_delimiter and delimiter in str(subkey):
                    raise ValueError(message)
                add[delimiter.join([str(key), str(subkey)])] = v
            remove.append(key)
        # Mutate only after the scan (the dict must not change size while
        # iterated). Delete processed keys BEFORE adding the joined ones so
        # that a joined key equal to a processed key is not thrown away.
        for k in remove:
            del dt[k]
        dt.update(add)
    return dt
|
|
|
|
|
|
def unflatten_dict(dt, delimiter="/"):
    """Rebuild a nested mapping from a flat one whose keys are delimiter-joined paths.

    Each key is split on ``delimiter``. All segments except the last name the
    intermediate levels, which are created on demand. The last segment
    receives the value. Every level, including the returned root, is an
    instance of ``type(dt)``. Lists are not reconstructed.

    Args:
        dt (dict): Flat mapping whose keys are strings.
        delimiter (str): Separator used to split keys into path segments.

    Returns:
        dict: A new nested mapping of the same type as ``dt``.
    """
    container = type(dt)
    root = container()
    for flat_key, leaf in dt.items():
        *parents, last = flat_key.split(delimiter)
        node = root
        for part in parents:
            node = node.setdefault(part, container())
        node[last] = leaf
    return root
|
|
|
|
|
|
class Trial:
    """State for a single model training run.

    Trials are driven by a runner that owns the event loop (in Ray Tune this
    is the ``TrialRunner``, which submits runs to a Ray cluster). A trial
    starts as PENDING and moves to RUNNING once started. It ends as ERROR on
    failure or TERMINATED on success.

    This class defines no ``__init__``. The attributes below are presumably
    assigned by a subclass or by the code creating the trial — confirm
    against callers.

    Attributes:
        trainable_name (str): Name of the trainable object to be executed.
        config (dict): Configuration dictionary with evaluated parameters.
        trial_id (str): Unique identifier for the trial.
        local_dir (str): ``local_dir`` as passed to ``tune.run``.
        logdir (str): Directory where the trial logs are saved.
        evaluated_params (dict): Parameters evaluated by the search algorithm.
        experiment_tag (str): Identifying trial name shown in the console.
        resources (Resources): Amount of resources this trial will use.
        status (str): One of PENDING, RUNNING, PAUSED, TERMINATED, ERROR.
        error_file (str): Path to the errors this trial has raised.
        start_time (float | None): ``time.time()`` of the first switch to
            RUNNING; must be ``None`` before that.
        last_result (dict): Most recent result passed to ``update_last_result``.
        last_update_time (float): ``time.time()`` of the last result update.
        n_steps: Iterable of integer window sizes for the ``last-<n>-avg``
            statistics.
        metric_analysis (dict): Maps metric name to its statistics dict.
        metric_n_steps (dict): Maps metric name to ``{str(n): deque}``
            windows of the last ``n`` values.
    """

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    PAUSED = "PAUSED"
    TERMINATED = "TERMINATED"
    ERROR = "ERROR"

    @classmethod
    def generate_id(cls):
        """Return a short id: the first 8 hex digits of a time-based ``uuid.uuid1()``."""
        return uuid.uuid1().hex[:8]

    def update_last_result(self, result):
        """Store ``result`` as the latest report and refresh per-metric statistics.

        ``result`` is modified in place: when ``self.experiment_tag`` is
        truthy, it is written under the ``"experiment_tag"`` key. Every entry
        of the flattened result whose value is a ``numbers.Number`` (``bool``
        included, since it subclasses ``int``) updates
        ``self.metric_analysis[metric]``:

        - ``max`` / ``min`` / ``last``: running extremes and latest value.
        - ``avg``: running mean that uses ``result["training_iteration"]`` as
          the observation count (a falsy value counts as 1).
        - ``last-<n>-avg``: mean of the last ``n`` values for each ``n`` in
          ``self.n_steps``, backed by bounded deques in
          ``self.metric_n_steps``.

        NOTE(review): the running mean assumes the metric was reported at
        every iteration starting from 1. A metric that first appears later
        has its first value over-weighted. ``"training_iteration"`` must be
        present from a metric's second report onward (otherwise KeyError).
        """
        if self.experiment_tag:
            result.update(experiment_tag=self.experiment_tag)

        self.last_result = result
        self.last_update_time = time.time()

        for metric, value in flatten_dict(result).items():
            if not isinstance(value, Number):
                continue

            if metric in self.metric_analysis:
                step = result["training_iteration"] or 1
                stats = self.metric_analysis[metric]
                stats["max"] = max(value, stats["max"])
                stats["min"] = min(value, stats["min"])
                # Incremental mean: fold the new value into the previous average.
                stats["avg"] = 1 / step * (value + (step - 1) * stats["avg"])
                stats["last"] = value

                for n in self.n_steps:
                    label = f"last-{n:d}-avg"
                    window = self.metric_n_steps[metric][str(n)]
                    window.append(value)  # bounded deque drops the oldest value
                    stats[label] = sum(window) / len(window)
            else:
                # First observation: every statistic starts at this value.
                stats = {"max": value, "min": value, "avg": value, "last": value}
                windows = {}
                self.metric_analysis[metric] = stats
                self.metric_n_steps[metric] = windows
                for n in self.n_steps:
                    stats[f"last-{n:d}-avg"] = value
                    # n is stored as a string key (the original code notes
                    # this is needed for correct restore).
                    windows[str(n)] = deque([value], maxlen=n)

    def set_status(self, status):
        """Set the trial status; the first switch to RUNNING records ``start_time``."""
        self.status = status
        if status == Trial.RUNNING and self.start_time is None:
            self.start_time = time.time()

    def is_finished(self):
        """Return True when the trial is in a terminal state (ERROR or TERMINATED)."""
        return self.status in (Trial.ERROR, Trial.TERMINATED)
|