diff --git a/flaml/model.py b/flaml/model.py index 4580a3892..e1941d5f1 100644 --- a/flaml/model.py +++ b/flaml/model.py @@ -921,7 +921,16 @@ class TransformersEstimatorModelSelection(TransformersEstimator): class SKLearnEstimator(BaseEstimator): - """The base class for tuning scikit-learn estimators.""" + """ + The base class for tuning scikit-learn estimators. + + Subclasses can modify the function signature of ``__init__`` to + ignore the values in ``config`` that are not relevant to the constructor + of their underlying estimator. For example, some regressors in ``scikit-learn`` + don't accept the ``n_jobs`` parameter contained in ``config``. For these, + one can add ``n_jobs=None,`` before ``**config`` to make sure ``config`` doesn't + contain an ``n_jobs`` key. + """ def __init__(self, task="binary", **config): super().__init__(task, **config) diff --git a/website/docs/Use-Cases/Task-Oriented-AutoML.md b/website/docs/Use-Cases/Task-Oriented-AutoML.md index ec885b345..84c1b1ea0 100644 --- a/website/docs/Use-Cases/Task-Oriented-AutoML.md +++ b/website/docs/Use-Cases/Task-Oriented-AutoML.md @@ -169,7 +169,7 @@ class MyRegularizedGreedyForest(SKLearnEstimator): return space ``` -In the constructor, we set `self.estimator_class` as `RGFClassifier` or `RGFRegressor` according to the task type. If the estimator you want to tune does not have a scikit-learn style `fit()` and `predict()` API, you can override the `fit()` and `predict()` function of `flaml.model.BaseEstimator`, like [XGBoostEstimator](../reference/model#xgboostestimator-objects). +In the constructor, we set `self.estimator_class` as `RGFClassifier` or `RGFRegressor` according to the task type. If the estimator you want to tune does not have a scikit-learn style `fit()` and `predict()` API, you can override the `fit()` and `predict()` function of `flaml.model.BaseEstimator`, like [XGBoostEstimator](../reference/model#xgboostestimator-objects). Importantly, we also add the `task="binary"` parameter in the signature of `__init__` so that it doesn't get grouped together with the `**config` kwargs that determines the parameters with which the underlying estimator (`self.estimator_class`) is constructed. If your estimator doesn't use one of the parameters that it is passed, for example some regressors in `scikit-learn` don't use the `n_jobs` parameter, it is enough to add `n_jobs=None` to the signature so that it is ignored by the `**config` dict. 2. Give the custom estimator a name and add it in AutoML. E.g.,