This commit is contained in:
Anonymous-submission-repo 2022-10-09 11:39:29 -04:00
parent c01e65bb48
commit 9bc32acafb
24 changed files with 511 additions and 71 deletions

View File

@ -736,6 +736,7 @@ class AutoML(BaseEstimator):
settings["use_ray"] = settings.get("use_ray", False)
settings["metric_constraints"] = settings.get("metric_constraints", [])
settings["cv_score_agg_func"] = settings.get("cv_score_agg_func", None)
settings["lexico_objectives"] = settings.get("lexico_objectives", None)
settings["fit_kwargs_by_estimator"] = settings.get(
"fit_kwargs_by_estimator", {}
)
@ -2149,6 +2150,7 @@ class AutoML(BaseEstimator):
cv_score_agg_func=None,
skip_transform=None,
fit_kwargs_by_estimator=None,
lexico_objectives=None,
**fit_kwargs,
):
"""Find a model for a given task.
@ -2402,6 +2404,19 @@ class AutoML(BaseEstimator):
[TrainingArgumentsForAuto](nlp/huggingface/training_args).
e.g.,
skip_transform: boolean, default=False | Whether to pre-process data prior to modeling.
lexico_objectives: dict, default=None | A dictionary with four elements.
It specifies the information used for multi-objective optimization with lexicographic preference.
e.g.,```lexico_objectives = {"metrics":["error_rate","pred_time"], "modes":["min","min"],
"tolerances":{"error_rate":0.01,"pred_time":0.0}, "targets":{"error_rate":0.0,"pred_time":0.0}}```
"metrics" and "modes" are lists of str: "metrics" gives the optimization objectives,
and "modes" specifies whether each objective is to be minimized or maximized.
Both lists are ordered by priority, from high to low.
"tolerances" is a dictionary specifying the optimality tolerance of each objective.
"targets" is a dictionary specifying the optimization target of each objective.
If lexico_objectives is provided, the arguments metric and hpo_method are ignored.
fit_kwargs_by_estimator: dict, default=None | The user-specified keyword arguments, grouped by estimator name.
For TransformersEstimator, available fit_kwargs can be found from
[TrainingArgumentsForAuto](nlp/huggingface/training_args).
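For reference, a minimal usage sketch of the new argument, adapted from the AutoML test added later in this commit (the OpenML dataset id, metrics, and budget are illustrative):
```python
from flaml import AutoML
from flaml.data import load_openml_dataset

# Prioritized objectives: validation loss first, prediction time second.
lexico_objectives = {
    "metrics": ["val_loss", "pred_time"],
    "modes": ["min", "min"],
    "tolerances": {"val_loss": 0.01, "pred_time": 0.0},
    "targets": {"val_loss": 0.0, "pred_time": 0.0},
}

X_train, X_test, y_train, y_test = load_openml_dataset(dataset_id=179, data_dir="test/data")
automl = AutoML()
automl.fit(
    X_train=X_train,
    y_train=y_train,
    X_val=X_test,
    y_val=y_test,
    task="classification",
    estimator_list=["xgboost"],
    time_budget=100,
    eval_method="holdout",
    lexico_objectives=lexico_objectives,
)
print(automl.best_estimator)
```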
@ -2502,7 +2517,11 @@ class AutoML(BaseEstimator):
self._settings.get("retrain_full") if retrain_full is None else retrain_full
)
split_type = split_type or self._settings.get("split_type")
hpo_method = hpo_method or self._settings.get("hpo_method")
if lexico_objectives is None:
hpo_method = hpo_method or self._settings.get("hpo_method")
else:
hpo_method = "cfo"
learner_selector = learner_selector or self._settings.get("learner_selector")
no_starting_points = starting_points is None
if no_starting_points:
@ -2606,7 +2625,9 @@ class AutoML(BaseEstimator):
self._state.cv_score_agg_func = cv_score_agg_func or self._settings.get(
"cv_score_agg_func"
)
self._state.lexico_objectives = lexico_objectives or self._settings.get(
"lexico_objectives"
)
self._retrain_in_budget = retrain_full == "budget" and (
eval_method == "holdout" and self._state.X_val is None
)
@ -2996,6 +3017,7 @@ class AutoML(BaseEstimator):
metric_constraints=self.metric_constraints,
seed=self._seed,
time_budget_s=time_left,
lexico_objectives=self._state.lexico_objectives,
)
else:
# if self._hpo_method is bo, sometimes the search space and the initial config dimension do not match
@ -3232,6 +3254,7 @@ class AutoML(BaseEstimator):
],
metric_constraints=self.metric_constraints,
seed=self._seed,
lexico_objectives=self._state.lexico_objectives,
)
else:
# if self._hpo_method is bo, sometimes the search space and the initial config dimension do not match

View File

@ -11,7 +11,7 @@ from flaml.data import (
MULTICHOICECLASSIFICATION,
SUMMARIZATION,
SEQCLASSIFICATION,
SEQREGRESSION
SEQREGRESSION,
)

View File

@ -1,6 +1,6 @@
# ChaCha for Online AutoML
FLAML includes *ChaCha* which is an automatic hyperparameter tuning solution for online machine learning. Online machine learning has the following properties: (1) data comes in sequential order; and (2) the performance of the machine learning model is evaluated online, i.e., at every iteration. *ChaCha* performs online AutoML respecting the aforementioned properties of online learning, and at the same time respecting the following constraints: (1) only a small constant number of 'live' models are allowed to perform online learning at the same time; and (2) no model persistence or offline training is allowed, which means that once we decide to replace a 'live' model with a new one, the replaced model can no longer be retrieved.
FLAML includes *ChaCha* which is an automatic hyperparameter tuning solution for online machine learning. Online machine learning has the following properties: (1) data comes in sequential order; and (2) the performance of the machine learning model is evaluated online, i.e., at every iteration. *ChaCha* performs online AutoML respecting the aforementioned properties of online learning, and at the same time respecting the following constraints: (1) only a small constant number of 'live' models are allowed to perform online learning at the same time; and (2) no model persistence or offline training is allowed, which means that once we decide to replace a 'live' model with a new one, the replaced model can no longer be retrieved.
For more technical details about *ChaCha*, please check our paper.

View File

@ -20,7 +20,7 @@ def evaluate_config(config):
# and the cost could be related to certain hyperparameters
# in this example, we assume it's proportional to x
time.sleep(config['x']/100000)
# use tune.report to report the metric to optimize
# use tune.report to report the metric to optimize
tune.report(metric=metric)
analysis = tune.run(
@ -35,7 +35,7 @@ analysis = tune.run(
num_samples=-1, # the maximal number of configs to try, -1 means infinite
time_budget_s=60, # the time budget in seconds
local_dir='logs/', # the local directory to store logs
# verbose=0, # verbosity
# verbose=0, # verbosity
# use_ray=True, # uncomment when performing parallel tuning using ray
)
@ -59,7 +59,7 @@ def evaluate_config(config):
# and the cost could be related to certain hyperparameters
# in this example, we assume it's proportional to x
time.sleep(config['x']/100000)
# use tune.report to report the metric to optimize
# use tune.report to report the metric to optimize
tune.report(metric=metric)
# provide a time budget (in seconds) for the tuning process

View File

@ -17,7 +17,10 @@
# Copyright (c) Microsoft Corporation.
from typing import Dict, Optional
import numpy as np
from flaml.tune import result
from .trial import Trial
from collections import defaultdict
import logging
@ -68,7 +71,6 @@ class ExperimentAnalysis:
@property
def results(self) -> Dict[str, Dict]:
"""Get the last result of all the trials of the experiment"""
return {trial.trial_id: trial.last_result for trial in self.trials}
def _validate_metric(self, metric: str) -> str:
@ -89,6 +91,42 @@ class ExperimentAnalysis:
raise ValueError("If set, `mode` has to be one of [min, max]")
return mode or self.default_mode
def lexico_best(self, trials):
results = {index: trial.last_result for index, trial in enumerate(trials)}
metrics = self.lexico_objectives["metrics"]
modes = self.lexico_objectives["modes"]
f_best = {}
keys = list(results.keys())
length = len(keys)
histories = defaultdict(list)
# Collect the per-trial history of each objective, negating maximization
# objectives so that every objective is minimized internally.
for time_index in range(length):
for objective, mode in zip(metrics, modes):
histories[objective].append(
results[keys[time_index]][objective]
if mode == "min"
else results[keys[time_index]][objective] * -1
)
obj_initial = self.lexico_objectives["metrics"][0]
feasible_index = [*range(len(histories[obj_initial]))]
# Filter trials objective by objective, in priority order: a trial stays
# feasible if it is within the tolerance of the best value, or better than
# the absolute target, for every higher-priority objective.
for k_metric in self.lexico_objectives["metrics"]:
k_values = np.array(histories[k_metric])
f_best[k_metric] = np.min(k_values.take(feasible_index))
feasible_index_prior = np.where(
k_values
<= max(
[
f_best[k_metric]
+ self.lexico_objectives["tolerances"][k_metric],
self.lexico_objectives["targets"][k_metric],
]
)
)[0].tolist()
feasible_index = [
val for val in feasible_index if val in feasible_index_prior
]
best_trial = trials[feasible_index[-1]]
return best_trial
def get_best_trial(
self,
metric: Optional[str] = None,
@ -120,9 +158,11 @@ class ExperimentAnalysis:
values are disregarded and these trials are never selected as
the best trial.
"""
if self.lexico_objectives is not None:
best_trial = self.lexico_best(self.trials)
return best_trial
metric = self._validate_metric(metric)
mode = self._validate_mode(mode)
if scope not in ["all", "last", "avg", "last-5-avg", "last-10-avg"]:
raise ValueError(
"ExperimentAnalysis: attempting to get best trial for "
@ -138,7 +178,6 @@ class ExperimentAnalysis:
for trial in self.trials:
if metric not in trial.metric_analysis:
continue
if scope in ["last", "avg", "last-5-avg", "last-10-avg"]:
metric_score = trial.metric_analysis[metric][scope]
else:
@ -158,7 +197,6 @@ class ExperimentAnalysis:
elif (mode == "min") and (best_metric_score > metric_score):
best_metric_score = metric_score
best_trial = trial
if not best_trial:
logger.warning(
"Could not find best trial. Did you pass the correct `metric` "

View File

@ -43,6 +43,9 @@ PID = "pid"
# (Optional) Default (anonymous) metric when using tune.report(x)
DEFAULT_METRIC = "_metric"
# (Optional) Default (anonymous) mode when using tune.report(x)
DEFAULT_MODE = "min"
# (Optional) Mean reward for current training iteration
EPISODE_REWARD_MEAN = "episode_reward_mean"

View File

@ -63,6 +63,7 @@ class BlendSearch(Searcher):
seed: Optional[int] = 20,
cost_attr: Optional[str] = "auto",
experimental: Optional[bool] = False,
lexico_objectives: Optional[dict] = None,
use_incumbent_result_in_evaluation=False,
):
"""Constructor.
@ -127,6 +128,7 @@ class BlendSearch(Searcher):
self.penalty = PENALTY # penalty term for constraints
self._metric, self._mode = metric, mode
self._use_incumbent_result_in_evaluation = use_incumbent_result_in_evaluation
self.lexico_objectives = lexico_objectives
init_config = low_cost_partial_config or {}
if not init_config:
logger.info(
@ -176,6 +178,7 @@ class BlendSearch(Searcher):
max_resource,
reduction_factor,
self.cost_attr,
self.lexico_objectives,
seed,
)
if global_search_alg is not None:
@ -480,11 +483,15 @@ class BlendSearch(Searcher):
del self._subspace[trial_id]
def _create_thread(self, config, result, space):
if self.lexico_objectives is None:
obj = result[self._ls.metric]
else:
obj = {k: result[k] for k in self.lexico_objectives["metrics"]}
self._search_thread_pool[self._thread_count] = SearchThread(
self._ls.mode,
self._ls.create(
config,
result[self._ls.metric],
obj,
cost=result.get(self.cost_attr, 1),
space=space,
),
@ -1044,6 +1051,7 @@ class BlendSearchTuner(BlendSearch, NNITuner):
self._ls.max_resource,
self._ls.resource_multiple_factor,
cost_attr=self.cost_attr,
lexico_objectives=self.lexico_objectives,
seed=self._ls.seed,
)
if self._gs is not None:

View File

@ -5,6 +5,7 @@
from typing import Dict, Optional, Tuple
import numpy as np
import logging
from collections import defaultdict
try:
from ray import __version__ as ray_version
@ -48,6 +49,7 @@ class FLOW2(Searcher):
max_resource: Optional[float] = None,
resource_multiple_factor: Optional[float] = None,
cost_attr: Optional[str] = "time_total_s",
lexico_objectives=None,
seed: Optional[int] = 20,
):
"""Constructor.
@ -90,13 +92,16 @@ class FLOW2(Searcher):
self.best_config = flatten_dict(init_config)
self.resource_attr = resource_attr
self.min_resource = min_resource
self.lexico_objectives = lexico_objectives
self.resource_multiple_factor = (
resource_multiple_factor or SAMPLE_MULTIPLY_FACTOR
)
self.cost_attr = cost_attr
self.max_resource = max_resource
self._resource = None
self._f_best = None
self._step_lb = np.Inf
self._histories = None
if space is not None:
self._init_search()
@ -263,9 +268,22 @@ class FLOW2(Searcher):
self.max_resource,
self.resource_multiple_factor,
self.cost_attr,
self.lexico_objectives,
self.seed + 1,
)
flow2.best_obj = obj * self.metric_op # minimize internally
if self.lexico_objectives is not None:
flow2.best_obj = {}
for k, v in obj.items():
flow2.best_obj[k] = (
v * -1
if self.lexico_objectives["modes"][
self.lexico_objectives["metrics"].index(k)
]
== "max"
else v
)
else:
flow2.best_obj = obj * self.metric_op # minimize internally
flow2.cost_incumbent = cost
self.seed += 1
return flow2
@ -303,6 +321,56 @@ class FLOW2(Searcher):
self._init_search()
return True
def lexico_compare(self, result) -> bool:
# Return True if `result` should replace the current incumbent under the
# lexicographic preference encoded in self.lexico_objectives.
def update_fbest():
# Recompute the best value of each objective over the trials that are
# feasible with respect to all higher-priority objectives.
obj_initial = self.lexico_objectives["metrics"][0]
feasible_index = [*range(len(self._histories[obj_initial]))]
for k_metric in self.lexico_objectives["metrics"]:
k_values = np.array(self._histories[k_metric])
self._f_best[k_metric] = np.min(k_values.take(feasible_index))
feasible_index_prior = np.where(
k_values
<= max(
[
self._f_best[k_metric]
+ self.lexico_objectives["tolerances"][k_metric],
self.lexico_objectives["targets"][k_metric],
]
)
)[0].tolist()
feasible_index = [
val for val in feasible_index if val in feasible_index_prior
]
if self._histories is None:
self._histories, self._f_best = defaultdict(list), {}
for k in self.lexico_objectives["metrics"]:
self._histories[k].append(result[k])
update_fbest()
return True
else:
for k in self.lexico_objectives["metrics"]:
self._histories[k].append(result[k])
update_fbest()
# Compare the new result with the incumbent objective by objective, in
# priority order: if both lie within the tolerance/target band of the
# current best, defer to the next objective; otherwise the smaller wins.
for k_metric in self.lexico_objectives["metrics"]:
k_T = self.lexico_objectives["tolerances"][k_metric]
k_c = self.lexico_objectives["targets"][k_metric]
if (result[k_metric] < max([self._f_best[k_metric] + k_T, k_c])) and (
self.best_obj[k_metric] < max([self._f_best[k_metric] + k_T, k_c])
):
continue
elif result[k_metric] < self.best_obj[k_metric]:
return True
else:
return False
for k_metr in self.lexico_objectives["metrics"]:
if result[k_metr] == self.best_obj[k_metr]:
continue
elif result[k_metr] < self.best_obj[k_metr]:
return True
else:
return False
def on_trial_complete(
self, trial_id: str, result: Optional[Dict] = None, error: bool = False
):
@ -313,10 +381,28 @@ class FLOW2(Searcher):
"""
self.trial_count_complete += 1
if not error and result:
obj = result.get(self._metric)
obj = (
result.get(self._metric)
if self.lexico_objectives is None
else {k: result[k] for k in self.lexico_objectives["metrics"]}
)
if obj:
obj *= self.metric_op
if self.best_obj is None or obj < self.best_obj:
obj = (
{
k: obj[k] * -1 if m == "max" else obj[k]
for k, m in zip(
self.lexico_objectives["metrics"],
self.lexico_objectives["modes"],
)
}
if isinstance(obj, dict)
else obj * self.metric_op
)
if (
self.best_obj is None
or (self.lexico_objectives is None and obj < self.best_obj)
or (self.lexico_objectives is not None and self.lexico_compare(obj))
):
self.best_obj = obj
self.best_config, self.step = self._configs[trial_id]
self.incumbent = self.normalize(self.best_config)
@ -329,7 +415,6 @@ class FLOW2(Searcher):
self._num_allowed4incumbent = 2 * self.dim
self._proposed_by.clear()
if self._K > 0:
# self._oldK must have been set when self._K>0
self.step *= np.sqrt(self._K / self._oldK)
self.step = min(self.step, self.step_ub)
self._iter_best_config = self.trial_count_complete
@ -340,7 +425,6 @@ class FLOW2(Searcher):
self._trunc = max(self._trunc >> 1, 1)
proposed_by = self._proposed_by.get(trial_id)
if proposed_by == self.incumbent:
# proposed by current incumbent and no better
self._num_complete4incumbent += 1
cost = (
result.get(self.cost_attr, 1)
@ -357,17 +441,34 @@ class FLOW2(Searcher):
if self._num_complete4incumbent == self.dir and (
not self._resource or self._resource == self.max_resource
):
# check stuck condition if using max resource
self._num_complete4incumbent -= 2
self._num_allowed4incumbent = max(self._num_allowed4incumbent, 2)
def on_trial_result(self, trial_id: str, result: Dict):
"""Early update of incumbent."""
if result:
obj = result.get(self._metric)
obj = (
result.get(self._metric)
if self.lexico_objectives is None
else {k: result[k] for k in self.lexico_objectives["metrics"]}
)
if obj:
obj *= self.metric_op
if self.best_obj is None or obj < self.best_obj:
obj = (
{
k: obj[k] * -1 if m == "max" else obj[k]
for k, m in zip(
self.lexico_objectives["metrics"],
self.lexico_objectives["modes"],
)
}
if isinstance(obj, dict)
else obj * self.metric_op
)
if (
self.best_obj is None
or (self.lexico_objectives is None and obj < self.best_obj)
or (self.lexico_objectives is not None and self.lexico_compare(obj))
):
self.best_obj = obj
config = self._configs[trial_id][0]
if self.best_config != config:
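The incumbent-replacement rule in `lexico_compare` can be restated compactly; the sketch below is a logically equivalent toy version with illustrative numbers, assuming `f_best` has already been computed as in `update_fbest`:
```python
def lexico_compare(result, best_obj, f_best, lexico_objectives):
    """Toy re-statement of the FLOW2 incumbent-replacement rule."""
    for metric in lexico_objectives["metrics"]:
        band = max(
            f_best[metric] + lexico_objectives["tolerances"][metric],
            lexico_objectives["targets"][metric],
        )
        if result[metric] < band and best_obj[metric] < band:
            continue  # both within the band: defer to the next objective
        return result[metric] < best_obj[metric]  # otherwise the smaller value wins
    # All objectives within their bands: break ties by the first strict improvement.
    for metric in lexico_objectives["metrics"]:
        if result[metric] != best_obj[metric]:
            return result[metric] < best_obj[metric]
    return False

lexico_objectives = {
    "metrics": ["error_rate", "pred_time"],
    "modes": ["min", "min"],
    "tolerances": {"error_rate": 0.02, "pred_time": 0.0},
    "targets": {"error_rate": 0.0, "pred_time": 0.0},
}
f_best = {"error_rate": 0.11, "pred_time": 0.4}
incumbent = {"error_rate": 0.12, "pred_time": 0.5}
candidate = {"error_rate": 0.125, "pred_time": 0.3}
# error_rate: both 0.12 and 0.125 are below the 0.13 band, so pred_time decides -> True
print(lexico_compare(candidate, incumbent, f_best, lexico_objectives))
```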

View File

@ -137,7 +137,10 @@ class SearchThread:
if result:
self.cost_last = result.get(self.cost_attr, 1)
self.cost_total += self.cost_last
if self._search_alg.metric in result:
if self._search_alg.metric in result and (
not hasattr(self._search_alg, "lexico_objectives")
or self._search_alg.lexico_objectives is None
):
obj = result[self._search_alg.metric] * self._metric_op
if obj < self.obj_best1 or self.best_result is None:
self.cost_best2 = self.cost_best1
@ -146,7 +149,11 @@ class SearchThread:
self.obj_best1 = obj
self.cost_best = self.cost_last
self.best_result = result
self._update_speed()
if (
not hasattr(self._search_alg, "lexico_objectives")
or self._search_alg.lexico_objectives is None
):
self._update_speed()
self.running -= 1
assert self.running >= 0

View File

@ -7,6 +7,7 @@ import numpy as np
import datetime
import time
import os
from collections import defaultdict
try:
from ray import __version__ as ray_version
@ -20,7 +21,7 @@ except (ImportError, AssertionError):
from .analysis import ExperimentAnalysis as EA
from .trial import Trial
from .result import DEFAULT_METRIC
from .result import DEFAULT_METRIC, DEFAULT_MODE
import logging
logger = logging.getLogger(__name__)
@ -33,16 +34,70 @@ _training_iteration = 0
INCUMBENT_RESULT = "__incumbent_result__"
def is_nan_or_inf(value):
return np.isnan(value) or np.isinf(value)
class ExperimentAnalysis(EA):
"""Class for storing the experiment results."""
def __init__(self, trials, metric, mode):
def __init__(self, trials, metric, mode, lexico_objectives):
try:
super().__init__(self, None, trials, metric, mode)
except (TypeError, ValueError):
self.trials = trials
self.default_metric = metric or DEFAULT_METRIC
self.default_mode = mode
self.default_mode = mode or DEFAULT_MODE
self.lexico_objectives = lexico_objectives
def lexico_best(self, trials):
results = {index: trial.last_result for index, trial in enumerate(trials)}
metrics = self.lexico_objectives["metrics"]
modes = self.lexico_objectives["modes"]
f_best = {}
keys = list(results.keys())
length = len(keys)
histories = defaultdict(list)
for time_index in range(length):
for objective, mode in zip(metrics, modes):
histories[objective].append(
results[keys[time_index]][objective]
if mode == "min"
else results[keys[time_index]][objective] * -1
)
obj_initial = self.lexico_objectives["metrics"][0]
feasible_index = [*range(len(histories[obj_initial]))]
for k_metric in self.lexico_objectives["metrics"]:
k_values = np.array(histories[k_metric])
f_best[k_metric] = np.min(k_values.take(feasible_index))
feasible_index_prior = np.where(
k_values
<= max(
[
f_best[k_metric]
+ self.lexico_objectives["tolerances"][k_metric],
self.lexico_objectives["targets"][k_metric],
]
)
)[0].tolist()
feasible_index = [
val for val in feasible_index if val in feasible_index_prior
]
best_trial = trials[feasible_index[-1]]
return best_trial
def get_best_trial(
self,
metric: Optional[str] = None,
mode: Optional[str] = None,
scope: str = "last",
filter_nan_and_inf: bool = True,
) -> Optional[Trial]:
if self.lexico_objectives is not None:
best_trial = self.lexico_best(self.trials)
else:
best_trial = super().get_best_trial(metric, mode, scope, filter_nan_and_inf)
return best_trial
def report(_metric=None, **kwargs):
@ -148,6 +203,7 @@ def run(
max_failure: Optional[int] = 100,
use_ray: Optional[bool] = False,
use_incumbent_result_in_evaluation: Optional[bool] = None,
lexico_objectives: Optional[dict] = None,
log_file_name: Optional[str] = None,
**ray_args,
):
@ -300,6 +356,18 @@ def run(
max_failure: int | the maximal consecutive number of failures to sample
a trial before the tuning is terminated.
use_ray: A boolean of whether to use ray as the backend.
lexico_objectives: A dictionary with four elements.
It specifies the information used for multi-objective optimization with lexicographic preference.
e.g.,```lexico_objectives = {"metrics":["error_rate","pred_time"], "modes":["min","min"],
"tolerances":{"error_rate":0.01,"pred_time":0.0}, "targets":{"error_rate":0.0,"pred_time":0.0}}```
"metrics" and "modes" are lists of str: "metrics" gives the optimization objectives,
and "modes" specifies whether each objective is to be minimized or maximized.
Both lists are ordered by priority, from high to low.
"tolerances" is a dictionary specifying the optimality tolerance of each objective.
"targets" is a dictionary specifying the optimization target of each objective.
If lexico_objectives is provided, the arguments metric, mode, and search_alg are ignored.
log_file_name: A string of the log file name. Default to None.
When set to None:
if local_dir is not given, no log file is created;
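A minimal sketch of `tune.run` with `lexico_objectives`; the toy evaluation function here stands in for the FashionMNIST example added in test/tune/test_lexiflow.py later in this commit:
```python
from flaml import tune

def evaluate_config(config):
    # Toy evaluation: return both objectives in one dict.
    error_rate = (config["x"] - 0.3) ** 2
    pred_time = 0.01 * config["x"]
    return {"error_rate": error_rate, "pred_time": pred_time}

lexico_objectives = {
    "metrics": ["error_rate", "pred_time"],
    "modes": ["min", "min"],
    "tolerances": {"error_rate": 0.02, "pred_time": 0.0},
    "targets": {"error_rate": 0.0, "pred_time": 0.0},
}

analysis = tune.run(
    evaluate_config,
    config={"x": tune.uniform(0, 1)},
    lexico_objectives=lexico_objectives,
    num_samples=100,
    time_budget_s=10,
    use_ray=False,
)
print(analysis.best_result)
```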
@ -374,15 +442,21 @@ def run(
try:
import optuna as _
SearchAlgorithm = BlendSearch
if lexico_objectives is None:
SearchAlgorithm = BlendSearch
else:
SearchAlgorithm = CFO
except ImportError:
SearchAlgorithm = CFO
logger.warning(
"Using CFO for search. To use BlendSearch, run: pip install flaml[blendsearch]"
)
if lexico_objectives is None:
metric = metric or DEFAULT_METRIC
else:
metric = lexico_objectives["metrics"][0] or DEFAULT_METRIC
search_alg = SearchAlgorithm(
metric=metric or DEFAULT_METRIC,
metric=metric,
mode=mode,
space=config,
points_to_evaluate=points_to_evaluate,
@ -398,6 +472,7 @@ def run(
config_constraints=config_constraints,
metric_constraints=metric_constraints,
use_incumbent_result_in_evaluation=use_incumbent_result_in_evaluation,
lexico_objectives=lexico_objectives,
)
else:
if metric is None or mode is None:
@ -532,7 +607,12 @@ def run(
logger.warning(
f"fail to sample a trial for {max_failure} times in a row, stopping."
)
analysis = ExperimentAnalysis(_runner.get_trials(), metric=metric, mode=mode)
analysis = ExperimentAnalysis(
_runner.get_trials(),
metric=metric,
mode=mode,
lexico_objectives=lexico_objectives,
)
return analysis
finally:
# recover the global variables in case of nested run

View File

@ -0,0 +1,40 @@
from flaml import AutoML
from flaml.data import load_openml_dataset
def _test_lexiflow():
X_train, X_test, y_train, y_test = load_openml_dataset(
dataset_id=179, data_dir="test/data"
)
lexico_objectives = {}
lexico_objectives["metrics"] = ["val_loss", "pred_time"]
lexico_objectives["tolerances"] = {"val_loss": 0.01, "pred_time": 0.0}
lexico_objectives["targets"] = {"val_loss": 0.0, "pred_time": 0.0}
lexico_objectives["modes"] = ["min", "min"]
automl = AutoML()
settings = {
"time_budget": 100,
"lexico_objectives": lexico_objectives,
"estimator_list": ["xgboost"],
"use_ray": True,
"task": "classification",
"max_iter": 10000000,
"train_time_limit": 60,
"verbose": 0,
"eval_method": "holdout",
"mem_thres": 128 * (1024**3),
"seed": 1,
}
automl.fit(X_train=X_train, y_train=y_train, X_val=X_test, y_val=y_test, **settings)
print(automl.predict(X_train))
print(automl.model)
print(automl.config_history)
print(automl.best_iteration)
print(automl.best_estimator)
if __name__ == "__main__":
_test_lexiflow()

View File

@ -20,7 +20,7 @@ def main():
logger.info(" ".join(f"{k}={v}" for k, v in vars(args).items()))
data_path = os.path.join(args.data, 'data.csv')
data_path = os.path.join(args.data, "data.csv")
df = pd.read_csv(data_path)
train_df, test_df = train_test_split(

View File

@ -19,7 +19,7 @@ environment:
os: Linux
command: >-
python data_prep.py
python data_prep.py
--data {inputs.data}
--test_train_ratio {inputs.test_train_ratio}
--train_data {outputs.train_data}

View File

@ -83,10 +83,10 @@ def build_and_submit_aml_pipeline(config):
################################################
# load component functions
################################################
data_prep_component = Component.from_yaml(ws, yaml_file=LOCAL_DIR
/ "data_prep/data_prep.yaml")
train_component = Component.from_yaml(ws, yaml_file=LOCAL_DIR
/ "train/train.yaml")
data_prep_component = Component.from_yaml(
ws, yaml_file=LOCAL_DIR / "data_prep/data_prep.yaml"
)
train_component = Component.from_yaml(ws, yaml_file=LOCAL_DIR / "train/train.yaml")
################################################
# build pipeline

View File

@ -14,16 +14,19 @@ def remote_run():
################################################
# connect to your Azure ML workspace
################################################
ws = Workspace(subscription_id=args.subscription_id,
resource_group=args.resource_group,
workspace_name=args.workspace)
ws = Workspace(
subscription_id=args.subscription_id,
resource_group=args.resource_group,
workspace_name=args.workspace,
)
################################################
# load component functions
################################################
pipeline_tuning_func = Component.from_yaml(ws, yaml_file=LOCAL_DIR
/ "tuner/component_spec.yaml")
pipeline_tuning_func = Component.from_yaml(
ws, yaml_file=LOCAL_DIR / "tuner/component_spec.yaml"
)
################################################
# build pipeline
@ -44,6 +47,7 @@ def remote_run():
def local_run():
logger.info("Run tuner locally.")
from tuner import tuner_func
tuner_func.tune_pipeline(concurrent_run=2)
@ -52,15 +56,18 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_mutually_exclusive_group(required=False)
parser.add_argument(
"--subscription_id", type=str, help="your_subscription_id", required=False,
"--subscription_id",
type=str,
help="your_subscription_id",
required=False,
)
parser.add_argument(
"--resource_group", type=str, help="your_resource_group", required=False)
parser.add_argument(
"--workspace", type=str, help="your_workspace", required=False)
"--resource_group", type=str, help="your_resource_group", required=False
)
parser.add_argument("--workspace", type=str, help="your_workspace", required=False)
parser.add_argument('--remote', dest='remote', action='store_true')
parser.add_argument('--local', dest='remote', action='store_false')
parser.add_argument("--remote", dest="remote", action="store_true")
parser.add_argument("--local", dest="remote", action="store_false")
parser.set_defaults(remote=True)
args = parser.parse_args()

View File

@ -5,7 +5,7 @@ import pandas as pd
from azureml.core import Run
class LightGBMCallbackHandler():
class LightGBMCallbackHandler:
def __init__(self):
pass
@ -24,16 +24,22 @@ class LightGBMCallbackHandler():
def main(args):
"""Main function of the script."""
train_path = os.path.join(args.train_data, 'data.csv')
train_path = os.path.join(args.train_data, "data.csv")
print("traning_path:", train_path)
test_path = os.path.join(args.test_data, 'data.csv')
test_path = os.path.join(args.test_data, "data.csv")
train_set = lgb.Dataset(train_path)
test_set = lgb.Dataset(test_path)
callbacks_handler = LightGBMCallbackHandler()
config = {"header": True, "objective": "binary", "label_column": 30, "metric": "binary_error",
"n_estimators": args.n_estimators, "learning_rate": args.learning_rate}
config = {
"header": True,
"objective": "binary",
"label_column": 30,
"metric": "binary_error",
"n_estimators": args.n_estimators,
"learning_rate": args.learning_rate,
}
gbm = lgb.train(
config,
train_set,
@ -44,9 +50,9 @@ def main(args):
],
)
print('Saving model...')
print("Saving model...")
# save model to file
gbm.save_model(os.path.join(args.model, 'model.txt'))
gbm.save_model(os.path.join(args.model, "model.txt"))
if __name__ == "__main__":

View File

@ -4,9 +4,9 @@ name: classifier
version: 0.0.1
display_name: Train lgbm classifier
inputs:
train_data:
train_data:
type: path
test_data:
test_data:
type: path
learning_rate:
type: float
@ -20,8 +20,8 @@ environment:
conda_dependencies_file: env.yaml
os: Linux
command: >-
python train.py
--train_data {inputs.train_data}
python train.py
--train_data {inputs.train_data}
--test_data {inputs.test_data}
--learning_rate {inputs.learning_rate}
--n_estimators {inputs.n_estimators}

View File

@ -9,4 +9,4 @@ environment:
conda_dependencies_file: env.yaml
os: Linux
command: >-
python tuner/tuner_func.py
python tuner/tuner_func.py

View File

@ -8,8 +8,7 @@ logger = logging.getLogger(__name__)
def run_with_config(config: dict):
"""Run the pipeline with a given config dict
"""
"""Run the pipeline with a given config dict"""
# pass the hyperparameters to AzureML jobs by overwriting the config file.
overrides = [f"{key}={value}" for key, value in config.items()]
@ -24,25 +23,25 @@ def run_with_config(config: dict):
while not stop:
# get status
status = run._core_run.get_status()
print(f'status: {status}')
print(f"status: {status}")
# get metrics
metrics = run._core_run.get_metrics(recursive=True)
if metrics:
run_metrics = list(metrics.values())
new_metric = run_metrics[0]['eval_binary_error']
new_metric = run_metrics[0]["eval_binary_error"]
if type(new_metric) == list:
new_metric = new_metric[-1]
print(f'eval_binary_error: {new_metric}')
print(f"eval_binary_error: {new_metric}")
tune.report(eval_binary_error=new_metric)
time.sleep(5)
if status == 'FAILED' or status == 'Completed':
if status == "FAILED" or status == "Completed":
stop = True
print("The run is terminated.")

128
test/tune/test_lexiflow.py Normal file
View File

@ -0,0 +1,128 @@
import torch
import thop
import torch.nn as nn
from flaml import tune
import torch.nn.functional as F
import torchvision
import numpy as np
import time
from ray import tune as raytune
DEVICE = torch.device("cpu")
BATCHSIZE = 128
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10
def _test_lexiflow():
train_dataset = torchvision.datasets.FashionMNIST(
"test/data",
train=True,
download=True,
transform=torchvision.transforms.ToTensor(),
)
train_loader = torch.utils.data.DataLoader(
torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),
batch_size=BATCHSIZE,
shuffle=True,
)
val_dataset = torchvision.datasets.FashionMNIST(
"test/data", train=False, transform=torchvision.transforms.ToTensor()
)
val_loader = torch.utils.data.DataLoader(
torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),
batch_size=BATCHSIZE,
shuffle=True,
)
def define_model(configuration):
n_layers = configuration["n_layers"]
layers = []
in_features = 28 * 28
for i in range(n_layers):
out_features = configuration["n_units_l{}".format(i)]
layers.append(nn.Linear(in_features, out_features))
layers.append(nn.ReLU())
p = configuration["dropout_{}".format(i)]
layers.append(nn.Dropout(p))
in_features = out_features
layers.append(nn.Linear(in_features, 10))
layers.append(nn.LogSoftmax(dim=1))
return nn.Sequential(*layers)
def train_model(model, optimizer, train_loader):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
optimizer.zero_grad()
F.nll_loss(model(data), target).backward()
optimizer.step()
def eval_model(model, valid_loader):
model.eval()
correct = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(valid_loader):
data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
pred = model(data).argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
accuracy = correct / N_VALID_EXAMPLES
flops, params = thop.profile(
model, inputs=(torch.randn(1, 28 * 28).to(DEVICE),), verbose=False
)
return np.log2(flops), 1 - accuracy, params
def evaluate_function(configuration):
model = define_model(configuration).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), configuration["lr"])
n_epoch = configuration["n_epoch"]
for epoch in range(n_epoch):
train_model(model, optimizer, train_loader)
flops, error_rate, params = eval_model(model, val_loader)
return {"error_rate": error_rate, "flops": flops, "params": params}
lexico_objectives = {}
lexico_objectives["metrics"] = ["error_rate", "flops"]
lexico_objectives["tolerances"] = {"error_rate": 0.02, "flops": 0.0}
lexico_objectives["targets"] = {"error_rate": 0.0, "flops": 0.0}
lexico_objectives["modes"] = ["min", "min"]
search_space = {
"n_layers": raytune.randint(lower=1, upper=3),
"n_units_l0": raytune.randint(lower=4, upper=128),
"n_units_l1": raytune.randint(lower=4, upper=128),
"n_units_l2": raytune.randint(lower=4, upper=128),
"dropout_0": raytune.uniform(lower=0.2, upper=0.5),
"dropout_1": raytune.uniform(lower=0.2, upper=0.5),
"dropout_2": raytune.uniform(lower=0.2, upper=0.5),
"lr": raytune.loguniform(lower=1e-5, upper=1e-1),
"n_epoch": raytune.randint(lower=1, upper=20),
}
low_cost_partial_config = {
"n_layers": 1,
"n_units_l0": 4,
"n_units_l1": 4,
"n_units_l2": 4,
"n_epoch": 1,
}
analysis = tune.run(
evaluate_function,
num_samples=100000000,
time_budget_s=100,
config=search_space,
use_ray=False,
lexico_objectives=lexico_objectives,
low_cost_partial_config=low_cost_partial_config,
)
result = analysis.best_result
print(result)
if __name__ == "__main__":
_test_lexiflow()

View File

@ -37,7 +37,7 @@ automl = AutoML()
settings = {
"time_budget": 60, # total running time in seconds
"metric": "accuracy", # metric to optimize
"task": "classification", # task type
"task": "classification", # task type
"log_file_name": "airlines_experiment.log", # flaml log file
}
experiment = mlflow.set_experiment("flaml") # the experiment name in AzureML workspace

View File

@ -205,7 +205,7 @@ Overall, to tune the hyperparameters of the AzureML pipeline, run:
```bash
# the training job will run remotely as an AzureML job in both choices
# run the tuning job locally
# run the tuning job locally
python submit_tune.py --local
# run the tuning job remotely
python submit_tune.py --remote --subscription_id <your subscription_id> --resource_group <your resource_group> --workspace <your workspace>

View File

@ -79,10 +79,10 @@ You can use FLAML in .NET in the following ways:
**Low-code**
- [*Model Builder*](https://dotnet.microsoft.com/apps/machinelearning-ai/ml-dotnet/model-builder) - A Visual Studio extension for training ML models using FLAML. For more information on how to install it, see the [install Model Builder](https://docs.microsoft.com/dotnet/machine-learning/how-to-guides/install-model-builder?tabs=visual-studio-2022) guide.
- [*ML.NET CLI*](https://docs.microsoft.com/dotnet/machine-learning/automate-training-with-cli) - A dotnet CLI tool for training machine learning models using FLAML on Windows, MacOS, and Linux. For more information on how to install the ML.NET CLI, see the [install the ML.NET CLI](https://docs.microsoft.com/dotnet/machine-learning/how-to-guides/install-ml-net-cli?tabs=windows) guide.
- [*ML.NET CLI*](https://docs.microsoft.com/dotnet/machine-learning/automate-training-with-cli) - A dotnet CLI tool for training machine learning models using FLAML on Windows, MacOS, and Linux. For more information on how to install the ML.NET CLI, see the [install the ML.NET CLI](https://docs.microsoft.com/dotnet/machine-learning/how-to-guides/install-ml-net-cli?tabs=windows) guide.
**Code-first**
- [*Microsoft.ML.AutoML*](https://www.nuget.org/packages/Microsoft.ML.AutoML/0.20.0-preview.22313.1) - NuGet package that provides direct access to the FLAML AutoML APIs that power low-code solutions like Model Builder and the ML.NET CLI. For more information on installing NuGet packages, see the install and use a NuGet package in [Visual Studio](https://docs.microsoft.com/nuget/quickstart/install-and-use-a-package-in-visual-studio) or [dotnet CLI](https://docs.microsoft.com/nuget/quickstart/install-and-use-a-package-using-the-dotnet-cli) guides.
- [*Microsoft.ML.AutoML*](https://www.nuget.org/packages/Microsoft.ML.AutoML/0.20.0-preview.22313.1) - NuGet package that provides direct access to the FLAML AutoML APIs that power low-code solutions like Model Builder and the ML.NET CLI. For more information on installing NuGet packages, see the install and use a NuGet package in [Visual Studio](https://docs.microsoft.com/nuget/quickstart/install-and-use-a-package-in-visual-studio) or [dotnet CLI](https://docs.microsoft.com/nuget/quickstart/install-and-use-a-package-using-the-dotnet-cli) guides.
To get started with the ML.NET API and AutoML, see the [csharp-notebooks](https://github.com/dotnet/csharp-notebooks#machine-learning).

View File

@ -17,7 +17,7 @@ For technical details, please check our research publications.
* [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.
* [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.
* [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.
* [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).
* [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).
* [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).
Many researchers and engineers have contributed to the technology development. In alphabetical order: Vijay Aski, Sebastien Bubeck, Surajit Chaudhuri, Kevin Chen, Yi Wei Chen, Nadiia Chepurko, Ofer Dekel, Alex Deng, Anshuman Dutt, Nicolo Fusi, Jianfeng Gao, Johannes Gehrke, Niklas Gustafsson, Silu Huang, Moe Kayali, Dongwoo Kim, Christian Konig, John Langford, Menghao Li, Mingqin Li, Susan Xueqing Liu, Zhe Liu, Naveen Gaur, Paul Mineiro, Vivek Narasayya, Jake Radzikowski, Marco Rossi, Amin Saied, Neil Tenenholtz, Olga Vrousgou, Chi Wang, Yue Wang, Markus Weimer, Qingyun Wu, Qiufeng Yin, Haozhe Zhang, Minjia Zhang, XiaoYun Zhang, Eric Zhu, Rui Zhuang.