autogen/test/spark/test_tune.py
Li Jiang 50334f2c52
Support spark dataframe as input dataset and spark models as estimators (#934)
* add basic support to Spark dataframe

add support to SynapseML LightGBM model

update to pyspark>=3.2.0 to leverage pandas_on_Spark API

* clean code, add TODOs

* add sample_train_data for pyspark.pandas dataframe, fix bugs

* improve some functions, fix bugs

* fix dict change size during iteration

* update model predict

* update LightGBM model, update test

* update SynapseML LightGBM params

* update synapseML and tests

* update TODOs

* Added support to roc_auc for spark models

* Added support to score of spark estimator

* Added test for automl score of spark estimator

* Added cv support to pyspark.pandas dataframe

* Update test, fix bugs

* Added tests

* Updated docs, tests, added a notebook

* Fix bugs in non-spark env

* Fix bugs and improve tests

* Fix uninstall pyspark

* Fix tests error

* Fix java.lang.OutOfMemoryError: Java heap space

* Fix test_performance

* Update test_sparkml to test_0sparkml to use the expected spark conf

* Remove unnecessary widgets in notebook

* Fix iloc java.lang.StackOverflowError

* fix pre-commit

* Added params check for spark dataframes

* Refactor code for train_test_split to a function

* Update train_test_split_pyspark

* Refactor if-else, remove unnecessary code

* Remove y from predict, remove mem control from n_iter compute

* Update workflow

* Improve _split_pyspark

* Fix test failure of too short training time

* Fix typos, improve docstrings

* Fix index errors of pandas_on_spark, add spark loss metric

* Fix typo of ndcgAtK

* Update NDCG metrics and tests

* Remove unuseful logger

* Use cache and count to ensure consistent indexes

* refactor for merge main

* fix errors of refactor

* Updated SparkLightGBMEstimator and cache

* Updated config2params

* Remove unused import

* Fix unknown parameters

* Update default_estimator_list

* Add unit tests for spark metrics
2023-03-25 19:59:46 +00:00

60 lines
1.6 KiB
Python

import lightgbm as lgb
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from flaml import tune
from flaml.automl.model import LGBMEstimator
from flaml.tune.spark.utils import check_spark
import os
import pytest
# Probe for Spark once at import time. check_spark() returns a pair; only its
# first element is used here, as the availability flag.
spark_available, _ = check_spark()
skip_spark = not spark_available
# Module-level pytest marker: every test collected from this file is skipped
# when skip_spark is true.
pytestmark = pytest.mark.skipif(
skip_spark, reason="Spark is not installed. Skip all spark tests."
)
# NOTE(review): presumably caps how many trials FLAML runs concurrently on
# Spark; test_tune_spark requests n_concurrent_trials=4, so this looks like a
# deliberate check of the cap -- confirm against the flaml.tune documentation.
os.environ["FLAML_MAX_CONCURRENT"] = "2"
# Shared dataset for all trials: loaded and split once when the module is
# imported, then read as globals by train_breast_cancer.
X, y = load_breast_cancer(return_X_y=True)
# A quarter of the samples is held out for evaluation. No random_state is
# passed, so the split is not seeded here and may differ between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)
def train_breast_cancer(config):
    """Train one LightGBM model for a trial and report its hold-out accuracy.

    Args:
        config: Mapping of hyperparameters; it is unpacked as keyword
            arguments into ``LGBMEstimator`` and the estimator's ``params``
            are handed to ``lgb.train``.

    Returns:
        A dict with the single key ``"mean_accuracy"``, holding the accuracy
        of the rounded predictions on the module-level ``X_test``/``y_test``.
    """
    booster = lgb.train(
        LGBMEstimator(**config).params,
        lgb.Dataset(X_train, label=y_train),
    )
    # Round the raw predictions to the nearest integer so they can be
    # compared against the integer class labels in y_test.
    predicted_labels = np.rint(booster.predict(X_test))
    return {"mean_accuracy": accuracy_score(y_test, predicted_labels)}
def test_tune_spark():
    """Run ``tune.run`` with ``use_spark=True`` over the LightGBM search space.

    The search space comes from ``LGBMEstimator.search_space`` for the shape
    of the training data; only each hyperparameter's ``"domain"`` entry is
    passed on to the tuner. The trial function is ``train_breast_cancer`` and
    the tuner maximizes its ``"mean_accuracy"`` result. The last result of the
    best trial is printed; nothing is asserted about its value.
    """
    # Keep just the sampling domain of every hyperparameter.
    config_search_space = {}
    for hp, space in LGBMEstimator.search_space(X_train.shape).items():
        config_search_space[hp] = space["domain"]

    analysis = tune.run(
        train_breast_cancer,
        config=config_search_space,
        # Optimization target.
        metric="mean_accuracy",
        mode="max",
        # Search budget.
        num_samples=-1,
        time_budget_s=5,
        # Execution backend.
        use_spark=True,
        n_concurrent_trials=4,
        verbose=3,
    )

    best_result = analysis.best_trial.last_result
    print("The best trial's result: ", best_result)
if __name__ == "__main__":
    # Allow running this file directly as a script, without going through
    # pytest collection.
    test_tune_spark()