try import catboost (#197)

This commit is contained in:
Chi Wang 2021-09-10 20:09:08 -07:00 committed by GitHub
parent 71219df6c6
commit 8f9f08cebc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -84,7 +84,10 @@ class SearchState:
self.cat_hp_cost[name] = space["cat_hp_cost"]
# if a starting point is provided, set the init config to be
# the starting point provided
if isinstance(starting_point, dict) and starting_point.get(name) is not None:
if (
isinstance(starting_point, dict)
and starting_point.get(name) is not None
):
self.init_config[name] = starting_point[name]
if isinstance(starting_point, list):
self.init_config = starting_point
@ -1143,9 +1146,9 @@ class AutoML:
else:
configs = [self._search_states[estimator].init_config]
for config in configs:
config['learner'] = estimator
config["learner"] = estimator
if len(self.estimator_list) > 1:
points.append({'ml': config})
points.append({"ml": config})
else:
points.append(config)
return points
@ -1475,7 +1478,12 @@ class AutoML:
elif self._state.task == "rank":
estimator_list = ["lgbm", "xgboost"]
else:
estimator_list = ["lgbm", "rf", "catboost", "xgboost", "extra_tree"]
try:
import catboost
estimator_list = ["lgbm", "rf", "catboost", "xgboost", "extra_tree"]
except ImportError:
estimator_list = ["lgbm", "rf", "xgboost", "extra_tree"]
if "regression" != self._state.task:
estimator_list += ["lrl1"]
for estimator_name in estimator_list:
@ -1765,8 +1773,11 @@ class AutoML:
self._max_iter_per_learner = len(points_to_evaluate)
low_cost_partial_config = None
else:
points_to_evaluate = search_state.init_config if isinstance(
search_state.init_config, list) else [search_state.init_config]
points_to_evaluate = (
search_state.init_config
if isinstance(search_state.init_config, list)
else [search_state.init_config]
)
low_cost_partial_config = search_state.low_cost_partial_config
if self._hpo_method in ("bs", "cfo", "grid", "cfocat"):
algo = SearchAlgo(