diff --git a/flaml/automl.py b/flaml/automl.py index fddf43540..95dfeb8ca 100644 --- a/flaml/automl.py +++ b/flaml/automl.py @@ -507,7 +507,11 @@ class AutoML(BaseEstimator): True - retrain only after search finishes; False - no retraining; 'budget' - do best effort to retrain without violating the time budget. - split_type: str, default="auto" | the data split type. + split_type: str or splitter object, default="auto" | the data split type. + A valid splitter object is an instance of a derived class of scikit-learn KFold + (https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html#sklearn.model_selection.KFold) + and have ``split`` and ``get_n_splits`` methods with the same signatures. + Valid str options depend on different tasks. For classification tasks, valid choices are [ "auto", 'stratified', 'uniform', 'time']. "auto" -> stratified. For regression tasks, valid choices are ["auto", 'uniform', 'time']. @@ -955,7 +959,7 @@ class AutoML(BaseEstimator): self._state.task in CLASSIFICATION and self._auto_augment and self._state.fit_kwargs.get("sample_weight") is None - and self._split_type not in ["time", "group"] + and self._split_type in ["stratified", "uniform"] ): # logger.info(f"label {pd.unique(y_train_all)}") label_set, counts = np.unique(y_train_all, return_counts=True) @@ -1183,11 +1187,14 @@ class AutoML(BaseEstimator): self._state.kf = TimeSeriesSplit(n_splits=n_splits, test_size=period) else: self._state.kf = TimeSeriesSplit(n_splits=n_splits) - else: + elif isinstance(self._split_type, str): # logger.info("Using RepeatedKFold") self._state.kf = RepeatedKFold( n_splits=n_splits, n_repeats=1, random_state=RANDOM_SEED ) + else: + # logger.info("Using splitter object") + self._state.kf = self._split_type def add_learner(self, learner_name, learner_class): """Add a customized learner. @@ -1277,7 +1284,11 @@ class AutoML(BaseEstimator): ['auto', 'cv', 'holdout']. split_ratio: A float of the validation data percentage for holdout. n_splits: An integer of the number of folds for cross-validation. - split_type: str, default="auto" | the data split type. + split_type: str or splitter object, default="auto" | the data split type. + A valid splitter object is an instance of a derived class of scikit-learn KFold + (https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html#sklearn.model_selection.KFold) + and have ``split`` and ``get_n_splits`` methods with the same signatures. + Valid str options depend on different tasks. For classification tasks, valid choices are [ "auto", 'stratified', 'uniform', 'time', 'group']. "auto" -> stratified. For regression tasks, valid choices are ["auto", 'uniform', 'time']. @@ -1399,7 +1410,12 @@ class AutoML(BaseEstimator): self._state.task = get_classification_objective( len(np.unique(self._y_train_all)) ) - if self._state.task in CLASSIFICATION: + if not isinstance(split_type, str): + assert hasattr(split_type, "split") and hasattr( + split_type, "get_n_splits" + ), "split_type must be a string or a splitter object with split and get_n_splits methods." + self._split_type = split_type + elif self._state.task in CLASSIFICATION: assert split_type in ["auto", "stratified", "uniform", "time", "group"] self._split_type = ( split_type @@ -1786,7 +1802,11 @@ class AutoML(BaseEstimator): True - retrain only after search finishes; False - no retraining; 'budget' - do best effort to retrain without violating the time budget. - split_type: str, default="auto" | the data split type. + split_type: str or splitter object, default="auto" | the data split type. + A valid splitter object is an instance of a derived class of scikit-learn KFold + (https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html#sklearn.model_selection.KFold) + and have ``split`` and ``get_n_splits`` methods with the same signatures. + Valid str options depend on different tasks. For classification tasks, valid choices are [ "auto", 'stratified', 'uniform', 'time']. "auto" -> stratified. For regression tasks, valid choices are ["auto", 'uniform', 'time']. diff --git a/flaml/ml.py b/flaml/ml.py index 145b045d0..480eb42db 100644 --- a/flaml/ml.py +++ b/flaml/ml.py @@ -412,7 +412,7 @@ def evaluate_model_CV( else: labels = None groups = None - shuffle = True + shuffle = False if task == TS_FORECAST else True if isinstance(kf, RepeatedStratifiedKFold): kf = kf.split(X_train_split, y_train_split) elif isinstance(kf, GroupKFold): @@ -423,7 +423,6 @@ def evaluate_model_CV( y_train_all = pd.DataFrame(y_train_all, columns=[TS_VALUE_COL]) train = X_train_all.join(y_train_all) kf = kf.split(train) - shuffle = False elif isinstance(kf, TimeSeriesSplit): kf = kf.split(X_train_split, y_train_split) else: diff --git a/test/automl/test_split.py b/test/automl/test_split.py index 067c0c1af..3a81ad3e3 100644 --- a/test/automl/test_split.py +++ b/test/automl/test_split.py @@ -2,7 +2,7 @@ import unittest from sklearn.datasets import fetch_openml from flaml.automl import AutoML -from sklearn.model_selection import train_test_split +from sklearn.model_selection import train_test_split, KFold from sklearn.metrics import accuracy_score @@ -123,6 +123,45 @@ def test_rank(): automl.fit(X, y, **automl_settings) +def test_object(): + from sklearn.externals._arff import ArffException + + try: + X, y = fetch_openml(name=dataset, return_X_y=True) + except (ArffException, ValueError): + from sklearn.datasets import load_wine + + X, y = load_wine(return_X_y=True) + + import numpy as np + + class TestKFold(KFold): + def __init__(self, n_splits): + self.n_splits = int(n_splits) + + def split(self, X): + rng = np.random.default_rng() + train_num = int(len(X) * 0.8) + for _ in range(self.n_splits): + permu_idx = rng.permutation(len(X)) + yield permu_idx[:train_num], permu_idx[train_num:] + + def get_n_splits(self, X=None, y=None, groups=None): + return self.n_splits + + automl = AutoML() + automl_settings = { + "time_budget": 2, + # "metric": 'accuracy', + "task": "classification", + "log_file_name": "test/{}.log".format(dataset), + "model_history": True, + "log_training_metric": True, + "split_type": TestKFold(5), + } + automl.fit(X, y, **automl_settings) + + if __name__ == "__main__": # unittest.main() test_groups()