mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-27 15:09:41 +00:00
Allow custom GroupKFold object as split_type (#616)
* Allow custom GroupKFold object
* handle unpickle error for prophet 1.1
* eval_method in test_object()
This commit is contained in:
parent
cbb85e2aab
commit
74cca60606
@ -1404,7 +1404,6 @@ class AutoML(BaseEstimator):
|
||||
len(np.unique(self._state.groups_all)) >= n_splits
|
||||
), "the number of groups must be equal or larger than n_splits"
|
||||
self._state.kf = GroupKFold(n_splits)
|
||||
self._state.kf.groups = self._state.groups_all
|
||||
elif self._split_type == "stratified":
|
||||
# logger.info("Using StratifiedKFold")
|
||||
assert y_train_all.size >= n_splits, (
|
||||
@ -1442,6 +1441,9 @@ class AutoML(BaseEstimator):
|
||||
else:
|
||||
# logger.info("Using splitter object")
|
||||
self._state.kf = self._split_type
|
||||
if isinstance(self._state.kf, GroupKFold):
|
||||
# self._split_type is either "group" or a GroupKFold object
|
||||
self._state.kf.groups = self._state.groups_all
|
||||
|
||||
def add_learner(self, learner_name, learner_class):
|
||||
"""Add a customized learner.
|
||||
@ -1681,10 +1683,7 @@ class AutoML(BaseEstimator):
|
||||
# Partially copied from fit() function
|
||||
# Initialize some attributes required for retrain_from_log
|
||||
self._decide_split_type(split_type)
|
||||
if record_id >= 0:
|
||||
eval_method = "cv"
|
||||
elif eval_method == "auto":
|
||||
eval_method = self._decide_eval_method(time_budget)
|
||||
eval_method = self._decide_eval_method(eval_method, time_budget)
|
||||
self.modelcount = 0
|
||||
self._auto_augment = auto_augment
|
||||
self._prepare_data(eval_method, split_ratio, n_splits)
|
||||
@ -1717,6 +1716,9 @@ class AutoML(BaseEstimator):
|
||||
assert hasattr(split_type, "split") and hasattr(
|
||||
split_type, "get_n_splits"
|
||||
), "split_type must be a string or a splitter object with split and get_n_splits methods."
|
||||
assert (
|
||||
not isinstance(split_type, GroupKFold) or self._state.groups is not None
|
||||
), "GroupKFold requires groups to be provided."
|
||||
self._split_type = split_type
|
||||
elif self._state.task in CLASSIFICATION:
|
||||
assert split_type in ["auto", "stratified", "uniform", "time", "group"]
|
||||
@ -1746,9 +1748,28 @@ class AutoML(BaseEstimator):
|
||||
assert split_type in ["auto", "uniform", "time", "group"]
|
||||
self._split_type = split_type if split_type != "auto" else "uniform"
|
||||
|
||||
def _decide_eval_method(self, time_budget):
|
||||
def _decide_eval_method(self, eval_method, time_budget):
|
||||
if not isinstance(self._split_type, str):
|
||||
assert eval_method in [
|
||||
"auto",
|
||||
"cv",
|
||||
], "eval_method must be 'auto' or 'cv' for custom data splitter."
|
||||
assert (
|
||||
self._state.X_val is None
|
||||
), "custom splitter and custom validation data can't be used together."
|
||||
return "cv"
|
||||
if self._state.X_val is not None:
|
||||
assert eval_method in [
|
||||
"auto",
|
||||
"holdout",
|
||||
], "eval_method must be 'auto' or 'holdout' for custom validation data."
|
||||
return "holdout"
|
||||
if eval_method != "auto":
|
||||
assert eval_method in [
|
||||
"holdout",
|
||||
"cv",
|
||||
], "eval_method must be 'holdout', 'cv' or 'auto'."
|
||||
return eval_method
|
||||
nrow, dim = self._nrow, self._ndim
|
||||
if (
|
||||
time_budget is None
|
||||
@ -2390,8 +2411,7 @@ class AutoML(BaseEstimator):
|
||||
logger.info(f"task = {task}")
|
||||
self._decide_split_type(split_type)
|
||||
logger.info(f"Data split method: {self._split_type}")
|
||||
if eval_method == "auto" or self._state.X_val is not None:
|
||||
eval_method = self._decide_eval_method(time_budget)
|
||||
eval_method = self._decide_eval_method(eval_method, time_budget)
|
||||
self._state.eval_method = eval_method
|
||||
logger.info("Evaluation method: {}".format(eval_method))
|
||||
|
||||
|
||||
@ -46,9 +46,11 @@ class TestScore:
|
||||
automl.score(X_test, y_test)
|
||||
automl.pickle("automl.pkl")
|
||||
with open("automl.pkl", "rb") as f:
|
||||
pickle.load(f)
|
||||
except ImportError:
|
||||
print("not using prophet due to ImportError")
|
||||
pickle.load(f) # v1.1 of prophet raises RecursionError
|
||||
except (ImportError, RecursionError):
|
||||
print(
|
||||
"not using prophet due to ImportError or RecursionError (when unpickling in v1.1)"
|
||||
)
|
||||
automl.fit(
|
||||
dataframe=df,
|
||||
**settings,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from sklearn.datasets import fetch_openml
|
||||
from flaml.automl import AutoML
|
||||
from sklearn.model_selection import train_test_split, KFold
|
||||
from sklearn.model_selection import GroupKFold, train_test_split, KFold
|
||||
from sklearn.metrics import accuracy_score
|
||||
|
||||
|
||||
@ -80,6 +80,19 @@ def test_groups():
|
||||
automl_settings["eval_method"] = "holdout"
|
||||
automl.fit(X, y, **automl_settings)
|
||||
|
||||
automl_settings["split_type"] = GroupKFold(n_splits=3)
|
||||
try:
|
||||
automl.fit(X, y, **automl_settings)
|
||||
raise RuntimeError(
|
||||
"GroupKFold object as split_type should fail when eval_method is holdout"
|
||||
)
|
||||
except AssertionError:
|
||||
# eval_method must be 'auto' or 'cv' for custom data splitter.
|
||||
pass
|
||||
|
||||
automl_settings["eval_method"] = "cv"
|
||||
automl.fit(X, y, **automl_settings)
|
||||
|
||||
|
||||
def test_rank():
|
||||
from sklearn.externals._arff import ArffException
|
||||
@ -150,7 +163,6 @@ def test_object():
|
||||
automl = AutoML()
|
||||
automl_settings = {
|
||||
"time_budget": 2,
|
||||
# "metric": 'accuracy',
|
||||
"task": "classification",
|
||||
"log_file_name": "test/{}.log".format(dataset),
|
||||
"model_history": True,
|
||||
@ -158,6 +170,9 @@ def test_object():
|
||||
"split_type": TestKFold(5),
|
||||
}
|
||||
automl.fit(X, y, **automl_settings)
|
||||
assert (
|
||||
automl._state.eval_method == "cv"
|
||||
), "eval_method must be 'cv' for custom data splitter"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user