diff --git a/flaml/model.py b/flaml/model.py index 09cef1b34..30b9d0969 100644 --- a/flaml/model.py +++ b/flaml/model.py @@ -254,7 +254,7 @@ class BaseEstimator: """ if self._model is not None: X = self._preprocess(X) - return self._model.predict(X) + return self._model.predict(X, **kwargs) else: logger.warning( "Estimator is not fit yet. Please run fit() before predict()." @@ -277,7 +277,7 @@ class BaseEstimator: assert self._task in CLASSIFICATION, "predict_proba() only for classification." X = self._preprocess(X) - return self._model.predict_proba(X) + return self._model.predict_proba(X, **kwargs) def score(self, X_val: DataFrame, y_val: Series, **kwargs): """Report the evaluation score of a trained estimator. @@ -312,7 +312,7 @@ class BaseEstimator: ) else: X_val = self._preprocess(X_val) - metric = kwargs.get("metric", None) + metric = kwargs.pop("metric", None) if metric: y_pred = self.predict(X_val, **kwargs) if is_min_metric(metric): @@ -1321,7 +1321,7 @@ class XGBoostEstimator(SKLearnEstimator): if not issparse(X): X = self._preprocess(X) dtest = xgb.DMatrix(X) - return super().predict(dtest) + return super().predict(dtest, **kwargs) @classmethod def _callbacks(cls, start_time, deadline): @@ -1823,7 +1823,7 @@ class Prophet(SKLearnEstimator): ) if self._model is not None: X = self._preprocess(X) - forecast = self._model.predict(X) + forecast = self._model.predict(X, **kwargs) return forecast["yhat"] else: logger.warning( @@ -1835,7 +1835,7 @@ class Prophet(SKLearnEstimator): from sklearn.metrics import r2_score from .ml import metric_loss_score - y_pred = self.predict(X_val) + y_pred = self.predict(X_val, **kwargs) self._metric = kwargs.get("metric", None) if self._metric: return metric_loss_score(self._metric, y_pred, y_val) @@ -1916,10 +1916,10 @@ class ARIMA(Prophet): X = self._preprocess(X.drop(columns=TS_TIMESTAMP_COL)) regressors = list(X) forecast = self._model.predict( - start=start, end=end, exog=X[regressors] + start=start, end=end, exog=X[regressors], **kwargs ) else: - forecast = self._model.predict(start=start, end=end) + forecast = self._model.predict(start=start, end=end, **kwargs) else: raise ValueError( "X needs to be either a pandas Dataframe with dates as the first column" @@ -2120,7 +2120,7 @@ class TS_SKLearn(SKLearnEstimator): ) = self.hcrystaball_model._transform_data_to_tsmodel_input_format( X.iloc[:i, :] ) - preds.append(self._model[i - 1].predict(X_pred)[-1]) + preds.append(self._model[i - 1].predict(X_pred, **kwargs)[-1]) forecast = DataFrame( data=np.asarray(preds).reshape(-1, 1), columns=[self.hcrystaball_model.name], @@ -2131,7 +2131,7 @@ class TS_SKLearn(SKLearnEstimator): X_pred, _, ) = self.hcrystaball_model._transform_data_to_tsmodel_input_format(X) - forecast = self._model.predict(X_pred) + forecast = self._model.predict(X_pred, **kwargs) return forecast else: logger.warning(