From 5b0932e442e41316a8ddaff460ec15cfdffd18ca Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Tue, 9 Nov 2021 21:23:54 -0800 Subject: [PATCH] Unify regression and classification for XGBoost (#276) * scikit-learn API for XGBoostRegressor --- flaml/ml.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/flaml/ml.py b/flaml/ml.py index 02c523d25..fdafec752 100644 --- a/flaml/ml.py +++ b/flaml/ml.py @@ -19,7 +19,6 @@ from sklearn.metrics import ( ) from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit from .model import ( - XGBoostEstimator, XGBoostSklearnEstimator, RandomForestEstimator, LGBMEstimator, @@ -41,10 +40,7 @@ logger = logging.getLogger(__name__) def get_estimator_class(task, estimator_name): # when adding a new learner, need to add an elif branch if "xgboost" == estimator_name: - if "regression" == task: - estimator_class = XGBoostEstimator - else: - estimator_class = XGBoostSklearnEstimator + estimator_class = XGBoostSklearnEstimator elif "rf" == estimator_name: estimator_class = RandomForestEstimator elif "lgbm" == estimator_name: