mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-09 16:26:20 +00:00
moving intermediate_results logging from model.py to huggingface/trainer.py (#403)
* replacing val_loss with automl_metric
This commit is contained in:
parent
569908fbe6
commit
dda4ac90a1
@ -591,6 +591,8 @@ class TransformersEstimator(BaseEstimator):
|
|||||||
num_labels=self._num_labels,
|
num_labels=self._num_labels,
|
||||||
per_model_config=self._per_model_config,
|
per_model_config=self._per_model_config,
|
||||||
)
|
)
|
||||||
|
if hasattr(self._trainer, "intermediate_results"):
|
||||||
|
self._intermediate_results = self._trainer.intermediate_results
|
||||||
self._trainer = None
|
self._trainer = None
|
||||||
|
|
||||||
def _delete_one_ckpt(self, ckpt_location):
|
def _delete_one_ckpt(self, ckpt_location):
|
||||||
@ -656,7 +658,7 @@ class TransformersEstimator(BaseEstimator):
|
|||||||
else np.argmax(predictions, axis=1)
|
else np.argmax(predictions, axis=1)
|
||||||
)
|
)
|
||||||
metric_dict = {
|
metric_dict = {
|
||||||
"val_loss": metric_loss_score(
|
"automl_metric": metric_loss_score(
|
||||||
metric_name=self._metric, y_predict=predictions, y_true=labels
|
metric_name=self._metric, y_predict=predictions, y_true=labels
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -669,10 +671,7 @@ class TransformersEstimator(BaseEstimator):
|
|||||||
X_train=self._X_train,
|
X_train=self._X_train,
|
||||||
y_train=self._y_train,
|
y_train=self._y_train,
|
||||||
)
|
)
|
||||||
metric_dict["val_loss"] = loss
|
metric_dict["automl_metric"] = loss
|
||||||
if not hasattr(self, "intermediate_results"):
|
|
||||||
self.intermediate_results = []
|
|
||||||
self.intermediate_results.append(metric_dict)
|
|
||||||
return metric_dict
|
return metric_dict
|
||||||
|
|
||||||
def _init_model_for_predict(self, X_test):
|
def _init_model_for_predict(self, X_test):
|
||||||
|
@ -74,6 +74,9 @@ class TrainerForAuto(Seq2SeqTrainer):
|
|||||||
ignore_keys,
|
ignore_keys,
|
||||||
metric_key_prefix,
|
metric_key_prefix,
|
||||||
)
|
)
|
||||||
|
if not hasattr(self, "intermediate_results"):
|
||||||
|
self.intermediate_results = []
|
||||||
|
self.intermediate_results.append(metrics)
|
||||||
# if metrics:
|
# if metrics:
|
||||||
# for key in list(metrics.keys()):
|
# for key in list(metrics.keys()):
|
||||||
# if key.startswith("eval_"):
|
# if key.startswith("eval_"):
|
||||||
|
@ -36,7 +36,7 @@ def custom_metric(
|
|||||||
metrics = trainer.evaluate(eval_dataset)
|
metrics = trainer.evaluate(eval_dataset)
|
||||||
estimator._metric = estimator_metric_backup
|
estimator._metric = estimator_metric_backup
|
||||||
|
|
||||||
return metrics.pop("eval_val_loss"), metrics
|
return metrics.pop("eval_automl_metric"), metrics
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.platform == "darwin", reason="do not run on mac os")
|
@pytest.mark.skipif(sys.platform == "darwin", reason="do not run on mac os")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user