diff --git a/flaml/automl.py b/flaml/automl.py
index 080c9202a..9a668d2e6 100644
--- a/flaml/automl.py
+++ b/flaml/automl.py
@@ -73,7 +73,7 @@ class SearchState:
             self.total_time_used - self.time_best_found,
         )
 
-    def __init__(self, learner_class, data_size, task, starting_point=None):
+    def __init__(self, learner_class, data_size, task, starting_point=None, period=None):
         self.init_eci = learner_class.cost_relative2lgbm()
         self._search_space_domain = {}
         self.init_config = {}
@@ -82,7 +82,10 @@ class SearchState:
         self.data_size = data_size
         self.ls_ever_converged = False
         self.learner_class = learner_class
-        search_space = learner_class.search_space(data_size=data_size, task=task)
+        if task == TS_FORECAST:
+            search_space = learner_class.search_space(data_size=data_size, task=task, pred_horizon=period)
+        else:
+            search_space = learner_class.search_space(data_size=data_size, task=task)
         for name, space in search_space.items():
             assert (
                 "domain" in space
@@ -808,6 +811,38 @@ class AutoML(BaseEstimator):
             X = self._transformer.transform(X)
         return X
 
+    def _validate_ts_data(
+        self,
+        dataframe,
+        y_train_all=None,
+    ):
+        assert (
+            dataframe[dataframe.columns[0]].dtype.name == "datetime64[ns]"
+        ), f"For '{TS_FORECAST}' task, the first column must contain timestamp values."
+        if y_train_all is not None:
+            y_df = pd.DataFrame(y_train_all) if isinstance(y_train_all, pd.Series) else pd.DataFrame(y_train_all, columns=["labels"])
+            dataframe = dataframe.join(y_df)
+        duplicates = dataframe.duplicated()
+        if any(duplicates):
+            logger.warning(
+                "Duplicate timestamp values found in timestamp column. "
+                f"\n{dataframe.loc[duplicates, dataframe.columns[0]]}"
+            )
+            dataframe = dataframe.drop_duplicates()
+            logger.warning("Removed duplicate rows based on all columns")
+        assert (
+            not dataframe[[dataframe.columns[0]]].duplicated().any()
+        ), "Duplicate timestamp values with different values for other columns."
+        ts_series = pd.to_datetime(dataframe[dataframe.columns[0]])
+        inferred_freq = pd.infer_freq(ts_series)
+        if inferred_freq is None:
+            logger.warning(
+                "Missing timestamps detected. To avoid errors with estimators, set the estimator list to ['prophet']."
+            )
+        if y_train_all is not None:
+            return dataframe.iloc[:, :-1], dataframe.iloc[:, -1]
+        return dataframe
+
     def _validate_data(
         self,
         X_train_all,
@@ -846,9 +881,7 @@ class AutoML(BaseEstimator):
             self._nrow, self._ndim = X_train_all.shape
             if self._state.task == TS_FORECAST:
                 X_train_all = pd.DataFrame(X_train_all)
-                assert (
-                    X_train_all[X_train_all.columns[0]].dtype.name == "datetime64[ns]"
-                ), f"For '{TS_FORECAST}' task, the first column must contain timestamp values."
+                X_train_all, y_train_all = self._validate_ts_data(X_train_all, y_train_all)
             X, y = X_train_all, y_train_all
         elif dataframe is not None and label is not None:
             assert isinstance(
@@ -857,9 +890,7 @@ class AutoML(BaseEstimator):
             assert label in dataframe.columns, "label must a column name in dataframe"
             self._df = True
             if self._state.task == TS_FORECAST:
-                assert (
-                    dataframe[dataframe.columns[0]].dtype.name == "datetime64[ns]"
-                ), f"For '{TS_FORECAST}' task, the first column must contain timestamp values."
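+                # shared time-series validation: checks the timestamp dtype, drops
+                # duplicate rows, and warns when no frequency can be inferred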
+                dataframe = self._validate_ts_data(dataframe)
             X = dataframe.drop(columns=label)
             self._nrow, self._ndim = X.shape
             y = dataframe[label]
@@ -2079,14 +2110,7 @@ class AutoML(BaseEstimator):
 
         logger.info(f"Minimizing error metric: {error_metric}")
         if "auto" == estimator_list:
-            if self._state.task == TS_FORECAST:
-                try:
-                    import prophet
-
-                    estimator_list = ["prophet", "arima", "sarimax"]
-                except ImportError:
-                    estimator_list = ["arima", "sarimax"]
-            elif self._state.task == "rank":
+            if self._state.task == "rank":
                 estimator_list = ["lgbm", "xgboost", "xgb_limitdepth"]
             elif _is_nlp_task(self._state.task):
                 estimator_list = ["transformer"]
@@ -2110,8 +2134,18 @@ class AutoML(BaseEstimator):
                     "extra_tree",
                     "xgb_limitdepth",
                 ]
-            if "regression" != self._state.task:
+            if TS_FORECAST == self._state.task:
+                # catboost is removed because it has a `name` parameter, making it incompatible with hcrystalball
+                estimator_list.remove("catboost")
+                try:
+                    import prophet
+
+                    estimator_list += ["prophet", "arima", "sarimax"]
+                except ImportError:
+                    estimator_list += ["arima", "sarimax"]
+            elif "regression" != self._state.task:
                 estimator_list += ["lrl1"]
+
         for estimator_name in estimator_list:
             if estimator_name not in self._state.learner_classes:
                 self.add_learner(
@@ -2127,6 +2161,7 @@ class AutoML(BaseEstimator):
                 data_size=self._state.data_size,
                 task=self._state.task,
                 starting_point=starting_points.get(estimator_name),
+                period=self._state.fit_kwargs.get("period"),
             )
         logger.info("List of ML learners in AutoML Run: {}".format(estimator_list))
         self.estimator_list = estimator_list
diff --git a/flaml/ml.py b/flaml/ml.py
index 86dceb295..0655d5cc8 100644
--- a/flaml/ml.py
+++ b/flaml/ml.py
@@ -20,13 +20,18 @@ from sklearn.metrics import (
 from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit
 from .model import (
     XGBoostSklearnEstimator,
+    XGBoost_TS_Regressor,
     XGBoostLimitDepthEstimator,
+    XGBoostLimitDepth_TS_Regressor,
     RandomForestEstimator,
+    RF_TS_Regressor,
     LGBMEstimator,
+    LGBM_TS_Regressor,
     LRL1Classifier,
     LRL2Classifier,
     CatBoostEstimator,
     ExtraTreesEstimator,
+    ExtraTrees_TS_Regressor,
     KNeighborsEstimator,
     Prophet,
     ARIMA,
@@ -89,13 +94,13 @@ huggingface_submetric_to_metric = {"rouge1": "rouge", "rouge2": "rouge"}
 def get_estimator_class(task, estimator_name):
     # when adding a new learner, need to add an elif branch
     if "xgboost" == estimator_name:
-        estimator_class = XGBoostSklearnEstimator
+        estimator_class = XGBoost_TS_Regressor if TS_FORECAST == task else XGBoostSklearnEstimator
     elif "xgb_limitdepth" == estimator_name:
-        estimator_class = XGBoostLimitDepthEstimator
+        estimator_class = XGBoostLimitDepth_TS_Regressor if TS_FORECAST == task else XGBoostLimitDepthEstimator
     elif "rf" == estimator_name:
-        estimator_class = RandomForestEstimator
+        estimator_class = RF_TS_Regressor if TS_FORECAST == task else RandomForestEstimator
     elif "lgbm" == estimator_name:
-        estimator_class = LGBMEstimator
+        estimator_class = LGBM_TS_Regressor if TS_FORECAST == task else LGBMEstimator
     elif "lrl1" == estimator_name:
         estimator_class = LRL1Classifier
     elif "lrl2" == estimator_name:
@@ -103,7 +108,7 @@ def get_estimator_class(task, estimator_name):
     elif "catboost" == estimator_name:
         estimator_class = CatBoostEstimator
     elif "extra_tree" == estimator_name:
-        estimator_class = ExtraTreesEstimator
+        estimator_class = ExtraTrees_TS_Regressor if TS_FORECAST == task else ExtraTreesEstimator
     elif "kneighbor" == estimator_name:
         estimator_class = KNeighborsEstimator
     elif "prophet" in estimator_name:
@@ -441,10 +446,6 @@ def evaluate_model_CV(
         groups = kf.groups
         kf = kf.split(X_train_split, y_train_split, groups)
         shuffle = False
-    elif isinstance(kf, TimeSeriesSplit) and task == TS_FORECAST:
-        y_train_all = pd.DataFrame(y_train_all, columns=[TS_VALUE_COL])
-        train = X_train_all.join(y_train_all)
-        kf = kf.split(train)
     elif isinstance(kf, TimeSeriesSplit):
         kf = kf.split(X_train_split, y_train_split)
     else:
diff --git a/flaml/model.py b/flaml/model.py
index 15dd96ffc..8736db77f 100644
--- a/flaml/model.py
+++ b/flaml/model.py
@@ -1789,6 +1789,135 @@ class SARIMAX(ARIMA):
         return train_time
 
+
+class TS_SKLearn_Regressor(SKLearnEstimator):
+    """The class for tuning SKLearn Regressors for time-series forecasting, using hcrystalball."""
+
+    base_class = SKLearnEstimator
+
+    @classmethod
+    def search_space(cls, data_size, pred_horizon, **params):
+        space = cls.base_class.search_space(data_size, **params)
+        space.update({
+            "optimize_for_horizon": {
+                "domain": tune.choice([True, False]),
+                "init_value": False,
+                "low_cost_init_value": False,
+            },
+            "lags": {
+                "domain": tune.randint(lower=1, upper=data_size[0] - pred_horizon),
+                "init_value": 3,
+            },
+        })
+        return space
+
+    def __init__(self, task=TS_FORECAST, **params):
+        super().__init__(task, **params)
+        self.hcrystaball_model = None
+
+    def transform_X(self, X):
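+        # hcrystalball expects time as the index: a lone datetime column becomes
+        # an empty frame indexed by it; later columns stay as exogenous features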
+        cols = list(X)
+        if len(cols) == 1:
+            ds_col = cols[0]
+            X = pd.DataFrame(index=X[ds_col])
+        elif len(cols) > 1:
+            ds_col = cols[0]
+            exog_cols = cols[1:]
+            X = X[exog_cols].set_index(X[ds_col])
+        return X
+
+    def _fit(self, X_train, y_train, budget=None, **kwargs):
+        from hcrystalball.wrappers import get_sklearn_wrapper
+
+        X_train = self.transform_X(X_train)
+        X_train = self._preprocess(X_train)
+        params = self.params.copy()
+        lags = params.pop("lags")
+        optimize_for_horizon = params.pop("optimize_for_horizon")
+        estimator = self.base_class(task="regression", **params)
+        self.hcrystaball_model = get_sklearn_wrapper(estimator.estimator_class)
+        self.hcrystaball_model.lags = int(lags)
+        self.hcrystaball_model.fit(X_train, y_train)
+        if optimize_for_horizon:
+            # Direct Multi-step Forecast Strategy - fit a separate model for each horizon
+            model_list = []
+            for i in range(1, kwargs["period"] + 1):
+                X_fit, y_fit = self.hcrystaball_model._transform_data_to_tsmodel_input_format(X_train, y_train, i)
+                self.hcrystaball_model.model.set_params(**estimator.params)
+                model = self.hcrystaball_model.model.fit(X_fit, y_fit)
+                model_list.append(model)
+            self._model = model_list
+        else:
+            X_fit, y_fit = self.hcrystaball_model._transform_data_to_tsmodel_input_format(X_train, y_train, kwargs["period"])
+            self.hcrystaball_model.model.set_params(**estimator.params)
+            model = self.hcrystaball_model.model.fit(X_fit, y_fit)
+            self._model = model
+
+    def fit(self, X_train, y_train, budget=None, **kwargs):
+        current_time = time.time()
+        self._fit(X_train, y_train, budget=budget, **kwargs)
+        train_time = time.time() - current_time
+        return train_time
+
+    def predict(self, X_test):
+        if self._model is not None:
+            X_test = self.transform_X(X_test)
+            X_test = self._preprocess(X_test)
+            if isinstance(self._model, list):
+                assert (
+                    len(self._model) == len(X_test)
+                ), "Model is optimized for horizon, length of X_test must be equal to `period`."
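+                # direct multi-step prediction: for step i, rebuild the model input
+                # from the first i test rows and keep that model's last prediction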
+                preds = []
+                for i in range(1, len(self._model) + 1):
+                    X_pred, _ = self.hcrystaball_model._transform_data_to_tsmodel_input_format(X_test.iloc[:i, :])
+                    preds.append(self._model[i - 1].predict(X_pred)[-1])
+                forecast = pd.DataFrame(data=np.asarray(preds).reshape(-1, 1),
+                                        columns=[self.hcrystaball_model.name],
+                                        index=X_test.index)
+            else:
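+                # a single model covers the whole horizon: build the feature
+                # matrix once and predict every step in one call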
"holdout", "log_type": "all", "label": "demand", @@ -360,3 +360,4 @@ if __name__ == "__main__": test_forecast_automl(60) test_multivariate_forecast_num(60) test_multivariate_forecast_cat(60) + test_numpy()