diff --git a/flaml/automl/automl.py b/flaml/automl/automl.py index 1b6bef199..42576a8a2 100644 --- a/flaml/automl/automl.py +++ b/flaml/automl/automl.py @@ -17,6 +17,7 @@ from sklearn.model_selection import ( GroupKFold, TimeSeriesSplit, GroupShuffleSplit, + StratifiedGroupKFold, ) from sklearn.utils import shuffle from sklearn.base import BaseEstimator @@ -1575,8 +1576,8 @@ class AutoML(BaseEstimator): else: # logger.info("Using splitter object") self._state.kf = self._split_type - if isinstance(self._state.kf, GroupKFold): - # self._split_type is either "group" or a GroupKFold object + if isinstance(self._state.kf, (GroupKFold, StratifiedGroupKFold)): + # self._split_type is either "group", a GroupKFold object, or a StratifiedGroupKFold object self._state.kf.groups = self._state.groups_all def add_learner(self, learner_name, learner_class): diff --git a/flaml/automl/ml.py b/flaml/automl/ml.py index 6285bc29c..dd17cffea 100644 --- a/flaml/automl/ml.py +++ b/flaml/automl/ml.py @@ -17,7 +17,12 @@ from sklearn.metrics import ( mean_absolute_percentage_error, ndcg_score, ) -from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit +from sklearn.model_selection import ( + RepeatedStratifiedKFold, + GroupKFold, + TimeSeriesSplit, + StratifiedGroupKFold, +) from flaml.automl.model import ( XGBoostSklearnEstimator, XGBoost_TS, @@ -517,7 +522,7 @@ def evaluate_model_CV( shuffle = getattr(kf, "shuffle", task not in TS_FORECAST) if isinstance(kf, RepeatedStratifiedKFold): kf = kf.split(X_train_split, y_train_split) - elif isinstance(kf, GroupKFold): + elif isinstance(kf, (GroupKFold, StratifiedGroupKFold)): groups = kf.groups kf = kf.split(X_train_split, y_train_split, groups) shuffle = False @@ -548,8 +553,16 @@ def evaluate_model_CV( weight[val_index], ) if groups is not None: - fit_kwargs["groups"] = groups[train_index] - groups_val = groups[val_index] + fit_kwargs["groups"] = ( + groups[train_index] + if isinstance(groups, np.ndarray) + else groups.iloc[train_index] + ) + groups_val = ( + groups[val_index] + if isinstance(groups, np.ndarray) + else groups.iloc[val_index] + ) else: groups_val = None val_loss_i, metric_i, train_time_i, pred_time_i = get_val_loss( diff --git a/test/automl/test_split.py b/test/automl/test_split.py index 7eb8c7b50..9223c520a 100644 --- a/test/automl/test_split.py +++ b/test/automl/test_split.py @@ -94,6 +94,33 @@ def test_groups(): automl.fit(X, y, **automl_settings) +def test_stratified_groupkfold(): + from sklearn.model_selection import StratifiedGroupKFold + from flaml.data import load_openml_dataset + + X_train, _, y_train, _ = load_openml_dataset(dataset_id=1169, data_dir="test/") + splitter = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0) + + automl = AutoML() + settings = { + "time_budget": 6, + "metric": "ap", + "eval_method": "cv", + "split_type": splitter, + "groups": X_train["Airline"], + "estimator_list": [ + "lgbm", + "rf", + "xgboost", + "extra_tree", + "xgb_limitdepth", + "lrl1", + ], + } + + automl.fit(X_train=X_train, y_train=y_train, **settings) + + def test_rank(): from sklearn.externals._arff import ArffException