diff --git a/flaml/automl.py b/flaml/automl.py index 29efebbae..3d48575df 100644 --- a/flaml/automl.py +++ b/flaml/automl.py @@ -416,14 +416,38 @@ class AutoML(BaseEstimator): .. code-block:: python def custom_metric( - X_test, y_test, estimator, labels, - X_train, y_train, weight_test=None, weight_train=None, - config=None, groups_test=None, groups_train=None, + X_val, y_val, estimator, labels, + X_train, y_train, weight_val=None, weight_train=None, + config=None, groups_val=None, groups_train=None, ): return metric_to_minimize, metrics_to_log which returns a float number as the minimization objective, - and a dictionary as the metrics to log. + and a dictionary as the metrics to log. E.g., + + .. code-block:: python + + def custom_metric( + X_val, y_val, estimator, labels, + X_train, y_train, weight_val=None, weight_train=None, + **args, + ): + from sklearn.metrics import log_loss + import time + + start = time.time() + y_pred = estimator.predict_proba(X_val) + pred_time = (time.time() - start) / len(X_val) + val_loss = log_loss(y_val, y_pred, labels=labels, sample_weight=weight_val) + y_pred = estimator.predict_proba(X_train) + train_loss = log_loss(y_train, y_pred, labels=labels, sample_weight=weight_train) + alpha = 0.5 + return val_loss * (1 + alpha) - alpha * train_loss, { + "val_loss": val_loss, + "train_loss": train_loss, + "pred_time": pred_time, + } + task: A string of the task type, e.g., 'classification', 'regression', 'ts_forecast', 'rank', 'seq-classification', 'seq-regression'. @@ -1641,14 +1665,38 @@ class AutoML(BaseEstimator): .. code-block:: python def custom_metric( - X_test, y_test, estimator, labels, - X_train, y_train, weight_test=None, weight_train=None, - config=None, groups_test=None, groups_train=None, + X_val, y_val, estimator, labels, + X_train, y_train, weight_val=None, weight_train=None, + config=None, groups_val=None, groups_train=None, ): return metric_to_minimize, metrics_to_log which returns a float number as the minimization objective, - and a dictionary as the metrics to log. + and a dictionary as the metrics to log. E.g., + + .. code-block:: python + + def custom_metric( + X_val, y_val, estimator, labels, + X_train, y_train, weight_val=None, weight_train=None, + **args, + ): + from sklearn.metrics import log_loss + import time + + start = time.time() + y_pred = estimator.predict_proba(X_val) + pred_time = (time.time() - start) / len(X_val) + val_loss = log_loss(y_val, y_pred, labels=labels, sample_weight=weight_val) + y_pred = estimator.predict_proba(X_train) + train_loss = log_loss(y_train, y_pred, labels=labels, sample_weight=weight_train) + alpha = 0.5 + return val_loss * (1 + alpha) - alpha * train_loss, { + "val_loss": val_loss, + "train_loss": train_loss, + "pred_time": pred_time, + } + task: A string of the task type, e.g., 'classification', 'regression', 'ts_forecast', 'rank', 'seq-classification', 'seq-regression'. diff --git a/flaml/ml.py b/flaml/ml.py index 9ad75695e..0441691b4 100644 --- a/flaml/ml.py +++ b/flaml/ml.py @@ -189,10 +189,10 @@ def _eval_estimator( estimator, X_train, y_train, - X_test, - y_test, - weight_test, - groups_test, + X_val, + y_val, + weight_val, + groups_val, eval_metric, obj, labels=None, @@ -201,10 +201,10 @@ def _eval_estimator( ): if isinstance(eval_metric, str): pred_start = time.time() - test_pred_y = get_y_pred(estimator, X_test, eval_metric, obj) - pred_time = (time.time() - pred_start) / X_test.shape[0] - test_loss = sklearn_metric_loss_score( - eval_metric, test_pred_y, y_test, labels, weight_test, groups_test + val_pred_y = get_y_pred(estimator, X_val, eval_metric, obj) + pred_time = (time.time() - pred_start) / X_val.shape[0] + val_loss = sklearn_metric_loss_score( + eval_metric, val_pred_y, y_val, labels, weight_val, groups_val ) metric_for_logging = {} if log_training_metric: @@ -218,34 +218,34 @@ def _eval_estimator( fit_kwargs.get("groups"), ) else: # customized metric function - test_loss, metric_for_logging = eval_metric( - X_test, - y_test, + val_loss, metric_for_logging = eval_metric( + X_val, + y_val, estimator, labels, X_train, y_train, - weight_test, + weight_val, fit_kwargs.get("sample_weight"), config, - groups_test, + groups_val, fit_kwargs.get("groups"), ) pred_time = metric_for_logging.get("pred_time", 0) - test_pred_y = None - # eval_metric may return test_pred_y but not necessarily. Setting None for now. - return test_loss, metric_for_logging, pred_time, test_pred_y + val_pred_y = None + # eval_metric may return val_pred_y but not necessarily. Setting None for now. + return val_loss, metric_for_logging, pred_time, val_pred_y -def get_test_loss( +def get_val_loss( config, estimator, X_train, y_train, - X_test, - y_test, - weight_test, - groups_test, + X_val, + y_val, + weight_val, + groups_val, eval_metric, obj, labels=None, @@ -255,20 +255,20 @@ def get_test_loss( ): start = time.time() - # if groups_test is not None: - # fit_kwargs['groups_val'] = groups_test - # fit_kwargs['X_val'] = X_test - # fit_kwargs['y_val'] = y_test + # if groups_val is not None: + # fit_kwargs['groups_val'] = groups_val + # fit_kwargs['X_val'] = X_val + # fit_kwargs['y_val'] = y_val estimator.fit(X_train, y_train, budget, **fit_kwargs) - test_loss, metric_for_logging, pred_time, _ = _eval_estimator( + val_loss, metric_for_logging, pred_time, _ = _eval_estimator( config, estimator, X_train, y_train, - X_test, - y_test, - weight_test, - groups_test, + X_val, + y_val, + weight_val, + groups_val, eval_metric, obj, labels, @@ -276,7 +276,7 @@ def get_test_loss( fit_kwargs, ) train_time = time.time() - start - return test_loss, metric_for_logging, train_time, pred_time + return val_loss, metric_for_logging, train_time, pred_time def evaluate_model_CV( @@ -349,7 +349,7 @@ def evaluate_model_CV( groups_val = groups[val_index] else: groups_val = None - val_loss_i, metric_i, train_time_i, pred_time_i = get_test_loss( + val_loss_i, metric_i, train_time_i, pred_time_i = get_val_loss( config, estimator, X_train, @@ -427,7 +427,7 @@ def compute_estimator( n_jobs=n_jobs, ) if "holdout" == eval_method: - val_loss, metric_for_logging, train_time, pred_time = get_test_loss( + val_loss, metric_for_logging, train_time, pred_time = get_val_loss( config_dic, estimator, X_train, diff --git a/notebook/flaml_automl.ipynb b/notebook/flaml_automl.ipynb index 52a678f12..9b29f95c6 100644 --- a/notebook/flaml_automl.ipynb +++ b/notebook/flaml_automl.ipynb @@ -34,9 +34,106 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: flaml[notebook] in /usr/local/lib/python3.9/site-packages (0.7.1)\n", + "Requirement already satisfied: scipy>=1.4.1 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.7.2)\n", + "Requirement already satisfied: lightgbm>=2.3.1 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (3.3.1)\n", + "Requirement already satisfied: pandas>=1.1.4 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.3.4)\n", + "Requirement already satisfied: NumPy>=1.16.2 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.21.4)\n", + "Requirement already satisfied: xgboost<=1.3.3,>=0.90 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.3.3)\n", + "Requirement already satisfied: scikit-learn>=0.24 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.0.1)\n", + "Requirement already satisfied: catboost>=0.26 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.0.3)\n", + "Requirement already satisfied: jupyter in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.0.0)\n", + "Requirement already satisfied: rgf-python in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (3.11.0)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (3.5.0)\n", + "Requirement already satisfied: openml==0.10.2 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (0.10.2)\n", + "Requirement already satisfied: xmltodict in /usr/local/lib/python3.9/site-packages (from openml==0.10.2->flaml[notebook]) (0.12.0)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.9/site-packages (from openml==0.10.2->flaml[notebook]) (2.26.0)\n", + "Requirement already satisfied: liac-arff>=2.4.0 in /usr/local/lib/python3.9/site-packages (from openml==0.10.2->flaml[notebook]) (2.5.0)\n", + "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.9/site-packages (from openml==0.10.2->flaml[notebook]) (2.8.2)\n", + "Requirement already satisfied: plotly in /usr/local/lib/python3.9/site-packages (from catboost>=0.26->flaml[notebook]) (5.4.0)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.9/site-packages (from catboost>=0.26->flaml[notebook]) (1.16.0)\n", + "Requirement already satisfied: graphviz in /usr/local/lib/python3.9/site-packages (from catboost>=0.26->flaml[notebook]) (0.18.2)\n", + "Requirement already satisfied: wheel in /usr/local/lib/python3.9/site-packages (from lightgbm>=2.3.1->flaml[notebook]) (0.37.0)\n", + "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.9/site-packages (from pandas>=1.1.4->flaml[notebook]) (2021.3)\n", + "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.9/site-packages (from scikit-learn>=0.24->flaml[notebook]) (1.1.0)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.9/site-packages (from scikit-learn>=0.24->flaml[notebook]) (3.0.0)\n", + "Requirement already satisfied: ipykernel in /usr/local/lib/python3.9/site-packages (from jupyter->flaml[notebook]) (6.5.1)\n", + "Requirement already satisfied: qtconsole in /usr/local/lib/python3.9/site-packages (from jupyter->flaml[notebook]) (5.2.0)\n", + "Requirement already satisfied: notebook in /usr/local/lib/python3.9/site-packages (from jupyter->flaml[notebook]) (6.4.6)\n", + "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.9/site-packages (from jupyter->flaml[notebook]) (7.6.5)\n", + "Requirement already satisfied: jupyter-console in /usr/local/lib/python3.9/site-packages (from jupyter->flaml[notebook]) (6.4.0)\n", + "Requirement already satisfied: nbconvert in /usr/local/lib/python3.9/site-packages (from jupyter->flaml[notebook]) (6.3.0)\n", + "Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (3.0.6)\n", + "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (8.4.0)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (1.3.2)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (4.28.1)\n", + "Requirement already satisfied: setuptools-scm>=4 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (6.3.2)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (0.11.0)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (21.3)\n", + "Requirement already satisfied: tomli>=1.0.0 in /usr/local/lib/python3.9/site-packages (from setuptools-scm>=4->matplotlib->flaml[notebook]) (1.2.2)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.9/site-packages (from setuptools-scm>=4->matplotlib->flaml[notebook]) (57.5.0)\n", + "Requirement already satisfied: traitlets<6.0,>=5.1.0 in /usr/local/lib/python3.9/site-packages (from ipykernel->jupyter->flaml[notebook]) (5.1.1)\n", + "Requirement already satisfied: tornado<7.0,>=4.2 in /usr/local/lib/python3.9/site-packages (from ipykernel->jupyter->flaml[notebook]) (6.1)\n", + "Requirement already satisfied: jupyter-client<8.0 in /usr/local/lib/python3.9/site-packages (from ipykernel->jupyter->flaml[notebook]) (7.0.6)\n", + "Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.9/site-packages (from ipykernel->jupyter->flaml[notebook]) (7.29.0)\n", + "Requirement already satisfied: debugpy<2.0,>=1.0.0 in /usr/local/lib/python3.9/site-packages (from ipykernel->jupyter->flaml[notebook]) (1.5.1)\n", + "Requirement already satisfied: matplotlib-inline<0.2.0,>=0.1.0 in /usr/local/lib/python3.9/site-packages (from ipykernel->jupyter->flaml[notebook]) (0.1.3)\n", + "Requirement already satisfied: widgetsnbextension~=3.5.0 in /usr/local/lib/python3.9/site-packages (from ipywidgets->jupyter->flaml[notebook]) (3.5.2)\n", + "Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /usr/local/lib/python3.9/site-packages (from ipywidgets->jupyter->flaml[notebook]) (1.0.2)\n", + "Requirement already satisfied: nbformat>=4.2.0 in /usr/local/lib/python3.9/site-packages (from ipywidgets->jupyter->flaml[notebook]) (5.1.3)\n", + "Requirement already satisfied: ipython-genutils~=0.2.0 in /usr/local/lib/python3.9/site-packages (from ipywidgets->jupyter->flaml[notebook]) (0.2.0)\n", + "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.9/site-packages (from jupyter-console->jupyter->flaml[notebook]) (3.0.22)\n", + "Requirement already satisfied: pygments in /usr/local/lib/python3.9/site-packages (from jupyter-console->jupyter->flaml[notebook]) (2.10.0)\n", + "Requirement already satisfied: jupyter-core in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (4.9.1)\n", + "Requirement already satisfied: bleach in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (4.1.0)\n", + "Requirement already satisfied: jinja2>=2.4 in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (3.0.3)\n", + "Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (0.3)\n", + "Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (0.5.9)\n", + "Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (0.1.2)\n", + "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (1.5.0)\n", + "Requirement already satisfied: defusedxml in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (0.7.1)\n", + "Requirement already satisfied: testpath in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (0.5.0)\n", + "Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (0.8.4)\n", + "Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.9/site-packages (from notebook->jupyter->flaml[notebook]) (1.8.0)\n", + "Requirement already satisfied: prometheus-client in /usr/local/lib/python3.9/site-packages (from notebook->jupyter->flaml[notebook]) (0.12.0)\n", + "Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.9/site-packages (from notebook->jupyter->flaml[notebook]) (22.3.0)\n", + "Requirement already satisfied: nest-asyncio>=1.5 in /usr/local/lib/python3.9/site-packages (from notebook->jupyter->flaml[notebook]) (1.5.1)\n", + "Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.9/site-packages (from notebook->jupyter->flaml[notebook]) (21.1.0)\n", + "Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.9/site-packages (from notebook->jupyter->flaml[notebook]) (0.12.1)\n", + "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.9/site-packages (from plotly->catboost>=0.26->flaml[notebook]) (8.0.1)\n", + "Requirement already satisfied: qtpy in /usr/local/lib/python3.9/site-packages (from qtconsole->jupyter->flaml[notebook]) (1.11.2)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/site-packages (from requests->openml==0.10.2->flaml[notebook]) (1.26.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/site-packages (from requests->openml==0.10.2->flaml[notebook]) (2021.10.8)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/site-packages (from requests->openml==0.10.2->flaml[notebook]) (3.3)\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.9/site-packages (from requests->openml==0.10.2->flaml[notebook]) (2.0.7)\n", + "Requirement already satisfied: decorator in /usr/local/lib/python3.9/site-packages (from ipython>=7.23.1->ipykernel->jupyter->flaml[notebook]) (5.1.0)\n", + "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.9/site-packages (from ipython>=7.23.1->ipykernel->jupyter->flaml[notebook]) (4.8.0)\n", + "Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.9/site-packages (from ipython>=7.23.1->ipykernel->jupyter->flaml[notebook]) (0.18.1)\n", + "Requirement already satisfied: pickleshare in /usr/local/lib/python3.9/site-packages (from ipython>=7.23.1->ipykernel->jupyter->flaml[notebook]) (0.7.5)\n", + "Requirement already satisfied: backcall in /usr/local/lib/python3.9/site-packages (from ipython>=7.23.1->ipykernel->jupyter->flaml[notebook]) (0.2.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/site-packages (from jinja2>=2.4->nbconvert->jupyter->flaml[notebook]) (2.0.1)\n", + "Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /usr/local/lib/python3.9/site-packages (from nbformat>=4.2.0->ipywidgets->jupyter->flaml[notebook]) (4.2.1)\n", + "Requirement already satisfied: wcwidth in /usr/local/lib/python3.9/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter->flaml[notebook]) (0.2.5)\n", + "Requirement already satisfied: ptyprocess in /usr/local/lib/python3.9/site-packages (from terminado>=0.8.3->notebook->jupyter->flaml[notebook]) (0.7.0)\n", + "Requirement already satisfied: cffi>=1.0.0 in /usr/local/lib/python3.9/site-packages (from argon2-cffi->notebook->jupyter->flaml[notebook]) (1.15.0)\n", + "Requirement already satisfied: webencodings in /usr/local/lib/python3.9/site-packages (from bleach->nbconvert->jupyter->flaml[notebook]) (0.5.1)\n", + "Requirement already satisfied: pycparser in /usr/local/lib/python3.9/site-packages (from cffi>=1.0.0->argon2-cffi->notebook->jupyter->flaml[notebook]) (2.21)\n", + "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /usr/local/lib/python3.9/site-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter->flaml[notebook]) (0.8.2)\n", + "Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.9/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter->flaml[notebook]) (21.2.0)\n", + "Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.9/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter->flaml[notebook]) (0.18.0)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n", + "\u001b[33mWARNING: You are using pip version 21.3; however, version 21.3.1 is available.\n", + "You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n" + ] + } + ], "source": [ "!pip install flaml[notebook];\n", "# from v0.6.6, catboost is made an optional dependency to build conda package.\n", @@ -977,7 +1074,7 @@ "source": [ "## 5. Customized Metric\n", "\n", - "It's also easy to customize the optimization metric. As an example, we demonstrate with a custom metric function which combines training loss and test loss as the final loss to minimize." + "It's also easy to customize the optimization metric. As an example, we demonstrate with a custom metric function which combines training loss and validation loss as the final loss to minimize." ] }, { @@ -986,22 +1083,22 @@ "metadata": {}, "outputs": [], "source": [ - "def custom_metric(X_test, y_test, estimator, labels, X_train, y_train,\n", - " weight_test=None, weight_train=None, config=None,\n", - " groups_test=None, groups_train=None):\n", + "def custom_metric(X_val, y_val, estimator, labels, X_train, y_train,\n", + " weight_val=None, weight_train=None, config=None,\n", + " groups_val=None, groups_train=None):\n", " from sklearn.metrics import log_loss\n", " import time\n", " start = time.time()\n", - " y_pred = estimator.predict_proba(X_test)\n", - " pred_time = (time.time() - start) / len(X_test)\n", - " test_loss = log_loss(y_test, y_pred, labels=labels,\n", - " sample_weight=weight_test)\n", + " y_pred = estimator.predict_proba(X_val)\n", + " pred_time = (time.time() - start) / len(X_val)\n", + " val_loss = log_loss(y_val, y_pred, labels=labels,\n", + " sample_weight=weight_val)\n", " y_pred = estimator.predict_proba(X_train)\n", " train_loss = log_loss(y_train, y_pred, labels=labels,\n", " sample_weight=weight_train)\n", " alpha = 0.5\n", - " return test_loss * (1 + alpha) - alpha * train_loss, {\n", - " \"test_loss\": test_loss, \"train_loss\": train_loss, \"pred_time\": pred_time\n", + " return val_loss * (1 + alpha) - alpha * train_loss, {\n", + " \"val_loss\": val_loss, \"train_loss\": train_loss, \"pred_time\": pred_time\n", " }\n", " # two elements are returned:\n", " # the first element is the metric to minimize as a float number,\n", diff --git a/test/automl/test_multiclass.py b/test/automl/test_multiclass.py index 9c80e66a5..28a6b148d 100644 --- a/test/automl/test_multiclass.py +++ b/test/automl/test_multiclass.py @@ -98,30 +98,30 @@ class MyLargeLGBM(LGBMEstimator): def custom_metric( - X_test, - y_test, + X_val, + y_val, estimator, labels, X_train, y_train, - weight_test=None, + weight_val=None, weight_train=None, config=None, - groups_test=None, + groups_val=None, groups_train=None, ): from sklearn.metrics import log_loss import time start = time.time() - y_pred = estimator.predict_proba(X_test) - pred_time = (time.time() - start) / len(X_test) - test_loss = log_loss(y_test, y_pred, labels=labels, sample_weight=weight_test) + y_pred = estimator.predict_proba(X_val) + pred_time = (time.time() - start) / len(X_val) + val_loss = log_loss(y_val, y_pred, labels=labels, sample_weight=weight_val) y_pred = estimator.predict_proba(X_train) train_loss = log_loss(y_train, y_pred, labels=labels, sample_weight=weight_train) alpha = 0.5 - return test_loss * (1 + alpha) - alpha * train_loss, { - "test_loss": test_loss, + return val_loss * (1 + alpha) - alpha * train_loss, { + "val_loss": val_loss, "train_loss": train_loss, "pred_time": pred_time, }