import logging
import os
from typing import Optional

import torch
from transformers.trainer import Trainer

from .modeling import CrossEncoder

logger = logging.getLogger(__name__)


class CETrainer(Trainer):
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`.
        if not hasattr(self.model, 'save_pretrained'):
            raise NotImplementedError(
                f'MODEL {self.model.__class__.__name__} does not support the save_pretrained interface'
            )
        else:
            self.model.save_pretrained(output_dir)
        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def compute_loss(self, model: CrossEncoder, inputs, return_outputs=False, **kwargs):
        # CrossEncoder.forward is expected to return a dict containing 'loss'.
        # `return_outputs` and extra kwargs (e.g. num_items_in_batch in newer
        # Transformers releases) are accepted so the base Trainer can call this.
        outputs = model(inputs)
        loss = outputs['loss']
        return (loss, outputs) if return_outputs else loss
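

# Illustrative usage sketch (not part of the original module): roughly how this
# trainer might be wired up. The checkpoint name, dataset variable, and the
# CrossEncoder construction below are assumptions; see the repo's training
# script for the actual entry point and arguments.
#
#   from transformers import AutoTokenizer, TrainingArguments
#
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   model = CrossEncoder(...)  # built from this repo's modeling code
#   trainer = CETrainer(
#       model=model,
#       args=TrainingArguments(output_dir="./ce_output", num_train_epochs=1),
#       train_dataset=train_dataset,  # batches must match what CrossEncoder.forward expects
#       tokenizer=tokenizer,
#   )
#   trainer.train()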