autogen/test/spark/test_ensemble.py
Li Jiang da2cd7ca89
Add support for using Spark as the backend of parallel training (#846)
* Added spark support for parallel training.

* Added tests and fixed a bug

* Added more tests and updated docs

* Updated setup.py and docs

* Added customize_learner and tests

* Update spark tests and setup.py

* Update docs and verbose

* Update logging, fix issue in cloud notebook

* Update github workflow for spark tests

* Update github workflow

* Remove hack of handling _choice_

* Allow for failures

* Fix tests, update docs

* Update setup.py

* Update Dockerfile for Spark

* Update tests, remove some warnings

* Add test for notebooks, update utils

* Add performance test for Spark

* Fix lru_cache maxsize

* Fix test failures on some platforms

* Fix coverage report failure

* resolve PR comments

* resolve PR comments 2nd round

* resolve PR comments 3rd round

* fix lint and rename test class

* resolve PR comments 4th round

* refactor customize_learner to broadcast_code
2022-12-23 08:18:49 -08:00

58 lines
1.9 KiB
Python

import unittest
from sklearn.datasets import load_wine
from flaml import AutoML
from flaml.tune.spark.utils import check_spark
import os
# Probe for Spark once at import time. check_spark() is unpacked as a pair;
# the first element is treated as the availability flag, the second is ignored.
spark_available, _ = check_spark()
skip_spark = not spark_available

# Exported through the environment before any test runs.
os.environ["FLAML_MAX_CONCURRENT"] = "2"

# The code customizing the learner is kept in a separate module to avoid a
# pylint issue. The module is looked up relative to the current working
# directory, so it is only found when the tests are launched from the repo root.
_custom_learner_file = os.path.join(os.getcwd(), "test", "spark", "custom_mylearner.py")
if not os.path.exists(_custom_learner_file):
    skip_my_learner = True
else:
    try:
        from test.spark.custom_mylearner import *
        from flaml.tune.spark.mylearner import MyRegularizedGreedyForest
    except ImportError:
        # Either import failing means the custom learner cannot be used.
        skip_my_learner = True
    else:
        skip_my_learner = False
class TestEnsemble(unittest.TestCase):
    """Smoke test: fit an AutoML ensemble with Spark-backed concurrent trials."""

    def setUp(self) -> None:
        """Skip the current test when Spark was not detected at import time."""
        if not skip_spark:
            return
        self.skipTest("Spark is not installed. Skip all spark tests.")

    @unittest.skipIf(
        skip_my_learner,
        "Please run pytest in the root directory of FLAML, i.e., the directory that contains the setup.py file",
    )
    def test_ensemble(self):
        """Run `AutoML.fit` on the wine dataset with a custom final estimator.

        The test makes no assertions; it passes when `fit` returns without
        raising.
        """
        searcher = AutoML()
        searcher.add_learner(learner_name="RGF", learner_class=MyRegularizedGreedyForest)
        features, labels = load_wine(return_X_y=True)

        # Ensemble configuration: the custom learner is the final estimator.
        stacking = {
            "final_estimator": MyRegularizedGreedyForest(),
            "passthrough": False,
        }

        searcher.fit(
            X_train=features,
            y_train=labels,
            time_budget=5,  # total running time in seconds
            estimator_list=["rf", "xgboost", "catboost"],
            task="classification",  # task type
            sample=True,  # whether to subsample training data
            log_file_name="test/wine.log",
            log_training_metric=True,  # whether to log training metric
            ensemble=stacking,
            n_jobs=1,
            n_concurrent_trials=2,
            use_spark=True,
        )
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()