mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-30 00:30:23 +00:00
bug fix in confg2params (#323)
* bug fix in confg2params * set the task property before config2params
This commit is contained in:
parent
18230ed22f
commit
54d303a95a
@ -90,9 +90,9 @@ class BaseEstimator:
|
||||
config: A dictionary containing the hyperparameter names, 'n_jobs' as keys.
|
||||
n_jobs is the number of parallel threads.
|
||||
"""
|
||||
self._task = task
|
||||
self.params = self.config2params(config)
|
||||
self.estimator_class = self._model = None
|
||||
self._task = task
|
||||
if "_estimator_type" in config:
|
||||
self._estimator_type = self.params.pop("_estimator_type")
|
||||
else:
|
||||
@ -678,7 +678,7 @@ class TransformersEstimator(BaseEstimator):
|
||||
# e.g., if your task == QUESTIONANSWERING, you need to return the answer instead
|
||||
# of the index
|
||||
|
||||
def config2params(cls, config: dict) -> dict:
|
||||
def config2params(self, config: dict) -> dict:
|
||||
params = config.copy()
|
||||
params[TransformersEstimator.ITER_HP] = params.get(
|
||||
TransformersEstimator.ITER_HP, sys.maxsize
|
||||
@ -758,7 +758,7 @@ class LGBMEstimator(BaseEstimator):
|
||||
},
|
||||
}
|
||||
|
||||
def config2params(cls, config: dict) -> dict:
|
||||
def config2params(self, config: dict) -> dict:
|
||||
params = config.copy()
|
||||
if "log_max_bin" in params:
|
||||
params["max_bin"] = (1 << params.pop("log_max_bin")) - 1
|
||||
@ -983,7 +983,7 @@ class XGBoostEstimator(SKLearnEstimator):
|
||||
def cost_relative2lgbm(cls):
|
||||
return 1.6
|
||||
|
||||
def config2params(cls, config: dict) -> dict:
|
||||
def config2params(self, config: dict) -> dict:
|
||||
params = config.copy()
|
||||
max_depth = params["max_depth"] = params.get("max_depth", 0)
|
||||
if max_depth == 0:
|
||||
@ -1087,7 +1087,7 @@ class XGBoostSklearnEstimator(SKLearnEstimator, LGBMEstimator):
|
||||
def cost_relative2lgbm(cls):
|
||||
return XGBoostEstimator.cost_relative2lgbm()
|
||||
|
||||
def config2params(cls, config: dict) -> dict:
|
||||
def config2params(self, config: dict) -> dict:
|
||||
params = config.copy()
|
||||
max_depth = params["max_depth"] = params.get("max_depth", 0)
|
||||
if max_depth == 0:
|
||||
@ -1184,12 +1184,14 @@ class RandomForestEstimator(SKLearnEstimator, LGBMEstimator):
|
||||
def cost_relative2lgbm(cls):
|
||||
return 2
|
||||
|
||||
def config2params(cls, config: dict) -> dict:
|
||||
def config2params(self, config: dict) -> dict:
|
||||
params = config.copy()
|
||||
if "max_leaves" in params:
|
||||
params["max_leaf_nodes"] = params.get(
|
||||
"max_leaf_nodes", params.pop("max_leaves")
|
||||
)
|
||||
if self._task not in CLASSIFICATION and "criterion" in config:
|
||||
params.pop("criterion")
|
||||
return params
|
||||
|
||||
def __init__(
|
||||
@ -1235,7 +1237,7 @@ class LRL1Classifier(SKLearnEstimator):
|
||||
def cost_relative2lgbm(cls):
|
||||
return 160
|
||||
|
||||
def config2params(cls, config: dict) -> dict:
|
||||
def config2params(self, config: dict) -> dict:
|
||||
params = config.copy()
|
||||
params["tol"] = params.get("tol", 0.0001)
|
||||
params["solver"] = params.get("solver", "saga")
|
||||
@ -1261,7 +1263,7 @@ class LRL2Classifier(SKLearnEstimator):
|
||||
def cost_relative2lgbm(cls):
|
||||
return 25
|
||||
|
||||
def config2params(cls, config: dict) -> dict:
|
||||
def config2params(self, config: dict) -> dict:
|
||||
params = config.copy()
|
||||
params["tol"] = params.get("tol", 0.0001)
|
||||
params["solver"] = params.get("solver", "lbfgs")
|
||||
@ -1330,7 +1332,7 @@ class CatBoostEstimator(BaseEstimator):
|
||||
X = X.to_numpy()
|
||||
return X
|
||||
|
||||
def config2params(cls, config: dict) -> dict:
|
||||
def config2params(self, config: dict) -> dict:
|
||||
params = config.copy()
|
||||
params["n_estimators"] = params.get("n_estimators", 8192)
|
||||
if "n_jobs" in params:
|
||||
@ -1440,7 +1442,7 @@ class KNeighborsEstimator(BaseEstimator):
|
||||
def cost_relative2lgbm(cls):
|
||||
return 30
|
||||
|
||||
def config2params(cls, config: dict) -> dict:
|
||||
def config2params(self, config: dict) -> dict:
|
||||
params = config.copy()
|
||||
params["weights"] = params.get("weights", "distance")
|
||||
return params
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user