"""! * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. """ import json from typing import IO from contextlib import contextmanager import logging logger = logging.getLogger("flaml.automl") class TrainingLogRecord(object): def __init__( self, record_id: int, iter_per_learner: int, logged_metric: float, trial_time: float, wall_clock_time: float, validation_loss: float, config: dict, learner: str, sample_size: int, ): self.record_id = record_id self.iter_per_learner = iter_per_learner self.logged_metric = logged_metric self.trial_time = trial_time self.wall_clock_time = wall_clock_time self.validation_loss = validation_loss self.config = config self.learner = learner self.sample_size = sample_size def dump(self, fp: IO[str]): d = vars(self) return json.dump(d, fp) @classmethod def load(cls, json_str: str): d = json.loads(json_str) return cls(**d) def __str__(self): return json.dumps(vars(self)) class TrainingLogCheckPoint(TrainingLogRecord): def __init__(self, curr_best_record_id: int): self.curr_best_record_id = curr_best_record_id class TrainingLogWriter(object): def __init__(self, output_filename: str): self.output_filename = output_filename self.file = None self.current_best_loss_record_id = None self.current_best_loss = float("+inf") self.current_sample_size = None self.current_record_id = 0 def open(self): self.file = open(self.output_filename, "w") def append_open(self): self.file = open(self.output_filename, "a") def append( self, it_counter: int, train_loss: float, trial_time: float, wall_clock_time: float, validation_loss, config, learner, sample_size, ): if self.file is None: raise IOError("Call open() to open the output file first.") if validation_loss is None: raise ValueError("TEST LOSS NONE ERROR!!!") record = TrainingLogRecord( self.current_record_id, it_counter, train_loss, trial_time, wall_clock_time, validation_loss, config, learner, sample_size, ) if ( validation_loss < self.current_best_loss or validation_loss == self.current_best_loss and self.current_sample_size is not None and sample_size > self.current_sample_size ): self.current_best_loss = validation_loss self.current_sample_size = sample_size self.current_best_loss_record_id = self.current_record_id self.current_record_id += 1 record.dump(self.file) self.file.write("\n") self.file.flush() def checkpoint(self): if self.file is None: raise IOError("Call open() to open the output file first.") if self.current_best_loss_record_id is None: logger.warning("flaml.training_log: checkpoint() called before any record is written, skipped.") return record = TrainingLogCheckPoint(self.current_best_loss_record_id) record.dump(self.file) self.file.write("\n") self.file.flush() def close(self): if self.file is not None: self.file.close() self.file = None # for pickle class TrainingLogReader(object): def __init__(self, filename: str): self.filename = filename self.file = None def open(self): self.file = open(self.filename) def records(self): if self.file is None: raise IOError("Call open() before reading log file.") for line in self.file: data = json.loads(line) if len(data) == 1: # Skip checkpoints. 
continue yield TrainingLogRecord(**data) def close(self): if self.file is not None: self.file.close() self.file = None # for pickle def get_record(self, record_id) -> TrainingLogRecord: if self.file is None: raise IOError("Call open() before reading log file.") for rec in self.records(): if rec.record_id == record_id: return rec raise ValueError(f"Cannot find record with id {record_id}.") @contextmanager def training_log_writer(filename: str, append: bool = False): try: w = TrainingLogWriter(filename) if not append: w.open() else: w.append_open() yield w finally: w.close() @contextmanager def training_log_reader(filename: str): try: r = TrainingLogReader(filename) r.open() yield r finally: r.close()
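

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original module):
# the file name "example_training.log" and the metric values below are made up
# purely to show how the writer and reader context managers round-trip records.
if __name__ == "__main__":
    with training_log_writer("example_training.log") as writer:
        # Log one hypothetical trial, then record a checkpoint of the best id.
        writer.append(
            it_counter=1,
            train_loss=0.5,
            trial_time=0.1,
            wall_clock_time=0.1,
            validation_loss=0.4,
            config={"n_estimators": 4},
            learner="lgbm",
            sample_size=1000,
        )
        writer.checkpoint()

    with training_log_reader("example_training.log") as reader:
        # Checkpoint lines (single-field JSON objects) are skipped by records().
        for record in reader.records():
            print(record)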