Mirror of https://github.com/microsoft/autogen.git, synced 2025-10-28 16:29:39 +00:00
enable ensemble when using ray (#583)
* enable ensemble when using ray
* sanitize config
This commit is contained in:
parent
0642b6e7bb
commit
f8cc38bc16
@@ -385,6 +385,15 @@ class AutoMLState:
         tune.report(**result)
         return result
 
+    def sanitize(self, config: dict) -> dict:
+        """Make a config ready for passing to estimator."""
+        config = config.get("ml", config).copy()
+        if "FLAML_sample_size" in config:
+            del config["FLAML_sample_size"]
+        if "learner" in config:
+            del config["learner"]
+        return config
+
     def _train_with_config(
         self,
         estimator,
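For clarity, here is a standalone, runnable restatement of the sanitize logic added above, with an illustrative input dict (the hyperparameter names and values are made up, not taken from the diff): it unwraps the optional "ml" layer and drops the two FLAML-internal keys, leaving only arguments the estimator's constructor accepts.

def sanitize(config: dict) -> dict:
    """Make a config ready for passing to estimator."""
    config = config.get("ml", config).copy()
    if "FLAML_sample_size" in config:
        del config["FLAML_sample_size"]
    if "learner" in config:
        del config["learner"]
    return config

# Illustrative raw search-space config: hyperparameters nested under "ml"
# plus the two keys FLAML uses internally during the search.
raw = {
    "ml": {
        "learner": "lgbm",
        "n_estimators": 100,
        "learning_rate": 0.1,
        "FLAML_sample_size": 10000,
    }
}
assert sanitize(raw) == {"n_estimators": 100, "learning_rate": 0.1}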
@@ -395,11 +404,7 @@ class AutoMLState:
         sample_size = config_w_resource.get(
             "FLAML_sample_size", len(self.y_train_all)
         )
-        config = config_w_resource.get("ml", config_w_resource).copy()
-        if "FLAML_sample_size" in config:
-            del config["FLAML_sample_size"]
-        if "learner" in config:
-            del config["learner"]
+        config = self.sanitize(config_w_resource)
 
         this_estimator_kwargs = self.fit_kwargs_by_estimator.get(
             estimator
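The ordering in this hunk matters: FLAML_sample_size is read from the raw config before self.sanitize strips it. A small runnable sketch of that ordering, using an illustrative dict and plain dict operations rather than FLAML's own objects:

config_w_resource = {"learner": "lgbm", "n_estimators": 20, "FLAML_sample_size": 500}

# Read the resource key first, while it is still present.
sample_size = config_w_resource.get("FLAML_sample_size", 1000)

# Then clean the config the same way sanitize() does.
config = config_w_resource.copy()
config.pop("FLAML_sample_size", None)
config.pop("learner", None)

assert sample_size == 500
assert config == {"n_estimators": 20}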
@@ -3203,7 +3208,7 @@ class AutoML(BaseEstimator):
                 x[1].learner_class(
                     task=self._state.task,
                     n_jobs=self._state.n_jobs,
-                    **x[1].best_config,
+                    **self._state.sanitize(x[1].best_config),
                 ),
             )
             for x in search_states[:2]
@@ -3214,13 +3219,15 @@ class AutoML(BaseEstimator):
                 x[1].learner_class(
                     task=self._state.task,
                     n_jobs=self._state.n_jobs,
-                    **x[1].best_config,
+                    **self._state.sanitize(x[1].best_config),
                 ),
             )
             for x in search_states[2:]
             if x[1].best_loss < 4 * self._selected.best_loss
         ]
-        logger.info(estimators)
+        logger.info(
+            [(estimator[0], estimator[1].params) for estimator in estimators]
+        )
         if len(estimators) > 1:
             if self._state.task in CLASSIFICATION:
                 from sklearn.ensemble import StackingClassifier as Stacker
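For context, the list comprehensions above build (name, estimator) pairs that are handed to scikit-learn's stacking ensemble (imported as Stacker in the context lines). A minimal sketch of that hand-off, using stock scikit-learn estimators and toy data as stand-ins for FLAML's learner classes and best_config values:

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier, StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=200, n_features=10, random_state=0)

# In the diff each entry is x[1].learner_class(task=..., n_jobs=...,
# **self._state.sanitize(x[1].best_config)); sanitize() ensures the kwargs
# contain only hyperparameters the constructor accepts.
estimators = [
    ("rf", RandomForestClassifier(n_estimators=50, random_state=0)),
    ("tree", DecisionTreeClassifier(random_state=0)),
]
stacker = StackingClassifier(estimators=estimators, final_estimator=LogisticRegression())
stacker.fit(X, y)
print(stacker.score(X, y))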
@@ -1 +1 @@
-__version__ = "1.0.6"
+__version__ = "1.0.7"
@@ -256,6 +256,7 @@ class TestClassification(unittest.TestCase):
                 time_budget=10,
                 task="classification",
                 n_concurrent_trials=2,
+                ensemble=True,
             )
         except ImportError:
             return
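The test change above is what the commit title refers to: ensembling is now exercised together with parallel (Ray-backed) tuning. A hedged end-to-end sketch of the same combination, with an illustrative dataset and budget rather than the test's own setup; it assumes flaml, ray, and scikit-learn are installed:

from flaml import AutoML
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)

automl = AutoML()
automl.fit(
    X,
    y,
    time_budget=10,             # seconds, illustrative
    task="classification",
    n_concurrent_trials=2,      # >1 runs trials in parallel on Ray
    ensemble=True,              # the combination this commit enables
)
print(automl.model)             # best model (a stacked ensemble when ensembling kicks in)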