mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-13 16:44:32 +00:00
Unify regression and classification for XGBoost (#276)
* scikit-learn API for XGBoostRegressor
This commit is contained in:
parent
3f09c694a3
commit
5b0932e442
@ -19,7 +19,6 @@ from sklearn.metrics import (
|
|||||||
)
|
)
|
||||||
from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit
|
from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit
|
||||||
from .model import (
|
from .model import (
|
||||||
XGBoostEstimator,
|
|
||||||
XGBoostSklearnEstimator,
|
XGBoostSklearnEstimator,
|
||||||
RandomForestEstimator,
|
RandomForestEstimator,
|
||||||
LGBMEstimator,
|
LGBMEstimator,
|
||||||
@ -41,10 +40,7 @@ logger = logging.getLogger(__name__)
|
|||||||
def get_estimator_class(task, estimator_name):
|
def get_estimator_class(task, estimator_name):
|
||||||
# when adding a new learner, need to add an elif branch
|
# when adding a new learner, need to add an elif branch
|
||||||
if "xgboost" == estimator_name:
|
if "xgboost" == estimator_name:
|
||||||
if "regression" == task:
|
estimator_class = XGBoostSklearnEstimator
|
||||||
estimator_class = XGBoostEstimator
|
|
||||||
else:
|
|
||||||
estimator_class = XGBoostSklearnEstimator
|
|
||||||
elif "rf" == estimator_name:
|
elif "rf" == estimator_name:
|
||||||
estimator_class = RandomForestEstimator
|
estimator_class = RandomForestEstimator
|
||||||
elif "lgbm" == estimator_name:
|
elif "lgbm" == estimator_name:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user