autogen/test/spark/test_utils.py

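# Tests for flaml.tune.spark.utils: with_parameters broadcasting, node CPU
# introspection, broadcast_code, and Broadcast unwrapping.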
from flaml.tune.spark.utils import (
    with_parameters,
    check_spark,
    get_n_cpus,
    get_broadcast_data,
)
from functools import partial
from timeit import timeit
import pytest
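
# Gate every test in this module on a working Spark installation: if PySpark
# is missing or Spark cannot start, pytestmark skips the whole module.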
try:
    from pyspark.sql import SparkSession
    import pyspark

    spark_available, _ = check_spark()
    skip_spark = not spark_available
except ImportError:
    print("Spark is not installed. Skip all spark tests.")
    skip_spark = True

pytestmark = pytest.mark.skipif(
    skip_spark, reason="Spark is not installed. Skip all spark tests."
)
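

# Compare with_parameters, which ships the large list to executors once as a
# Spark broadcast variable, against functools.partial, which pickles the list
# into every task closure. The timings are printed rather than asserted.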
def test_with_parameters_spark():
    def train(config, data=None):
        # with_parameters passes a Broadcast handle; unwrap it before use.
        if isinstance(data, pyspark.broadcast.Broadcast):
            data = data.value
        print(config, len(data))

    data = ["a"] * 10**6
    with_parameters_train = with_parameters(train, data=data)
    partial_train = partial(train, data=data)

    spark = SparkSession.builder.getOrCreate()
    rdd = spark.sparkContext.parallelize(list(range(2)))

    t_partial = timeit(
        lambda: rdd.map(lambda x: partial_train(config=x)).collect(), number=5
    )
    print("python_partial_train: " + str(t_partial))

    t_spark = timeit(
        lambda: rdd.map(lambda x: with_parameters_train(config=x)).collect(),
        number=5,
    )
    print("spark_with_parameters_train: " + str(t_spark))
    # assert t_spark < t_partial
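

# get_n_cpus reports the number of CPUs available on a Spark node; the value
# is machine-dependent, so only the return type is asserted.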
def test_get_n_cpus_spark():
    n_cpus = get_n_cpus()
    assert isinstance(n_cpus, int)
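

# broadcast_code writes the given source under flaml.tune.spark.mylearner,
# making the custom learner importable by name; the import below verifies
# the round trip.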
def test_broadcast_code():
    from flaml.tune.spark.utils import broadcast_code
    from flaml.automl.model import LGBMEstimator

    custom_code = """
from flaml.automl.model import LGBMEstimator
from flaml import tune

class MyLargeLGBM(LGBMEstimator):
    @classmethod
    def search_space(cls, **params):
        return {
            "n_estimators": {
                "domain": tune.lograndint(lower=4, upper=32768),
                "init_value": 32768,
                "low_cost_init_value": 4,
            },
            "num_leaves": {
                "domain": tune.lograndint(lower=4, upper=32768),
                "init_value": 32768,
                "low_cost_init_value": 4,
            },
        }
"""

    _ = broadcast_code(custom_code=custom_code)
    from flaml.tune.spark.mylearner import MyLargeLGBM

    assert isinstance(MyLargeLGBM(), LGBMEstimator)
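

# get_broadcast_data should unwrap a pyspark Broadcast back to its original value.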
def test_get_broadcast_data():
    data = ["a"] * 10
    spark = SparkSession.builder.getOrCreate()
    bc_data = spark.sparkContext.broadcast(data)
    assert get_broadcast_data(bc_data) == data
if __name__ == "__main__":
test_with_parameters_spark()
test_get_n_cpus_spark()
test_broadcast_code()
test_get_broadcast_data()