Unify regression and classification for XGBoost (#276)

* scikit-learn API for XGBoostRegressor
This commit is contained in:
Chi Wang 2021-11-09 21:23:54 -08:00 committed by GitHub
parent 3f09c694a3
commit 5b0932e442
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -19,7 +19,6 @@ from sklearn.metrics import (
)
from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit
from .model import (
XGBoostEstimator,
XGBoostSklearnEstimator,
RandomForestEstimator,
LGBMEstimator,
@ -41,10 +40,7 @@ logger = logging.getLogger(__name__)
def get_estimator_class(task, estimator_name):
# when adding a new learner, need to add an elif branch
if "xgboost" == estimator_name:
if "regression" == task:
estimator_class = XGBoostEstimator
else:
estimator_class = XGBoostSklearnEstimator
estimator_class = XGBoostSklearnEstimator
elif "rf" == estimator_name:
estimator_class = RandomForestEstimator
elif "lgbm" == estimator_name: