mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-12 01:36:37 +00:00

* Refactor into automl subpackage Moved some of the packages into an automl subpackage to tidy before the task-based refactor. This is in response to discussions with the group and a comment on the first task-based PR. Only changes here are moving subpackages and modules into the new automl, fixing imports to work with this structure and fixing some dependencies in setup.py. * Fix doc building post automl subpackage refactor * Fix broken links in website post automl subpackage refactor * Fix broken links in website post automl subpackage refactor * Remove vw from test deps as this is breaking the build * Move default back to the top-level I'd moved this to automl as that's where it's used internally, but had missed that this is actually part of the public interface, so it makes sense to live where it was. * Re-add top level modules with deprecation warnings flaml.data, flaml.ml and flaml.model are re-added to the top level, being re-exported from flaml.automl for backwards compatibility. Adding a deprecation warning so that we can have a planned removal later. * Fix model.py line-endings * Pin pytorch-lightning to less than 1.8.0 We're seeing strange lightning-related bugs from pytorch-forecasting since the release of lightning 1.8.0. Going to try constraining this to see if we have a fix. * Fix the lightning version pin Was optimistic with setting it in the 1.7.x range, but that isn't compatible with Python 3.6 * Remove lightning version pin * Revert dependency version changes * Minor change to retrigger the build * Fix line endings in ml.py and model.py Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu> Co-authored-by: EgorKraevTransferwise <egor.kraev@transferwise.com>
182 lines
5.0 KiB
Python
182 lines
5.0 KiB
Python
"""!
|
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
|
* Licensed under the MIT License.
|
|
"""
|
|
|
|
import json
|
|
from typing import IO
|
|
from contextlib import contextmanager
|
|
import logging
|
|
|
|
# Logger shared with the rest of the AutoML package; this module uses it to
# warn when a checkpoint is requested before there is a best record to point at.
logger = logging.getLogger(name="flaml.automl")
|
|
|
|
|
|
class TrainingLogRecord(object):
    """One trial entry of the training log, serialized as a JSON object.

    The instance attributes are exactly the fields written by :meth:`dump`
    and :meth:`__str__` (via ``vars(self)``), and :meth:`load` rebuilds an
    instance by passing those fields back to the constructor as keywords.
    """

    def __init__(
        self,
        record_id: int,
        iter_per_learner: int,
        logged_metric: float,
        trial_time: float,
        wall_clock_time: float,
        validation_loss: float,
        config: dict,
        learner: str,
        sample_size: int,
    ):
        # Assignment order determines the key order of the serialized JSON,
        # so keep it in sync with the constructor's parameter order.
        self.record_id = record_id
        self.iter_per_learner = iter_per_learner
        self.logged_metric = logged_metric
        self.trial_time = trial_time
        self.wall_clock_time = wall_clock_time
        self.validation_loss = validation_loss
        self.config = config
        self.learner = learner
        self.sample_size = sample_size

    def dump(self, fp: IO[str]):
        """Write this record to ``fp`` as one JSON object (no trailing newline).

        The whole object is encoded to a string before anything is written.
        ``json.dump`` streams chunks while encoding, so a non-serializable
        field used to leave a partial JSON fragment in the log. Now it
        raises with ``fp`` untouched.

        Args:
            fp: Writable text stream.

        Returns:
            None.

        Raises:
            TypeError: If a field (e.g. a value inside ``config``) is not
                JSON serializable.
        """
        text = json.dumps(vars(self))
        fp.write(text)

    @classmethod
    def load(cls, json_str: str):
        """Build an instance from a JSON string such as one produced by :meth:`dump`.

        Args:
            json_str: JSON object whose keys match the constructor parameters.
        """
        return cls(**json.loads(json_str))

    def __str__(self):
        """Return the JSON text of this record (same content as :meth:`dump`)."""
        return json.dumps(vars(self))
|
|
|
|
|
|
class TrainingLogCheckPoint(TrainingLogRecord):
    """Log entry that records which record id is currently the best.

    ``TrainingLogRecord.__init__`` is deliberately not called. The only
    instance attribute is ``curr_best_record_id``, so the inherited ``dump``
    (which serializes ``vars(self)``) emits a single-field JSON object. That
    one-field shape is what ``TrainingLogReader.records`` uses to skip
    checkpoint lines.
    """

    def __init__(self, curr_best_record_id: int):
        # Sole attribute, so the serialized object has exactly one key.
        self.curr_best_record_id: int = curr_best_record_id
|
|
|
|
|
|
class TrainingLogWriter(object):
    """Writes trial records to a JSON-lines training log file.

    Each :meth:`append` writes one ``TrainingLogRecord`` per line and tracks
    which record has the best validation loss so far. :meth:`checkpoint`
    writes a one-field line holding that record's id.
    """

    def __init__(self, output_filename: str):
        """Create a writer for ``output_filename``; the file is not opened yet.

        Args:
            output_filename: Path of the log file to write.
        """
        self.output_filename = output_filename
        self.file = None
        # Id of the best record written so far; None until one qualifies.
        self.current_best_loss_record_id = None
        self.current_best_loss = float("+inf")
        # Sample size of the current best record, used to break loss ties.
        self.current_sample_size = None
        # Id that will be assigned to the next appended record.
        self.current_record_id = 0

    def _reopen(self, mode: str):
        """Close any handle already held, then open the log file in ``mode``."""
        # Overwriting self.file without closing it would leak the old handle.
        self.close()
        self.file = open(self.output_filename, mode)

    def _write_line(self, record):
        """Write ``record`` as one JSON line and flush the file buffer."""
        record.dump(self.file)
        self.file.write("\n")
        self.file.flush()

    def open(self):
        """Open the log file for writing, truncating existing content."""
        self._reopen("w")

    def append_open(self):
        """Open the log file for appending after existing content."""
        self._reopen("a")

    def append(
        self,
        it_counter: int,
        train_loss: float,
        trial_time: float,
        wall_clock_time: float,
        validation_loss,
        config,
        learner,
        sample_size,
    ):
        """Write one trial record and update the best-record tracking.

        The record becomes the new best when its validation loss is strictly
        lower than the best so far, or equal to it with a strictly larger
        sample size. Bookkeeping is committed only after the record has been
        written. If serialization fails, the record id is not consumed and the
        best id never points at a record missing from the file.

        Args:
            it_counter: Stored in the record as ``iter_per_learner``.
            train_loss: Stored in the record as ``logged_metric``.
            trial_time: Stored as ``trial_time``.
            wall_clock_time: Stored as ``wall_clock_time``.
            validation_loss: Loss used to decide the best record; must not be None.
            config: Stored as ``config``.
            learner: Stored as ``learner``.
            sample_size: Stored as ``sample_size``; breaks loss ties.

        Raises:
            IOError: If the file has not been opened.
            ValueError: If ``validation_loss`` is None.
        """
        if self.file is None:
            raise IOError("Call open() to open the output file first.")
        if validation_loss is None:
            raise ValueError("TEST LOSS NONE ERROR!!!")
        record_id = self.current_record_id
        record = TrainingLogRecord(
            record_id,
            it_counter,
            train_loss,
            trial_time,
            wall_clock_time,
            validation_loss,
            config,
            learner,
            sample_size,
        )
        is_new_best = validation_loss < self.current_best_loss or (
            validation_loss == self.current_best_loss
            and self.current_sample_size is not None
            and sample_size > self.current_sample_size
        )
        self._write_line(record)
        if is_new_best:
            self.current_best_loss = validation_loss
            self.current_sample_size = sample_size
            self.current_best_loss_record_id = record_id
        self.current_record_id += 1

    def checkpoint(self):
        """Write a line holding the id of the best record so far.

        If no best record has been identified yet, a warning is logged and
        nothing is written.

        Raises:
            IOError: If the file has not been opened.
        """
        if self.file is None:
            raise IOError("Call open() to open the output file first.")
        if self.current_best_loss_record_id is None:
            logger.warning(
                "flaml.training_log: checkpoint() called before any record is written, skipped."
            )
            return
        self._write_line(TrainingLogCheckPoint(self.current_best_loss_record_id))

    def close(self):
        """Close the log file if it is open; safe to call more than once."""
        if self.file is not None:
            self.file.close()
        self.file = None  # for pickle
|
|
|
|
|
|
class TrainingLogReader(object):
    """Reads trial records back from a JSON-lines training log file.

    Call :meth:`open` before :meth:`records` or :meth:`get_record`, and
    :meth:`close` when done.
    """

    def __init__(self, filename: str):
        """Create a reader for ``filename``; the file is not opened yet."""
        self.filename = filename
        self.file = None

    def open(self):
        """Open the log file for reading.

        Any handle already held is closed first, so re-opening does not leak it.
        """
        self.close()
        self.file = open(self.filename)

    def records(self):
        """Yield every trial record in the file, starting from the beginning.

        Checkpoint lines (single-field JSON objects) and blank lines are skipped.

        Raises:
            IOError: If the file has not been opened (raised on first iteration,
                since this is a generator).
        """
        if self.file is None:
            raise IOError("Call open() before reading log file.")
        # Rewind so every call scans the whole file. Without this, a second
        # call (e.g. a second get_record) resumed where the previous scan
        # stopped and missed earlier records.
        self.file.seek(0)
        for line in self.file:
            if not line.strip():
                # Tolerate blank lines instead of failing in json.loads.
                continue
            data = json.loads(line)
            if len(data) == 1:
                # Skip checkpoints.
                continue
            yield TrainingLogRecord(**data)

    def close(self):
        """Close the log file if it is open; safe to call more than once."""
        if self.file is not None:
            self.file.close()
        self.file = None  # for pickle

    def get_record(self, record_id) -> "TrainingLogRecord":
        """Return the record whose ``record_id`` equals the given id.

        Args:
            record_id: Id to look for.

        Raises:
            IOError: If the file has not been opened.
            ValueError: If no record with that id exists in the file.
        """
        if self.file is None:
            raise IOError("Call open() before reading log file.")
        for rec in self.records():
            if rec.record_id == record_id:
                return rec
        raise ValueError(f"Cannot find record with id {record_id}.")
|
|
|
|
|
|
@contextmanager
def training_log_writer(filename: str, append: bool = False):
    """Yield a ``TrainingLogWriter`` opened on ``filename``, closing it on exit.

    Args:
        filename: Path of the log file.
        append: If true, keep existing content and add records after it;
            otherwise the file is truncated.
    """
    try:
        writer = TrainingLogWriter(filename)
        if append:
            writer.append_open()
        else:
            writer.open()
        yield writer
    finally:
        writer.close()
|
|
|
|
|
|
@contextmanager
def training_log_reader(filename: str):
    """Yield a ``TrainingLogReader`` opened on ``filename``, closing it on exit.

    Args:
        filename: Path of the log file to read.
    """
    try:
        reader = TrainingLogReader(filename)
        reader.open()
        yield reader
    finally:
        reader.close()
|