From 74cca60606f43c885a177d3a2eb4e47828ba36a1 Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Wed, 29 Jun 2022 21:04:25 -0700 Subject: [PATCH] Allow custom GroupKFold object as split_type (#616) * Allow custom GroupKFold object * handle unpickle error for prophet 1.1 * eval_method in test_object() --- flaml/automl.py | 36 ++++++++++++++++++++++++++++-------- test/automl/test_score.py | 8 +++++--- test/automl/test_split.py | 19 +++++++++++++++++-- 3 files changed, 50 insertions(+), 13 deletions(-) diff --git a/flaml/automl.py b/flaml/automl.py index 412aee696..669fa5c24 100644 --- a/flaml/automl.py +++ b/flaml/automl.py @@ -1404,7 +1404,6 @@ class AutoML(BaseEstimator): len(np.unique(self._state.groups_all)) >= n_splits ), "the number of groups must be equal or larger than n_splits" self._state.kf = GroupKFold(n_splits) - self._state.kf.groups = self._state.groups_all elif self._split_type == "stratified": # logger.info("Using StratifiedKFold") assert y_train_all.size >= n_splits, ( @@ -1442,6 +1441,9 @@ class AutoML(BaseEstimator): else: # logger.info("Using splitter object") self._state.kf = self._split_type + if isinstance(self._state.kf, GroupKFold): + # self._split_type is either "group" or a GroupKFold object + self._state.kf.groups = self._state.groups_all def add_learner(self, learner_name, learner_class): """Add a customized learner. @@ -1681,10 +1683,7 @@ class AutoML(BaseEstimator): # Partially copied from fit() function # Initilize some attributes required for retrain_from_log self._decide_split_type(split_type) - if record_id >= 0: - eval_method = "cv" - elif eval_method == "auto": - eval_method = self._decide_eval_method(time_budget) + eval_method = self._decide_eval_method(eval_method, time_budget) self.modelcount = 0 self._auto_augment = auto_augment self._prepare_data(eval_method, split_ratio, n_splits) @@ -1717,6 +1716,9 @@ class AutoML(BaseEstimator): assert hasattr(split_type, "split") and hasattr( split_type, "get_n_splits" ), "split_type must be a string or a splitter object with split and get_n_splits methods." + assert ( + not isinstance(split_type, GroupKFold) or self._state.groups is not None + ), "GroupKFold requires groups to be provided." self._split_type = split_type elif self._state.task in CLASSIFICATION: assert split_type in ["auto", "stratified", "uniform", "time", "group"] @@ -1746,9 +1748,28 @@ class AutoML(BaseEstimator): assert split_type in ["auto", "uniform", "time", "group"] self._split_type = split_type if split_type != "auto" else "uniform" - def _decide_eval_method(self, time_budget): + def _decide_eval_method(self, eval_method, time_budget): + if not isinstance(self._split_type, str): + assert eval_method in [ + "auto", + "cv", + ], "eval_method must be 'auto' or 'cv' for custom data splitter." + assert ( + self._state.X_val is None + ), "custom splitter and custom validation data can't be used together." + return "cv" if self._state.X_val is not None: + assert eval_method in [ + "auto", + "holdout", + ], "eval_method must be 'auto' or 'holdout' for custom validation data." return "holdout" + if eval_method != "auto": + assert eval_method in [ + "holdout", + "cv", + ], "eval_method must be 'holdout', 'cv' or 'auto'." + return eval_method nrow, dim = self._nrow, self._ndim if ( time_budget is None @@ -2390,8 +2411,7 @@ class AutoML(BaseEstimator): logger.info(f"task = {task}") self._decide_split_type(split_type) logger.info(f"Data split method: {self._split_type}") - if eval_method == "auto" or self._state.X_val is not None: - eval_method = self._decide_eval_method(time_budget) + eval_method = self._decide_eval_method(eval_method, time_budget) self._state.eval_method = eval_method logger.info("Evaluation method: {}".format(eval_method)) diff --git a/test/automl/test_score.py b/test/automl/test_score.py index c9b879066..3fd94a24b 100644 --- a/test/automl/test_score.py +++ b/test/automl/test_score.py @@ -46,9 +46,11 @@ class TestScore: automl.score(X_test, y_test) automl.pickle("automl.pkl") with open("automl.pkl", "rb") as f: - pickle.load(f) - except ImportError: - print("not using prophet due to ImportError") + pickle.load(f) # v1.1 of prophet raises RecursionError + except (ImportError, RecursionError): + print( + "not using prophet due to ImportError or RecursionError (when unpickling in v1.1)" + ) automl.fit( dataframe=df, **settings, diff --git a/test/automl/test_split.py b/test/automl/test_split.py index be18e44f2..b40631cb2 100644 --- a/test/automl/test_split.py +++ b/test/automl/test_split.py @@ -1,6 +1,6 @@ from sklearn.datasets import fetch_openml from flaml.automl import AutoML -from sklearn.model_selection import train_test_split, KFold +from sklearn.model_selection import GroupKFold, train_test_split, KFold from sklearn.metrics import accuracy_score @@ -80,6 +80,19 @@ def test_groups(): automl_settings["eval_method"] = "holdout" automl.fit(X, y, **automl_settings) + automl_settings["split_type"] = GroupKFold(n_splits=3) + try: + automl.fit(X, y, **automl_settings) + raise RuntimeError( + "GroupKFold object as split_type should fail when eval_method is holdout" + ) + except AssertionError: + # eval_method must be 'auto' or 'cv' for custom data splitter. + pass + + automl_settings["eval_method"] = "cv" + automl.fit(X, y, **automl_settings) + def test_rank(): from sklearn.externals._arff import ArffException @@ -150,7 +163,6 @@ def test_object(): automl = AutoML() automl_settings = { "time_budget": 2, - # "metric": 'accuracy', "task": "classification", "log_file_name": "test/{}.log".format(dataset), "model_history": True, @@ -158,6 +170,9 @@ def test_object(): "split_type": TestKFold(5), } automl.fit(X, y, **automl_settings) + assert ( + automl._state.eval_method == "cv" + ), "eval_method must be 'cv' for custom data splitter" if __name__ == "__main__":