mirror of
https://github.com/microsoft/autogen.git
synced 2025-10-27 15:59:35 +00:00
try import catboost (#197)
This commit is contained in:
parent
71219df6c6
commit
8f9f08cebc
@ -84,7 +84,10 @@ class SearchState:
|
||||
self.cat_hp_cost[name] = space["cat_hp_cost"]
|
||||
# if a starting point is provided, set the init config to be
|
||||
# the starting point provided
|
||||
if isinstance(starting_point, dict) and starting_point.get(name) is not None:
|
||||
if (
|
||||
isinstance(starting_point, dict)
|
||||
and starting_point.get(name) is not None
|
||||
):
|
||||
self.init_config[name] = starting_point[name]
|
||||
if isinstance(starting_point, list):
|
||||
self.init_config = starting_point
|
||||
@ -1143,9 +1146,9 @@ class AutoML:
|
||||
else:
|
||||
configs = [self._search_states[estimator].init_config]
|
||||
for config in configs:
|
||||
config['learner'] = estimator
|
||||
config["learner"] = estimator
|
||||
if len(self.estimator_list) > 1:
|
||||
points.append({'ml': config})
|
||||
points.append({"ml": config})
|
||||
else:
|
||||
points.append(config)
|
||||
return points
|
||||
@ -1475,7 +1478,12 @@ class AutoML:
|
||||
elif self._state.task == "rank":
|
||||
estimator_list = ["lgbm", "xgboost"]
|
||||
else:
|
||||
estimator_list = ["lgbm", "rf", "catboost", "xgboost", "extra_tree"]
|
||||
try:
|
||||
import catboost
|
||||
|
||||
estimator_list = ["lgbm", "rf", "catboost", "xgboost", "extra_tree"]
|
||||
except ImportError:
|
||||
estimator_list = ["lgbm", "rf", "xgboost", "extra_tree"]
|
||||
if "regression" != self._state.task:
|
||||
estimator_list += ["lrl1"]
|
||||
for estimator_name in estimator_list:
|
||||
@ -1765,8 +1773,11 @@ class AutoML:
|
||||
self._max_iter_per_learner = len(points_to_evaluate)
|
||||
low_cost_partial_config = None
|
||||
else:
|
||||
points_to_evaluate = search_state.init_config if isinstance(
|
||||
search_state.init_config, list) else [search_state.init_config]
|
||||
points_to_evaluate = (
|
||||
search_state.init_config
|
||||
if isinstance(search_state.init_config, list)
|
||||
else [search_state.init_config]
|
||||
)
|
||||
low_cost_partial_config = search_state.low_cost_partial_config
|
||||
if self._hpo_method in ("bs", "cfo", "grid", "cfocat"):
|
||||
algo = SearchAlgo(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user