bug fix in confg2params (#323)

* bug fix in confg2params

* set the task property before config2params
This commit is contained in:
Chi Wang 2021-12-03 19:37:49 -08:00 committed by GitHub
parent 18230ed22f
commit 54d303a95a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -90,9 +90,9 @@ class BaseEstimator:
config: A dictionary containing the hyperparameter names, 'n_jobs' as keys.
n_jobs is the number of parallel threads.
"""
self._task = task
self.params = self.config2params(config)
self.estimator_class = self._model = None
self._task = task
if "_estimator_type" in config:
self._estimator_type = self.params.pop("_estimator_type")
else:
@ -678,7 +678,7 @@ class TransformersEstimator(BaseEstimator):
# e.g., if your task == QUESTIONANSWERING, you need to return the answer instead
# of the index
def config2params(cls, config: dict) -> dict:
def config2params(self, config: dict) -> dict:
params = config.copy()
params[TransformersEstimator.ITER_HP] = params.get(
TransformersEstimator.ITER_HP, sys.maxsize
@ -758,7 +758,7 @@ class LGBMEstimator(BaseEstimator):
},
}
def config2params(cls, config: dict) -> dict:
def config2params(self, config: dict) -> dict:
params = config.copy()
if "log_max_bin" in params:
params["max_bin"] = (1 << params.pop("log_max_bin")) - 1
@ -983,7 +983,7 @@ class XGBoostEstimator(SKLearnEstimator):
def cost_relative2lgbm(cls):
return 1.6
def config2params(cls, config: dict) -> dict:
def config2params(self, config: dict) -> dict:
params = config.copy()
max_depth = params["max_depth"] = params.get("max_depth", 0)
if max_depth == 0:
@ -1087,7 +1087,7 @@ class XGBoostSklearnEstimator(SKLearnEstimator, LGBMEstimator):
def cost_relative2lgbm(cls):
return XGBoostEstimator.cost_relative2lgbm()
def config2params(cls, config: dict) -> dict:
def config2params(self, config: dict) -> dict:
params = config.copy()
max_depth = params["max_depth"] = params.get("max_depth", 0)
if max_depth == 0:
@ -1184,12 +1184,14 @@ class RandomForestEstimator(SKLearnEstimator, LGBMEstimator):
def cost_relative2lgbm(cls):
return 2
def config2params(cls, config: dict) -> dict:
def config2params(self, config: dict) -> dict:
params = config.copy()
if "max_leaves" in params:
params["max_leaf_nodes"] = params.get(
"max_leaf_nodes", params.pop("max_leaves")
)
if self._task not in CLASSIFICATION and "criterion" in config:
params.pop("criterion")
return params
def __init__(
@ -1235,7 +1237,7 @@ class LRL1Classifier(SKLearnEstimator):
def cost_relative2lgbm(cls):
return 160
def config2params(cls, config: dict) -> dict:
def config2params(self, config: dict) -> dict:
params = config.copy()
params["tol"] = params.get("tol", 0.0001)
params["solver"] = params.get("solver", "saga")
@ -1261,7 +1263,7 @@ class LRL2Classifier(SKLearnEstimator):
def cost_relative2lgbm(cls):
return 25
def config2params(cls, config: dict) -> dict:
def config2params(self, config: dict) -> dict:
params = config.copy()
params["tol"] = params.get("tol", 0.0001)
params["solver"] = params.get("solver", "lbfgs")
@ -1330,7 +1332,7 @@ class CatBoostEstimator(BaseEstimator):
X = X.to_numpy()
return X
def config2params(cls, config: dict) -> dict:
def config2params(self, config: dict) -> dict:
params = config.copy()
params["n_estimators"] = params.get("n_estimators", 8192)
if "n_jobs" in params:
@ -1440,7 +1442,7 @@ class KNeighborsEstimator(BaseEstimator):
def cost_relative2lgbm(cls):
return 30
def config2params(cls, config: dict) -> dict:
def config2params(self, config: dict) -> dict:
params = config.copy()
params["weights"] = params.get("weights", "distance")
return params