Merge branch 'main' into fix_alerts

This commit is contained in:
zsk 2022-07-26 16:25:56 -04:00 committed by GitHub
commit 655b7bfefa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -254,7 +254,7 @@ class BaseEstimator:
""" """
if self._model is not None: if self._model is not None:
X = self._preprocess(X) X = self._preprocess(X)
return self._model.predict(X) return self._model.predict(X, **kwargs)
else: else:
logger.warning( logger.warning(
"Estimator is not fit yet. Please run fit() before predict()." "Estimator is not fit yet. Please run fit() before predict()."
@ -277,7 +277,7 @@ class BaseEstimator:
assert self._task in CLASSIFICATION, "predict_proba() only for classification." assert self._task in CLASSIFICATION, "predict_proba() only for classification."
X = self._preprocess(X) X = self._preprocess(X)
return self._model.predict_proba(X) return self._model.predict_proba(X, **kwargs)
def score(self, X_val: DataFrame, y_val: Series, **kwargs): def score(self, X_val: DataFrame, y_val: Series, **kwargs):
"""Report the evaluation score of a trained estimator. """Report the evaluation score of a trained estimator.
@ -312,7 +312,7 @@ class BaseEstimator:
) )
else: else:
X_val = self._preprocess(X_val) X_val = self._preprocess(X_val)
metric = kwargs.get("metric", None) metric = kwargs.pop("metric", None)
if metric: if metric:
y_pred = self.predict(X_val, **kwargs) y_pred = self.predict(X_val, **kwargs)
if is_min_metric(metric): if is_min_metric(metric):
@ -1321,7 +1321,7 @@ class XGBoostEstimator(SKLearnEstimator):
if not issparse(X): if not issparse(X):
X = self._preprocess(X) X = self._preprocess(X)
dtest = xgb.DMatrix(X) dtest = xgb.DMatrix(X)
return super().predict(dtest) return super().predict(dtest, **kwargs)
@classmethod @classmethod
def _callbacks(cls, start_time, deadline): def _callbacks(cls, start_time, deadline):
@ -1823,7 +1823,7 @@ class Prophet(SKLearnEstimator):
) )
if self._model is not None: if self._model is not None:
X = self._preprocess(X) X = self._preprocess(X)
forecast = self._model.predict(X) forecast = self._model.predict(X, **kwargs)
return forecast["yhat"] return forecast["yhat"]
else: else:
logger.warning( logger.warning(
@ -1835,7 +1835,7 @@ class Prophet(SKLearnEstimator):
from sklearn.metrics import r2_score from sklearn.metrics import r2_score
from .ml import metric_loss_score from .ml import metric_loss_score
y_pred = self.predict(X_val) y_pred = self.predict(X_val, **kwargs)
self._metric = kwargs.get("metric", None) self._metric = kwargs.get("metric", None)
if self._metric: if self._metric:
return metric_loss_score(self._metric, y_pred, y_val) return metric_loss_score(self._metric, y_pred, y_val)
@ -1916,10 +1916,10 @@ class ARIMA(Prophet):
X = self._preprocess(X.drop(columns=TS_TIMESTAMP_COL)) X = self._preprocess(X.drop(columns=TS_TIMESTAMP_COL))
regressors = list(X) regressors = list(X)
forecast = self._model.predict( forecast = self._model.predict(
start=start, end=end, exog=X[regressors] start=start, end=end, exog=X[regressors], **kwargs
) )
else: else:
forecast = self._model.predict(start=start, end=end) forecast = self._model.predict(start=start, end=end, **kwargs)
else: else:
raise ValueError( raise ValueError(
"X needs to be either a pandas Dataframe with dates as the first column" "X needs to be either a pandas Dataframe with dates as the first column"
@ -2120,7 +2120,7 @@ class TS_SKLearn(SKLearnEstimator):
) = self.hcrystaball_model._transform_data_to_tsmodel_input_format( ) = self.hcrystaball_model._transform_data_to_tsmodel_input_format(
X.iloc[:i, :] X.iloc[:i, :]
) )
preds.append(self._model[i - 1].predict(X_pred)[-1]) preds.append(self._model[i - 1].predict(X_pred, **kwargs)[-1])
forecast = DataFrame( forecast = DataFrame(
data=np.asarray(preds).reshape(-1, 1), data=np.asarray(preds).reshape(-1, 1),
columns=[self.hcrystaball_model.name], columns=[self.hcrystaball_model.name],
@ -2131,7 +2131,7 @@ class TS_SKLearn(SKLearnEstimator):
X_pred, X_pred,
_, _,
) = self.hcrystaball_model._transform_data_to_tsmodel_input_format(X) ) = self.hcrystaball_model._transform_data_to_tsmodel_input_format(X)
forecast = self._model.predict(X_pred) forecast = self._model.predict(X_pred, **kwargs)
return forecast return forecast
else: else:
logger.warning( logger.warning(