mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-03 19:29:52 +00:00
Unify regression and classification for XGBoost (#276)
* scikit-learn API for XGBoostRegressor
This commit is contained in:
parent
3f09c694a3
commit
5b0932e442
@ -19,7 +19,6 @@ from sklearn.metrics import (
|
||||
)
|
||||
from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit
|
||||
from .model import (
|
||||
XGBoostEstimator,
|
||||
XGBoostSklearnEstimator,
|
||||
RandomForestEstimator,
|
||||
LGBMEstimator,
|
||||
@ -41,10 +40,7 @@ logger = logging.getLogger(__name__)
|
||||
def get_estimator_class(task, estimator_name):
|
||||
# when adding a new learner, need to add an elif branch
|
||||
if "xgboost" == estimator_name:
|
||||
if "regression" == task:
|
||||
estimator_class = XGBoostEstimator
|
||||
else:
|
||||
estimator_class = XGBoostSklearnEstimator
|
||||
estimator_class = XGBoostSklearnEstimator
|
||||
elif "rf" == estimator_name:
|
||||
estimator_class = RandomForestEstimator
|
||||
elif "lgbm" == estimator_name:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user