Time series forecasting with sklearn regressors (#362)

* add sklearn regressors as learners for ts_forecast task

* add direct forecasting strategy
- add warnings and errors for duplicate rows and missing values

* add preprocessing for sklearn time series forecasting
- update automl.py
- update test/test_forecast.py

* update model.py and test_forecast.py for cv eval_method

* add "hcrystalball" dependency in setup.py

* update automl.py
- add _validate_ts_data function for abstraction
- include xgb_limitdepth as a learner

* update model.py
- update search space for sklearn ts regressors

* update automl.py and test_forecast.py for numpy array inputs

* add documentation to model.py

* add documentation explaining the removal of the catboost regressor

* update automl.py
- refine the _validate_ts_data() function
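For context, a minimal usage sketch of what this commit enables (synthetic data; the task name, `period`, and the estimator names follow the tests added below):

```python
import numpy as np
import pandas as pd
from flaml import AutoML

# Monthly synthetic series; for ts_forecast the first column must be timestamps.
X_train = pd.DataFrame({"ds": pd.date_range("2017-01-31", periods=72, freq="M")})
y_train = pd.Series(np.sin(np.arange(72) / 6))

automl = AutoML()
automl.fit(
    X_train=X_train,
    y_train=y_train,
    task="ts_forecast",
    period=12,  # forecast horizon, forwarded into the learners' search space
    estimator_list=["lgbm", "rf", "xgboost", "extra_tree"],  # new sklearn TS learners
    time_budget=15,
)
X_test = pd.DataFrame({"ds": pd.date_range("2023-01-31", periods=12, freq="M")})
print(automl.predict(X_test))
```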

Signed-off-by: Kevin Chen <chenkevin.8787@gmail.com>
Kevin Chen 2022-01-07 02:12:38 -05:00 committed by GitHub
parent 612668e8ed
commit d4273669e6
5 changed files with 197 additions and 30 deletions

flaml/automl.py

@@ -73,7 +73,7 @@ class SearchState:
self.total_time_used - self.time_best_found,
)
def __init__(self, learner_class, data_size, task, starting_point=None):
def __init__(self, learner_class, data_size, task, starting_point=None, period=None):
self.init_eci = learner_class.cost_relative2lgbm()
self._search_space_domain = {}
self.init_config = {}
@@ -82,7 +82,10 @@ class SearchState:
self.data_size = data_size
self.ls_ever_converged = False
self.learner_class = learner_class
search_space = learner_class.search_space(data_size=data_size, task=task)
if task == TS_FORECAST:
search_space = learner_class.search_space(data_size=data_size, task=task, pred_horizon=period)
else:
search_space = learner_class.search_space(data_size=data_size, task=task)
for name, space in search_space.items():
assert (
"domain" in space
@@ -808,6 +811,38 @@ class AutoML(BaseEstimator):
X = self._transformer.transform(X)
return X
def _validate_ts_data(
self,
dataframe,
y_train_all=None,
):
assert (
dataframe[dataframe.columns[0]].dtype.name == "datetime64[ns]"
), f"For '{TS_FORECAST}' task, the first column must contain timestamp values."
if y_train_all is not None:
y_df = pd.DataFrame(y_train_all) if isinstance(y_train_all, pd.Series) else pd.DataFrame(y_train_all, columns=['labels'])
dataframe = dataframe.join(y_df)
duplicates = dataframe.duplicated()
if any(duplicates):
logger.warning(
"Duplicate timestamp values found in timestamp column. "
f"\n{dataframe.loc[duplicates, dataframe.columns[0]]}"
)
dataframe = dataframe.drop_duplicates()
logger.warning("Removed duplicate rows based on all columns")
assert (
not dataframe.duplicated(subset=[dataframe.columns[0]]).any()
), "Duplicate timestamp values with different values for other columns."
ts_series = pd.to_datetime(dataframe[dataframe.columns[0]])
inferred_freq = pd.infer_freq(ts_series)
if inferred_freq is None:
logger.warning(
"Missing timestamps detected. To avoid errors with the estimators, set the estimator list to ['prophet']."
)
if y_train_all is not None:
return dataframe.iloc[:, :-1], dataframe.iloc[:, -1]
return dataframe
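A standalone sketch of what this validation catches, using plain pandas rather than the private method itself:

```python
import pandas as pd

df = pd.DataFrame({
    "ds": pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-02", "2021-01-05"]),
    "y": [1.0, 2.0, 2.0, 4.0],
})
df = df.drop_duplicates()  # fully duplicated rows are dropped with a warning
# same timestamp with *different* values in other columns trips the assertion
assert not df.duplicated(subset=["ds"]).any()
# the gap (Jan 3-4 missing) makes the frequency uninferable -> warning
print(pd.infer_freq(df["ds"]))  # None
```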
def _validate_data(
self,
X_train_all,
@@ -846,9 +881,7 @@ class AutoML(BaseEstimator):
self._nrow, self._ndim = X_train_all.shape
if self._state.task == TS_FORECAST:
X_train_all = pd.DataFrame(X_train_all)
assert (
X_train_all[X_train_all.columns[0]].dtype.name == "datetime64[ns]"
), f"For '{TS_FORECAST}' task, the first column must contain timestamp values."
X_train_all, y_train_all = self._validate_ts_data(X_train_all, y_train_all)
X, y = X_train_all, y_train_all
elif dataframe is not None and label is not None:
assert isinstance(
@@ -857,9 +890,7 @@ class AutoML(BaseEstimator):
assert label in dataframe.columns, "label must be a column name in dataframe"
self._df = True
if self._state.task == TS_FORECAST:
assert (
dataframe[dataframe.columns[0]].dtype.name == "datetime64[ns]"
), f"For '{TS_FORECAST}' task, the first column must contain timestamp values."
dataframe = self._validate_ts_data(dataframe)
X = dataframe.drop(columns=label)
self._nrow, self._ndim = X.shape
y = dataframe[label]
@@ -2079,14 +2110,7 @@ class AutoML(BaseEstimator):
logger.info(f"Minimizing error metric: {error_metric}")
if "auto" == estimator_list:
if self._state.task == TS_FORECAST:
try:
import prophet
estimator_list = ["prophet", "arima", "sarimax"]
except ImportError:
estimator_list = ["arima", "sarimax"]
elif self._state.task == "rank":
if self._state.task == "rank":
estimator_list = ["lgbm", "xgboost", "xgb_limitdepth"]
elif _is_nlp_task(self._state.task):
estimator_list = ["transformer"]
@@ -2110,8 +2134,18 @@ class AutoML(BaseEstimator):
"extra_tree",
"xgb_limitdepth",
]
if "regression" != self._state.task:
if TS_FORECAST == self._state.task:
# catboost is removed because it has a `name` parameter, making it incompatible with hcrystalball
estimator_list.remove("catboost")
try:
import prophet
estimator_list += ["prophet", "arima", "sarimax"]
except ImportError:
estimator_list += ["arima", "sarimax"]
elif "regression" != self._state.task:
estimator_list += ["lrl1"]
for estimator_name in estimator_list:
if estimator_name not in self._state.learner_classes:
self.add_learner(
@@ -2127,6 +2161,7 @@ class AutoML(BaseEstimator):
data_size=self._state.data_size,
task=self._state.task,
starting_point=starting_points.get(estimator_name),
period=self._state.fit_kwargs.get("period"),
)
logger.info("List of ML learners in AutoML Run: {}".format(estimator_list))
self.estimator_list = estimator_list
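Net effect for estimator_list="auto" on a forecasting task, sketched below; the starting list is an assumption based on the default list defined earlier in this function (context not shown in this hunk):

```python
# assumed default list from earlier in this function
estimator_list = ["lgbm", "rf", "catboost", "xgboost", "extra_tree", "xgb_limitdepth"]
estimator_list.remove("catboost")  # incompatible with hcrystalball (see model.py)
try:
    import prophet  # noqa: F401
    estimator_list += ["prophet", "arima", "sarimax"]
except ImportError:
    estimator_list += ["arima", "sarimax"]
print(estimator_list)
```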

flaml/ml.py

@@ -20,13 +20,18 @@ from sklearn.metrics import (
from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit
from .model import (
XGBoostSklearnEstimator,
XGBoost_TS_Regressor,
XGBoostLimitDepthEstimator,
XGBoostLimitDepth_TS_Regressor,
RandomForestEstimator,
RF_TS_Regressor,
LGBMEstimator,
LGBM_TS_Regressor,
LRL1Classifier,
LRL2Classifier,
CatBoostEstimator,
ExtraTreesEstimator,
ExtraTrees_TS_Regressor,
KNeighborsEstimator,
Prophet,
ARIMA,
@@ -89,13 +94,13 @@ huggingface_submetric_to_metric = {"rouge1": "rouge", "rouge2": "rouge"}
def get_estimator_class(task, estimator_name):
# when adding a new learner, need to add an elif branch
if "xgboost" == estimator_name:
estimator_class = XGBoostSklearnEstimator
estimator_class = XGBoost_TS_Regressor if TS_FORECAST == task else XGBoostSklearnEstimator
elif "xgb_limitdepth" == estimator_name:
estimator_class = XGBoostLimitDepthEstimator
estimator_class = XGBoostLimitDepth_TS_Regressor if TS_FORECAST == task else XGBoostLimitDepthEstimator
elif "rf" == estimator_name:
estimator_class = RandomForestEstimator
estimator_class = RF_TS_Regressor if TS_FORECAST == task else RandomForestEstimator
elif "lgbm" == estimator_name:
estimator_class = LGBMEstimator
estimator_class = LGBM_TS_Regressor if TS_FORECAST == task else LGBMEstimator
elif "lrl1" == estimator_name:
estimator_class = LRL1Classifier
elif "lrl2" == estimator_name:
@@ -103,7 +108,7 @@ def get_estimator_class(task, estimator_name):
elif "catboost" == estimator_name:
estimator_class = CatBoostEstimator
elif "extra_tree" == estimator_name:
estimator_class = ExtraTreesEstimator
estimator_class = ExtraTrees_TS_Regressor if TS_FORECAST == task else ExtraTreesEstimator
elif "kneighbor" == estimator_name:
estimator_class = KNeighborsEstimator
elif "prophet" in estimator_name:
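With this dispatch, the same estimator name resolves to a different class depending on the task; a small sketch (import path as in this diff's flaml/ml.py):

```python
from flaml.ml import get_estimator_class

print(get_estimator_class("ts_forecast", "lgbm").__name__)  # LGBM_TS_Regressor
print(get_estimator_class("regression", "lgbm").__name__)   # LGBMEstimator
```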
@@ -441,10 +446,6 @@ def evaluate_model_CV(
groups = kf.groups
kf = kf.split(X_train_split, y_train_split, groups)
shuffle = False
elif isinstance(kf, TimeSeriesSplit) and task == TS_FORECAST:
y_train_all = pd.DataFrame(y_train_all, columns=[TS_VALUE_COL])
train = X_train_all.join(y_train_all)
kf = kf.split(train)
elif isinstance(kf, TimeSeriesSplit):
kf = kf.split(X_train_split, y_train_split)
else:
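The deleted branch re-joined `y` into the training frame before splitting; `TimeSeriesSplit` can split features and labels directly while keeping every validation fold strictly after its training fold, e.g.:

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

X, y = np.arange(20).reshape(-1, 1), np.arange(20)
for train_idx, val_idx in TimeSeriesSplit(n_splits=3).split(X, y):
    # every validation fold starts strictly after its training fold ends
    assert train_idx[-1] < val_idx[0]
    print(len(train_idx), len(val_idx))
```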

flaml/model.py

@@ -1789,6 +1789,135 @@ class SARIMAX(ARIMA):
return train_time
class TS_SKLearn_Regressor(SKLearnEstimator):
""" The class for tuning SKLearn Regressors for time-series forecasting, using hcrystalball"""
base_class = SKLearnEstimator
@classmethod
def search_space(cls, data_size, pred_horizon, **params):
space = cls.base_class.search_space(data_size, **params)
space.update({
"optimize_for_horizon": {
"domain": tune.choice([True, False]),
"init_value": False,
"low_cost_init_value": False,
},
"lags": {
"domain": tune.randint(lower=1, upper=data_size[0] - pred_horizon),
"init_value": 3,
},
})
return space
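The `lags` upper bound above comes from a simple counting argument: one supervised example needs `lags` past values plus `pred_horizon` future values, so the lag window can be at most `data_size[0] - pred_horizon`. A tiny sketch with illustrative numbers:

```python
n_rows, pred_horizon = 100, 12
# a lag window plus the forecast horizon must fit inside the training data
lags_upper = n_rows - pred_horizon
print(f"lags is tuned over [1, {lags_upper})")  # cf. tune.randint above
```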
def __init__(self, task=TS_FORECAST, **params):
super().__init__(task, **params)
self.hcrystaball_model = None
def transform_X(self, X):
cols = list(X)
if len(cols) == 1:
ds_col = cols[0]
X = pd.DataFrame(index=X[ds_col])
elif len(cols) > 1:
ds_col = cols[0]
exog_cols = cols[1:]
X = X[exog_cols].set_index(X[ds_col])
return X
def _fit(self, X_train, y_train, budget=None, **kwargs):
from hcrystalball.wrappers import get_sklearn_wrapper
X_train = self.transform_X(X_train)
X_train = self._preprocess(X_train)
params = self.params.copy()
lags = params.pop("lags")
optimize_for_horizon = params.pop("optimize_for_horizon")
estimator = self.base_class(task="regression", **params)
self.hcrystaball_model = get_sklearn_wrapper(estimator.estimator_class)
self.hcrystaball_model.lags = int(lags)
self.hcrystaball_model.fit(X_train, y_train)
if optimize_for_horizon:
# Direct Multi-step Forecast Strategy - fit a separate model for each horizon step
model_list = []
for i in range(1, kwargs["period"] + 1):
X_fit, y_fit = self.hcrystaball_model._transform_data_to_tsmodel_input_format(X_train, y_train, i)
self.hcrystaball_model.model.set_params(**estimator.params)
model = self.hcrystaball_model.model.fit(X_fit, y_fit)
model_list.append(model)
self._model = model_list
else:
X_fit, y_fit = self.hcrystaball_model._transform_data_to_tsmodel_input_format(X_train, y_train, kwargs["period"])
self.hcrystaball_model.model.set_params(**estimator.params)
model = self.hcrystaball_model.model.fit(X_fit, y_fit)
self._model = model
def fit(self, X_train, y_train, budget=None, **kwargs):
current_time = time.time()
self._fit(X_train, y_train, budget=budget, **kwargs)
train_time = time.time() - current_time
return train_time
def predict(self, X_test):
if self._model is not None:
X_test = self.transform_X(X_test)
X_test = self._preprocess(X_test)
if isinstance(self._model, list):
assert (
len(self._model) == len(X_test)
), "Model is optimized for horizon, length of X_test must be equal to `period`."
preds = []
for i in range(1, len(self._model) + 1):
X_pred, _ = self.hcrystaball_model._transform_data_to_tsmodel_input_format(X_test.iloc[:i, :])
preds.append(self._model[i - 1].predict(X_pred)[-1])
forecast = pd.DataFrame(data=np.asarray(preds).reshape(-1, 1),
columns=[self.hcrystaball_model.name],
index=X_test.index)
else:
X_pred, _ = self.hcrystaball_model._transform_data_to_tsmodel_input_format(X_test)
forecast = self._model.predict(X_pred)
return forecast
else:
logger.warning(
"Estimator is not fit yet. Please run fit() before predict()."
)
return np.ones(X_test.shape[0])
class LGBM_TS_Regressor(TS_SKLearn_Regressor):
""" The class for tuning LGBM Regressor for time-series forecasting"""
base_class = LGBMEstimator
class XGBoost_TS_Regressor(TS_SKLearn_Regressor):
""" The class for tuning XGBoost Regressor for time-series forecasting"""
base_class = XGBoostSklearnEstimator
# catboost regressor is invalid because it has a `name` parameter, making it incompatible with hcrystalball
# class CatBoost_TS_Regressor(TS_SKLearn_Regressor):
# base_class = CatBoostEstimator
class RF_TS_Regressor(TS_SKLearn_Regressor):
""" The class for tuning Random Forest Regressor for time-series forecasting"""
base_class = RandomForestEstimator
class ExtraTrees_TS_Regressor(TS_SKLearn_Regressor):
""" The class for tuning Extra Trees Regressor for time-series forecasting"""
base_class = ExtraTreesEstimator
class XGBoostLimitDepth_TS_Regressor(TS_SKLearn_Regressor):
""" The class for tuning XGBoost Regressor with limited depth for time-series forecasting"""
base_class = XGBoostLimitDepthEstimator
class suppress_stdout_stderr(object):
def __init__(self):
# Open a pair of null files
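The `optimize_for_horizon` branch above implements the direct multi-step strategy: one regressor per forecast step, each trained to predict a different number of steps past the same lag window. A library-free sketch of the idea with plain sklearn (`make_lags` is a hypothetical helper defined here, not part of this commit):

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

def make_lags(series, n_lags):
    # Row j holds the n_lags values preceding position j + n_lags.
    return np.array([series[t - n_lags:t] for t in range(n_lags, len(series))])

y = np.sin(np.arange(200) / 10)
n_lags, horizon = 3, 4
X = make_lags(y, n_lags)  # shape (197, 3)

models = []
for i in range(1, horizon + 1):
    # Model i is trained to predict i steps past the end of its lag window.
    target = y[n_lags + i - 1:]
    models.append(RandomForestRegressor(n_estimators=50).fit(X[: len(target)], target))

last_window = y[-n_lags:].reshape(1, -1)
print([m.predict(last_window)[0] for m in models])  # one prediction per step 1..4
```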

setup.py

@@ -60,6 +60,7 @@ setuptools.setup(
"torch",
"nltk",
"rouge_score",
"hcrystalball==0.1.10",
"seqeval",
],
"catboost": ["catboost>=0.26"],
@@ -85,8 +86,8 @@ setuptools.setup(
"nltk",
"rouge_score",
],
"ts_forecast": ["prophet>=1.0.1", "statsmodels>=0.12.2"],
"forecast": ["prophet>=1.0.1", "statsmodels>=0.12.2"],
"ts_forecast": ["prophet>=1.0.1", "statsmodels>=0.12.2", "hcrystalball==0.1.10"],
"forecast": ["prophet>=1.0.1", "statsmodels>=0.12.2", "hcrystalball==0.1.10"],
"benchmark": ["catboost>=0.26", "psutil==5.8.0", "xgboost==1.3.3"],
},
classifiers=[

test/test_forecast.py

@@ -105,6 +105,7 @@ def test_numpy():
task="ts_forecast",
time_budget=3, # time budget in seconds
log_file_name="test/ts_forecast.log",
n_splits=3, # number of splits
)
print(automl.predict(X_train[72:]))
except ImportError:
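This hunk belongs to the numpy-input test; a self-contained sketch of that input format, consistent with the visible arguments (timestamps as a 1-D `datetime64` array in place of a dataframe):

```python
import numpy as np
from flaml import AutoML

X_train = np.arange("2014-01", "2021-01", dtype="datetime64[M]")  # 84 monthly stamps
y_train = np.random.random(size=84)
automl = AutoML()
automl.fit(
    X_train=X_train[:72],
    y_train=y_train[:72],
    period=12,
    task="ts_forecast",
    time_budget=3,
    n_splits=3,
)
print(automl.predict(X_train[72:]))
```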
@@ -280,7 +281,6 @@ def load_multi_dataset_cat(time_horizon):
def test_multivariate_forecast_cat(budget=5):
time_horizon = 180
train_df, test_df = load_multi_dataset_cat(time_horizon)
print(train_df)
X_test = test_df[
["timeStamp", "season", "above_monthly_avg"]
] # test dataframe must contain values for the regressors / multivariate variables
@@ -290,7 +290,7 @@ def test_multivariate_forecast_cat(budget=5):
"time_budget": budget, # total running time in seconds
"metric": "mape", # primary metric
"task": "ts_forecast", # task type
"log_file_name": "test/energy_forecast_numerical.log", # flaml log file
"log_file_name": "test/energy_forecast_categorical.log", # flaml log file
"eval_method": "holdout",
"log_type": "all",
"label": "demand",
@@ -360,3 +360,4 @@ if __name__ == "__main__":
test_forecast_automl(60)
test_multivariate_forecast_num(60)
test_multivariate_forecast_cat(60)
test_numpy()