mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-26 08:34:10 +00:00
Disable shuffle for custom CV (#659)
* Disable shuffle for custom CV * Add custom fold shuffle test * Update test_split.py * Update test_split.py
This commit is contained in:
parent
ca9f9054e7
commit
e43485607a
@ -459,7 +459,7 @@ def evaluate_model_CV(
|
|||||||
"label_list"
|
"label_list"
|
||||||
) # pass the label list on to compute the evaluation metric
|
) # pass the label list on to compute the evaluation metric
|
||||||
groups = None
|
groups = None
|
||||||
shuffle = False if task in TS_FORECAST else True
|
shuffle = getattr(kf, "shuffle", task not in TS_FORECAST)
|
||||||
if isinstance(kf, RepeatedStratifiedKFold):
|
if isinstance(kf, RepeatedStratifiedKFold):
|
||||||
kf = kf.split(X_train_split, y_train_split)
|
kf = kf.split(X_train_split, y_train_split)
|
||||||
elif isinstance(kf, GroupKFold):
|
elif isinstance(kf, GroupKFold):
|
||||||
|
@ -174,6 +174,11 @@ def test_object():
|
|||||||
automl._state.eval_method == "cv"
|
automl._state.eval_method == "cv"
|
||||||
), "eval_method must be 'cv' for custom data splitter"
|
), "eval_method must be 'cv' for custom data splitter"
|
||||||
|
|
||||||
|
kf = TestKFold(5)
|
||||||
|
kf.shuffle = True
|
||||||
|
automl_settings["split_type"] = kf
|
||||||
|
automl.fit(X, y, **automl_settings)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_groups()
|
test_groups()
|
||||||
|
@ -364,7 +364,7 @@ For both classification and regression, time-based split can be enforced if the
|
|||||||
|
|
||||||
When `eval_method="cv"`, `split_type` can also be set as a custom splitter. It needs to be an instance of a derived class of scikit-learn
|
When `eval_method="cv"`, `split_type` can also be set as a custom splitter. It needs to be an instance of a derived class of scikit-learn
|
||||||
[KFold](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html#sklearn.model_selection.KFold)
|
[KFold](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html#sklearn.model_selection.KFold)
|
||||||
and have ``split`` and ``get_n_splits`` methods with the same signatures.
|
and have ``split`` and ``get_n_splits`` methods with the same signatures. To disable shuffling, the splitter instance must contain the attribute `shuffle=False`.
|
||||||
|
|
||||||
### Parallel tuning
|
### Parallel tuning
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user