fix zerodivision (#1000)

* fix zerodivision

* update

* remove final

---------

Co-authored-by: Li Jiang <lijiang1@microsoft.com>
This commit is contained in:
Susan Xueqing Liu 2023-04-22 23:55:51 -04:00 committed by GitHub
parent da0d8c05e1
commit 7114b8f742
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1136,8 +1136,7 @@ class TransformersEstimator(BaseEstimator):
except ZeroDivisionError: except ZeroDivisionError:
logger.warning("Zero division error appeared in HuggingFace Transformers.") logger.warning("Zero division error appeared in HuggingFace Transformers.")
predictions = np.array([-0.05] * len(test_dataset)) predictions = np.array([-0.05] * len(test_dataset))
else: return predictions
return predictions
def score(self, X_val: DataFrame, y_val: Series, **kwargs): def score(self, X_val: DataFrame, y_val: Series, **kwargs):
import transformers import transformers
@ -1169,14 +1168,13 @@ class TransformersEstimator(BaseEstimator):
kwargs = {} if self._task not in NLG_TASKS else {"metric_key_prefix": "predict"} kwargs = {} if self._task not in NLG_TASKS else {"metric_key_prefix": "predict"}
try: try:
predictions = new_trainer.predict(test_dataset, **kwargs) predictions = new_trainer.predict(test_dataset, **kwargs).predictions
except ZeroDivisionError: except ZeroDivisionError:
logger.warning("Zero division error appeared in HuggingFace Transformers.") logger.warning("Zero division error appeared in HuggingFace Transformers.")
predictions = np.array([0] * len(test_dataset)) predictions = np.array([0] * len(test_dataset))
post_y_pred, _ = postprocess_prediction_and_true( post_y_pred, _ = postprocess_prediction_and_true(
task=self._task, task=self._task,
y_pred=predictions.predictions, y_pred=predictions,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
hf_args=self._training_args, hf_args=self._training_args,
X=X, X=X,