Merge branch 'main' into fix_alerts

This commit is contained in:
zsk 2022-07-26 16:25:56 -04:00 committed by GitHub
commit 655b7bfefa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -254,7 +254,7 @@ class BaseEstimator:
"""
if self._model is not None:
X = self._preprocess(X)
return self._model.predict(X)
return self._model.predict(X, **kwargs)
else:
logger.warning(
"Estimator is not fit yet. Please run fit() before predict()."
@ -277,7 +277,7 @@ class BaseEstimator:
assert self._task in CLASSIFICATION, "predict_proba() only for classification."
X = self._preprocess(X)
return self._model.predict_proba(X)
return self._model.predict_proba(X, **kwargs)
def score(self, X_val: DataFrame, y_val: Series, **kwargs):
"""Report the evaluation score of a trained estimator.
@ -312,7 +312,7 @@ class BaseEstimator:
)
else:
X_val = self._preprocess(X_val)
metric = kwargs.get("metric", None)
metric = kwargs.pop("metric", None)
if metric:
y_pred = self.predict(X_val, **kwargs)
if is_min_metric(metric):
@ -1321,7 +1321,7 @@ class XGBoostEstimator(SKLearnEstimator):
if not issparse(X):
X = self._preprocess(X)
dtest = xgb.DMatrix(X)
return super().predict(dtest)
return super().predict(dtest, **kwargs)
@classmethod
def _callbacks(cls, start_time, deadline):
@ -1823,7 +1823,7 @@ class Prophet(SKLearnEstimator):
)
if self._model is not None:
X = self._preprocess(X)
forecast = self._model.predict(X)
forecast = self._model.predict(X, **kwargs)
return forecast["yhat"]
else:
logger.warning(
@ -1835,7 +1835,7 @@ class Prophet(SKLearnEstimator):
from sklearn.metrics import r2_score
from .ml import metric_loss_score
y_pred = self.predict(X_val)
y_pred = self.predict(X_val, **kwargs)
self._metric = kwargs.get("metric", None)
if self._metric:
return metric_loss_score(self._metric, y_pred, y_val)
@ -1916,10 +1916,10 @@ class ARIMA(Prophet):
X = self._preprocess(X.drop(columns=TS_TIMESTAMP_COL))
regressors = list(X)
forecast = self._model.predict(
start=start, end=end, exog=X[regressors]
start=start, end=end, exog=X[regressors], **kwargs
)
else:
forecast = self._model.predict(start=start, end=end)
forecast = self._model.predict(start=start, end=end, **kwargs)
else:
raise ValueError(
"X needs to be either a pandas Dataframe with dates as the first column"
@ -2120,7 +2120,7 @@ class TS_SKLearn(SKLearnEstimator):
) = self.hcrystaball_model._transform_data_to_tsmodel_input_format(
X.iloc[:i, :]
)
preds.append(self._model[i - 1].predict(X_pred)[-1])
preds.append(self._model[i - 1].predict(X_pred, **kwargs)[-1])
forecast = DataFrame(
data=np.asarray(preds).reshape(-1, 1),
columns=[self.hcrystaball_model.name],
@ -2131,7 +2131,7 @@ class TS_SKLearn(SKLearnEstimator):
X_pred,
_,
) = self.hcrystaball_model._transform_data_to_tsmodel_input_format(X)
forecast = self._model.predict(X_pred)
forecast = self._model.predict(X_pred, **kwargs)
return forecast
else:
logger.warning(