Unify regression and classification for XGBoost (#276)

* scikit-learn API for XGBoostRegressor
This commit is contained in:
Chi Wang 2021-11-09 21:23:54 -08:00 committed by GitHub
parent 3f09c694a3
commit 5b0932e442
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -19,7 +19,6 @@ from sklearn.metrics import (
) )
from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit
from .model import ( from .model import (
XGBoostEstimator,
XGBoostSklearnEstimator, XGBoostSklearnEstimator,
RandomForestEstimator, RandomForestEstimator,
LGBMEstimator, LGBMEstimator,
@ -41,10 +40,7 @@ logger = logging.getLogger(__name__)
def get_estimator_class(task, estimator_name): def get_estimator_class(task, estimator_name):
# when adding a new learner, need to add an elif branch # when adding a new learner, need to add an elif branch
if "xgboost" == estimator_name: if "xgboost" == estimator_name:
if "regression" == task: estimator_class = XGBoostSklearnEstimator
estimator_class = XGBoostEstimator
else:
estimator_class = XGBoostSklearnEstimator
elif "rf" == estimator_name: elif "rf" == estimator_name:
estimator_class = RandomForestEstimator estimator_class = RandomForestEstimator
elif "lgbm" == estimator_name: elif "lgbm" == estimator_name: