From 8f9f08cebcd892af723da3c9023a0bd48efcdfd9 Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Fri, 10 Sep 2021 20:09:08 -0700 Subject: [PATCH] try import catboost (#197) --- flaml/automl.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/flaml/automl.py b/flaml/automl.py index 31ef63ac4..621af4200 100644 --- a/flaml/automl.py +++ b/flaml/automl.py @@ -84,7 +84,10 @@ class SearchState: self.cat_hp_cost[name] = space["cat_hp_cost"] # if a starting point is provided, set the init config to be # the starting point provided - if isinstance(starting_point, dict) and starting_point.get(name) is not None: + if ( + isinstance(starting_point, dict) + and starting_point.get(name) is not None + ): self.init_config[name] = starting_point[name] if isinstance(starting_point, list): self.init_config = starting_point @@ -1143,9 +1146,9 @@ class AutoML: else: configs = [self._search_states[estimator].init_config] for config in configs: - config['learner'] = estimator + config["learner"] = estimator if len(self.estimator_list) > 1: - points.append({'ml': config}) + points.append({"ml": config}) else: points.append(config) return points @@ -1475,7 +1478,12 @@ class AutoML: elif self._state.task == "rank": estimator_list = ["lgbm", "xgboost"] else: - estimator_list = ["lgbm", "rf", "catboost", "xgboost", "extra_tree"] + try: + import catboost + + estimator_list = ["lgbm", "rf", "catboost", "xgboost", "extra_tree"] + except ImportError: + estimator_list = ["lgbm", "rf", "xgboost", "extra_tree"] if "regression" != self._state.task: estimator_list += ["lrl1"] for estimator_name in estimator_list: @@ -1765,8 +1773,11 @@ class AutoML: self._max_iter_per_learner = len(points_to_evaluate) low_cost_partial_config = None else: - points_to_evaluate = search_state.init_config if isinstance( - search_state.init_config, list) else [search_state.init_config] + points_to_evaluate = ( + search_state.init_config + if isinstance(search_state.init_config, list) + else [search_state.init_config] + ) low_cost_partial_config = search_state.low_cost_partial_config if self._hpo_method in ("bs", "cfo", "grid", "cfocat"): algo = SearchAlgo(