data in csv (#430)

* data in csv

* support ray ObjectRef #365

* use object store to store data when using ray

* make lgbm tuning example a test

* homepage title
Chi Wang 2022-01-30 19:36:41 -08:00 committed by GitHub
parent 6960a833ec
commit 8a44dd4318
11 changed files with 225 additions and 59 deletions

View File

@@ -1667,9 +1667,18 @@ class AutoML(BaseEstimator):
for estimator in self.estimator_list:
search_state = self._search_states[estimator]
if not hasattr(search_state, "training_function"):
search_state.training_function = partial(
AutoMLState._compute_with_config_base, self._state, estimator
)
if self._use_ray:
from ray.tune import with_parameters
search_state.training_function = with_parameters(
AutoMLState._compute_with_config_base,
self=self._state,
estimator=estimator,
)
else:
search_state.training_function = partial(
AutoMLState._compute_with_config_base, self._state, estimator
)
states = self._search_states
mem_res = self._mem_thres
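The change above swaps functools.partial for ray.tune.with_parameters when running on Ray: with_parameters places the bound arguments (here the potentially large AutoMLState) in the Ray object store once, instead of embedding them in the serialized trainable that is shipped to every trial. A minimal sketch of the difference, using a hypothetical train function and dummy data that are not part of this commit:

from functools import partial
from ray import tune

def train(config, data=None):
    # report a dummy score derived from the large bound object
    tune.report(score=config["x"] * len(data))

data = list(range(1_000_000))  # stands in for a large training set

# partial: `data` is pickled into the trainable itself (shown only for contrast)
trainable_partial = partial(train, data=data)

# with_parameters: `data` is stored once in the Ray object store and
# fetched by reference in each trial
trainable_obj_store = tune.with_parameters(train, data=data)

analysis = tune.run(
    trainable_obj_store, config={"x": tune.uniform(0, 1)}, num_samples=2
)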
@@ -1761,12 +1770,14 @@ class AutoML(BaseEstimator):
shape (n, m). For 'ts_forecast' task, the first column of X_train
must be the timestamp column (datetime type). Other columns in
the dataframe are assumed to be exogenous variables (categorical or numeric).
When using Ray, X_train can be a ray.ObjectRef.
y_train: A numpy array or a pandas series of labels in shape (n, ).
dataframe: A dataframe of training data including label column.
For 'ts_forecast' task, dataframe must be specified and must have
at least two columns, timestamp and label, where the first
column is the timestamp column (datetime type). Other columns in
the dataframe are assumed to be exogenous variables (categorical or numeric).
When using Ray, dataframe can be a ray.ObjectRef.
label: A str of the label column name, e.g., 'label';
Note: If X_train and y_train are provided,
dataframe and label are ignored;
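Together with the test changes later in this commit, the intended calling pattern reads roughly as follows; a hedged sketch in which the iris data is only a stand-in and a local Ray runtime is assumed:

import ray
from sklearn.datasets import load_iris
from flaml import AutoML

X_train, y_train = load_iris(return_X_y=True)
ray.init()
X_train_ref = ray.put(X_train)  # store the training data once in the object store

automl = AutoML()
automl.fit(
    X_train=X_train_ref,  # a ray.ObjectRef is accepted in place of the raw data
    y_train=y_train,
    task="classification",
    time_budget=30,
    use_ray=True,
)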
@@ -1993,6 +2004,28 @@ class AutoML(BaseEstimator):
)
min_sample_size = min_sample_size or self._settings.get("min_sample_size")
use_ray = self._settings.get("use_ray") if use_ray is None else use_ray
self._state.n_jobs = n_jobs
self._n_concurrent_trials = n_concurrent_trials
self._early_stop = early_stop
self._use_ray = use_ray or n_concurrent_trials > 1
# use the following condition if we have an estimation of average_trial_time and average_trial_overhead
# self._use_ray = use_ray or n_concurrent_trials > (average_trial_time + average_trial_overhead) / average_trial_time
if self._use_ray:
import ray
n_cpus = use_ray and ray.available_resources()["CPU"] or os.cpu_count()
self._state.resources_per_trial = (
# when using gpu, default cpu is 1 per job; otherwise, default cpu is n_cpus / n_concurrent_trials
{"cpu": max(int(n_cpus / n_concurrent_trials), 1), "gpu": gpu_per_trial}
if gpu_per_trial == 0
else {"cpu": 1, "gpu": gpu_per_trial}
if n_jobs < 0
else {"cpu": n_jobs, "gpu": gpu_per_trial}
)
if isinstance(X_train, ray.ObjectRef):
X_train = ray.get(X_train)
elif isinstance(dataframe, ray.ObjectRef):
dataframe = ray.get(dataframe)
self._state.task = task
self._state.log_training_metric = log_training_metric
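A small standalone sketch of the resources_per_trial decision above, evaluated for a few illustrative inputs (8 available CPUs and 4 concurrent trials assumed):

n_cpus, n_concurrent_trials = 8, 4
for n_jobs, gpu_per_trial in [(4, 0), (-1, 1), (2, 1)]:
    resources = (
        {"cpu": max(int(n_cpus / n_concurrent_trials), 1), "gpu": gpu_per_trial}
        if gpu_per_trial == 0
        else {"cpu": 1, "gpu": gpu_per_trial}
        if n_jobs < 0
        else {"cpu": n_jobs, "gpu": gpu_per_trial}
    )
    print(n_jobs, gpu_per_trial, resources)
# (4, 0)  -> {"cpu": 2, "gpu": 0}: no GPU, so CPUs are split across trials
# (-1, 1) -> {"cpu": 1, "gpu": 1}: GPU with unspecified n_jobs, 1 CPU per trial
# (2, 1)  -> {"cpu": 2, "gpu": 1}: GPU with explicit n_jobs, n_jobs CPUs per trial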
@@ -2023,24 +2056,6 @@ class AutoML(BaseEstimator):
self._state.eval_method = eval_method
logger.info("Evaluation method: {}".format(eval_method))
self._state.n_jobs = n_jobs
self._n_concurrent_trials = n_concurrent_trials
self._early_stop = early_stop
self._use_ray = use_ray or n_concurrent_trials > 1
# use the following condition if we have an estimation of average_trial_time and average_trial_overhead
# self._use_ray = use_ray or n_concurrent_trials > (average_trial_time + average_trial_overhead) / average_trial_time
if self._use_ray:
import ray
n_cpus = use_ray and ray.available_resources()["CPU"] or os.cpu_count()
self._state.resources_per_trial = (
# when using gpu, default cpu is 1 per job; otherwise, default cpu is n_cpus / n_concurrent_trials
{"cpu": max(int(n_cpus / n_concurrent_trials), 1), "gpu": gpu_per_trial}
if gpu_per_trial == 0
else {"cpu": 1, "gpu": gpu_per_trial}
if n_jobs < 0
else {"cpu": n_jobs, "gpu": gpu_per_trial}
)
self._retrain_in_budget = retrain_full == "budget" and (
eval_method == "holdout" and self._state.X_val is None
)

View File

@@ -195,3 +195,20 @@ class ExperimentAnalysis:
"""
best_trial = self.get_best_trial(metric, mode, scope)
return best_trial.config if best_trial else None
@property
def best_result(self) -> Dict:
"""Get the last result of the best trial of the experiment
The best trial is determined by comparing the last trial results
using the `metric` and `mode` parameters passed to `tune.run()`.
If you didn't pass these parameters, use
`get_best_trial(metric, mode, scope).last_result` instead.
"""
if not self.default_metric or not self.default_mode:
raise ValueError(
"To fetch the `best_result`, pass a `metric` and `mode` "
"parameter to `tune.run()`. Alternatively, use "
"`get_best_trial(metric, mode).last_result` to set "
"the metric and mode explicitly and fetch the last result."
)
return self.best_trial.last_result
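A hedged usage sketch of the new property, with a toy trainable standing in for a real one; as the docstring above requires, metric and mode must have been passed to tune.run:

from flaml import tune

def train_fn(config):
    # a toy trainable: report the squared distance from 1 as "mse"
    return {"mse": (config["x"] - 1) ** 2}

analysis = tune.run(
    train_fn,
    config={"x": tune.uniform(-2, 2)},
    metric="mse",
    mode="min",
    time_budget_s=1,
    num_samples=-1,
)
print(analysis.best_config)  # config of the best trial
print(analysis.best_result)  # last result reported by the best trial, e.g. {"mse": ...}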

View File

@@ -265,7 +265,12 @@ class TestClassification(unittest.TestCase):
X_train = scipy.sparse.eye(900000)
y_train = np.random.randint(2, size=900000)
try:
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
import ray
X_train_ref = ray.put(X_train)
automl_experiment.fit(
X_train=X_train_ref, y_train=y_train, **automl_settings
)
print(automl_experiment.predict(X_train))
print(automl_experiment.model)
print(automl_experiment.config_history)

View File

@@ -216,6 +216,15 @@ class TestMultiClass(unittest.TestCase):
filename=automl_settings["log_file_name"], time_budget=6
)
print(metric_history)
try:
import ray
df = ray.put(df)
automl_settings["dataframe"] = df
automl_settings["use_ray"] = True
automl_experiment.fit(**automl_settings)
except ImportError:
pass
def test_classification(self, as_frame=False):
automl_experiment = AutoML()

test/object_store.py Normal file
View File

@@ -0,0 +1,64 @@
from flaml import tune
from flaml.model import LGBMEstimator
import lightgbm
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_squared_error
import ray
data = fetch_california_housing(return_X_y=False, as_frame=True)
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42
)
X_train_ref = ray.put(X_train)
print(isinstance(X_train_ref, ray.ObjectRef))
def train_lgbm(config: dict) -> dict:
# convert config dict to lgbm params
params = LGBMEstimator(**config).params
# train the model
# train_set = lightgbm.Dataset(X_train, y_train)
X_train = ray.get(X_train_ref)
train_set = lightgbm.Dataset(X_train, y_train)
model = lightgbm.train(params, train_set)
# evaluate the model
pred = model.predict(X_test)
mse = mean_squared_error(y_test, pred)
# return eval results as a dictionary
return {"mse": mse}
# load a built-in search space from flaml
flaml_lgbm_search_space = LGBMEstimator.search_space(X_train.shape)
# specify the search space as a dict from hp name to domain; you can define your own search space in the same way
config_search_space = {
hp: space["domain"] for hp, space in flaml_lgbm_search_space.items()
}
# give guidance about hp values corresponding to low training cost, i.e., {"n_estimators": 4, "num_leaves": 4}
low_cost_partial_config = {
hp: space["low_cost_init_value"]
for hp, space in flaml_lgbm_search_space.items()
if "low_cost_init_value" in space
}
# initial points to evaluate
points_to_evaluate = [
{
hp: space["init_value"]
for hp, space in flaml_lgbm_search_space.items()
if "init_value" in space
}
]
# run the tuning, minimizing mse, with total time budget 3 seconds
analysis = tune.run(
train_lgbm,
metric="mse",
mode="min",
config=config_search_space,
low_cost_partial_config=low_cost_partial_config,
points_to_evaluate=points_to_evaluate,
time_budget_s=3,
num_samples=-1,
)
print(analysis.best_result)

View File

@@ -9,10 +9,13 @@ from flaml.model import LGBMEstimator
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)
ray.init(address="auto")
X_train_ref = ray.put(X_train)
def train_breast_cancer(config):
params = LGBMEstimator(**config).params
X_train = ray.get(X_train_ref)
train_set = lgb.Dataset(X_train, label=y_train)
gbm = lgb.train(params, train_set)
preds = gbm.predict(X_test)
@@ -21,7 +24,6 @@ def train_breast_cancer(config):
if __name__ == "__main__":
ray.init(address="auto")
flaml_lgbm_search_space = LGBMEstimator.search_space(X_train.shape)
config_search_space = {
hp: space["domain"] for hp, space in flaml_lgbm_search_space.items()

View File

@@ -0,0 +1,37 @@
import time
from azureml.core import Workspace, Experiment, ScriptRunConfig, Environment
ws = Workspace.from_config()
ray_environment_name = "aml-ray-cpu"
ray_environment_dockerfile_path = "./Docker/Dockerfile-cpu"
# Build CPU image for Ray
ray_cpu_env = Environment.from_dockerfile(
name=ray_environment_name, dockerfile=ray_environment_dockerfile_path
)
ray_cpu_env.register(workspace=ws)
ray_cpu_build_details = ray_cpu_env.build(workspace=ws)
while ray_cpu_build_details.status not in ["Succeeded", "Failed"]:
print(
f"Awaiting completion of ray CPU environment build. Current status is: {ray_cpu_build_details.status}"
)
time.sleep(10)
env = Environment.get(workspace=ws, name=ray_environment_name)
compute_target = ws.compute_targets["cpucluster"]
command = ["python distribute_tune.py"]
config = ScriptRunConfig(
source_directory="ray/",
command=command,
compute_target=compute_target,
environment=env,
)
config.run_config.node_count = 2
config.run_config.environment_variables["_AZUREML_CR_START_RAY"] = "true"
config.run_config.environment_variables["AZUREML_COMPUTE_USE_COMMON_RUNTIME"] = "true"
exp = Experiment(ws, "test-ray")
run = exp.submit(config)
print(run.get_portal_url()) # link to ml.azure.com
run.wait_for_completion(show_output=True)

View File

@@ -5,17 +5,28 @@ from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_squared_error
X, y = fetch_california_housing(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42
data = fetch_california_housing(return_X_y=False, as_frame=True)
df, X, y = data.frame, data.data, data.target
df_train, _, X_train, X_test, _, y_test = train_test_split(
df, X, y, test_size=0.33, random_state=42
)
csv_file_name = "test/housing.csv"
df_train.to_csv(csv_file_name, index=False)
# X, y = fetch_california_housing(return_X_y=True, as_frame=True)
# X_train, X_test, y_train, y_test = train_test_split(
# X, y, test_size=0.33, random_state=42
# )
def train_lgbm(config: dict) -> dict:
# convert config dict to lgbm params
params = LGBMEstimator(**config).params
# train the model
train_set = lightgbm.Dataset(X_train, y_train)
# train_set = lightgbm.Dataset(X_train, y_train)
# LightGBM only accepts a csv in which all columns are in a valid numeric format, even if string columns are set to be ignored.
train_set = lightgbm.Dataset(
csv_file_name, params={"label_column": "name:MedHouseVal", "header": True}
)
model = lightgbm.train(params, train_set)
# evaluate the model
pred = model.predict(X_test)
@@ -24,34 +35,40 @@ def train_lgbm(config: dict) -> dict:
return {"mse": mse}
# load a built-in search space from flaml
flaml_lgbm_search_space = LGBMEstimator.search_space(X_train.shape)
# specify the search space as a dict from hp name to domain; you can define your own search space in the same way
config_search_space = {
hp: space["domain"] for hp, space in flaml_lgbm_search_space.items()
}
# give guidance about hp values corresponding to low training cost, i.e., {"n_estimators": 4, "num_leaves": 4}
low_cost_partial_config = {
hp: space["low_cost_init_value"]
for hp, space in flaml_lgbm_search_space.items()
if "low_cost_init_value" in space
}
# initial points to evaluate
points_to_evaluate = [
{
hp: space["init_value"]
for hp, space in flaml_lgbm_search_space.items()
if "init_value" in space
def test_tune_lgbm_csv():
# load a built-in search space from flaml
flaml_lgbm_search_space = LGBMEstimator.search_space(X_train.shape)
# specify the search space as a dict from hp name to domain; you can define your own search space in the same way
config_search_space = {
hp: space["domain"] for hp, space in flaml_lgbm_search_space.items()
}
]
# run the tuning, minimizing mse, with total time budget 3 seconds
analysis = tune.run(
train_lgbm,
metric="mse",
mode="min",
config=config_search_space,
low_cost_partial_config=low_cost_partial_config,
points_to_evaluate=points_to_evaluate,
time_budget_s=3,
num_samples=-1,
)
# give guidance about hp values corresponding to low training cost, i.e., {"n_estimators": 4, "num_leaves": 4}
low_cost_partial_config = {
hp: space["low_cost_init_value"]
for hp, space in flaml_lgbm_search_space.items()
if "low_cost_init_value" in space
}
# initial points to evaluate
points_to_evaluate = [
{
hp: space["init_value"]
for hp, space in flaml_lgbm_search_space.items()
if "init_value" in space
}
]
# run the tuning, minimizing mse, with total time budget 3 seconds
analysis = tune.run(
train_lgbm,
metric="mse",
mode="min",
config=config_search_space,
low_cost_partial_config=low_cost_partial_config,
points_to_evaluate=points_to_evaluate,
time_budget_s=3,
num_samples=-1,
)
print(analysis.best_result)
if __name__ == "__main__":
test_tune_lgbm_csv()
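The csv path is the core of this commit: LightGBM can build a Dataset directly from a numeric-only csv file, selecting the label by column name. A minimal standalone sketch of that pattern, assuming the test/housing.csv file written above (MedHouseVal is the california-housing target column):

import lightgbm

train_set = lightgbm.Dataset(
    "test/housing.csv",
    params={"label_column": "name:MedHouseVal", "header": True},
)
booster = lightgbm.train({"objective": "regression"}, train_set)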

View File

@@ -50,7 +50,7 @@ def train_lgbm(config: dict) -> dict:
# convert config dict to lgbm params
params = LGBMEstimator(**config).params
# train the model
train_set = lightgbm.Dataset(X_train, y_train)
train_set = lightgbm.Dataset(csv_file_name)
model = lightgbm.train(params, train_set)
# evaluate the model
pred = model.predict(X_test)

View File

@@ -29,7 +29,7 @@ export default function Home() {
const {siteConfig} = useDocusaurusContext();
return (
<Layout
title={`Hello from ${siteConfig.title}`}
title={`AutoML & Tuning`}
description="A Fast Library for Automated Machine Learning and Tuning">
<HomepageHeader />
<main>