diff --git a/flaml/model.py b/flaml/model.py index 61ea6b456..6dfdbc2b1 100644 --- a/flaml/model.py +++ b/flaml/model.py @@ -591,6 +591,8 @@ class TransformersEstimator(BaseEstimator): num_labels=self._num_labels, per_model_config=self._per_model_config, ) + if hasattr(self._trainer, "intermediate_results"): + self._intermediate_results = self._trainer.intermediate_results self._trainer = None def _delete_one_ckpt(self, ckpt_location): @@ -656,7 +658,7 @@ class TransformersEstimator(BaseEstimator): else np.argmax(predictions, axis=1) ) metric_dict = { - "val_loss": metric_loss_score( + "automl_metric": metric_loss_score( metric_name=self._metric, y_predict=predictions, y_true=labels ) } @@ -669,10 +671,7 @@ class TransformersEstimator(BaseEstimator): X_train=self._X_train, y_train=self._y_train, ) - metric_dict["val_loss"] = loss - if not hasattr(self, "intermediate_results"): - self.intermediate_results = [] - self.intermediate_results.append(metric_dict) + metric_dict["automl_metric"] = loss return metric_dict def _init_model_for_predict(self, X_test): diff --git a/flaml/nlp/huggingface/trainer.py b/flaml/nlp/huggingface/trainer.py index 2bd81bf22..b657fc535 100644 --- a/flaml/nlp/huggingface/trainer.py +++ b/flaml/nlp/huggingface/trainer.py @@ -74,6 +74,9 @@ class TrainerForAuto(Seq2SeqTrainer): ignore_keys, metric_key_prefix, ) + if not hasattr(self, "intermediate_results"): + self.intermediate_results = [] + self.intermediate_results.append(metrics) # if metrics: # for key in list(metrics.keys()): # if key.startswith("eval_"): diff --git a/test/nlp/test_autohf_custom_metric.py b/test/nlp/test_autohf_custom_metric.py index 861cdecd7..613857bd7 100644 --- a/test/nlp/test_autohf_custom_metric.py +++ b/test/nlp/test_autohf_custom_metric.py @@ -36,7 +36,7 @@ def custom_metric( metrics = trainer.evaluate(eval_dataset) estimator._metric = estimator_metric_backup - return metrics.pop("eval_val_loss"), metrics + return metrics.pop("eval_automl_metric"), metrics @pytest.mark.skipif(sys.platform == "darwin", reason="do not run on mac os")