import os
try:
    from transformers import Trainer as TFTrainer
    from transformers import Seq2SeqTrainer
except ImportError:
    TFTrainer = object
    Seq2SeqTrainer = object  # same fallback, for the commented Seq2SeqTrainerForAuto sketch below


class TrainerForAuto(TFTrainer):
    def evaluate(
        self,
        eval_dataset=None,
        ignore_keys=None,
        metric_key_prefix="eval",
        is_seq2seq=False,
    ):
        """Overriding transformers.Trainer.evaluate by saving metrics and checkpoint path."""
        from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

        ckpt_dir = os.path.join(
            self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        )
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        # TODO: if your task is seq2seq (e.g., SUMMARIZATION), uncomment the
        # block below and indent the `metrics = eval_dataset and ...` call
        # under the `else:` branch.
        # if is_seq2seq:
        #     metrics = eval_dataset and super().evaluate(
        #         eval_dataset,
        #         ignore_keys,
        #         metric_key_prefix,
        #         num_beams=self.args.num_beams,
        #     )
        # else:
        metrics = eval_dataset and super().evaluate(
            eval_dataset,
            ignore_keys,
            metric_key_prefix,
        )
        if metrics:
            # strip the "eval_" prefix so metric names are uniform downstream
            for key in list(metrics.keys()):
                if key.startswith("eval_"):
                    metrics[key[5:]] = metrics.pop(key)
        if hasattr(self, "ckpt_to_global_step"):
            self.ckpt_to_global_step[ckpt_dir] = self.state.global_step
            if metrics:
                self.ckpt_to_metric[ckpt_dir] = metrics
        else:
            self.ckpt_to_global_step = {ckpt_dir: self.state.global_step}
            self.ckpt_to_metric = {ckpt_dir: metrics} if metrics else {}
        return metrics  # callers of Trainer.evaluate expect the metrics dict back
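

# Example (a hypothetical helper, not used by TrainerForAuto itself): after a
# tuning run, `trainer.ckpt_to_metric` maps each checkpoint directory to its
# prefix-stripped eval metrics, so the best checkpoint can be selected
# directly.  The "loss" key assumes the standard Trainer eval output.
def best_checkpoint(ckpt_to_metric):
    """Return the (ckpt_dir, metrics) pair with the lowest validation loss."""
    # e.g. best_dir, best_metrics = best_checkpoint(trainer.ckpt_to_metric)
    return min(ckpt_to_metric.items(), key=lambda kv: kv[1]["loss"])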


# TODO: if your task is SUMMARIZATION, you need a different
# Seq2SeqTrainerForAuto class; uncomment the code below.
# Note: it is implemented here as a sketch, but it has not been
# verified; debug Seq2SeqTrainerForAuto to make sure it is correct.

# class Seq2SeqTrainerForAuto(Seq2SeqTrainer, TrainerForAuto):
|
|
|
|
# def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
|
|
|
|
# """Overriding transformers.Trainer.evaluate by saving metrics and checkpoint path"""
|
|
|
|
# super(TrainerForAuto).evaluate(
|
|
|
|
# eval_dataset, ignore_keys, metric_key_prefix, is_seq2seq=True
|
|
|
|
# )
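

# A self-contained toy (hypothetical names, unrelated to transformers) showing
# the super() pattern used in the sketch above: in _C(_B, _A), super(_B, self)
# skips _B in the MRO and dispatches to _A's method.
class _A:
    def evaluate(self):
        return "A.evaluate"


class _B(_A):
    def evaluate(self):
        return "B.evaluate"


class _C(_B, _A):
    def evaluate(self):
        return super(_B, self).evaluate()  # skips _B, lands on _A.evaluate


assert _C().evaluate() == "A.evaluate"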
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: if your task is QUESTIONANSWERING, uncomment the code below
|
|
|
|
# by adapting the code in https://github.com/huggingface/transformers/blob/master/examples/pytorch/question-answering/trainer_qa.py#L28

# class QATrainerForAuto(TrainerForAuto):
#     pass
# TODO: if your task is QUESTIONANSWERING, do the post processing here
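
# One possible shape for that adaptation, kept commented out like the stub
# above.  A sketch only: eval_examples and post_process_function are the hooks
# used by the linked QuestionAnsweringTrainer example, not attributes this
# module already defines.

# class QATrainerForAuto(TrainerForAuto):
#     def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.eval_examples = eval_examples
#         self.post_process_function = post_process_function
#
#     def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
#         # run the shared checkpoint/metric bookkeeping first
#         metrics = super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)
#         # TODO: post-process raw span predictions with self.post_process_function
#         return metrics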