diff --git a/flaml/automl.py b/flaml/automl.py
index d6f23ce1a..fddf43540 100644
--- a/flaml/automl.py
+++ b/flaml/automl.py
@@ -1974,10 +1974,6 @@ class AutoML(BaseEstimator):
         self._min_sample_size = min_sample_size
         self._prepare_data(eval_method, split_ratio, n_splits)
 
-        if _is_nlp_task(self._state.task):
-            self._state.fit_kwargs["metric"] = metric
-            self._state.fit_kwargs["use_ray"] = self._use_ray
-
         self._sample = (
             sample
             and task != "rank"
@@ -1996,24 +1992,50 @@ class AutoML(BaseEstimator):
                 metric = "mape"
             elif self._state.task == "rank":
                 metric = "ndcg"
+            elif _is_nlp_task(self._state.task):
+                from .nlp.utils import load_default_huggingface_metric_for_task
+
+                metric = load_default_huggingface_metric_for_task(self._state.task)
             else:
                 metric = "r2"
+
+        if _is_nlp_task(self._state.task):
+            self._state.fit_kwargs["metric"] = metric
+            self._state.fit_kwargs["use_ray"] = self._use_ray
+
         self._state.metric = metric
-        if metric in [
-            "r2",
-            "accuracy",
-            "roc_auc",
-            "roc_auc_ovr",
-            "roc_auc_ovo",
-            "f1",
-            "ap",
-            "micro_f1",
-            "macro_f1",
-            "ndcg",
-        ]:
-            error_metric = f"1-{metric}"
-        elif isinstance(metric, str):
-            error_metric = metric
+
+        def is_to_reverse_metric(metric, task):
+            if metric.startswith("ndcg"):
+                return True, f"1-{metric}"
+            if metric in [
+                "r2",
+                "accuracy",
+                "roc_auc",
+                "roc_auc_ovr",
+                "roc_auc_ovo",
+                "f1",
+                "ap",
+                "micro_f1",
+                "macro_f1",
+            ]:
+                return True, f"1-{metric}"
+            if _is_nlp_task(task):
+                from .ml import huggingface_metric_to_mode
+
+                if (
+                    metric in huggingface_metric_to_mode
+                    and huggingface_metric_to_mode[metric] == "max"
+                ):
+                    return True, f"-{metric}"
+            return False, None
+
+        if isinstance(metric, str):
+            is_reverse, reverse_metric = is_to_reverse_metric(metric, task)
+            if is_reverse:
+                error_metric = reverse_metric
+            else:
+                error_metric = metric
         else:
             error_metric = "customized metric"
         logger.info(f"Minimizing error metric: {error_metric}")
diff --git a/flaml/ml.py b/flaml/ml.py
index 8293455a1..145b045d0 100644
--- a/flaml/ml.py
+++ b/flaml/ml.py
@@ -38,6 +38,52 @@ import logging
 
 logger = logging.getLogger(__name__)
 
+sklearn_metric_name_set = {
+    "r2",
+    "rmse",
+    "mae",
+    "mse",
+    "accuracy",
+    "roc_auc",
+    "roc_auc_ovr",
+    "roc_auc_ovo",
+    "log_loss",
+    "mape",
+    "f1",
+    "ap",
+    "ndcg",
+    "micro_f1",
+    "macro_f1",
+}
+huggingface_metric_to_mode = {
+    "accuracy": "max",
+    "bertscore": "max",
+    "bleu": "max",
+    "bleurt": "max",
+    "cer": "min",
+    "chrf": "min",
+    "code_eval": "max",
+    "comet": "max",
+    "competition_math": "max",
+    "coval": "max",
+    "cuad": "max",
+    "f1": "max",
+    "gleu": "max",
+    "google_bleu": "max",
+    "matthews_correlation": "max",
+    "meteor": "max",
+    "pearsonr": "max",
+    "precision": "max",
+    "recall": "max",
+    "rouge": "max",
+    "sacrebleu": "max",
+    "sari": "max",
+    "seqeval": "max",
+    "spearmanr": "max",
+    "ter": "min",
+    "wer": "min",
+}
+
 
 def get_estimator_class(task, estimator_name):
     # when adding a new learner, need to add an elif branch
@@ -75,6 +121,74 @@ def get_estimator_class(task, estimator_name):
     return estimator_class
 
 
+def metric_loss_score(
+    metric_name,
+    y_predict,
+    y_true,
+    labels=None,
+    sample_weight=None,
+    groups=None,
+):
+    if is_in_sklearn_metric_name_set(metric_name):
+        return sklearn_metric_loss_score(
+            metric_name, y_predict, y_true, labels, sample_weight, groups
+        )
+    else:
+        """
+        hf's datasets.load_metric("pearsonr") returns nan (hf's bug), overwriting it here
+        """
+        if metric_name == "spearmanr":
+            from scipy.stats import spearmanr
+
+            y_true = y_true.to_list() if type(y_true) == pd.Series else list(y_true)
+            score = spearmanr(list(y_predict), y_true)[0]
+            metric_mode = "max"
+        elif metric_name == "pearsonr":
+            from scipy.stats import pearsonr
+
+            y_true = y_true.to_list() if type(y_true) == pd.Series else list(y_true)
+            score = pearsonr(list(y_predict), y_true)[0]
+            metric_mode = "max"
+        else:
+            try:
+                import datasets
+
+                metric = datasets.load_metric(metric_name)
+                metric_mode = huggingface_metric_to_mode[metric_name]
+                score = metric.compute(predictions=y_predict, references=y_true)[
+                    metric_name
+                ]
+            except ImportError:
+                raise Exception(
+                    metric_name
+                    + " is not a built-in sklearn metric and nlp is not installed. "
+                    "Currently built-in sklearn metrics are: "
+                    "r2, rmse, mae, mse, accuracy, roc_auc, roc_auc_ovr, roc_auc_ovo, "
+                    "log_loss, mape, f1, micro_f1, macro_f1, ap. "
+                    "If the metric is an nlp metric, please pip install flaml[nlp] "
+                    "or pass a customized metric function to AutoML.fit(metric=func)"
+                )
+            # If the metric is not found in the huggingface datasets metric list
+            # (i.e., a FileNotFoundError is raised), ask the user to provide a customized metric
+            except FileNotFoundError:
+                raise Exception(
+                    metric_name
+                    + " is neither an sklearn metric nor a huggingface metric. "
+                    "Currently built-in sklearn metrics are: "
+                    "r2, rmse, mae, mse, accuracy, roc_auc, roc_auc_ovr, roc_auc_ovo, "
+                    "log_loss, mape, f1, micro_f1, macro_f1, ap. "
+                    "Currently built-in huggingface metrics are: "
+                    + ", ".join(huggingface_metric_to_mode.keys())
+                    + ". Please pass a customized metric function to AutoML.fit(metric=func)"
+                )
+        multiplier = -1 if metric_mode == "max" else 1
+        return score * multiplier
+
+
+def is_in_sklearn_metric_name_set(metric_name):
+    return metric_name.startswith("ndcg") or metric_name in sklearn_metric_name_set
+
+
 def sklearn_metric_loss_score(
     metric_name,
     y_predict,
@@ -102,6 +216,7 @@ def sklearn_metric_loss_score(
         score: A float number of the loss, the lower the better.
     """
     metric_name = metric_name.lower()
+
     if "r2" == metric_name:
         score = 1.0 - r2_score(y_true, y_predict, sample_weight=sample_weight)
     elif metric_name == "rmse":
@@ -162,14 +277,6 @@ def sklearn_metric_loss_score(
                 score += 1
         else:
             score = 1 - ndcg_score([y_true], [y_predict])
-    else:
-        raise ValueError(
-            metric_name + " is not a built-in metric, "
-            "currently built-in metrics are: "
-            "r2, rmse, mae, mse, accuracy, roc_auc, roc_auc_ovr, roc_auc_ovo,"
-            "log_loss, mape, f1, micro_f1, macro_f1, ap. "
" - "please pass a customized metric function to AutoML.fit(metric=func)" - ) return score @@ -203,13 +310,13 @@ def _eval_estimator( pred_start = time.time() val_pred_y = get_y_pred(estimator, X_val, eval_metric, obj) pred_time = (time.time() - pred_start) / X_val.shape[0] - val_loss = sklearn_metric_loss_score( + val_loss = metric_loss_score( eval_metric, val_pred_y, y_val, labels, weight_val, groups_val ) metric_for_logging = {"pred_time": pred_time} if log_training_metric: train_pred_y = get_y_pred(estimator, X_train, eval_metric, obj) - metric_for_logging["train_loss"] = sklearn_metric_loss_score( + metric_for_logging["train_loss"] = metric_loss_score( eval_metric, train_pred_y, y_train, diff --git a/flaml/model.py b/flaml/model.py index 79404675e..fe427b284 100644 --- a/flaml/model.py +++ b/flaml/model.py @@ -590,9 +590,7 @@ class TransformersEstimator(BaseEstimator): return best_ckpt def _compute_metrics_by_dataset_name(self, eval_pred): - from .ml import sklearn_metric_loss_score - import datasets - from .nlp.utils import load_default_huggingface_metric_for_task + from .ml import metric_loss_score predictions, labels = eval_pred predictions = ( @@ -601,25 +599,11 @@ class TransformersEstimator(BaseEstimator): else np.argmax(predictions, axis=1) ) - if isinstance(self._metric_name, str): - return { - "val_loss": sklearn_metric_loss_score( - metric_name=self._metric_name, y_predict=predictions, y_true=labels + return { + "val_loss": metric_loss_score( + metric_name=self._metric_name, y_predict=predictions, y_true=labels ) } - else: - ( - default_metric_name, - default_metric_mode, - ) = load_default_huggingface_metric_for_task(self._task) - metric = datasets.load_metric(default_metric_name) - multiplier = -1 if default_metric_mode == "max" else 1 - return { - "val_loss": metric.compute(predictions=predictions, references=labels)[ - default_metric_name - ] - * multiplier - } def predict_proba(self, X_test): from datasets import Dataset @@ -673,7 +657,7 @@ class TransformersEstimator(BaseEstimator): if self._task == SEQCLASSIFICATION: return np.argmax(predictions.predictions, axis=1) elif self._task == SEQREGRESSION: - return predictions.predictions + return predictions.predictions.reshape((len(predictions.predictions),)) # TODO: elif self._task == your task, return the corresponding prediction # e.g., if your task == QUESTIONANSWERING, you need to return the answer instead # of the index diff --git a/flaml/nlp/utils.py b/flaml/nlp/utils.py index 97526081a..27edd6555 100644 --- a/flaml/nlp/utils.py +++ b/flaml/nlp/utils.py @@ -2,14 +2,23 @@ import argparse from dataclasses import dataclass, field from typing import Dict, Any +from ..data import ( + SUMMARIZATION, + SEQREGRESSION, + SEQCLASSIFICATION, + NLG_TASKS +) + def load_default_huggingface_metric_for_task(task): from ..data import SEQCLASSIFICATION, SEQREGRESSION if task == SEQCLASSIFICATION: - return "accuracy", "max" + return "accuracy" elif task == SEQREGRESSION: - return "rmse", "max" + return "rmse" + elif task == SUMMARIZATION: + return "rouge" # TODO: elif task == your task, return the default metric name for your task, # e.g., if task == MULTIPLECHOICE, return "accuracy" # notice this metric name has to be in ['accuracy', 'bertscore', 'bleu', 'bleurt',