{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# AutoML with FLAML Library for SynapseML models and Spark DataFrames\n", "\n", "\n", "## 1. Introduction\n", "\n", "FLAML is a Python library (https://github.com/microsoft/FLAML) designed to automatically produce accurate machine learning models\n", "with low computational cost. It is fast and economical. Its simple and lightweight design makes it easy\n", "to use and extend, for example by adding new learners. FLAML can\n", "- serve as an economical AutoML engine,\n", "- be used as a fast hyperparameter tuning tool, or\n", "- be embedded in self-tuning software that requires low latency and low resource usage in repetitive tuning tasks.\n", "\n", "In this notebook, we demonstrate how to use the FLAML library to do AutoML for SynapseML models and Spark DataFrames, and we compare the results of FLAML AutoML with those of a default SynapseML model.\n", "In this example, we use LightGBM to build a classification model that predicts bankruptcy.\n", "\n", "Since the dataset is imbalanced, `AUC` is a better metric than `Accuracy`: a model that always predicts the majority class (not bankrupt) already achieves high accuracy. With 1 minute of training, FLAML achieved an AUC of **0.79**, while the default SynapseML model only reached an AUC of **0.64**.\n", "\n", "FLAML requires `Python>=3.7`. To run this notebook example, please install FLAML with the `synapse` option (quote the requirement so the shell does not treat `>` as a redirect):\n", "```bash\n", "pip install \"flaml[synapse]>=1.1.3\"\n", "```" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# %pip install \"flaml[synapse]>=1.1.3\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 
Load data and preprocess" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ":: loading settings :: url = jar:file:/datadrive/spark/spark33/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Ivy Default Cache set to: /home/lijiang1/.ivy2/cache\n", "The jars for the packages stored in: /home/lijiang1/.ivy2/jars\n", "com.microsoft.azure#synapseml_2.12 added as a dependency\n", "org.apache.hadoop#hadoop-azure added as a dependency\n", "com.microsoft.azure#azure-storage added as a dependency\n", ":: resolving dependencies :: org.apache.spark#spark-submit-parent-bfb2447b-61c5-4941-bf9b-0548472077eb;1.0\n", "\tconfs: [default]\n", "\tfound com.microsoft.azure#synapseml_2.12;0.10.2 in central\n", "\tfound com.microsoft.azure#synapseml-core_2.12;0.10.2 in central\n", "\tfound org.scalactic#scalactic_2.12;3.2.14 in local-m2-cache\n", "\tfound org.scala-lang#scala-reflect;2.12.15 in central\n", "\tfound io.spray#spray-json_2.12;1.3.5 in central\n", "\tfound com.jcraft#jsch;0.1.54 in central\n", "\tfound org.apache.httpcomponents.client5#httpclient5;5.1.3 in central\n", "\tfound org.apache.httpcomponents.core5#httpcore5;5.1.3 in central\n", "\tfound org.apache.httpcomponents.core5#httpcore5-h2;5.1.3 in central\n", "\tfound org.slf4j#slf4j-api;1.7.25 in local-m2-cache\n", "\tfound commons-codec#commons-codec;1.15 in local-m2-cache\n", "\tfound org.apache.httpcomponents#httpmime;4.5.13 in local-m2-cache\n", "\tfound org.apache.httpcomponents#httpclient;4.5.13 in local-m2-cache\n", "\tfound org.apache.httpcomponents#httpcore;4.4.13 in central\n", "\tfound commons-logging#commons-logging;1.2 in central\n", "\tfound com.linkedin.isolation-forest#isolation-forest_3.2.0_2.12;2.0.8 in central\n", "\tfound com.chuusai#shapeless_2.12;2.3.2 in central\n", "\tfound org.typelevel#macro-compat_2.12;1.1.1 in central\n", "\tfound 
org.apache.spark#spark-avro_2.12;3.2.0 in central\n", "\tfound org.tukaani#xz;1.8 in central\n", "\tfound org.spark-project.spark#unused;1.0.0 in central\n", "\tfound org.testng#testng;6.8.8 in central\n", "\tfound org.beanshell#bsh;2.0b4 in central\n", "\tfound com.beust#jcommander;1.27 in central\n", "\tfound com.microsoft.azure#synapseml-deep-learning_2.12;0.10.2 in central\n", "\tfound com.microsoft.azure#synapseml-opencv_2.12;0.10.2 in central\n", "\tfound org.openpnp#opencv;3.2.0-1 in central\n", "\tfound com.microsoft.azure#onnx-protobuf_2.12;0.9.1 in central\n", "\tfound com.microsoft.cntk#cntk;2.4 in central\n", "\tfound com.microsoft.onnxruntime#onnxruntime_gpu;1.8.1 in central\n", "\tfound com.microsoft.azure#synapseml-cognitive_2.12;0.10.2 in central\n", "\tfound com.microsoft.cognitiveservices.speech#client-jar-sdk;1.14.0 in central\n", "\tfound com.microsoft.azure#synapseml-vw_2.12;0.10.2 in central\n", "\tfound com.github.vowpalwabbit#vw-jni;8.9.1 in central\n", "\tfound com.microsoft.azure#synapseml-lightgbm_2.12;0.10.2 in central\n", "\tfound com.microsoft.ml.lightgbm#lightgbmlib;3.2.110 in central\n", "\tfound org.apache.hadoop#hadoop-azure;3.3.1 in central\n", "\tfound org.apache.hadoop.thirdparty#hadoop-shaded-guava;1.1.1 in local-m2-cache\n", "\tfound org.eclipse.jetty#jetty-util-ajax;9.4.40.v20210413 in central\n", "\tfound org.eclipse.jetty#jetty-util;9.4.40.v20210413 in central\n", "\tfound org.codehaus.jackson#jackson-mapper-asl;1.9.13 in local-m2-cache\n", "\tfound org.codehaus.jackson#jackson-core-asl;1.9.13 in local-m2-cache\n", "\tfound org.wildfly.openssl#wildfly-openssl;1.0.7.Final in local-m2-cache\n", "\tfound com.microsoft.azure#azure-storage;8.6.6 in central\n", "\tfound com.fasterxml.jackson.core#jackson-core;2.9.4 in central\n", "\tfound org.apache.commons#commons-lang3;3.4 in local-m2-cache\n", "\tfound com.microsoft.azure#azure-keyvault-core;1.2.4 in central\n", "\tfound com.google.guava#guava;24.1.1-jre in central\n", 
"\tfound com.google.code.findbugs#jsr305;1.3.9 in central\n", "\tfound org.checkerframework#checker-compat-qual;2.0.0 in central\n", "\tfound com.google.errorprone#error_prone_annotations;2.1.3 in central\n", "\tfound com.google.j2objc#j2objc-annotations;1.1 in central\n", "\tfound org.codehaus.mojo#animal-sniffer-annotations;1.14 in central\n", ":: resolution report :: resolve 992ms :: artifacts dl 77ms\n", "\t:: modules in use:\n", "\tcom.beust#jcommander;1.27 from central in [default]\n", "\tcom.chuusai#shapeless_2.12;2.3.2 from central in [default]\n", "\tcom.fasterxml.jackson.core#jackson-core;2.9.4 from central in [default]\n", "\tcom.github.vowpalwabbit#vw-jni;8.9.1 from central in [default]\n", "\tcom.google.code.findbugs#jsr305;1.3.9 from central in [default]\n", "\tcom.google.errorprone#error_prone_annotations;2.1.3 from central in [default]\n", "\tcom.google.guava#guava;24.1.1-jre from central in [default]\n", "\tcom.google.j2objc#j2objc-annotations;1.1 from central in [default]\n", "\tcom.jcraft#jsch;0.1.54 from central in [default]\n", "\tcom.linkedin.isolation-forest#isolation-forest_3.2.0_2.12;2.0.8 from central in [default]\n", "\tcom.microsoft.azure#azure-keyvault-core;1.2.4 from central in [default]\n", "\tcom.microsoft.azure#azure-storage;8.6.6 from central in [default]\n", "\tcom.microsoft.azure#onnx-protobuf_2.12;0.9.1 from central in [default]\n", "\tcom.microsoft.azure#synapseml-cognitive_2.12;0.10.2 from central in [default]\n", "\tcom.microsoft.azure#synapseml-core_2.12;0.10.2 from central in [default]\n", "\tcom.microsoft.azure#synapseml-deep-learning_2.12;0.10.2 from central in [default]\n", "\tcom.microsoft.azure#synapseml-lightgbm_2.12;0.10.2 from central in [default]\n", "\tcom.microsoft.azure#synapseml-opencv_2.12;0.10.2 from central in [default]\n", "\tcom.microsoft.azure#synapseml-vw_2.12;0.10.2 from central in [default]\n", "\tcom.microsoft.azure#synapseml_2.12;0.10.2 from central in [default]\n", "\tcom.microsoft.cntk#cntk;2.4 
from central in [default]\n", "\tcom.microsoft.cognitiveservices.speech#client-jar-sdk;1.14.0 from central in [default]\n", "\tcom.microsoft.ml.lightgbm#lightgbmlib;3.2.110 from central in [default]\n", "\tcom.microsoft.onnxruntime#onnxruntime_gpu;1.8.1 from central in [default]\n", "\tcommons-codec#commons-codec;1.15 from local-m2-cache in [default]\n", "\tcommons-logging#commons-logging;1.2 from central in [default]\n", "\tio.spray#spray-json_2.12;1.3.5 from central in [default]\n", "\torg.apache.commons#commons-lang3;3.4 from local-m2-cache in [default]\n", "\torg.apache.hadoop#hadoop-azure;3.3.1 from central in [default]\n", "\torg.apache.hadoop.thirdparty#hadoop-shaded-guava;1.1.1 from local-m2-cache in [default]\n", "\torg.apache.httpcomponents#httpclient;4.5.13 from local-m2-cache in [default]\n", "\torg.apache.httpcomponents#httpcore;4.4.13 from central in [default]\n", "\torg.apache.httpcomponents#httpmime;4.5.13 from local-m2-cache in [default]\n", "\torg.apache.httpcomponents.client5#httpclient5;5.1.3 from central in [default]\n", "\torg.apache.httpcomponents.core5#httpcore5;5.1.3 from central in [default]\n", "\torg.apache.httpcomponents.core5#httpcore5-h2;5.1.3 from central in [default]\n", "\torg.apache.spark#spark-avro_2.12;3.2.0 from central in [default]\n", "\torg.beanshell#bsh;2.0b4 from central in [default]\n", "\torg.checkerframework#checker-compat-qual;2.0.0 from central in [default]\n", "\torg.codehaus.jackson#jackson-core-asl;1.9.13 from local-m2-cache in [default]\n", "\torg.codehaus.jackson#jackson-mapper-asl;1.9.13 from local-m2-cache in [default]\n", "\torg.codehaus.mojo#animal-sniffer-annotations;1.14 from central in [default]\n", "\torg.eclipse.jetty#jetty-util;9.4.40.v20210413 from central in [default]\n", "\torg.eclipse.jetty#jetty-util-ajax;9.4.40.v20210413 from central in [default]\n", "\torg.openpnp#opencv;3.2.0-1 from central in [default]\n", "\torg.scala-lang#scala-reflect;2.12.15 from central in [default]\n", 
"\torg.scalactic#scalactic_2.12;3.2.14 from local-m2-cache in [default]\n", "\torg.slf4j#slf4j-api;1.7.25 from local-m2-cache in [default]\n", "\torg.spark-project.spark#unused;1.0.0 from central in [default]\n", "\torg.testng#testng;6.8.8 from central in [default]\n", "\torg.tukaani#xz;1.8 from central in [default]\n", "\torg.typelevel#macro-compat_2.12;1.1.1 from central in [default]\n", "\torg.wildfly.openssl#wildfly-openssl;1.0.7.Final from local-m2-cache in [default]\n", "\t:: evicted modules:\n", "\tcommons-codec#commons-codec;1.11 by [commons-codec#commons-codec;1.15] in [default]\n", "\tcom.microsoft.azure#azure-storage;7.0.1 by [com.microsoft.azure#azure-storage;8.6.6] in [default]\n", "\torg.slf4j#slf4j-api;1.7.12 by [org.slf4j#slf4j-api;1.7.25] in [default]\n", "\torg.apache.commons#commons-lang3;3.8.1 by [org.apache.commons#commons-lang3;3.4] in [default]\n", "\t---------------------------------------------------------------------\n", "\t| | modules || artifacts |\n", "\t| conf | number| search|dwnlded|evicted|| number|dwnlded|\n", "\t---------------------------------------------------------------------\n", "\t| default | 57 | 0 | 0 | 4 || 53 | 0 |\n", "\t---------------------------------------------------------------------\n", ":: retrieving :: org.apache.spark#spark-submit-parent-bfb2447b-61c5-4941-bf9b-0548472077eb\n", "\tconfs: [default]\n", "\t0 artifacts copied, 53 already retrieved (0kB/20ms)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "23/02/28 02:12:16 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). 
For SparkR, use setLogLevel(newLevel).\n" ] } ], "source": [ "import pyspark\n", "\n", "spark = (\n", " pyspark.sql.SparkSession.builder.appName(\"MyApp\")\n", " .config(\n", " \"spark.jars.packages\",\n", " f\"com.microsoft.azure:synapseml_2.12:0.10.2,org.apache.hadoop:hadoop-azure:{pyspark.__version__},com.microsoft.azure:azure-storage:8.6.6\",\n", " )\n", " .config(\"spark.sql.debug.maxToStringFields\", \"100\")\n", " .getOrCreate()\n", ")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "23/02/28 02:12:32 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-azure-file-system.properties,hadoop-metrics2.properties\n", "records read: 6819\n", "Schema: \n", "root\n", " |-- Bankrupt?: integer (nullable = true)\n", " |-- ROA(C) before interest and depreciation before interest: double (nullable = true)\n", " |-- ROA(A) before interest and % after tax: double (nullable = true)\n", " |-- ROA(B) before interest and depreciation after tax: double (nullable = true)\n", " |-- Operating Gross Margin: double (nullable = true)\n", " |-- Realized Sales Gross Margin: double (nullable = true)\n", " |-- Operating Profit Rate: double (nullable = true)\n", " |-- Pre-tax net Interest Rate: double (nullable = true)\n", " |-- After-tax net Interest Rate: double (nullable = true)\n", " |-- Non-industry income and expenditure/revenue: double (nullable = true)\n", " |-- Continuous interest rate (after tax): double (nullable = true)\n", " |-- Operating Expense Rate: double (nullable = true)\n", " |-- Research and development expense rate: double (nullable = true)\n", " |-- Cash flow rate: double (nullable = true)\n", " |-- Interest-bearing debt interest rate: double (nullable = true)\n", " |-- Tax rate (A): double (nullable = true)\n", " |-- Net Value Per Share (B): double (nullable = true)\n", " |-- Net Value Per Share (A): double (nullable = true)\n", " |-- Net Value Per Share (C): 
double (nullable = true)\n", " |-- Persistent EPS in the Last Four Seasons: double (nullable = true)\n", " |-- Cash Flow Per Share: double (nullable = true)\n", " |-- Revenue Per Share (Yuan ??): double (nullable = true)\n", " |-- Operating Profit Per Share (Yuan ??): double (nullable = true)\n", " |-- Per Share Net profit before tax (Yuan ??): double (nullable = true)\n", " |-- Realized Sales Gross Profit Growth Rate: double (nullable = true)\n", " |-- Operating Profit Growth Rate: double (nullable = true)\n", " |-- After-tax Net Profit Growth Rate: double (nullable = true)\n", " |-- Regular Net Profit Growth Rate: double (nullable = true)\n", " |-- Continuous Net Profit Growth Rate: double (nullable = true)\n", " |-- Total Asset Growth Rate: double (nullable = true)\n", " |-- Net Value Growth Rate: double (nullable = true)\n", " |-- Total Asset Return Growth Rate Ratio: double (nullable = true)\n", " |-- Cash Reinvestment %: double (nullable = true)\n", " |-- Current Ratio: double (nullable = true)\n", " |-- Quick Ratio: double (nullable = true)\n", " |-- Interest Expense Ratio: double (nullable = true)\n", " |-- Total debt/Total net worth: double (nullable = true)\n", " |-- Debt ratio %: double (nullable = true)\n", " |-- Net worth/Assets: double (nullable = true)\n", " |-- Long-term fund suitability ratio (A): double (nullable = true)\n", " |-- Borrowing dependency: double (nullable = true)\n", " |-- Contingent liabilities/Net worth: double (nullable = true)\n", " |-- Operating profit/Paid-in capital: double (nullable = true)\n", " |-- Net profit before tax/Paid-in capital: double (nullable = true)\n", " |-- Inventory and accounts receivable/Net value: double (nullable = true)\n", " |-- Total Asset Turnover: double (nullable = true)\n", " |-- Accounts Receivable Turnover: double (nullable = true)\n", " |-- Average Collection Days: double (nullable = true)\n", " |-- Inventory Turnover Rate (times): double (nullable = true)\n", " |-- Fixed Assets Turnover 
Frequency: double (nullable = true)\n", " |-- Net Worth Turnover Rate (times): double (nullable = true)\n", " |-- Revenue per person: double (nullable = true)\n", " |-- Operating profit per person: double (nullable = true)\n", " |-- Allocation rate per person: double (nullable = true)\n", " |-- Working Capital to Total Assets: double (nullable = true)\n", " |-- Quick Assets/Total Assets: double (nullable = true)\n", " |-- Current Assets/Total Assets: double (nullable = true)\n", " |-- Cash/Total Assets: double (nullable = true)\n", " |-- Quick Assets/Current Liability: double (nullable = true)\n", " |-- Cash/Current Liability: double (nullable = true)\n", " |-- Current Liability to Assets: double (nullable = true)\n", " |-- Operating Funds to Liability: double (nullable = true)\n", " |-- Inventory/Working Capital: double (nullable = true)\n", " |-- Inventory/Current Liability: double (nullable = true)\n", " |-- Current Liabilities/Liability: double (nullable = true)\n", " |-- Working Capital/Equity: double (nullable = true)\n", " |-- Current Liabilities/Equity: double (nullable = true)\n", " |-- Long-term Liability to Current Assets: double (nullable = true)\n", " |-- Retained Earnings to Total Assets: double (nullable = true)\n", " |-- Total income/Total expense: double (nullable = true)\n", " |-- Total expense/Assets: double (nullable = true)\n", " |-- Current Asset Turnover Rate: double (nullable = true)\n", " |-- Quick Asset Turnover Rate: double (nullable = true)\n", " |-- Working capitcal Turnover Rate: double (nullable = true)\n", " |-- Cash Turnover Rate: double (nullable = true)\n", " |-- Cash Flow to Sales: double (nullable = true)\n", " |-- Fixed Assets to Assets: double (nullable = true)\n", " |-- Current Liability to Liability: double (nullable = true)\n", " |-- Current Liability to Equity: double (nullable = true)\n", " |-- Equity to Long-term Liability: double (nullable = true)\n", " |-- Cash Flow to Total Assets: double (nullable = true)\n", " |-- 
Cash Flow to Liability: double (nullable = true)\n", " |-- CFO to Assets: double (nullable = true)\n", " |-- Cash Flow to Equity: double (nullable = true)\n", " |-- Current Liability to Current Assets: double (nullable = true)\n", " |-- Liability-Assets Flag: double (nullable = true)\n", " |-- Net Income to Total Assets: double (nullable = true)\n", " |-- Total assets to GNP price: double (nullable = true)\n", " |-- No-credit Interval: double (nullable = true)\n", " |-- Gross Profit to Sales: double (nullable = true)\n", " |-- Net Income to Stockholder's Equity: double (nullable = true)\n", " |-- Liability to Equity: double (nullable = true)\n", " |-- Degree of Financial Leverage (DFL): double (nullable = true)\n", " |-- Interest Coverage Ratio (Interest expense to EBIT): double (nullable = true)\n", " |-- Net Income Flag: double (nullable = true)\n", " |-- Equity to Liability: double (nullable = true)\n", "\n" ] } ], "source": [ "df = (\n", " spark.read.format(\"csv\")\n", " .option(\"header\", True)\n", " .option(\"inferSchema\", True)\n", " .load(\n", " \"wasbs://publicwasb@mmlspark.blob.core.windows.net/company_bankruptcy_prediction_data.csv\"\n", " )\n", ")\n", "# print dataset size\n", "print(\"records read: \" + str(df.count()))\n", "print(\"Schema: \")\n", "df.printSchema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Split the dataset into training (80%) and test (20%) sets." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "train, test = df.randomSplit([0.8, 0.2], seed=41)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Add a featurizer that assembles all feature columns (every column except the `Bankrupt?` label) into a single `features` vector column." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.feature import VectorAssembler\n", "\n", "feature_cols = df.columns[1:]\n", "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", "train_data = featurizer.transform(train)[\"Bankrupt?\", \"features\"]\n", 
"test_data = featurizer.transform(test)[\"Bankrupt?\", \"features\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Default SynapseML LightGBM" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "23/02/28 02:12:42 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n", "[LightGBM] [Warning] Find whitespaces in feature_names, replace with underlines\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "from synapse.ml.lightgbm import LightGBMClassifier\n", "\n", "model = LightGBMClassifier(\n", " objective=\"binary\", featuresCol=\"features\", labelCol=\"Bankrupt?\", isUnbalance=True\n", ")\n", "\n", "model = model.fit(train_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Model Prediction" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DataFrame[evaluation_type: string, confusion_matrix: matrix, accuracy: double, precision: double, recall: double, AUC: double]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "[Stage 27:> (0 + 1) / 1]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+---------------+--------------------+-----------------+------------------+-------------------+------------------+\n", "|evaluation_type| confusion_matrix| accuracy| precision| recall| AUC|\n", "+---------------+--------------------+-----------------+------------------+-------------------+------------------+\n", "| Classification|1250.0 23.0 \\n3...|0.958997722095672|0.3611111111111111|0.29545454545454547|0.6386934942512319|\n", "+---------------+--------------------+-----------------+------------------+-------------------+------------------+\n", "\n" ] }, { "name": "stderr", 
"output_type": "stream", "text": [ " \r" ] } ], "source": [ "def predict(model):\n", " from synapse.ml.train import ComputeModelStatistics\n", "\n", " predictions = model.transform(test_data)\n", " # predictions.limit(10).show()\n", "\n", " # compute classification metrics (confusion matrix, accuracy, precision, recall, AUC) on the test set\n", " metrics = ComputeModelStatistics(\n", " evaluationMetric=\"classification\",\n", " labelCol=\"Bankrupt?\",\n", " scoredLabelsCol=\"prediction\",\n", " ).transform(predictions)\n", " display(metrics)\n", " return metrics\n", "\n", "default_metrics = predict(model)\n", "default_metrics.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run FLAML\n", "In the FLAML AutoML run configuration, users can specify the task type, time budget, error metric, learner list, whether to subsample, the resampling strategy, and so on. All of these arguments have default values, which are used when users do not provide them." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# import the AutoML class from the flaml package\n", "from flaml import AutoML\n", "from flaml.automl.spark.utils import to_pandas_on_spark\n", "\n", "automl = AutoML()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "settings = {\n", " \"time_budget\": 30, # total running time in seconds\n", " \"metric\": \"roc_auc\", # optimize the area under the ROC curve\n", " \"estimator_list\": [\"lgbm_spark\"], # list of ML learners; we tune lightgbm in this example\n", " \"task\": \"classification\", # task type\n", " \"log_file_name\": \"flaml_experiment.log\", # flaml log file\n", " \"seed\": 41, # random seed\n", " \"force_cancel\": True, # force stop training once time_budget is used up\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Disable Arrow optimization to suppress the warning below, which occurs because Arrow cannot convert the `VectorUDT` type of the `features` column:\n", "```\n", "/opt/spark/python/lib/pyspark.zip/pyspark/sql/pandas/conversion.py:87: UserWarning: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.pyspark.enabled' is set to true; 
however, failed by the reason below:\n", " Unsupported type in conversion to Arrow: VectorUDT\n", "Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.\n", " warnings.warn(msg)\n", "```" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"false\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | index | \n", "Bankrupt? | \n", "features | \n", "
---|---|---|---|
0 | \n", "0 | \n", "0 | \n", "[0.0828, 0.0693, 0.0884, 0.6468, 0.6468, 0.997... | \n", "
1 | \n", "1 | \n", "0 | \n", "[0.1606, 0.1788, 0.1832, 0.5897, 0.5897, 0.998... | \n", "
2 | \n", "2 | \n", "0 | \n", "[0.204, 0.2638, 0.2598, 0.4483, 0.4483, 0.9959... | \n", "
3 | \n", "3 | \n", "0 | \n", "[0.217, 0.1881, 0.2451, 0.5992, 0.5992, 0.9962... | \n", "
4 | \n", "4 | \n", "0 | \n", "[0.2314, 0.1628, 0.2068, 0.6001, 0.6001, 0.998... | \n", "