mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-16 13:30:55 +00:00

* add basic support to Spark dataframe add support to SynapseML LightGBM model update to pyspark>=3.2.0 to leverage pandas_on_Spark API * clean code, add TODOs * add sample_train_data for pyspark.pandas dataframe, fix bugs * improve some functions, fix bugs * fix dict change size during iteration * update model predict * update LightGBM model, update test * update SynapseML LightGBM params * update synapseML and tests * update TODOs * Added support to roc_auc for spark models * Added support to score of spark estimator * Added test for automl score of spark estimator * Added cv support to pyspark.pandas dataframe * Update test, fix bugs * Added tests * Updated docs, tests, added a notebook * Fix bugs in non-spark env * Fix bugs and improve tests * Fix uninstall pyspark * Fix tests error * Fix java.lang.OutOfMemoryError: Java heap space * Fix test_performance * Update test_sparkml to test_0sparkml to use the expected spark conf * Remove unnecessary widgets in notebook * Fix iloc java.lang.StackOverflowError * fix pre-commit * Added params check for spark dataframes * Refactor code for train_test_split to a function * Update train_test_split_pyspark * Refactor if-else, remove unnecessary code * Remove y from predict, remove mem control from n_iter compute * Update workflow * Improve _split_pyspark * Fix test failure of too short training time * Fix typos, improve docstrings * Fix index errors of pandas_on_spark, add spark loss metric * Fix typo of ndcgAtK * Update NDCG metrics and tests * Remove unuseful logger * Use cache and count to ensure consistent indexes * refactor for merge maain * fix errors of refactor * Updated SparkLightGBMEstimator and cache * Updated config2params * Remove unused import * Fix unknown parameters * Update default_estimator_list * Add unit tests for spark metrics
832 lines
44 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
"# AutoML with the FLAML Library for SynapseML Models and Spark DataFrames\n",
|
|
"\n",
|
|
"\n",
|
|
"## 1. Introduction\n",
|
|
"\n",
|
|
"FLAML is a Python library (https://github.com/microsoft/FLAML) designed to automatically produce accurate machine learning models \n",
|
|
"with low computational cost. It is fast and economical. The simple and lightweight design makes it easy \n",
|
|
"to use and extend, such as adding new learners. FLAML can \n",
|
|
"- serve as an economical AutoML engine,\n",
|
|
"- be used as a fast hyperparameter tuning tool, or \n",
"- be embedded in self-tuning software that requires low latency and low resource use in repetitive\n",
"  tuning tasks.\n",
|
|
"\n",
"In this notebook, we demonstrate how to use the FLAML library to run AutoML on SynapseML models and Spark dataframes. We also compare the results of FLAML AutoML with those of a default SynapseML model.\n",
"In this example, we use LightGBM to build a classification model that predicts bankruptcy.\n",
|
|
"\n",
"Since the dataset is imbalanced, `AUC` is a better metric than `Accuracy`. FLAML (1 min of training) achieved an AUC of **0.79**, while the default SynapseML model reached only **0.64**.\n",
|
|
"\n",
|
|
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the `synapse` option:\n",
|
|
"```bash\n",
"pip install \"flaml[synapse]>=1.1.3\"\n",
|
|
"```\n",
|
|
" "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# %pip install \"flaml[synapse]>=1.1.3\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 2. Load data and preprocess"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
":: loading settings :: url = jar:file:/datadrive/spark/spark33/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Ivy Default Cache set to: /home/lijiang1/.ivy2/cache\n",
|
|
"The jars for the packages stored in: /home/lijiang1/.ivy2/jars\n",
|
|
"com.microsoft.azure#synapseml_2.12 added as a dependency\n",
|
|
"org.apache.hadoop#hadoop-azure added as a dependency\n",
|
|
"com.microsoft.azure#azure-storage added as a dependency\n",
|
|
":: resolving dependencies :: org.apache.spark#spark-submit-parent-bfb2447b-61c5-4941-bf9b-0548472077eb;1.0\n",
|
|
"\tconfs: [default]\n",
|
|
"\tfound com.microsoft.azure#synapseml_2.12;0.10.2 in central\n",
|
|
"\tfound com.microsoft.azure#synapseml-core_2.12;0.10.2 in central\n",
|
|
"\tfound org.scalactic#scalactic_2.12;3.2.14 in local-m2-cache\n",
|
|
"\tfound org.scala-lang#scala-reflect;2.12.15 in central\n",
|
|
"\tfound io.spray#spray-json_2.12;1.3.5 in central\n",
|
|
"\tfound com.jcraft#jsch;0.1.54 in central\n",
|
|
"\tfound org.apache.httpcomponents.client5#httpclient5;5.1.3 in central\n",
|
|
"\tfound org.apache.httpcomponents.core5#httpcore5;5.1.3 in central\n",
|
|
"\tfound org.apache.httpcomponents.core5#httpcore5-h2;5.1.3 in central\n",
|
|
"\tfound org.slf4j#slf4j-api;1.7.25 in local-m2-cache\n",
|
|
"\tfound commons-codec#commons-codec;1.15 in local-m2-cache\n",
|
|
"\tfound org.apache.httpcomponents#httpmime;4.5.13 in local-m2-cache\n",
|
|
"\tfound org.apache.httpcomponents#httpclient;4.5.13 in local-m2-cache\n",
|
|
"\tfound org.apache.httpcomponents#httpcore;4.4.13 in central\n",
|
|
"\tfound commons-logging#commons-logging;1.2 in central\n",
|
|
"\tfound com.linkedin.isolation-forest#isolation-forest_3.2.0_2.12;2.0.8 in central\n",
|
|
"\tfound com.chuusai#shapeless_2.12;2.3.2 in central\n",
|
|
"\tfound org.typelevel#macro-compat_2.12;1.1.1 in central\n",
|
|
"\tfound org.apache.spark#spark-avro_2.12;3.2.0 in central\n",
|
|
"\tfound org.tukaani#xz;1.8 in central\n",
|
|
"\tfound org.spark-project.spark#unused;1.0.0 in central\n",
|
|
"\tfound org.testng#testng;6.8.8 in central\n",
|
|
"\tfound org.beanshell#bsh;2.0b4 in central\n",
|
|
"\tfound com.beust#jcommander;1.27 in central\n",
|
|
"\tfound com.microsoft.azure#synapseml-deep-learning_2.12;0.10.2 in central\n",
|
|
"\tfound com.microsoft.azure#synapseml-opencv_2.12;0.10.2 in central\n",
|
|
"\tfound org.openpnp#opencv;3.2.0-1 in central\n",
|
|
"\tfound com.microsoft.azure#onnx-protobuf_2.12;0.9.1 in central\n",
|
|
"\tfound com.microsoft.cntk#cntk;2.4 in central\n",
|
|
"\tfound com.microsoft.onnxruntime#onnxruntime_gpu;1.8.1 in central\n",
|
|
"\tfound com.microsoft.azure#synapseml-cognitive_2.12;0.10.2 in central\n",
|
|
"\tfound com.microsoft.cognitiveservices.speech#client-jar-sdk;1.14.0 in central\n",
|
|
"\tfound com.microsoft.azure#synapseml-vw_2.12;0.10.2 in central\n",
|
|
"\tfound com.github.vowpalwabbit#vw-jni;8.9.1 in central\n",
|
|
"\tfound com.microsoft.azure#synapseml-lightgbm_2.12;0.10.2 in central\n",
|
|
"\tfound com.microsoft.ml.lightgbm#lightgbmlib;3.2.110 in central\n",
|
|
"\tfound org.apache.hadoop#hadoop-azure;3.3.1 in central\n",
|
|
"\tfound org.apache.hadoop.thirdparty#hadoop-shaded-guava;1.1.1 in local-m2-cache\n",
|
|
"\tfound org.eclipse.jetty#jetty-util-ajax;9.4.40.v20210413 in central\n",
|
|
"\tfound org.eclipse.jetty#jetty-util;9.4.40.v20210413 in central\n",
|
|
"\tfound org.codehaus.jackson#jackson-mapper-asl;1.9.13 in local-m2-cache\n",
|
|
"\tfound org.codehaus.jackson#jackson-core-asl;1.9.13 in local-m2-cache\n",
|
|
"\tfound org.wildfly.openssl#wildfly-openssl;1.0.7.Final in local-m2-cache\n",
|
|
"\tfound com.microsoft.azure#azure-storage;8.6.6 in central\n",
|
|
"\tfound com.fasterxml.jackson.core#jackson-core;2.9.4 in central\n",
|
|
"\tfound org.apache.commons#commons-lang3;3.4 in local-m2-cache\n",
|
|
"\tfound com.microsoft.azure#azure-keyvault-core;1.2.4 in central\n",
|
|
"\tfound com.google.guava#guava;24.1.1-jre in central\n",
|
|
"\tfound com.google.code.findbugs#jsr305;1.3.9 in central\n",
|
|
"\tfound org.checkerframework#checker-compat-qual;2.0.0 in central\n",
|
|
"\tfound com.google.errorprone#error_prone_annotations;2.1.3 in central\n",
|
|
"\tfound com.google.j2objc#j2objc-annotations;1.1 in central\n",
|
|
"\tfound org.codehaus.mojo#animal-sniffer-annotations;1.14 in central\n",
|
|
":: resolution report :: resolve 992ms :: artifacts dl 77ms\n",
|
|
"\t:: modules in use:\n",
|
|
"\tcom.beust#jcommander;1.27 from central in [default]\n",
|
|
"\tcom.chuusai#shapeless_2.12;2.3.2 from central in [default]\n",
|
|
"\tcom.fasterxml.jackson.core#jackson-core;2.9.4 from central in [default]\n",
|
|
"\tcom.github.vowpalwabbit#vw-jni;8.9.1 from central in [default]\n",
|
|
"\tcom.google.code.findbugs#jsr305;1.3.9 from central in [default]\n",
|
|
"\tcom.google.errorprone#error_prone_annotations;2.1.3 from central in [default]\n",
|
|
"\tcom.google.guava#guava;24.1.1-jre from central in [default]\n",
|
|
"\tcom.google.j2objc#j2objc-annotations;1.1 from central in [default]\n",
|
|
"\tcom.jcraft#jsch;0.1.54 from central in [default]\n",
|
|
"\tcom.linkedin.isolation-forest#isolation-forest_3.2.0_2.12;2.0.8 from central in [default]\n",
|
|
"\tcom.microsoft.azure#azure-keyvault-core;1.2.4 from central in [default]\n",
|
|
"\tcom.microsoft.azure#azure-storage;8.6.6 from central in [default]\n",
|
|
"\tcom.microsoft.azure#onnx-protobuf_2.12;0.9.1 from central in [default]\n",
|
|
"\tcom.microsoft.azure#synapseml-cognitive_2.12;0.10.2 from central in [default]\n",
|
|
"\tcom.microsoft.azure#synapseml-core_2.12;0.10.2 from central in [default]\n",
|
|
"\tcom.microsoft.azure#synapseml-deep-learning_2.12;0.10.2 from central in [default]\n",
|
|
"\tcom.microsoft.azure#synapseml-lightgbm_2.12;0.10.2 from central in [default]\n",
|
|
"\tcom.microsoft.azure#synapseml-opencv_2.12;0.10.2 from central in [default]\n",
|
|
"\tcom.microsoft.azure#synapseml-vw_2.12;0.10.2 from central in [default]\n",
|
|
"\tcom.microsoft.azure#synapseml_2.12;0.10.2 from central in [default]\n",
|
|
"\tcom.microsoft.cntk#cntk;2.4 from central in [default]\n",
|
|
"\tcom.microsoft.cognitiveservices.speech#client-jar-sdk;1.14.0 from central in [default]\n",
|
|
"\tcom.microsoft.ml.lightgbm#lightgbmlib;3.2.110 from central in [default]\n",
|
|
"\tcom.microsoft.onnxruntime#onnxruntime_gpu;1.8.1 from central in [default]\n",
|
|
"\tcommons-codec#commons-codec;1.15 from local-m2-cache in [default]\n",
|
|
"\tcommons-logging#commons-logging;1.2 from central in [default]\n",
|
|
"\tio.spray#spray-json_2.12;1.3.5 from central in [default]\n",
|
|
"\torg.apache.commons#commons-lang3;3.4 from local-m2-cache in [default]\n",
|
|
"\torg.apache.hadoop#hadoop-azure;3.3.1 from central in [default]\n",
|
|
"\torg.apache.hadoop.thirdparty#hadoop-shaded-guava;1.1.1 from local-m2-cache in [default]\n",
|
|
"\torg.apache.httpcomponents#httpclient;4.5.13 from local-m2-cache in [default]\n",
|
|
"\torg.apache.httpcomponents#httpcore;4.4.13 from central in [default]\n",
|
|
"\torg.apache.httpcomponents#httpmime;4.5.13 from local-m2-cache in [default]\n",
|
|
"\torg.apache.httpcomponents.client5#httpclient5;5.1.3 from central in [default]\n",
|
|
"\torg.apache.httpcomponents.core5#httpcore5;5.1.3 from central in [default]\n",
|
|
"\torg.apache.httpcomponents.core5#httpcore5-h2;5.1.3 from central in [default]\n",
|
|
"\torg.apache.spark#spark-avro_2.12;3.2.0 from central in [default]\n",
|
|
"\torg.beanshell#bsh;2.0b4 from central in [default]\n",
|
|
"\torg.checkerframework#checker-compat-qual;2.0.0 from central in [default]\n",
|
|
"\torg.codehaus.jackson#jackson-core-asl;1.9.13 from local-m2-cache in [default]\n",
|
|
"\torg.codehaus.jackson#jackson-mapper-asl;1.9.13 from local-m2-cache in [default]\n",
|
|
"\torg.codehaus.mojo#animal-sniffer-annotations;1.14 from central in [default]\n",
|
|
"\torg.eclipse.jetty#jetty-util;9.4.40.v20210413 from central in [default]\n",
|
|
"\torg.eclipse.jetty#jetty-util-ajax;9.4.40.v20210413 from central in [default]\n",
|
|
"\torg.openpnp#opencv;3.2.0-1 from central in [default]\n",
|
|
"\torg.scala-lang#scala-reflect;2.12.15 from central in [default]\n",
|
|
"\torg.scalactic#scalactic_2.12;3.2.14 from local-m2-cache in [default]\n",
|
|
"\torg.slf4j#slf4j-api;1.7.25 from local-m2-cache in [default]\n",
|
|
"\torg.spark-project.spark#unused;1.0.0 from central in [default]\n",
|
|
"\torg.testng#testng;6.8.8 from central in [default]\n",
|
|
"\torg.tukaani#xz;1.8 from central in [default]\n",
|
|
"\torg.typelevel#macro-compat_2.12;1.1.1 from central in [default]\n",
|
|
"\torg.wildfly.openssl#wildfly-openssl;1.0.7.Final from local-m2-cache in [default]\n",
|
|
"\t:: evicted modules:\n",
|
|
"\tcommons-codec#commons-codec;1.11 by [commons-codec#commons-codec;1.15] in [default]\n",
|
|
"\tcom.microsoft.azure#azure-storage;7.0.1 by [com.microsoft.azure#azure-storage;8.6.6] in [default]\n",
|
|
"\torg.slf4j#slf4j-api;1.7.12 by [org.slf4j#slf4j-api;1.7.25] in [default]\n",
|
|
"\torg.apache.commons#commons-lang3;3.8.1 by [org.apache.commons#commons-lang3;3.4] in [default]\n",
|
|
"\t---------------------------------------------------------------------\n",
|
|
"\t| | modules || artifacts |\n",
|
|
"\t| conf | number| search|dwnlded|evicted|| number|dwnlded|\n",
|
|
"\t---------------------------------------------------------------------\n",
|
|
"\t| default | 57 | 0 | 0 | 4 || 53 | 0 |\n",
|
|
"\t---------------------------------------------------------------------\n",
|
|
":: retrieving :: org.apache.spark#spark-submit-parent-bfb2447b-61c5-4941-bf9b-0548472077eb\n",
|
|
"\tconfs: [default]\n",
|
|
"\t0 artifacts copied, 53 already retrieved (0kB/20ms)\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"23/02/28 02:12:16 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Setting default log level to \"WARN\".\n",
|
|
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import pyspark\n",
|
|
"\n",
|
|
"spark = (\n",
|
|
" pyspark.sql.SparkSession.builder.appName(\"MyApp\")\n",
|
|
" .config(\n",
|
|
" \"spark.jars.packages\",\n",
|
|
" f\"com.microsoft.azure:synapseml_2.12:0.10.2,org.apache.hadoop:hadoop-azure:{pyspark.__version__},com.microsoft.azure:azure-storage:8.6.6\",\n",
|
|
" )\n",
|
|
" .config(\"spark.sql.debug.maxToStringFields\", \"100\")\n",
|
|
" .getOrCreate()\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"23/02/28 02:12:32 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-azure-file-system.properties,hadoop-metrics2.properties\n",
|
|
"records read: 6819\n",
|
|
"Schema: \n",
|
|
"root\n",
|
|
" |-- Bankrupt?: integer (nullable = true)\n",
|
|
" |-- ROA(C) before interest and depreciation before interest: double (nullable = true)\n",
|
|
" |-- ROA(A) before interest and % after tax: double (nullable = true)\n",
|
|
" |-- ROA(B) before interest and depreciation after tax: double (nullable = true)\n",
|
|
" |-- Operating Gross Margin: double (nullable = true)\n",
|
|
" |-- Realized Sales Gross Margin: double (nullable = true)\n",
|
|
" |-- Operating Profit Rate: double (nullable = true)\n",
|
|
" |-- Pre-tax net Interest Rate: double (nullable = true)\n",
|
|
" |-- After-tax net Interest Rate: double (nullable = true)\n",
|
|
" |-- Non-industry income and expenditure/revenue: double (nullable = true)\n",
|
|
" |-- Continuous interest rate (after tax): double (nullable = true)\n",
|
|
" |-- Operating Expense Rate: double (nullable = true)\n",
|
|
" |-- Research and development expense rate: double (nullable = true)\n",
|
|
" |-- Cash flow rate: double (nullable = true)\n",
|
|
" |-- Interest-bearing debt interest rate: double (nullable = true)\n",
|
|
" |-- Tax rate (A): double (nullable = true)\n",
|
|
" |-- Net Value Per Share (B): double (nullable = true)\n",
|
|
" |-- Net Value Per Share (A): double (nullable = true)\n",
|
|
" |-- Net Value Per Share (C): double (nullable = true)\n",
|
|
" |-- Persistent EPS in the Last Four Seasons: double (nullable = true)\n",
|
|
" |-- Cash Flow Per Share: double (nullable = true)\n",
|
|
" |-- Revenue Per Share (Yuan ??): double (nullable = true)\n",
|
|
" |-- Operating Profit Per Share (Yuan ??): double (nullable = true)\n",
|
|
" |-- Per Share Net profit before tax (Yuan ??): double (nullable = true)\n",
|
|
" |-- Realized Sales Gross Profit Growth Rate: double (nullable = true)\n",
|
|
" |-- Operating Profit Growth Rate: double (nullable = true)\n",
|
|
" |-- After-tax Net Profit Growth Rate: double (nullable = true)\n",
|
|
" |-- Regular Net Profit Growth Rate: double (nullable = true)\n",
|
|
" |-- Continuous Net Profit Growth Rate: double (nullable = true)\n",
|
|
" |-- Total Asset Growth Rate: double (nullable = true)\n",
|
|
" |-- Net Value Growth Rate: double (nullable = true)\n",
|
|
" |-- Total Asset Return Growth Rate Ratio: double (nullable = true)\n",
|
|
" |-- Cash Reinvestment %: double (nullable = true)\n",
|
|
" |-- Current Ratio: double (nullable = true)\n",
|
|
" |-- Quick Ratio: double (nullable = true)\n",
|
|
" |-- Interest Expense Ratio: double (nullable = true)\n",
|
|
" |-- Total debt/Total net worth: double (nullable = true)\n",
|
|
" |-- Debt ratio %: double (nullable = true)\n",
|
|
" |-- Net worth/Assets: double (nullable = true)\n",
|
|
" |-- Long-term fund suitability ratio (A): double (nullable = true)\n",
|
|
" |-- Borrowing dependency: double (nullable = true)\n",
|
|
" |-- Contingent liabilities/Net worth: double (nullable = true)\n",
|
|
" |-- Operating profit/Paid-in capital: double (nullable = true)\n",
|
|
" |-- Net profit before tax/Paid-in capital: double (nullable = true)\n",
|
|
" |-- Inventory and accounts receivable/Net value: double (nullable = true)\n",
|
|
" |-- Total Asset Turnover: double (nullable = true)\n",
|
|
" |-- Accounts Receivable Turnover: double (nullable = true)\n",
|
|
" |-- Average Collection Days: double (nullable = true)\n",
|
|
" |-- Inventory Turnover Rate (times): double (nullable = true)\n",
|
|
" |-- Fixed Assets Turnover Frequency: double (nullable = true)\n",
|
|
" |-- Net Worth Turnover Rate (times): double (nullable = true)\n",
|
|
" |-- Revenue per person: double (nullable = true)\n",
|
|
" |-- Operating profit per person: double (nullable = true)\n",
|
|
" |-- Allocation rate per person: double (nullable = true)\n",
|
|
" |-- Working Capital to Total Assets: double (nullable = true)\n",
|
|
" |-- Quick Assets/Total Assets: double (nullable = true)\n",
|
|
" |-- Current Assets/Total Assets: double (nullable = true)\n",
|
|
" |-- Cash/Total Assets: double (nullable = true)\n",
|
|
" |-- Quick Assets/Current Liability: double (nullable = true)\n",
|
|
" |-- Cash/Current Liability: double (nullable = true)\n",
|
|
" |-- Current Liability to Assets: double (nullable = true)\n",
|
|
" |-- Operating Funds to Liability: double (nullable = true)\n",
|
|
" |-- Inventory/Working Capital: double (nullable = true)\n",
|
|
" |-- Inventory/Current Liability: double (nullable = true)\n",
|
|
" |-- Current Liabilities/Liability: double (nullable = true)\n",
|
|
" |-- Working Capital/Equity: double (nullable = true)\n",
|
|
" |-- Current Liabilities/Equity: double (nullable = true)\n",
|
|
" |-- Long-term Liability to Current Assets: double (nullable = true)\n",
|
|
" |-- Retained Earnings to Total Assets: double (nullable = true)\n",
|
|
" |-- Total income/Total expense: double (nullable = true)\n",
|
|
" |-- Total expense/Assets: double (nullable = true)\n",
|
|
" |-- Current Asset Turnover Rate: double (nullable = true)\n",
|
|
" |-- Quick Asset Turnover Rate: double (nullable = true)\n",
|
|
" |-- Working capitcal Turnover Rate: double (nullable = true)\n",
|
|
" |-- Cash Turnover Rate: double (nullable = true)\n",
|
|
" |-- Cash Flow to Sales: double (nullable = true)\n",
|
|
" |-- Fixed Assets to Assets: double (nullable = true)\n",
|
|
" |-- Current Liability to Liability: double (nullable = true)\n",
|
|
" |-- Current Liability to Equity: double (nullable = true)\n",
|
|
" |-- Equity to Long-term Liability: double (nullable = true)\n",
|
|
" |-- Cash Flow to Total Assets: double (nullable = true)\n",
|
|
" |-- Cash Flow to Liability: double (nullable = true)\n",
|
|
" |-- CFO to Assets: double (nullable = true)\n",
|
|
" |-- Cash Flow to Equity: double (nullable = true)\n",
|
|
" |-- Current Liability to Current Assets: double (nullable = true)\n",
|
|
" |-- Liability-Assets Flag: double (nullable = true)\n",
|
|
" |-- Net Income to Total Assets: double (nullable = true)\n",
|
|
" |-- Total assets to GNP price: double (nullable = true)\n",
|
|
" |-- No-credit Interval: double (nullable = true)\n",
|
|
" |-- Gross Profit to Sales: double (nullable = true)\n",
|
|
" |-- Net Income to Stockholder's Equity: double (nullable = true)\n",
|
|
" |-- Liability to Equity: double (nullable = true)\n",
|
|
" |-- Degree of Financial Leverage (DFL): double (nullable = true)\n",
|
|
" |-- Interest Coverage Ratio (Interest expense to EBIT): double (nullable = true)\n",
|
|
" |-- Net Income Flag: double (nullable = true)\n",
|
|
" |-- Equity to Liability: double (nullable = true)\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"df = (\n",
|
|
" spark.read.format(\"csv\")\n",
|
|
" .option(\"header\", True)\n",
|
|
" .option(\"inferSchema\", True)\n",
|
|
" .load(\n",
|
|
" \"wasbs://publicwasb@mmlspark.blob.core.windows.net/company_bankruptcy_prediction_data.csv\"\n",
|
|
" )\n",
|
|
")\n",
|
|
"# print dataset size\n",
|
|
"print(\"records read: \" + str(df.count()))\n",
|
|
"print(\"Schema: \")\n",
|
|
"df.printSchema()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
"Split the dataset into training and test sets."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"train, test = df.randomSplit([0.8, 0.2], seed=41)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
"Add a featurizer to assemble the feature columns into a single vector column."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from pyspark.ml.feature import VectorAssembler\n",
|
|
"\n",
|
|
"feature_cols = df.columns[1:]\n",
|
|
"featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n",
|
|
"train_data = featurizer.transform(train)[\"Bankrupt?\", \"features\"]\n",
|
|
"test_data = featurizer.transform(test)[\"Bankrupt?\", \"features\"]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Default SynapseML LightGBM"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"23/02/28 02:12:42 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n",
|
|
"[LightGBM] [Warning] Find whitespaces in feature_names, replace with underlines\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" \r"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from synapse.ml.lightgbm import LightGBMClassifier\n",
|
|
"\n",
|
|
"model = LightGBMClassifier(\n",
|
|
" objective=\"binary\", featuresCol=\"features\", labelCol=\"Bankrupt?\", isUnbalance=True\n",
|
|
")\n",
|
|
"\n",
|
|
"model = model.fit(train_data)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"#### Model Prediction"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"DataFrame[evaluation_type: string, confusion_matrix: matrix, accuracy: double, precision: double, recall: double, AUC: double]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[Stage 27:> (0 + 1) / 1]\r"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"+---------------+--------------------+-----------------+------------------+-------------------+------------------+\n",
|
|
"|evaluation_type| confusion_matrix| accuracy| precision| recall| AUC|\n",
|
|
"+---------------+--------------------+-----------------+------------------+-------------------+------------------+\n",
|
|
"| Classification|1250.0 23.0 \\n3...|0.958997722095672|0.3611111111111111|0.29545454545454547|0.6386934942512319|\n",
|
|
"+---------------+--------------------+-----------------+------------------+-------------------+------------------+\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" \r"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def predict(model):\n",
|
|
" from synapse.ml.train import ComputeModelStatistics\n",
|
|
"\n",
|
|
" predictions = model.transform(test_data)\n",
|
|
" # predictions.limit(10).show()\n",
|
|
" \n",
|
|
" metrics = ComputeModelStatistics(\n",
|
|
" evaluationMetric=\"classification\",\n",
|
|
" labelCol=\"Bankrupt?\",\n",
|
|
" scoredLabelsCol=\"prediction\",\n",
|
|
" ).transform(predictions)\n",
|
|
" display(metrics)\n",
|
|
" return metrics\n",
|
|
"\n",
|
|
"default_metrics = predict(model)\n",
|
|
"default_metrics.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Run FLAML\n",
"In the FLAML AutoML run configuration, users can specify the task type, time budget, error metric, learner list, whether to subsample, the resampling strategy, and so on. Every argument has a default value that is used when the user does not provide one."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# import the AutoML class from the flaml package\n",
|
|
"from flaml import AutoML\n",
|
|
"from flaml.automl.spark.utils import to_pandas_on_spark\n",
|
|
"\n",
|
|
"automl = AutoML()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"settings = {\n",
|
|
" \"time_budget\": 30, # total running time in seconds\n",
|
|
" \"metric\": 'roc_auc',\n",
|
|
" \"estimator_list\": ['lgbm_spark'], # list of ML learners; we tune lightgbm in this example\n",
|
|
" \"task\": 'classification', # task type\n",
|
|
" \"log_file_name\": 'flaml_experiment.log', # flaml log file\n",
|
|
" \"seed\": 41, # random seed\n",
|
|
" \"force_cancel\": True, # force stop training once time_budget is used up\n",
|
|
"}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
"Disable Arrow optimization to suppress the warning below:\n",
|
|
"```\n",
|
|
"/opt/spark/python/lib/pyspark.zip/pyspark/sql/pandas/conversion.py:87: UserWarning: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, failed by the reason below:\n",
|
|
" Unsupported type in conversion to Arrow: VectorUDT\n",
|
|
"Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.\n",
|
|
" warnings.warn(msg)\n",
|
|
"```"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"false\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>index</th>\n",
|
|
" <th>Bankrupt?</th>\n",
|
|
" <th>features</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>[0.0828, 0.0693, 0.0884, 0.6468, 0.6468, 0.997...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>[0.1606, 0.1788, 0.1832, 0.5897, 0.5897, 0.998...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>2</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>[0.204, 0.2638, 0.2598, 0.4483, 0.4483, 0.9959...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>3</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>[0.217, 0.1881, 0.2451, 0.5992, 0.5992, 0.9962...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>4</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>[0.2314, 0.1628, 0.2068, 0.6001, 0.6001, 0.998...</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" index Bankrupt? features\n",
|
|
"0 0 0 [0.0828, 0.0693, 0.0884, 0.6468, 0.6468, 0.9971, 0.7958, 0.8078, 0.3047, 0.78, 0.0027, 0.0029, 0.428, 0.0, 0.0, 0.1273, 0.1273, 0.1273, 0.1872, 0.3127, 0.0038, 0.062, 0.1482, 0.022, 0.8478, 0.6893, 0.6893, 0.2176, 0.0, 0.0002, 0.2628, 0.291, 0.0039, 0.0025, 0.6306, 0.0137, 0.1776, 0.8224, 0.005, 0.3696, 0.0054, 0.062, 0.1473, 0.3986, 0.1109, 0.0003, 0.0182, 7150000000.0, 0.0003, 0.0302, 0.0025, 0.3763, 0.0009, 0.6971, 0.262, 0.3948, 0.0918, 0.0025, 0.0027, 0.1828, 0.242, 0.2766, 0.0039, 0.984, 0.7264, 0.3382, 0.0, 0.0, 0.0021, 1.0, 3650000000.0, 2500000000.0, 0.5939, 3060000000.0, 0.6714, 0.4836, 0.984, 0.3382, 0.1109, 0.0, 0.3666, 0.0, 0.1653, 0.072, 0.0, 0.0, 0.0, 0.6237, 0.6468, 0.7483, 0.2847, 0.0268, 0.5652, 1.0, 0.0199]\n",
|
|
"1 1 0 [0.1606, 0.1788, 0.1832, 0.5897, 0.5897, 0.9986, 0.7969, 0.8088, 0.3034, 0.781, 0.0003, 0.0002, 0.4434, 0.0002, 0.0, 0.1341, 0.1341, 0.1341, 0.1637, 0.2935, 0.0215, 0.0575, 0.1295, 0.0222, 0.848, 0.6894, 0.6894, 0.2176, 6700000000.0, 0.0003, 0.2646, 0.1561, 0.0075, 0.0016, 0.6306, 0.0275, 0.2228, 0.7772, 0.0061, 0.3952, 0.0054, 0.0574, 0.1285, 0.4264, 0.2579, 0.0218, 0.0003, 7550000000.0, 0.0029, 0.0569, 0.0184, 0.3689, 0.0009, 0.8013, 0.3721, 0.9357, 0.1842, 0.0028, 0.0042, 0.232, 0.2865, 0.2785, 0.0123, 1.0, 0.7403, 0.3506, 0.0, 0.811, 0.0019, 0.1083, 0.0001, 5310000000.0, 0.5939, 7880000000.0, 0.6715, 0.0499, 1.0, 0.3506, 0.1109, 0.463, 0.4385, 0.1781, 0.2476, 0.0388, 0.0, 0.5917, 4370000000.0, 0.6236, 0.5897, 0.8023, 0.2947, 0.0268, 0.5651, 1.0, 0.0151]\n",
|
|
"2 2 0 [0.204, 0.2638, 0.2598, 0.4483, 0.4483, 0.9959, 0.7937, 0.8063, 0.3034, 0.7782, 0.0007, 0.0004, 0.4511, 0.0003, 0.0, 0.1387, 0.1387, 0.1387, 0.1546, 0.263, 0.004, 0.0393, 0.0757, 0.0187, 0.8468, 0.6872, 0.6872, 0.2173, 0.0002, 0.0004, 0.2588, 0.1568, 0.0025, 0.0007, 0.6305, 0.04, 0.2419, 0.7581, 0.0048, 0.4073, 0.0054, 0.0394, 0.1165, 0.4142, 0.0315, 0.0009, 0.0074, 5310000000.0, 3030000000.0, 0.0195, 0.002, 0.3723, 0.0124, 0.6252, 0.1282, 0.3562, 0.0377, 0.0008, 0.0008, 0.2515, 0.3097, 0.2767, 0.0046, 1.0, 0.7042, 0.3617, 0.0, 0.8891, 0.0013, 0.0213, 0.0006, 0.0002, 0.5933, 0.0002, 0.6715, 0.5863, 1.0, 0.3617, 0.1109, 0.635, 0.4584, 0.3252, 0.3106, 0.1097, 0.0, 0.6816, 0.0003, 0.6221, 0.4483, 0.8117, 0.3038, 0.0268, 0.5651, 1.0, 0.0136]\n",
|
|
"3 3 0 [0.217, 0.1881, 0.2451, 0.5992, 0.5992, 0.9962, 0.794, 0.8061, 0.3034, 0.7781, 0.0029, 0.0038, 0.4555, 0.0003, 0.0, 0.1277, 0.1277, 0.1277, 0.1387, 0.271, 0.0049, 0.0319, 0.0091, 0.022, 0.848, 0.6893, 0.6893, 0.2176, 9790000000.0, 0.0011, 0.2629, 0.0, 0.004, 0.004, 0.6305, 0.2222, 0.286, 0.714, 0.0052, 0.6137, 0.0054, 0.0608, 0.1361, 0.407, 0.039, 0.0008, 0.0078, 0.0002, 0.0006, 0.1497, 0.0091, 0.3072, 0.0015, 0.6671, 0.6679, 0.656, 0.6709, 0.004, 0.012, 0.2966, 0.3228, 0.2769, 0.0003, 1.0, 0.6453, 0.523, 0.0, 0.8015, 0.002, 0.112, 0.0008, 0.0008, 0.5937, 0.0022, 0.6723, 0.022, 1.0, 0.523, 0.1109, 0.9353, 0.4857, 0.402, 1.0, 0.0707, 0.0, 0.6196, 0.0011, 0.6236, 0.5992, 0.6346, 0.4359, 0.0268, 0.565, 1.0, 0.0108]\n",
|
|
"4 4 0 [0.2314, 0.1628, 0.2068, 0.6001, 0.6001, 0.9988, 0.796, 0.8078, 0.3015, 0.7801, 0.0003, 0.0002, 0.458, 0.0005, 0.0, 0.1351, 0.1351, 0.1351, 0.1599, 0.315, 0.0085, 0.088, 0.1271, 0.0223, 0.8481, 0.6894, 0.6894, 0.2176, 3860000000.0, 0.0003, 0.2633, 0.363, 0.011, 0.0072, 0.6306, 0.0214, 0.2081, 0.7919, 0.0053, 0.3832, 0.0123, 0.088, 0.1261, 0.3996, 0.0885, 0.0008, 0.0075, 0.0005, 0.0003, 0.025, 0.0108, 0.3855, 0.0044, 0.8522, 0.8464, 0.8194, 0.0331, 0.0111, 0.0013, 0.1393, 0.3341, 0.277, 0.0003, 0.637, 0.7459, 0.3384, 0.0024, 0.8278, 0.002, 0.184, 0.0003, 0.0003, 0.594, 3320000000.0, 0.6715, 0.1798, 0.637, 0.3384, 0.1171, 0.587, 0.4524, 0.521, 0.2972, 0.0265, 0.0, 0.5269, 0.0003, 0.6241, 0.6001, 0.7985, 0.2903, 0.0268, 0.5651, 1.0, 0.0164]"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"df = to_pandas_on_spark(to_pandas_on_spark(train_data).to_spark(index_col=\"index\"))\n",
|
|
"\n",
|
|
"df.head()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[flaml.automl.automl: 02-28 02:12:59] {2922} INFO - task = classification\n",
|
|
"[flaml.automl.automl: 02-28 02:13:00] {2924} INFO - Data split method: stratified\n",
|
|
"[flaml.automl.automl: 02-28 02:13:00] {2927} INFO - Evaluation method: cv\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/datadrive/spark/spark33/python/pyspark/pandas/utils.py:975: PandasAPIOnSparkAdviceWarning: `to_pandas` loads all data into the driver's memory. It should only be used if the resulting pandas Series is expected to be small.\n",
|
|
" warnings.warn(message, PandasAPIOnSparkAdviceWarning)\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[flaml.automl.automl: 02-28 02:13:01] {3054} INFO - Minimizing error metric: 1-roc_auc\n",
|
|
"[flaml.automl.automl: 02-28 02:13:01] {3209} INFO - List of ML learners in AutoML Run: ['lgbm_spark']\n",
|
|
"[flaml.automl.automl: 02-28 02:13:01] {3539} INFO - iteration 0, current learner lgbm_spark\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/datadrive/spark/spark33/python/pyspark/pandas/utils.py:975: PandasAPIOnSparkAdviceWarning: `to_numpy` loads all data into the driver's memory. It should only be used if the resulting NumPy ndarray is expected to be small.\n",
|
|
" warnings.warn(message, PandasAPIOnSparkAdviceWarning)\n",
|
|
"/datadrive/spark/spark33/python/pyspark/pandas/utils.py:975: PandasAPIOnSparkAdviceWarning: If `index_col` is not specified for `to_spark`, the existing index is lost when converting to Spark DataFrame.\n",
|
|
" warnings.warn(message, PandasAPIOnSparkAdviceWarning)\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[LightGBM] [Warning] Find whitespaces in feature_names, replace with underlines\n",
|
|
"[LightGBM] [Warning] Find whitespaces in feature_names, replace with underlines\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/datadrive/spark/spark33/python/pyspark/pandas/utils.py:975: PandasAPIOnSparkAdviceWarning: `to_numpy` loads all data into the driver's memory. It should only be used if the resulting NumPy ndarray is expected to be small.\n",
|
|
" warnings.warn(message, PandasAPIOnSparkAdviceWarning)\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[flaml.automl.automl: 02-28 02:13:48] {3677} INFO - Estimated sufficient time budget=464999s. Estimated necessary time budget=465s.\n",
|
|
"[flaml.automl.automl: 02-28 02:13:48] {3724} INFO - at 48.5s,\testimator lgbm_spark's best error=0.0871,\tbest estimator lgbm_spark's best error=0.0871\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/datadrive/spark/spark33/python/pyspark/pandas/utils.py:975: PandasAPIOnSparkAdviceWarning: If `index_col` is not specified for `to_spark`, the existing index is lost when converting to Spark DataFrame.\n",
|
|
" warnings.warn(message, PandasAPIOnSparkAdviceWarning)\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[LightGBM] [Warning] Find whitespaces in feature_names, replace with underlines\n",
|
|
"[LightGBM] [Warning] Find whitespaces in feature_names, replace with underlines\n",
|
|
"[flaml.automl.automl: 02-28 02:13:54] {3988} INFO - retrain lgbm_spark for 6.2s\n",
|
|
"[flaml.automl.automl: 02-28 02:13:54] {3995} INFO - retrained model: LightGBMClassifier_a2177c5be001\n",
|
|
"[flaml.automl.automl: 02-28 02:13:54] {3239} INFO - fit succeeded\n",
|
|
"[flaml.automl.automl: 02-28 02:13:54] {3240} INFO - Time taken to find the best model: 48.4579541683197\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"'''The main flaml automl API'''\n",
|
|
"automl.fit(dataframe=df, label='Bankrupt?', labelCol=\"Bankrupt?\", isUnbalance=True, **settings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Best model and metric"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Best hyperparameter config: {'numIterations': 4, 'numLeaves': 4, 'minDataInLeaf': 20, 'learningRate': 0.09999999999999995, 'log_max_bin': 8, 'featureFraction': 1.0, 'lambdaL1': 0.0009765625, 'lambdaL2': 1.0}\n",
|
|
"Best roc_auc on validation data: 0.9129\n",
|
|
"Training duration of best run: 6.237 s\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"'''retrieve best config'''\n",
|
|
"print('Best hyperparameter config:', automl.best_config)\n",
|
|
"print('Best roc_auc on validation data: {0:.4g}'.format(1-automl.best_loss))\n",
|
|
"print('Training duration of best run: {0:.4g} s'.format(automl.best_config_train_time))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"DataFrame[evaluation_type: string, confusion_matrix: matrix, accuracy: double, precision: double, recall: double, AUC: double]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"+---------------+--------------------+------------------+-------------------+------------------+------------------+\n",
|
|
"|evaluation_type| confusion_matrix| accuracy| precision| recall| AUC|\n",
|
|
"+---------------+--------------------+------------------+-------------------+------------------+------------------+\n",
|
|
"| Classification|1218.0 55.0 \\n1...|0.9453302961275627|0.32926829268292684|0.6136363636363636|0.7852156680711276|\n",
|
|
"+---------------+--------------------+------------------+-------------------+------------------+------------------+\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"flaml_metrics = predict(automl.model.estimator)\n",
|
|
"flaml_metrics.show()"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"accelerator": "GPU",
|
|
"colab": {
|
|
"collapsed_sections": [],
|
|
"include_colab_link": true,
|
|
"name": "Copy of automl_nlp.ipynb",
|
|
"provenance": []
|
|
},
|
|
"gpuClass": "standard",
|
|
"kernelspec": {
|
|
"display_name": "flaml-dev",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.8"
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "cbbf4d250a3560c7073bd6e01a7ecfe1c772dc45f2100f74412fcaea735f0880"
|
|
}
|
|
},
|
|
"widgets": {}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|