Mirror of https://github.com/microsoft/autogen.git, synced 2025-09-21 22:23:44 +00:00

* Added spark support for parallel training
* Added tests and fixed a bug
* Added more tests and updated docs
* Updated setup.py and docs
* Added customize_learner and tests
* Update spark tests and setup.py
* Update docs and verbose
* Update logging, fix issue in cloud notebook
* Update GitHub workflow for spark tests
* Update GitHub workflow
* Remove hack of handling _choice_
* Allow for failures
* Fix tests, update docs
* Update setup.py
* Update Dockerfile for Spark
* Update tests, remove some warnings
* Add test for notebooks, update utils
* Add performance test for Spark
* Fix lru_cache maxsize
* Fix test failures on some platforms
* Fix coverage report failure
* Resolve PR comments
* Resolve PR comments (2nd round)
* Resolve PR comments (3rd round)
* Fix lint and rename test class
* Resolve PR comments (4th round)
* Refactor customize_learner to broadcast_code
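The commit series above adds Spark-backed parallel trial execution to FLAML; the file below is the utility test suite for that work. For orientation, a minimal sketch of how parallel training is requested once pyspark is installed, assuming FLAML's use_spark and n_concurrent_trials fit options (the dataset, time budget, and trial count are illustrative, not taken from this repo):

from flaml import AutoML
from sklearn.datasets import load_iris

# Illustrative data; any sklearn-style X/y works here.
X_train, y_train = load_iris(return_X_y=True)

automl = AutoML()
automl.fit(
    X_train,
    y_train,
    task="classification",
    time_budget=30,          # seconds; illustrative
    use_spark=True,          # run trials as Spark jobs
    n_concurrent_trials=2,   # trials evaluated in parallel
)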
102 lines
2.7 KiB
Python
from flaml.tune.spark.utils import (
    with_parameters,
    check_spark,
    get_n_cpus,
    get_broadcast_data,
)
from functools import partial
from timeit import timeit
import pytest

try:
    from pyspark.sql import SparkSession
    import pyspark

    spark_available, _ = check_spark()
    skip_spark = not spark_available
except ImportError:
    print("Spark is not installed. Skip all spark tests.")
    skip_spark = True

pytestmark = pytest.mark.skipif(
    skip_spark, reason="Spark is not installed. Skip all spark tests."
)


def test_with_parameters_spark():
    def train(config, data=None):
        # with_parameters delivers data as a Spark broadcast variable;
        # unwrap it before use.
        if isinstance(data, pyspark.broadcast.Broadcast):
            data = data.value
        print(config, len(data))

    data = ["a"] * 10**6

    with_parameters_train = with_parameters(train, data=data)
    partial_train = partial(train, data=data)

    spark = SparkSession.builder.getOrCreate()
    rdd = spark.sparkContext.parallelize(list(range(2)))

    # Baseline: functools.partial captures data in the closure, so the full
    # list is serialized and shipped with every task.
    t_partial = timeit(
        lambda: rdd.map(lambda x: partial_train(config=x)).collect(), number=5
    )
    print("python_partial_train: " + str(t_partial))

    # with_parameters broadcasts data once; tasks only carry a reference.
    t_spark = timeit(
        lambda: rdd.map(lambda x: with_parameters_train(config=x)).collect(),
        number=5,
    )
    print("spark_with_parameters_train: " + str(t_spark))

    # Relative timing varies across platforms, so it is not asserted:
    # assert t_spark < t_partial


def test_get_n_cpus_spark():
    n_cpus = get_n_cpus()
    assert isinstance(n_cpus, int)


def test_broadcast_code():
    from flaml.tune.spark.utils import broadcast_code
    from flaml.automl.model import LGBMEstimator

    custom_code = """
from flaml.automl.model import LGBMEstimator
from flaml import tune

class MyLargeLGBM(LGBMEstimator):
    @classmethod
    def search_space(cls, **params):
        return {
            "n_estimators": {
                "domain": tune.lograndint(lower=4, upper=32768),
                "init_value": 32768,
                "low_cost_init_value": 4,
            },
            "num_leaves": {
                "domain": tune.lograndint(lower=4, upper=32768),
                "init_value": 32768,
                "low_cost_init_value": 4,
            },
        }
"""

    _ = broadcast_code(custom_code=custom_code)
    # broadcast_code writes the source to flaml/tune/spark/mylearner.py,
    # making the custom learner importable on every node.
    from flaml.tune.spark.mylearner import MyLargeLGBM

    assert isinstance(MyLargeLGBM(), LGBMEstimator)


def test_get_broadcast_data():
    data = ["a"] * 10
    spark = SparkSession.builder.getOrCreate()
    bc_data = spark.sparkContext.broadcast(data)
    # get_broadcast_data unwraps a Broadcast back to the underlying value.
    assert get_broadcast_data(bc_data) == data


if __name__ == "__main__":
    test_with_parameters_spark()
    test_get_n_cpus_spark()
    test_broadcast_code()
    test_get_broadcast_data()
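For context on what test_broadcast_code exercises: broadcast_code materializes the custom learner's source on every node as flaml.tune.spark.mylearner, after which the class can be registered like any custom estimator. A minimal sketch, assuming FLAML's AutoML.add_learner API; the pared-down custom_code and the learner name "my_lgbm" are illustrative:

from flaml import AutoML
from flaml.tune.spark.utils import broadcast_code

# Pared-down version of the custom_code string from test_broadcast_code;
# a real learner would also override search_space.
custom_code = """
from flaml.automl.model import LGBMEstimator

class MyLGBM(LGBMEstimator):
    pass
"""

# broadcast_code writes the source to flaml/tune/spark/mylearner.py so the
# class is importable on the driver and on every executor.
_ = broadcast_code(custom_code=custom_code)
from flaml.tune.spark.mylearner import MyLGBM

automl = AutoML()
automl.add_learner(learner_name="my_lgbm", learner_class=MyLGBM)
# "my_lgbm" can now be passed via estimator_list when calling automl.fit.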