mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-24 01:11:45 +00:00
Merge branch 'main' into fix_alerts
This commit is contained in:
commit
655b7bfefa
@ -254,7 +254,7 @@ class BaseEstimator:
|
||||
"""
|
||||
if self._model is not None:
|
||||
X = self._preprocess(X)
|
||||
return self._model.predict(X)
|
||||
return self._model.predict(X, **kwargs)
|
||||
else:
|
||||
logger.warning(
|
||||
"Estimator is not fit yet. Please run fit() before predict()."
|
||||
@ -277,7 +277,7 @@ class BaseEstimator:
|
||||
assert self._task in CLASSIFICATION, "predict_proba() only for classification."
|
||||
|
||||
X = self._preprocess(X)
|
||||
return self._model.predict_proba(X)
|
||||
return self._model.predict_proba(X, **kwargs)
|
||||
|
||||
def score(self, X_val: DataFrame, y_val: Series, **kwargs):
|
||||
"""Report the evaluation score of a trained estimator.
|
||||
@ -312,7 +312,7 @@ class BaseEstimator:
|
||||
)
|
||||
else:
|
||||
X_val = self._preprocess(X_val)
|
||||
metric = kwargs.get("metric", None)
|
||||
metric = kwargs.pop("metric", None)
|
||||
if metric:
|
||||
y_pred = self.predict(X_val, **kwargs)
|
||||
if is_min_metric(metric):
|
||||
@ -1321,7 +1321,7 @@ class XGBoostEstimator(SKLearnEstimator):
|
||||
if not issparse(X):
|
||||
X = self._preprocess(X)
|
||||
dtest = xgb.DMatrix(X)
|
||||
return super().predict(dtest)
|
||||
return super().predict(dtest, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _callbacks(cls, start_time, deadline):
|
||||
@ -1823,7 +1823,7 @@ class Prophet(SKLearnEstimator):
|
||||
)
|
||||
if self._model is not None:
|
||||
X = self._preprocess(X)
|
||||
forecast = self._model.predict(X)
|
||||
forecast = self._model.predict(X, **kwargs)
|
||||
return forecast["yhat"]
|
||||
else:
|
||||
logger.warning(
|
||||
@ -1835,7 +1835,7 @@ class Prophet(SKLearnEstimator):
|
||||
from sklearn.metrics import r2_score
|
||||
from .ml import metric_loss_score
|
||||
|
||||
y_pred = self.predict(X_val)
|
||||
y_pred = self.predict(X_val, **kwargs)
|
||||
self._metric = kwargs.get("metric", None)
|
||||
if self._metric:
|
||||
return metric_loss_score(self._metric, y_pred, y_val)
|
||||
@ -1916,10 +1916,10 @@ class ARIMA(Prophet):
|
||||
X = self._preprocess(X.drop(columns=TS_TIMESTAMP_COL))
|
||||
regressors = list(X)
|
||||
forecast = self._model.predict(
|
||||
start=start, end=end, exog=X[regressors]
|
||||
start=start, end=end, exog=X[regressors], **kwargs
|
||||
)
|
||||
else:
|
||||
forecast = self._model.predict(start=start, end=end)
|
||||
forecast = self._model.predict(start=start, end=end, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"X needs to be either a pandas Dataframe with dates as the first column"
|
||||
@ -2120,7 +2120,7 @@ class TS_SKLearn(SKLearnEstimator):
|
||||
) = self.hcrystaball_model._transform_data_to_tsmodel_input_format(
|
||||
X.iloc[:i, :]
|
||||
)
|
||||
preds.append(self._model[i - 1].predict(X_pred)[-1])
|
||||
preds.append(self._model[i - 1].predict(X_pred, **kwargs)[-1])
|
||||
forecast = DataFrame(
|
||||
data=np.asarray(preds).reshape(-1, 1),
|
||||
columns=[self.hcrystaball_model.name],
|
||||
@ -2131,7 +2131,7 @@ class TS_SKLearn(SKLearnEstimator):
|
||||
X_pred,
|
||||
_,
|
||||
) = self.hcrystaball_model._transform_data_to_tsmodel_input_format(X)
|
||||
forecast = self._model.predict(X_pred)
|
||||
forecast = self._model.predict(X_pred, **kwargs)
|
||||
return forecast
|
||||
else:
|
||||
logger.warning(
|
||||
|
Loading…
x
Reference in New Issue
Block a user