autogen/notebook/flaml_demo.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Copyright (c) 2020 Microsoft Corporation. All rights reserved. \n",
"\n",
"Licensed under the MIT License.\n",
"\n",
"# Demo of AutoML with FLAML Library\n",
"\n",
"\n",
"## 1. Introduction\n",
"\n",
"FLAML is a Python library (https://github.com/microsoft/FLAML) designed to automatically produce accurate machine learning models \n",
"with low computational cost. It is fast and cheap. The simple and lightweight design makes it easy \n",
"to use and extend, such as adding new learners. FLAML can \n",
"- serve as an economical AutoML engine,\n",
"- be used as a fast hyperparameter tuning tool, or \n",
"- be embedded in self-tuning software that requires low latency & resource in repetitive\n",
" tuning tasks.\n",
"\n",
"In this notebook, we use one real data example (binary classification) to showcase how to ues FLAML library.\n",
"\n",
"FLAML requires `Python>=3.6`. To run this notebook example, please install flaml with the [notebook] option:\n",
"```bash\n",
"pip install flaml[notebook]\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 2. Real Data Example\n",
"### Load data and preprocess\n",
"\n",
"Download [Airlines dataset](https://www.openml.org/d/1169) from OpenML. The task is to predict whether a given flight will be delayed, given the information of the scheduled departure."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"load dataset from ./openml_ds1169.pkl\n",
"Dataset name: airlines\n",
"X_train.shape: (404537, 7), y_train.shape: (404537,);\n",
"X_test.shape: (134846, 7), y_test.shape: (134846,)\n"
]
}
],
"source": [
"from flaml.data import load_openml_dataset\n",
"X_train, X_test, y_train, y_test = load_openml_dataset(dataset_id = 1169, data_dir = './')"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Run FLAML\n",
"In the FLAML automl run configuration, users can specify the task type, time budget, error metric, learner list, whether to subsample, resampling strategy type, and so on. All these arguments have default values which will be used if users do not provide them. For example, the default ML learners of FLAML are `['lgbm', 'xgboost', 'catboost', 'rf', 'extra_tree', 'lrl1']`. "
]
},
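{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"As a minimal sketch of relying on those defaults (illustration only, not executed in this notebook): a run can be launched by passing just the training data and any settings you want to override, and every argument left unspecified falls back to its default value:\n",
"```python\n",
"from flaml import AutoML\n",
"automl = AutoML()\n",
"automl.fit(X_train = X_train, y_train = y_train, task = 'classification', time_budget = 60)\n",
"```"
]
},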
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"outputs": [],
"source": [
"''' import AutoML class from flaml package '''\n",
"from flaml import AutoML\n",
"automl = AutoML()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"outputs": [],
"source": [
"settings = {\n",
" \"time_budget\": 60, # total running time in seconds\n",
" \"metric\": 'accuracy', # primary metrics can be chosen from: ['accuracy','roc_auc','f1','log_loss','mae','mse','r2']\n",
" \"estimator_list\": ['lgbm', 'rf', 'xgboost'], # list of ML learners\n",
" \"task\": 'classification', # task type \n",
" \"sample\": False, # whether to subsample training data\n",
" \"log_file_name\": 'airlines_experiment.log', # cache directory of flaml log files \n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"[flaml.automl: 12-15 07:41:38] {660} INFO - List of ML learners in AutoML Run: ['lgbm', 'rf', 'xgboost']\n",
"[flaml.automl: 12-15 07:41:38] {665} INFO - Evaluation method: holdout\n",
"[flaml.automl: 12-15 07:41:38] {683} INFO - Minimizing error metric: 1-accuracy\n",
"[flaml.automl: 12-15 07:41:39] {327} INFO - Using StratifiedKFold\n",
"[flaml.automl: 12-15 07:41:39] {728} INFO - iteration 0 current learner lgbm\n",
"[flaml.automl: 12-15 07:41:41] {793} INFO - at 3.6s,\tbest lgbm's error=0.3748,\tbest lgbm's error=0.3748\n",
"[flaml.automl: 12-15 07:41:41] {728} INFO - iteration 1 current learner lgbm\n",
"[flaml.automl: 12-15 07:41:45] {793} INFO - at 7.5s,\tbest lgbm's error=0.3735,\tbest lgbm's error=0.3735\n",
"[flaml.automl: 12-15 07:41:45] {728} INFO - iteration 2 current learner lgbm\n",
"[flaml.automl: 12-15 07:41:47] {793} INFO - at 9.2s,\tbest lgbm's error=0.3668,\tbest lgbm's error=0.3668\n",
"[flaml.automl: 12-15 07:41:47] {728} INFO - iteration 3 current learner lgbm\n",
"[flaml.automl: 12-15 07:41:49] {793} INFO - at 11.4s,\tbest lgbm's error=0.3613,\tbest lgbm's error=0.3613\n",
"[flaml.automl: 12-15 07:41:49] {728} INFO - iteration 4 current learner lgbm\n",
"[flaml.automl: 12-15 07:41:53] {793} INFO - at 15.0s,\tbest lgbm's error=0.3613,\tbest lgbm's error=0.3613\n",
"[flaml.automl: 12-15 07:41:53] {728} INFO - iteration 5 current learner xgboost\n",
"[flaml.automl: 12-15 07:41:56] {793} INFO - at 18.1s,\tbest xgboost's error=0.3740,\tbest lgbm's error=0.3613\n",
"[flaml.automl: 12-15 07:41:56] {728} INFO - iteration 6 current learner lgbm\n",
"[flaml.automl: 12-15 07:42:00] {793} INFO - at 22.7s,\tbest lgbm's error=0.3613,\tbest lgbm's error=0.3613\n",
"[flaml.automl: 12-15 07:42:00] {728} INFO - iteration 7 current learner xgboost\n",
"[flaml.automl: 12-15 07:42:02] {793} INFO - at 24.8s,\tbest xgboost's error=0.3659,\tbest lgbm's error=0.3613\n",
"[flaml.automl: 12-15 07:42:02] {728} INFO - iteration 8 current learner lgbm\n",
"[flaml.automl: 12-15 07:42:11] {793} INFO - at 33.0s,\tbest lgbm's error=0.3544,\tbest lgbm's error=0.3544\n",
"[flaml.automl: 12-15 07:42:11] {728} INFO - iteration 9 current learner rf\n",
"[flaml.automl: 12-15 07:42:20] {793} INFO - at 41.9s,\tbest rf's error=0.3895,\tbest lgbm's error=0.3544\n",
"[flaml.automl: 12-15 07:42:20] {728} INFO - iteration 10 current learner xgboost\n",
"[flaml.automl: 12-15 07:42:24] {793} INFO - at 45.8s,\tbest xgboost's error=0.3659,\tbest lgbm's error=0.3544\n",
"[flaml.automl: 12-15 07:42:24] {728} INFO - iteration 11 current learner lgbm\n",
"[flaml.automl: 12-15 07:42:29] {793} INFO - at 51.5s,\tbest lgbm's error=0.3410,\tbest lgbm's error=0.3410\n",
"[flaml.automl: 12-15 07:42:29] {728} INFO - iteration 12 current learner rf\n",
"[flaml.automl: 12-15 07:42:29] {793} INFO - at 51.5s,\tbest rf's error=0.3895,\tbest lgbm's error=0.3410\n",
"[flaml.automl: 12-15 07:42:29] {728} INFO - iteration 13 current learner lgbm\n",
"[flaml.automl: 12-15 07:42:35] {793} INFO - at 57.1s,\tbest lgbm's error=0.3383,\tbest lgbm's error=0.3383\n",
"[flaml.automl: 12-15 07:42:35] {728} INFO - iteration 14 current learner xgboost\n",
"[flaml.automl: 12-15 07:42:38] {793} INFO - at 60.4s,\tbest xgboost's error=0.3659,\tbest lgbm's error=0.3383\n",
"[flaml.automl: 12-15 07:42:38] {814} INFO - LGBMClassifier(learning_rate=0.5482637744255212, max_bin=1023,\n",
" min_child_weight=1.1930700595990091, n_estimators=76,\n",
" num_leaves=67, objective='binary',\n",
" reg_alpha=3.668052110134859e-10, reg_lambda=0.49371485228257217,\n",
" subsample=0.6)\n",
"[flaml.automl: 12-15 07:42:38] {702} INFO - fit succeeded\n"
]
}
],
"source": [
"'''The main flaml automl API'''\n",
"automl.fit(X_train = X_train, y_train = y_train, **settings)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Best model and metric"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Best ML leaner: lgbm\nBest hyperparmeter config: {'n_estimators': 76.23660313632638, 'max_leaves': 66.93360726547702, 'min_child_weight': 1.1930700595990091, 'learning_rate': 0.5482637744255212, 'subsample': 0.6, 'log_max_bin': 10.0, 'reg_alpha': 3.668052110134859e-10, 'reg_lambda': 0.49371485228257217, 'colsample_bytree': 1.0}\nBest accuracy on validation data: 0.6617\nTraining duration of best run: 5.522 s\n"
]
}
],
"source": [
"''' retrieve best config and best learner'''\n",
"print('Best ML leaner:', automl.best_estimator)\n",
"print('Best hyperparmeter config:', automl.best_config)\n",
"print('Best accuracy on validation data: {0:.4g}'.format(1-automl.best_loss))\n",
"print('Training duration of best run: {0:.4g} s'.format(automl.best_config_train_time))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"LGBMClassifier(learning_rate=0.5482637744255212, max_bin=1023,\n",
" min_child_weight=1.1930700595990091, n_estimators=76,\n",
" num_leaves=67, objective='binary',\n",
" reg_alpha=3.668052110134859e-10, reg_lambda=0.49371485228257217,\n",
" subsample=0.6)"
]
},
"metadata": {},
"execution_count": 6
}
],
"source": [
"automl.model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"outputs": [],
"source": [
"''' pickle and save the best model '''\n",
"import pickle\n",
"with open('best_model.pkl', 'wb') as f:\n",
" pickle.dump(automl.model, f, pickle.HIGHEST_PROTOCOL)"
]
},
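{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"The pickled model can be loaded back later, e.g., in a separate serving script. A minimal sketch (illustration only, not executed in this notebook), assuming the unpickled estimator exposes a scikit-learn style `predict`, as FLAML's wrapped estimators do:\n",
"```python\n",
"import pickle\n",
"with open('best_model.pkl', 'rb') as f:\n",
"    loaded_model = pickle.load(f)\n",
"print(loaded_model.predict(X_test))\n",
"```"
]
},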
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Predicted labels [1 0 1 ... 1 0 0]\nTrue labels [0 0 0 ... 0 1 0]\n"
]
}
],
"source": [
"''' compute predictions of testing dataset ''' \n",
"y_pred = automl.predict(X_test)\n",
"print('Predicted labels', y_pred)\n",
"print('True labels', y_test)\n",
"y_pred_proba = automl.predict_proba(X_test)[:,1]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"accuracy = 0.6666493629770256\n",
"roc_auc = 0.7173397375696496\n",
"log_loss = 0.6095801351363471\n",
"f1 = 0.580528363863719\n"
]
}
],
"source": [
"''' compute different metric values on testing dataset'''\n",
"from flaml.ml import sklearn_metric_loss_score\n",
"print('accuracy', '=', 1 - sklearn_metric_loss_score('accuracy', y_pred, y_test))\n",
"print('roc_auc', '=', 1 - sklearn_metric_loss_score('roc_auc', y_pred_proba, y_test))\n",
"print('log_loss', '=', sklearn_metric_loss_score('log_loss', y_pred_proba, y_test))\n",
"print('f1', '=', 1 - sklearn_metric_loss_score('f1', y_pred, y_test))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Log history"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'Current Learner': 'lgbm', 'Current Sample': 364083, 'Current Hyper-parameters': {'n_estimators': 4, 'max_leaves': 4, 'min_child_weight': 20, 'learning_rate': 0.1, 'subsample': 1.0, 'log_max_bin': 8, 'reg_alpha': 1e-10, 'reg_lambda': 1.0, 'colsample_bytree': 1.0}, 'Best Learner': 'lgbm', 'Best Hyper-parameters': {'n_estimators': 4, 'max_leaves': 4, 'min_child_weight': 20, 'learning_rate': 0.1, 'subsample': 1.0, 'log_max_bin': 8, 'reg_alpha': 1e-10, 'reg_lambda': 1.0, 'colsample_bytree': 1.0}}\n{'Current Learner': 'lgbm', 'Current Sample': 364083, 'Current Hyper-parameters': {'n_estimators': 4.345841756255061, 'max_leaves': 10.353390566270846, 'min_child_weight': 20.0, 'learning_rate': 0.04742496726415123, 'subsample': 0.9045133325444861, 'log_max_bin': 10.0, 'reg_alpha': 1e-10, 'reg_lambda': 1.0, 'colsample_bytree': 0.9407474408255333}, 'Best Learner': 'lgbm', 'Best Hyper-parameters': {'n_estimators': 4, 'max_leaves': 4, 'min_child_weight': 20, 'learning_rate': 0.1, 'subsample': 1.0, 'log_max_bin': 8, 'reg_alpha': 1e-10, 'reg_lambda': 1.0, 'colsample_bytree': 1.0}}\n{'Current Learner': 'lgbm', 'Current Sample': 364083, 'Current Hyper-parameters': {'n_estimators': 4.0, 'max_leaves': 4.0, 'min_child_weight': 9.874086709908818, 'learning_rate': 0.21085939699865755, 'subsample': 1.0, 'log_max_bin': 3.0, 'reg_alpha': 2.6875093824678297e-10, 'reg_lambda': 0.7230542131309051, 'colsample_bytree': 1.0}, 'Best Learner': 'lgbm', 'Best Hyper-parameters': {'n_estimators': 4.0, 'max_leaves': 4.0, 'min_child_weight': 9.874086709908818, 'learning_rate': 0.21085939699865755, 'subsample': 1.0, 'log_max_bin': 3.0, 'reg_alpha': 2.6875093824678297e-10, 'reg_lambda': 0.7230542131309051, 'colsample_bytree': 1.0}}\n{'Current Learner': 'lgbm', 'Current Sample': 364083, 'Current Hyper-parameters': {'n_estimators': 6.30703808576676, 'max_leaves': 4.615126183980338, 'min_child_weight': 5.419442970309873, 'learning_rate': 0.45611181052279925, 'subsample': 1.0, 'log_max_bin': 3.0, 'reg_alpha': 1e-10, 'reg_lambda': 0.5948168429421155, 'colsample_bytree': 1.0}, 'Best Learner': 'lgbm', 'Best Hyper-parameters': {'n_estimators': 6.30703808576676, 'max_leaves': 4.615126183980338, 'min_child_weight': 5.419442970309873, 'learning_rate': 0.45611181052279925, 'subsample': 1.0, 'log_max_bin': 3.0, 'reg_alpha': 1e-10, 'reg_lambda': 0.5948168429421155, 'colsample_bytree': 1.0}}\n{'Current Learner': 'lgbm', 'Current Sample': 364083, 'Current Hyper-parameters': {'n_estimators': 13.346655408225933, 'max_leaves': 7.128882408907543, 'min_child_weight': 3.5378687932000563, 'learning_rate': 0.27022645132691947, 'subsample': 1.0, 'log_max_bin': 3.9062497595361734, 'reg_alpha': 4.798429666191569e-10, 'reg_lambda': 0.31076883570242425, 'colsample_bytree': 1.0}, 'Best Learner': 'lgbm', 'Best Hyper-parameters': {'n_estimators': 13.346655408225933, 'max_leaves': 7.128882408907543, 'min_child_weight': 3.5378687932000563, 'learning_rate': 0.27022645132691947, 'subsample': 1.0, 'log_max_bin': 3.9062497595361734, 'reg_alpha': 4.798429666191569e-10, 'reg_lambda': 0.31076883570242425, 'colsample_bytree': 1.0}}\n{'Current Learner': 'lgbm', 'Current Sample': 364083, 'Current Hyper-parameters': {'n_estimators': 9.168255249166949, 'max_leaves': 16.406314436487644, 'min_child_weight': 1.2440119163470513, 'learning_rate': 0.34085789038743874, 'subsample': 0.8622669492242545, 'log_max_bin': 3.9088586623653176, 'reg_alpha': 6.716698258358434e-10, 'reg_lambda': 0.08971222222676836, 'colsample_bytree': 0.7}, 'Best Learner': 'lgbm', 'Best Hyper-parameters': 
{'n_estimators': 13.346655408225933, 'max_leaves': 7.128882408907543, 'min_child_weight': 3.5378687932000563, 'learning_rate': 0.27022645132691947, 'subsample': 1.0, 'log_max_bin': 3.9062497595361734, 'reg_alpha': 4.798429666191569e-10, 'reg_lambda': 0.31076883570242425, 'colsample_bytree': 1.0}}\n{'Current Learner': 'lgbm', 'Current Sample': 364083, 'Current Hyper-parameters': {'n_estimators': 19.429346778070144, 'max_leaves': 4.0, 'min_child_weight': 10.061411336518901, 'learning_rate': 0.21423102429501803, 'subsample': 1.0, 'lo
]
}
],
"source": [
"from flaml.data import get_output_from_log\n",
"time_history, best_valid_loss_history, valid_loss_history, config_history, train_loss_history = \\\n",
" get_output_from_log(filename = settings['log_file_name'], time_budget = 60)\n",
"\n",
"for config in config_history:\n",
" print(config)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\r\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n<!-- Created with matplotlib (https://matplotlib.org/) -->\r\n<svg height=\"277.314375pt\" version=\"1.1\" viewBox=\"0 0 392.14375 277.314375\" width=\"392.14375pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n <defs>\r\n <style type=\"text/css\">\r\n*{stroke-linecap:butt;stroke-linejoin:round;}\r\n </style>\r\n </defs>\r\n <g id=\"figure_1\">\r\n <g id=\"patch_1\">\r\n <path d=\"M 0 277.314375 \r\nL 392.14375 277.314375 \r\nL 392.14375 0 \r\nL 0 0 \r\nz\r\n\" style=\"fill:none;\"/>\r\n </g>\r\n <g id=\"axes_1\">\r\n <g id=\"patch_2\">\r\n <path d=\"M 50.14375 239.758125 \r\nL 384.94375 239.758125 \r\nL 384.94375 22.318125 \r\nL 50.14375 22.318125 \r\nz\r\n\" style=\"fill:#ffffff;\"/>\r\n </g>\r\n <g id=\"PathCollection_1\">\r\n <defs>\r\n <path d=\"M 0 3 \r\nC 0.795609 3 1.55874 2.683901 2.12132 2.12132 \r\nC 2.683901 1.55874 3 0.795609 3 0 \r\nC 3 -0.795609 2.683901 -1.55874 2.12132 -2.12132 \r\nC 1.55874 -2.683901 0.795609 -3 0 -3 \r\nC -0.795609 -3 -1.55874 -2.683901 -2.12132 -2.12132 \r\nC -2.683901 -1.55874 -3 -0.795609 -3 0 \r\nC -3 0.795609 -2.683901 1.55874 -2.12132 2.12132 \r\nC -1.55874 2.683901 -0.795609 3 0 3 \r\nz\r\n\" id=\"m3c030bdee4\" style=\"stroke:#1f77b4;\"/>\r\n </defs>\r\n <g clip-path=\"url(#p1603db58c1)\">\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"65.361932\" xlink:href=\"#m3c030bdee4\" y=\"173.301165\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"79.674828\" xlink:href=\"#m3c030bdee4\" y=\"222.528543\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"87.738087\" xlink:href=\"#m3c030bdee4\" y=\"168.05406\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"97.300075\" xlink:href=\"#m3c030bdee4\" y=\"142.295549\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"109.900498\" xlink:href=\"#m3c030bdee4\" y=\"120.925524\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"120.541856\" xlink:href=\"#m3c030bdee4\" y=\"127.412853\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"130.415606\" xlink:href=\"#m3c030bdee4\" y=\"130.370312\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"147.983308\" xlink:href=\"#m3c030bdee4\" y=\"170.152902\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"160.53905\" xlink:href=\"#m3c030bdee4\" y=\"135.80822\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"174.517433\" xlink:href=\"#m3c030bdee4\" y=\"146.588634\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"186.019966\" xlink:href=\"#m3c030bdee4\" y=\"138.765679\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"233.019224\" xlink:href=\"#m3c030bdee4\" y=\"94.499199\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"283.373055\" xlink:href=\"#m3c030bdee4\" y=\"229.874489\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"296.670361\" xlink:href=\"#m3c030bdee4\" y=\"170.343706\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"305.914886\" xlink:href=\"#m3c030bdee4\" y=\"140.482913\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"318.831045\" xlink:href=\"#m3c030bdee4\" y=\"106.13823\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"338.241899\" xlink:href=\"#m3c030bdee4\" y=\"42.791372\"/>\r\n <use style=\"fill:#1f77b4;stroke:#1f77b4;\" x=\"369.725568\" xlink:href=\"#m3c030bdee4\" y=\"32.201761\"/>\r\n </g>\r\n </g>\r\n <g id=\"matplotlib.axis_1\">\r\n <g 
id=\"xtick_1\">\r\n <g id=\"line2d_1\">\r\n <defs>\r\n <path d=\"M 0 0 \r\nL 0 3.5 \r\n\" id=\"madf86e0e70\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n </defs>\r\n <g>\r\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"101.977632\" xlink:href=\"#madf86e0e70\" y=\"239.758125\"/>\r\n </g>\r\n </g>\r\n <g id=\"text_1\">\r\n <!-- 10 -->\r\n <defs>\r\n <path d=
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nO3deXxW9Zn38c+XECQoGBdAFhWogOIGJVVb64ZVcLdWuzjddKaOM1PHeWydaqeddqbjjDNM7fLU1sc6HWtb2yog0kpBW7eOopKIsog4iIgJIKCCiAFCcj1/nBN7E+4kN5CTO7nzfb9evJJzzu/c5zptzJXfrojAzMyspV7FDsDMzLomJwgzM8vLCcLMzPJygjAzs7ycIMzMLC8nCDMzy8sJwmwPSDpF0rJix2GWJScI63YkrZT0kWLGEBF/jIixWX2+pMmSHpe0WdJ6SY9JujCr55nl4wRhloeksiI++1LgXuAuYDgwGPhH4II9+CxJ8n/ntkf8g2MlQ1IvSTdIelnSG5LukXRgzvV7Ja2VtCn96/zonGt3SvqRpNmStgBnpDWVL0tamN7za0l90/KnS6rNub/Vsun1v5e0RtJqSX8hKSQdkecdBNwCfCsi7oiITRHRFBGPRcQX0jLflPTznHtGpJ/XOz1+VNJNkp4A3gW+Kqm6xXP+j6RZ6ff7SPpPSaskvS7pNkkVe/l/h5UAJwgrJX8LXAycBgwF3gJuzbn+O2A0MAh4FvhFi/svB24C+gP/k577ODAFGAkcB3y+jefnLStpCnAd8BHgiDS+1owFDgWmtVGmEJ8BriJ5l/8LjJU0Ouf65cDd6ff/DowBxqfxDSOpsVgP5wRhpeQvgX+IiNqI2AZ8E7i0+S/riPhJRGzOuXa8pP1z7r8/Ip5I/2Lfmp77fkSsjog3gd+Q/BJtTWtlPw78d0QsiYh3gX9q4zMOSr+uKfit87szfd6OiNgE3A98CiBNFEcCs9IayxeA/xMRb0bEZuBfgU/u5fOtBDhBWCk5HLhP0kZJG4GlQCMwWFKZpJvT5qe3gZXpPQfn3P9ans9cm/P9u8B+bTy/tbJDW3x2vuc0eyP9OqSNMoVo+Yy7SRMESe1hZpqsBgL9gJqc/93mpOeth3OCsFLyGnBORFTm/OsbEXUkvxQvImnm2R8Ykd6jnPuzWtp4DUlnc7ND2yi7jOQ9PtZGmS0kv9SbHZKnTMt3eRA4WNJ4kkTR3Ly0AagHjs7532z/iGgrEVoP4QRh3VW5pL45/3oDtwE3STocQNJASRel5fsD20j+Qu9H0ozSWe4BrpB0lKR+tNG+H8n6+9cBX5d0haQBaef7hyXdnhZ7DjhV0mFpE9mN7QUQETtI+jWmAgcCD6Xnm4AfA9+RNAhA0jBJk/f4ba1kOEFYdzWb5C/f5n/fBL4HzAIelLQZeAo4MS1/F/AqUAe8kF7rFBHxO+D7wCPAcmBeemlbK+WnAZ8ArgRWA68D/0LSj0BEPAT8GlgI1AC/LTCUu0lqUPemCaPZV9K4nkqb335P0lluPZy8YZBZ55J0FLAY2KfFL2qzLsU1CLNOIOmjkvpIOoBkWOlvnBysq3OCMOscfwmsB14mGVn1V8UNx6x9bmIyM7O8XIMwM7O8ehc7gI508MEHx4gRI4odhplZt1FTU7MhIvJOjCypBDFixAiqq6vbL2hmZgBIerW1a25iMjOzvJwgzMwsLycIMzPLywnCzMzycoIwM7O8SmoUk5lZTzJzQR1T5y5j9cZ6hlZWcP3ksVw8YViHfb4ThJlZNzRzQR03zlhEfUMjAHUb67lxxiKADksSbmIyM+uGps5d9l5yaFbf0MjUucs67BmuQZiZdSNbGxp5dNk66jbW572+upXze8IJwsysi9va0MhjL63ngYVr+P3S13l3eyO9BE151lodWlnRYc91gjAz64K27Wjk8Zc28MDC1fx+6Tre2baDA/qVc9H4YZx/3BDWbtrK12Yu3qmZqaK8jOsnd9xmgE4QZmZdxPYdTfzP8vX8duEaHlryOpu37aCyXznnHzeE844bwkmjDqK87E9dx2W91H1HMUmaQrJPcBlwR0TcnKfM6cB3gXJgQ0Sclp6vBO4AjgECuDIi5rW838ysO9u+o4knXt7AAwvX8OCStby9dQcD+vZmyjGHcN5xQzj5iIN3Sgq5Lp4wrEMTQkuZJQhJZcCtwFlALTBf0qyIeCGnTCXwQ2BKRKySNCjnI74HzImISyX1AfplFauZWWdqaGziyZff4IGFq5m75HU21TfQv29vzh53COenSaFP7+IPMs2yBnECsDwiVgBI+hVwEfBCTpnLgRkRsQogItalZQcApwKfT89vB7ZnGKuZWaZ2NDYxb8UbPLBwDXOWrGXjuw3st09vzh43mPOOG8KHRx/MPr3Lih3mTrJMEMOA13KOa4ETW5QZA5RLehToD3wvIu4CRpHs3/vfko4HaoBrI2JLy4dIugq4CuCwww7r6HcwM9tjOxqbeOaVN/ntojXMWbyWN7dsZ98+ZXxk3GDOO3YIp44ZSN/yrpUUcmWZIJTnXMtBWb2BicCZQAUwT9JT6fn3A9dExNOSvgfcAHx9lw+MuB24HaCqqsobbJtZp2q53MWXzhrDkMoKHli0mjmL17Lhne3061PGmUclSeH0sV07KeTKMkHUAofmHA8HVucpsyGtGWyR9DhwPPBHoDYink7LTSNJEGZmXUa+5S6uu/d5IBlyOumoQZx/7BBOHzuIij7dIynkyjJBzAdGSxoJ1AGfJOlzyHU/8ANJvYE+JE1Q34mItZJekzQ2IpaR1DBewMysC8m33AXAAf3KeeKGSfTr071nEmQWfUTskPRFYC7JMNefRMQSSVen12+LiKWS5gALgSaSobCL04+4BvhFOoJpBXBFVrGame2J1pa12PhuQ7dPDpDxPIiImA3MbnHuthbHU4Gpee59DqjKMj4zs70xtLIi75pIHbncRTEVf6CtmVk3dcXJI3Y519HLXRSTE4SZ2R56e+sOJDhkQF8EDKus4N8uOTbT2c2dqfs3kpmZFUFTUzC9ppZTRg/kritPKHY4mXANwsxsDzz1yhvUbazn0onDix1KZpwgzMz2wLTq2nT9pMHFDiUzThBmZrtp89YGZi9ewwXHD+02s6L3hBOEmdlu+t2itWxtaCrp5iVwgjAz223TamoZNXBfJhxaWexQMuUEYWa2G1Zu2MIzK9/k0onDkfKtSVo6nCDMzHbDjGdr6SW4ZEJpNy+BE4SZWcGamoLpz9bx4dEDOWT/vsUOJ3NOEGZmBXpqRenPfcjlBGFmVqBpNaU/9yGXE4SZWQGa5z5cWOJzH3I5QZiZFWD2ojU9Yu5DLicIM7MCTKup5X0D92V8ic99yOUEYWbWjpUbtjB/5VtcOvHQkp/7kMsJwsysHdPTuQ8fLZF9HgrlBGFm1obcfR96wtyHXE4QZmZtmLfiDVZv2tqjOqebOUGYmbWhee7DWT1k7kMuJwgzs1Zs3trA73rY3IdcThBmZq14YGEy9+Gyq
kOLHUpROEGYmbViWk0tRwzaj+OH71/sUIrCCcLMLI9XNmyh+tW3esS+D61xgjAzy2N6Tc+c+5DLCcLMrIXGpmD6s7WcOmYggwf0rLkPuZwgzMxamPfyG6zpoXMfcjlBmJm1MK3mNQb07c1Hjup5cx9yZZogJE2RtEzSckk3tFLmdEnPSVoi6bGc8yslLUqvVWcZp5lZs7e3NjBnyVouHN8z5z7k6p3VB0sqA24FzgJqgfmSZkXECzllKoEfAlMiYpWkQS0+5oyI2JBVjGZmLc1e2LzvQ8+c+5AryxrECcDyiFgREduBXwEXtShzOTAjIlYBRMS6DOMxM2vXvT187kOuLBPEMOC1nOPa9FyuMcABkh6VVCPpsznXAngwPX9VhnGamQGwYv071Lz6Fpf14LkPuTJrYgLy/a8beZ4/ETgTqADmSXoqIl4CTo6I1Wmz00OSXoyIx3d5SJI8rgI47LDDOvQFzKxn6an7PrQmyxpELZDbiDccWJ2nzJyI2JL2NTwOHA8QEavTr+uA+0iarHY
2020-12-04 09:40:27 -08:00
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"plt.title('Learning Curve')\n",
"plt.xlabel('Wall Clock Time (s)')\n",
"plt.ylabel('Validation Accuracy')\n",
"plt.scatter(time_history, 1-np.array(valid_loss_history))\n",
"plt.plot(time_history, 1-np.array(best_valid_loss_history))\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3. Customized Learner"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Some experienced automl users may have a preferred model to tune or may already have a reasonably by-hand-tuned model before launching the automl experiment. They need to select optimal configurations for the customized model mixed with standard built-in learners. \n",
"\n",
"FLAML can easily incorporate customized/new learners (preferably with sklearn API) provided by users in a real-time manner, as demonstrated below."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Example of Regularized Greedy Forest\n",
"\n",
"[Regularized Greedy Forest](https://arxiv.org/abs/1109.0887) (RGF) is a machine learning method currently not included in FLAML. The RGF has many tuning parameters, the most critical of which are: `[max_leaf, n_iter, n_tree_search, opt_interval, min_samples_leaf]`. To run a customized/new learner, the user needs to provide the following information:\n",
"* an implementation of the customized/new learner\n",
"* a list of hyperparameter names and types\n",
"* rough ranges of hyperparameters (i.e., upper/lower bounds)\n",
"* choose initial value corresponding to low cost for cost-related hyperparameters (e.g., initial value for max_leaf and n_iter should be small)\n",
"\n",
"In this example, the above information for RGF is wrapped in a python class called *MyRegularizedGreedyForest* that exposes the hyperparameters."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"outputs": [],
"source": [
"''' BaseEstimator is the parent class for a customized learner '''\n",
"from flaml.model import BaseEstimator\n",
"from flaml.space import ConfigSearchInfo\n",
"''' import the RGF implementation from rgf.sklearn module'''\n",
"from rgf.sklearn import RGFClassifier, RGFRegressor\n",
"\n",
"\n",
"class MyRegularizedGreedyForest(BaseEstimator):\n",
"\n",
" # search space\n",
" params_configsearch_info = {\n",
" 'max_leaf': ConfigSearchInfo(name = 'max_leaf', type = int, lower = 4, init = 4, upper = 10000),\n",
" 'n_iter': ConfigSearchInfo(name = 'n_iter', type = int, lower = 1, init = 1, upper = 32768),\n",
" 'n_tree_search': ConfigSearchInfo(name = 'n_tree_search', type = int, lower = 1, init = 1, upper = 32768),\n",
" 'opt_interval': ConfigSearchInfo(name = 'opt_interval', type = int, lower = 1, init = 100, upper = 10000),\n",
" 'learning_rate': ConfigSearchInfo(name = 'learning_rate', type = float, lower = 0.01, init = 1.0, upper = 20.0),\n",
" 'min_samples_leaf': ConfigSearchInfo(name = 'min_samples_leaf', type = int, lower = 1, init = 20, upper = 20)\n",
" }\n",
" \n",
" def __init__(self, objective_name = 'binary:logistic', n_jobs = 1, max_leaf = 1000, \n",
" n_iter = 1, n_tree_search = 1, opt_interval = 1, learning_rate = 1.0, min_samples_leaf = 1):\n",
"\n",
" '''regression for RGFRegressor; binary:logistic and multiclass for RGFClassifier'''\n",
" self.objective_name = objective_name\n",
"\n",
" if 'regression' in objective_name:\n",
" self.estimator_class = RGFRegressor\n",
" else:\n",
" self.estimator_class = RGFClassifier\n",
"\n",
" # round integer hyperparameters\n",
" self.params = {\n",
" \"n_jobs\": n_jobs,\n",
" 'max_leaf': int(round(max_leaf)),\n",
" 'n_iter': int(round(n_iter)),\n",
" 'n_tree_search': int(round(n_tree_search)),\n",
" 'opt_interval': int(round(opt_interval)),\n",
" 'learning_rate': learning_rate,\n",
" 'min_samples_leaf':int(round(min_samples_leaf))\n",
" } \n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Add Customized Learner and Run FLAML AutoML\n",
"\n",
"After adding RGF into the list of learners, we run automl by tuning hyperpameters of RGF as well as the default learners. "
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"outputs": [],
"source": [
"''' add a new learner RGF'''\n",
"automl = AutoML()\n",
"automl.add_learner(learner_name = 'RGF', learner_class = MyRegularizedGreedyForest)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"[flaml.automl: 12-15 07:42:43] {660} INFO - List of ML learners in AutoML Run: ['RGF', 'lgbm', 'rf', 'xgboost']\n",
"[flaml.automl: 12-15 07:42:43] {665} INFO - Evaluation method: holdout\n",
"[flaml.automl: 12-15 07:42:43] {683} INFO - Minimizing error metric: 1-accuracy\n",
"[flaml.automl: 12-15 07:42:45] {327} INFO - Using StratifiedKFold\n",
"[flaml.automl: 12-15 07:42:45] {728} INFO - iteration 0 current learner RGF\n",
"[flaml.automl: 12-15 07:42:47] {793} INFO - at 4.0s,\tbest RGF's error=0.3764,\tbest RGF's error=0.3764\n",
"[flaml.automl: 12-15 07:42:47] {728} INFO - iteration 1 current learner RGF\n",
"[flaml.automl: 12-15 07:42:52] {793} INFO - at 8.7s,\tbest RGF's error=0.3764,\tbest RGF's error=0.3764\n",
"[flaml.automl: 12-15 07:42:52] {728} INFO - iteration 2 current learner lgbm\n",
"[flaml.automl: 12-15 07:42:52] {793} INFO - at 8.9s,\tbest lgbm's error=0.3790,\tbest RGF's error=0.3764\n",
"[flaml.automl: 12-15 07:42:52] {728} INFO - iteration 3 current learner lgbm\n",
"[flaml.automl: 12-15 07:42:53] {793} INFO - at 9.3s,\tbest lgbm's error=0.3790,\tbest RGF's error=0.3764\n",
"[flaml.automl: 12-15 07:42:53] {728} INFO - iteration 4 current learner lgbm\n",
"[flaml.automl: 12-15 07:42:53] {793} INFO - at 9.8s,\tbest lgbm's error=0.3718,\tbest lgbm's error=0.3718\n",
"[flaml.automl: 12-15 07:42:53] {728} INFO - iteration 5 current learner lgbm\n",
"[flaml.automl: 12-15 07:42:53] {793} INFO - at 10.0s,\tbest lgbm's error=0.3652,\tbest lgbm's error=0.3652\n",
"[flaml.automl: 12-15 07:42:53] {728} INFO - iteration 6 current learner lgbm\n",
"[flaml.automl: 12-15 07:42:54] {793} INFO - at 10.5s,\tbest lgbm's error=0.3652,\tbest lgbm's error=0.3652\n",
"[flaml.automl: 12-15 07:42:54] {728} INFO - iteration 7 current learner lgbm\n",
"[flaml.automl: 12-15 07:42:55] {793} INFO - at 11.8s,\tbest lgbm's error=0.3652,\tbest lgbm's error=0.3652\n",
"[flaml.automl: 12-15 07:42:55] {728} INFO - iteration 8 current learner lgbm\n",
"[flaml.automl: 12-15 07:42:57] {793} INFO - at 14.0s,\tbest lgbm's error=0.3568,\tbest lgbm's error=0.3568\n",
"[flaml.automl: 12-15 07:42:57] {728} INFO - iteration 9 current learner lgbm\n",
"[flaml.automl: 12-15 07:43:02] {793} INFO - at 18.1s,\tbest lgbm's error=0.3547,\tbest lgbm's error=0.3547\n",
"[flaml.automl: 12-15 07:43:02] {728} INFO - iteration 10 current learner lgbm\n",
"[flaml.automl: 12-15 07:43:07] {793} INFO - at 23.2s,\tbest lgbm's error=0.3522,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:07] {728} INFO - iteration 11 current learner xgboost\n",
"[flaml.automl: 12-15 07:43:07] {793} INFO - at 23.9s,\tbest xgboost's error=0.3764,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:07] {728} INFO - iteration 12 current learner xgboost\n",
"[flaml.automl: 12-15 07:43:08] {793} INFO - at 24.7s,\tbest xgboost's error=0.3671,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:08] {728} INFO - iteration 13 current learner xgboost\n",
"[flaml.automl: 12-15 07:43:09] {793} INFO - at 26.0s,\tbest xgboost's error=0.3671,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:09] {728} INFO - iteration 14 current learner lgbm\n",
"[flaml.automl: 12-15 07:43:18] {793} INFO - at 34.7s,\tbest lgbm's error=0.3522,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:18] {728} INFO - iteration 15 current learner rf\n",
"[flaml.automl: 12-15 07:43:19] {793} INFO - at 35.3s,\tbest rf's error=0.4323,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:19] {728} INFO - iteration 16 current learner rf\n",
"[flaml.automl: 12-15 07:43:19] {793} INFO - at 36.0s,\tbest rf's error=0.4033,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:19] {728} INFO - iteration 17 current learner RGF\n",
"[flaml.automl: 12-15 07:43:28] {793} INFO - at 44.7s,\tbest RGF's error=0.3764,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:28] {728} INFO - iteration 18 current learner xgboost\n",
"[flaml.automl: 12-15 07:43:29] {793} INFO - at 45.4s,\tbest xgboost's error=0.3602,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:29] {728} INFO - iteration 19 current learner xgboost\n",
"[flaml.automl: 12-15 07:43:31] {793} INFO - at 47.3s,\tbest xgboost's error=0.3544,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:31] {728} INFO - iteration 20 current learner xgboost\n",
"[flaml.automl: 12-15 07:43:32] {793} INFO - at 48.9s,\tbest xgboost's error=0.3525,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:32] {728} INFO - iteration 21 current learner xgboost\n",
"[flaml.automl: 12-15 07:43:37] {793} INFO - at 53.5s,\tbest xgboost's error=0.3525,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:37] {728} INFO - iteration 22 current learner lgbm\n",
"[flaml.automl: 12-15 07:43:42] {793} INFO - at 59.0s,\tbest lgbm's error=0.3522,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:42] {728} INFO - iteration 23 current learner xgboost\n",
"[flaml.automl: 12-15 07:43:43] {793} INFO - at 59.9s,\tbest xgboost's error=0.3525,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:43] {728} INFO - iteration 24 current learner rf\n",
"[flaml.automl: 12-15 07:43:43] {793} INFO - at 59.9s,\tbest rf's error=0.4033,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:43] {728} INFO - iteration 25 current learner RGF\n",
"[flaml.automl: 12-15 07:43:47] {793} INFO - at 63.9s,\tbest RGF's error=0.3764,\tbest lgbm's error=0.3522\n",
"[flaml.automl: 12-15 07:43:47] {814} INFO - LGBMClassifier(colsample_bytree=0.7, learning_rate=0.06177098582210786,\n",
" max_bin=127, min_child_weight=5.058775453728698, n_estimators=80,\n",
" num_leaves=17, objective='binary',\n",
" reg_alpha=3.690867311882246e-10, reg_lambda=1.0,\n",
" subsample=0.7382230019481447)\n",
"[flaml.automl: 12-15 07:43:47] {702} INFO - fit succeeded\n"
]
}
],
"source": [
"settings = {\n",
" \"time_budget\": 60, # total running time in seconds\n",
" \"metric\": 'accuracy', \n",
" \"estimator_list\": ['RGF', 'lgbm', 'rf', 'xgboost'], # list of ML learners\n",
" \"task\": 'classification', # task type \n",
" \"sample\": True, # whether to subsample training data\n",
" \"log_file_name\": 'airlines_experiment.log', # cache directory of flaml log files \n",
" \"log_training_metric\": True, # whether to log training metric\n",
"}\n",
"\n",
"'''The main flaml automl API'''\n",
"automl.fit(X_train = X_train, y_train = y_train, **settings)"
]
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3.7.9 64-bit ('test': conda)",
"metadata": {
"interpreter": {
"hash": "d432c3c2bcf16c697a4c55907b7ae9cb502fbbf6a7955e813637a3b18956f9d0"
}
}
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9-final"
}
},
"nbformat": 4,
"nbformat_minor": 2
}