diff --git a/flaml/automl.py b/flaml/automl.py index b51944f07..17931a1de 100644 --- a/flaml/automl.py +++ b/flaml/automl.py @@ -385,6 +385,15 @@ class AutoMLState: tune.report(**result) return result + def sanitize(self, config: dict) -> dict: + """Make a config ready for passing to estimator.""" + config = config.get("ml", config).copy() + if "FLAML_sample_size" in config: + del config["FLAML_sample_size"] + if "learner" in config: + del config["learner"] + return config + def _train_with_config( self, estimator, @@ -395,11 +404,7 @@ class AutoMLState: sample_size = config_w_resource.get( "FLAML_sample_size", len(self.y_train_all) ) - config = config_w_resource.get("ml", config_w_resource).copy() - if "FLAML_sample_size" in config: - del config["FLAML_sample_size"] - if "learner" in config: - del config["learner"] + config = self.sanitize(config_w_resource) this_estimator_kwargs = self.fit_kwargs_by_estimator.get( estimator @@ -3203,7 +3208,7 @@ class AutoML(BaseEstimator): x[1].learner_class( task=self._state.task, n_jobs=self._state.n_jobs, - **x[1].best_config, + **self._state.sanitize(x[1].best_config), ), ) for x in search_states[:2] @@ -3214,13 +3219,15 @@ class AutoML(BaseEstimator): x[1].learner_class( task=self._state.task, n_jobs=self._state.n_jobs, - **x[1].best_config, + **self._state.sanitize(x[1].best_config), ), ) for x in search_states[2:] if x[1].best_loss < 4 * self._selected.best_loss ] - logger.info(estimators) + logger.info( + [(estimator[0], estimator[1].params) for estimator in estimators] + ) if len(estimators) > 1: if self._state.task in CLASSIFICATION: from sklearn.ensemble import StackingClassifier as Stacker diff --git a/flaml/version.py b/flaml/version.py index 382021f30..9e604c040 100644 --- a/flaml/version.py +++ b/flaml/version.py @@ -1 +1 @@ -__version__ = "1.0.6" +__version__ = "1.0.7" diff --git a/test/automl/test_classification.py b/test/automl/test_classification.py index 69cd019f4..6eac50a57 100644 --- a/test/automl/test_classification.py +++ b/test/automl/test_classification.py @@ -256,6 +256,7 @@ class TestClassification(unittest.TestCase): time_budget=10, task="classification", n_concurrent_trials=2, + ensemble=True, ) except ImportError: return