Support time series forecasting for discrete target variable (#416)

* support 'ts_forecast_classification' task to forecast discrete values

* update test_forecast.py
- add test for forecasting discrete values

* update test_model.py

* pre-commit changes
Kevin Chen, 2022-01-24 21:39:36 -05:00 (committed by GitHub)
parent 4814091d87
commit 81f54026c9
9 changed files with 140 additions and 56 deletions

flaml/automl.py

@@ -42,7 +42,7 @@ from .data import (
CLASSIFICATION,
TOKENCLASSIFICATION,
TS_FORECAST,
-FORECAST,
+TS_FORECASTREGRESSION,
REGRESSION,
_is_nlp_task,
NLG_TASKS,
@@ -84,7 +84,7 @@ class SearchState:
self.data_size = data_size
self.ls_ever_converged = False
self.learner_class = learner_class
-if task == TS_FORECAST:
+if task in TS_FORECAST:
search_space = learner_class.search_space(
data_size=data_size, task=task, pred_horizon=period
)
@@ -816,7 +816,7 @@ class AutoML(BaseEstimator):
return X
elif issparse(X):
X = X.tocsr()
-if self._state.task == TS_FORECAST:
+if self._state.task in TS_FORECAST:
X = pd.DataFrame(X)
if self._transformer:
X = self._transformer.transform(X)
@@ -894,7 +894,7 @@ class AutoML(BaseEstimator):
), "# rows in X_train must match length of y_train."
self._df = isinstance(X_train_all, pd.DataFrame)
self._nrow, self._ndim = X_train_all.shape
-if self._state.task == TS_FORECAST:
+if self._state.task in TS_FORECAST:
X_train_all = pd.DataFrame(X_train_all)
X_train_all, y_train_all = self._validate_ts_data(
X_train_all, y_train_all
@@ -906,7 +906,7 @@ class AutoML(BaseEstimator):
), "dataframe must be a pandas DataFrame"
assert label in dataframe.columns, "label must be a column name in dataframe"
self._df = True
-if self._state.task == TS_FORECAST:
+if self._state.task in TS_FORECAST:
dataframe = self._validate_ts_data(dataframe)
X = dataframe.drop(columns=label)
self._nrow, self._ndim = X.shape
@@ -1078,7 +1078,7 @@ class AutoML(BaseEstimator):
if X_val is None and eval_method == "holdout":
# if eval_method = holdout, make holdout data
if self._split_type == "time":
-if self._state.task == TS_FORECAST:
+if self._state.task in TS_FORECAST:
num_samples = X_train_all.shape[0]
period = self._state.fit_kwargs["period"]
assert (
@@ -1239,7 +1239,7 @@
)
elif self._split_type == "time":
# logger.info("Using TimeSeriesSplit")
-if self._state.task == TS_FORECAST:
+if self._state.task in TS_FORECAST:
period = self._state.fit_kwargs["period"]
if period * (n_splits + 1) > y_train_all.size:
n_splits = int(y_train_all.size / period - 1)
@@ -1386,7 +1386,7 @@ class AutoML(BaseEstimator):
auto_augment = (
self._settings.get("auto_augment") if auto_augment is None else auto_augment
)
-self._state.task = TS_FORECAST if task == FORECAST else task
+self._state.task = task
self._estimator_type = "classifier" if task in CLASSIFICATION else "regressor"
self._state.fit_kwargs = fit_kwargs
@@ -1489,7 +1489,7 @@ class AutoML(BaseEstimator):
elif self._state.task in REGRESSION:
assert split_type in ["auto", "uniform", "time", "group"]
self._split_type = split_type if split_type != "auto" else "uniform"
-elif self._state.task == TS_FORECAST:
+elif self._state.task in TS_FORECAST:
assert split_type in ["auto", "time"]
self._split_type = "time"
assert isinstance(
@@ -1994,7 +1994,7 @@ class AutoML(BaseEstimator):
min_sample_size = min_sample_size or self._settings.get("min_sample_size")
use_ray = self._settings.get("use_ray") if use_ray is None else use_ray
-self._state.task = TS_FORECAST if task == FORECAST else task
+self._state.task = task
self._state.log_training_metric = log_training_metric
self._state.fit_kwargs = fit_kwargs
@@ -2070,7 +2070,7 @@ class AutoML(BaseEstimator):
metric = "roc_auc"
elif "multi" in self._state.task:
metric = "log_loss"
-elif self._state.task == TS_FORECAST:
+elif self._state.task in TS_FORECAST:
metric = "mape"
elif self._state.task == "rank":
metric = "ndcg"
@@ -2148,16 +2148,17 @@
"extra_tree",
"xgb_limitdepth",
]
-if TS_FORECAST == self._state.task:
+if self._state.task in TS_FORECAST:
# catboost is removed because it has a `name` parameter, making it incompatible with hcrystalball
if "catboost" in estimator_list:
estimator_list.remove("catboost")
-try:
-    import prophet
-    estimator_list += ["prophet", "arima", "sarimax"]
-except ImportError:
-    estimator_list += ["arima", "sarimax"]
+if self._state.task in TS_FORECASTREGRESSION:
+    try:
+        import prophet
+        estimator_list += ["prophet", "arima", "sarimax"]
+    except ImportError:
+        estimator_list += ["arima", "sarimax"]
elif "regression" != self._state.task:
estimator_list += ["lrl1"]
@@ -2802,7 +2803,7 @@ class AutoML(BaseEstimator):
if self._max_iter > 1:
self._state.time_from_start -= self._state.time_budget
if (
-self._state.task == TS_FORECAST
+self._state.task in TS_FORECAST
or self._trained_estimator is None
or self._trained_estimator.model is None
or (

flaml/data.py

@@ -26,10 +26,18 @@ CLASSIFICATION = (
)
SEQREGRESSION = "seq-regression"
REGRESSION = ("regression", SEQREGRESSION)
TS_FORECAST = "ts_forecast"
TS_FORECASTREGRESSION = (
"forecast",
"ts_forecast",
"ts_forecast_regression",
)
TS_FORECASTCLASSIFICATION = "ts_forecast_classification"
TS_FORECAST = (
*TS_FORECASTREGRESSION,
TS_FORECASTCLASSIFICATION,
)
TS_TIMESTAMP_COL = "ds"
TS_VALUE_COL = "y"
FORECAST = "forecast"
SUMMARIZATION = "summarization"
NLG_TASKS = (SUMMARIZATION,)
NLU_TASKS = (
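Since TS_FORECAST is now a tuple of task names rather than a single string, every former equality check becomes a membership test. A sketch of the intended semantics:

    from flaml.data import TS_FORECAST, TS_FORECASTREGRESSION

    assert "ts_forecast_classification" in TS_FORECAST
    assert "forecast" in TS_FORECAST and "ts_forecast" in TS_FORECAST
    # only the regression variants qualify for prophet/arima/sarimax
    assert "ts_forecast_classification" not in TS_FORECASTREGRESSION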
@@ -266,7 +274,7 @@ class DataTransformer:
n = X.shape[0]
cat_columns, num_columns, datetime_columns = [], [], []
drop = False
-if task == TS_FORECAST:
+if task in TS_FORECAST:
X = X.rename(columns={X.columns[0]: TS_TIMESTAMP_COL})
ds_col = X.pop(TS_TIMESTAMP_COL)
if isinstance(y, Series):
@@ -323,7 +331,7 @@ class DataTransformer:
X[column] = X[column].fillna(np.nan)
num_columns.append(column)
X = X[cat_columns + num_columns]
-if task == TS_FORECAST:
+if task in TS_FORECAST:
X.insert(0, TS_TIMESTAMP_COL, ds_col)
if cat_columns:
X[cat_columns] = X[cat_columns].astype("category")
@@ -397,7 +405,7 @@ class DataTransformer:
self._num_columns,
self._datetime_columns,
)
-if self._task == TS_FORECAST:
+if self._task in TS_FORECAST:
X = X.rename(columns={X.columns[0]: TS_TIMESTAMP_COL})
ds_col = X.pop(TS_TIMESTAMP_COL)
for column in datetime_columns:
@@ -419,7 +427,7 @@ class DataTransformer:
X[column] = X[column].map(datetime.toordinal)
del tmp_dt
X = X[cat_columns + num_columns].copy()
-if self._task == TS_FORECAST:
+if self._task in TS_FORECAST:
X.insert(0, TS_TIMESTAMP_COL, ds_col)
for column in cat_columns:
if X[column].dtype.name == "object":

flaml/ml.py

@@ -20,18 +20,18 @@ from sklearn.metrics import (
from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit
from .model import (
XGBoostSklearnEstimator,
-XGBoost_TS_Regressor,
+XGBoost_TS,
XGBoostLimitDepthEstimator,
-XGBoostLimitDepth_TS_Regressor,
+XGBoostLimitDepth_TS,
RandomForestEstimator,
-RF_TS_Regressor,
+RF_TS,
LGBMEstimator,
-LGBM_TS_Regressor,
+LGBM_TS,
LRL1Classifier,
LRL2Classifier,
CatBoostEstimator,
ExtraTreesEstimator,
-ExtraTrees_TS_Regressor,
+ExtraTrees_TS,
KNeighborsEstimator,
Prophet,
ARIMA,
@@ -94,21 +94,15 @@ huggingface_submetric_to_metric = {"rouge1": "rouge", "rouge2": "rouge"}
def get_estimator_class(task, estimator_name):
# when adding a new learner, need to add an elif branch
if "xgboost" == estimator_name:
-estimator_class = (
-XGBoost_TS_Regressor if TS_FORECAST == task else XGBoostSklearnEstimator
-)
+estimator_class = XGBoost_TS if task in TS_FORECAST else XGBoostSklearnEstimator
elif "xgb_limitdepth" == estimator_name:
estimator_class = (
-XGBoostLimitDepth_TS_Regressor
-if TS_FORECAST == task
-else XGBoostLimitDepthEstimator
+XGBoostLimitDepth_TS if task in TS_FORECAST else XGBoostLimitDepthEstimator
)
elif "rf" == estimator_name:
-estimator_class = (
-RF_TS_Regressor if TS_FORECAST == task else RandomForestEstimator
-)
+estimator_class = RF_TS if task in TS_FORECAST else RandomForestEstimator
elif "lgbm" == estimator_name:
-estimator_class = LGBM_TS_Regressor if TS_FORECAST == task else LGBMEstimator
+estimator_class = LGBM_TS if task in TS_FORECAST else LGBMEstimator
elif "lrl1" == estimator_name:
estimator_class = LRL1Classifier
elif "lrl2" == estimator_name:
@@ -116,9 +110,7 @@ def get_estimator_class(task, estimator_name):
elif "catboost" == estimator_name:
estimator_class = CatBoostEstimator
elif "extra_tree" == estimator_name:
-estimator_class = (
-ExtraTrees_TS_Regressor if TS_FORECAST == task else ExtraTreesEstimator
-)
+estimator_class = ExtraTrees_TS if task in TS_FORECAST else ExtraTreesEstimator
elif "kneighbor" == estimator_name:
estimator_class = KNeighborsEstimator
elif "prophet" in estimator_name:
@@ -453,7 +445,7 @@ def evaluate_model_CV(
else:
labels = None
groups = None
-shuffle = False if task == TS_FORECAST else True
+shuffle = False if task in TS_FORECAST else True
if isinstance(kf, RepeatedStratifiedKFold):
kf = kf.split(X_train_split, y_train_split)
elif isinstance(kf, GroupKFold):
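Because the renamed *_TS classes now serve both forecasting flavors, the dispatch depends only on membership in TS_FORECAST. A quick sketch of the expected behavior:

    from flaml.ml import get_estimator_class

    # the same TS wrapper class is returned for both forecasting flavors
    assert get_estimator_class("ts_forecast_classification", "lgbm") is \
        get_estimator_class("ts_forecast_regression", "lgbm")
    # non-forecast tasks keep the plain estimator
    get_estimator_class("classification", "lgbm")  # LGBMEstimator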

flaml/model.py

@@ -23,6 +23,7 @@ from .data import (
group_counts,
CLASSIFICATION,
TS_FORECAST,
+TS_FORECASTREGRESSION,
TS_TIMESTAMP_COL,
TS_VALUE_COL,
SEQCLASSIFICATION,
@@ -1571,7 +1572,7 @@ class Prophet(SKLearnEstimator):
}
return space
-def __init__(self, task=TS_FORECAST, n_jobs=1, **params):
+def __init__(self, task="ts_forecast", n_jobs=1, **params):
super().__init__(task, **params)
def _join(self, X_train, y_train):
@@ -1796,7 +1797,7 @@ class SARIMAX(ARIMA):
return train_time
-class TS_SKLearn_Regressor(SKLearnEstimator):
+class TS_SKLearn(SKLearnEstimator):
"""The class for tuning SKLearn Regressors for time-series forecasting, using hcrystalball"""
base_class = SKLearnEstimator
@@ -1819,9 +1820,12 @@ class TS_SKLearn_Regressor(SKLearnEstimator):
)
return space
-def __init__(self, task=TS_FORECAST, **params):
+def __init__(self, task="ts_forecast", **params):
super().__init__(task, **params)
self.hcrystaball_model = None
+self.ts_task = (
+    "regression" if task in TS_FORECASTREGRESSION else "classification"
+)
def transform_X(self, X):
cols = list(X)
@@ -1842,7 +1846,7 @@ class TS_SKLearn_Regressor(SKLearnEstimator):
params = self.params.copy()
lags = params.pop("lags")
optimize_for_horizon = params.pop("optimize_for_horizon")
-estimator = self.base_class(task="regression", **params)
+estimator = self.base_class(task=self.ts_task, **params)
self.hcrystaball_model = get_sklearn_wrapper(estimator.estimator_class)
self.hcrystaball_model.lags = int(lags)
self.hcrystaball_model.fit(X_train, y_train)
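The new ts_task attribute is what routes the wrapped FLAML estimator to a classifier or a regressor before hcrystalball wraps it for lagged features. A minimal sketch of the assumed mapping:

    from flaml.model import LGBMEstimator

    # task="classification" should make the base estimator pick LGBMClassifier,
    # while task="regression" picks LGBMRegressor; get_sklearn_wrapper then
    # wraps whichever class is chosen
    est = LGBMEstimator(task="classification")
    print(est.estimator_class)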
@@ -1913,13 +1917,13 @@ class TS_SKLearn_Regressor(SKLearnEstimator):
return np.ones(X.shape[0])
-class LGBM_TS_Regressor(TS_SKLearn_Regressor):
+class LGBM_TS(TS_SKLearn):
"""The class for tuning LGBM Regressor for time-series forecasting"""
base_class = LGBMEstimator
-class XGBoost_TS_Regressor(TS_SKLearn_Regressor):
+class XGBoost_TS(TS_SKLearn):
"""The class for tuning XGBoost Regressor for time-series forecasting"""
base_class = XGBoostSklearnEstimator
@@ -1930,19 +1934,19 @@ class XGBoost_TS_Regressor(TS_SKLearn_Regressor):
# base_class = CatBoostEstimator
-class RF_TS_Regressor(TS_SKLearn_Regressor):
+class RF_TS(TS_SKLearn):
"""The class for tuning Random Forest Regressor for time-series forecasting"""
base_class = RandomForestEstimator
-class ExtraTrees_TS_Regressor(TS_SKLearn_Regressor):
+class ExtraTrees_TS(TS_SKLearn):
"""The class for tuning Extra Trees Regressor for time-series forecasting"""
base_class = ExtraTreesEstimator
-class XGBoostLimitDepth_TS_Regressor(TS_SKLearn_Regressor):
+class XGBoostLimitDepth_TS(TS_SKLearn):
"""The class for tuning XGBoost Regressor with unlimited depth for time-series forecasting"""
base_class = XGBoostLimitDepthEstimator

flaml/nlp/README.md

@@ -1,6 +1,6 @@
# AutoML for NLP
-This directory contains utility functions used by AutoNLP. Currently we support four NLP tasks: sequence classification, sequence regression, multiple choice and summarization.
+This directory contains utility functions used by AutoNLP. Currently we support four NLP tasks: sequence classification, sequence regression, multiple choice and summarization.
Please refer to this [link](https://microsoft.github.io/FLAML/docs/Examples/AutoML-NLP) for examples.

flaml/nlp/huggingface/data_collator.py

@@ -7,6 +7,7 @@ class DataCollatorForAuto(DataCollatorWithPadding):
def __call__(self, features):
from itertools import chain
import torch
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature.pop(label_name) for feature in features]
batch_size = len(features)
@@ -27,6 +28,7 @@ class DataCollatorForAuto(DataCollatorWithPadding):
class DataCollatorForPredict(DataCollatorWithPadding):
def __call__(self, features):
from itertools import chain
batch_size = len(features)
num_choices = len(features[0]["input_ids"])
flattened_features = [

setup.py

@@ -86,7 +86,11 @@ setuptools.setup(
"nltk",
"rouge_score",
],
"ts_forecast": ["prophet>=1.0.1", "statsmodels>=0.12.2", "hcrystalball==0.1.10"],
"ts_forecast": [
"prophet>=1.0.1",
"statsmodels>=0.12.2",
"hcrystalball==0.1.10",
],
"forecast": ["prophet>=1.0.1", "statsmodels>=0.12.2", "hcrystalball==0.1.10"],
"benchmark": ["catboost>=0.26", "psutil==5.8.0", "xgboost==1.3.3"],
},
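Both extras pin the same dependencies, so either one enables the new task; for example:

    pip install "flaml[ts_forecast]"   # or: pip install "flaml[forecast]"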

test/test_forecast.py

@@ -352,8 +352,81 @@ def test_multivariate_forecast_cat(budget=5):
# plt.show()
def test_forecast_classification(budget=5):
from hcrystalball.utils import get_sales_data
from hcrystalball.wrappers import get_sklearn_wrapper
time_horizon = 30
df = get_sales_data(n_dates=180, n_assortments=1, n_states=1, n_stores=1)
df = df[["Sales", "Open", "Promo", "Promo2"]]
# feature engineering
import numpy as np
df["above_mean_sales"] = np.where(df["Sales"] > df["Sales"].mean(), 1, 0)
df.reset_index(inplace=True)
train_df = df[:-time_horizon]
test_df = df[-time_horizon:]
X_train, X_test = (
train_df[["Date", "Open", "Promo", "Promo2"]],
test_df[["Date", "Open", "Promo", "Promo2"]],
)
y_train, y_test = train_df["above_mean_sales"], test_df["above_mean_sales"]
automl = AutoML()
settings = {
"time_budget": budget, # total running time in seconds
"metric": "accuracy", # primary metric
"task": "ts_forecast_classification", # task type
"log_file_name": "test/sales_classification_forecast.log", # flaml log file
"eval_method": "holdout",
}
"""The main flaml automl API"""
automl.fit(X_train=X_train, y_train=y_train, **settings, period=time_horizon)
""" retrieve best config and best learner"""
print("Best ML leaner:", automl.best_estimator)
print("Best hyperparmeter config:", automl.best_config)
print(f"Best mape on validation data: {automl.best_loss}")
print(f"Training duration of best run: {automl.best_config_train_time}s")
print(automl.model.estimator)
""" pickle and save the automl object """
import pickle
with open("automl.pkl", "wb") as f:
pickle.dump(automl, f, pickle.HIGHEST_PROTOCOL)
""" compute predictions of testing dataset """
y_pred = automl.predict(X_test)
""" compute different metric values on testing dataset"""
from flaml.ml import sklearn_metric_loss_score
print(y_test)
print(y_pred)
print("accuracy", "=", 1 - sklearn_metric_loss_score("accuracy", y_test, y_pred))
from flaml.data import get_output_from_log
(
time_history,
best_valid_loss_history,
valid_loss_history,
config_history,
metric_history,
) = get_output_from_log(filename=settings["log_file_name"], time_budget=budget)
for config in config_history:
print(config)
print(automl.resource_attr)
print(automl.max_resource)
print(automl.min_resource)
# import matplotlib.pyplot as plt
#
# plt.title("Learning Curve")
# plt.xlabel("Wall Clock Time (s)")
# plt.ylabel("Validation Accuracy")
# plt.scatter(time_history, 1 - np.array(valid_loss_history))
# plt.step(time_history, 1 - np.array(best_valid_loss_history), where="post")
# plt.show()
if __name__ == "__main__":
test_forecast_automl(60)
test_multivariate_forecast_num(60)
test_multivariate_forecast_cat(60)
test_numpy()
test_forecast_classification()

test/test_model.py

@@ -12,7 +12,7 @@ from flaml.model import (
RandomForestEstimator,
Prophet,
ARIMA,
-LGBM_TS_Regressor,
+LGBM_TS,
)
@@ -98,7 +98,7 @@ def test_prep():
# X_test needs to be either a pandas DataFrame with dates as the first column or an int number of periods for predict().
pass
-lgbm = LGBM_TS_Regressor(optimize_for_horizon=True, lags=1)
+lgbm = LGBM_TS(optimize_for_horizon=True, lags=1)
X = DataFrame(
{
"A": [