
# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from typing import Optional
import jieba
import numpy as np
import torch
from datasets import load_dataset
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_chinese import Rouge
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    Trainer,
)
from transformers.trainer import TRAINING_ARGS_NAME
from arguments import ModelArguments, DataTrainingArguments
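

# Trainer subclass for P-tuning v2: at save time only the parameters that are
# still trainable (the prefix encoder, since the backbone is frozen in half
# precision) are written out, which keeps checkpoints small. The `save_changed`
# flag is accepted for compatibility but is not consulted in `_save` below.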
class PrefixTrainer(Trainer):
    def __init__(self, *args, save_changed=False, **kwargs):
        self.save_changed = save_changed
        super().__init__(*args, **kwargs)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        print(f"Saving model checkpoint to {output_dir}")
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        print("Saving PrefixEncoder")
        state_dict = self.model.state_dict()
        filtered_state_dict = {}
        for k, v in self.model.named_parameters():
            if v.requires_grad:
                filtered_state_dict[k] = state_dict[k]
        self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
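

# The file given by --train_file is loaded with datasets.load_dataset, using its
# extension (e.g. "json" or "csv") as the builder name, then tokenized with the
# preprocessing function defined below. Example record (illustrative only; the
# actual column names are whatever --prompt_column / --response_column select):
#   {"prompt": "...", "response": "..."}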
def load_training_dataset(tokenizer, data_args, model_args, training_args):
    # Load dataset
    data_files = {}
    data_files["train"] = data_args.train_file
    extension = data_args.train_file.split(".")[-1]
    # Preprocess dataset
    raw_datasets = load_dataset(
        extension,
        data_files=data_files,
        cache_dir=model_args.cache_dir,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    if "train" not in raw_datasets:
        raise ValueError("--do_train requires a train dataset")
    train_dataset = raw_datasets["train"]
    with training_args.main_process_first(desc="train dataset map pre-processing"):
        return train_dataset.map(
            preprocess(tokenizer, data_args),
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on train dataset",
        )
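

# Labels are padded with -100 when ignore_pad_token_for_loss is set, so padded
# positions are skipped by the cross-entropy loss (ignore_index=-100).
# padding=False because every example is already padded to max_seq_length in
# the preprocessing step.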
def load_data_collator(tokenizer, model, data_args):
    label_pad_token_id = (
        -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )
    return DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=None,
        padding=False,
    )
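

# Generation length and beam settings from data_args only matter when
# --predict_with_generate is enabled; otherwise compute_metrics is left unset
# and only the training loss is reported.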
def load_trainer(
    tokenizer, model, train_dataset, data_collator, data_args, training_args
):
    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.val_max_target_length
    )
    training_args.generation_num_beams = (
        data_args.num_beams
        if data_args.num_beams is not None
        else training_args.generation_num_beams
    )
    # Init PrefixTrainer
    return PrefixTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics(tokenizer, data_args)
        if training_args.predict_with_generate
        else None,
    )
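

# Each example is encoded as [prompt tokens] [response tokens] [eos], and the
# labels mask out the prompt positions so the loss is computed on the response
# only. Illustrative layout (token ids are made up):
#   input_ids: [p1, p2, p3, r1, r2, eos, pad, pad]
#   labels:    [-100, -100, -100, r1, r2, eos, -100, -100]
# (prompt and padding positions become -100 when ignore_pad_token_for_loss is set)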
def preprocess(tokenizer, data_args):
    def preprocess_function_train(examples):
        # Get the column names for input/target.
        prompt_column = data_args.prompt_column
        response_column = data_args.response_column
        max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
        model_inputs = {
            "input_ids": [],
            "labels": [],
        }
        for i in range(len(examples[prompt_column])):
            if examples[prompt_column][i] and examples[response_column][i]:
                query, answer = examples[prompt_column][i], examples[response_column][i]
                a_ids = tokenizer.encode(
                    text=str(query),
                    add_special_tokens=True,
                    truncation=True,
                    max_length=data_args.max_source_length,
                )
                b_ids = tokenizer.encode(
                    text=str(answer),
                    add_special_tokens=False,
                    truncation=True,
                    max_length=data_args.max_target_length,
                )
                context_length = len(a_ids)
                input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
                labels = (
                    [tokenizer.pad_token_id] * context_length
                    + b_ids
                    + [tokenizer.eos_token_id]
                )
                pad_len = max_seq_length - len(input_ids)
                input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
                labels = labels + [tokenizer.pad_token_id] * pad_len
                if data_args.ignore_pad_token_for_loss:
                    labels = [
                        (l if l != tokenizer.pad_token_id else -100) for l in labels
                    ]
                model_inputs["input_ids"].append(input_ids)
                model_inputs["labels"].append(labels)
        return model_inputs

    return preprocess_function_train
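

# Chinese-friendly metrics: ROUGE is computed on jieba-segmented text
# (rouge_chinese expects whitespace-separated tokens), while BLEU-4 is computed
# on raw characters with smoothing; both are reported as percentages.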
def compute_metrics(tokenizer, data_args):
    def metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
        for pred, label in zip(decoded_preds, decoded_labels):
            hypothesis = list(jieba.cut(pred))
            reference = list(jieba.cut(label))
            rouge = Rouge()
            scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
            result = scores[0]
            for k, v in result.items():
                score_dict[k].append(round(v["f"] * 100, 4))
            bleu_score = sentence_bleu(
                [list(label)],
                list(pred),
                smoothing_function=SmoothingFunction().method3,
            )
            score_dict["bleu-4"].append(round(bleu_score * 100, 4))
        for k, v in score_dict.items():
            score_dict[k] = float(np.mean(v))
        return score_dict

    return metrics
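

# Example invocation (illustrative only; the exact flag names come from the
# ModelArguments / DataTrainingArguments dataclasses in arguments.py and from
# Seq2SeqTrainingArguments):
#   python <this_script>.py \
#     --model_name_or_path <path to a ChatGLM-style checkpoint with a prefix encoder> \
#     --train_file data/train.json \
#     --prompt_column prompt --response_column response \
#     --max_source_length 256 --max_target_length 256 \
#     --pre_seq_len 128 \
#     --do_train --output_dir output/ptuning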
def main():
    # Load parameters
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # Load config
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True
    )
    config.pre_seq_len = model_args.pre_seq_len
    config.prefix_projection = model_args.prefix_projection
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True
    )
    # Load model for P-tuning v2
    model = AutoModel.from_pretrained(
        model_args.model_name_or_path, config=config, trust_remote_code=True
    )
    model = model.half()
    model.transformer.prefix_encoder.float()
    # Load training dataset
    train_dataset = load_training_dataset(
        tokenizer, data_args, model_args, training_args
    )
    # Load data collator
    data_collator = load_data_collator(tokenizer, model, data_args)
    # Load trainer
    trainer = load_trainer(
        tokenizer, model, train_dataset, data_collator, data_args, training_args
    )
    # Training
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    # Save model
    trainer.save_model()  # Saves the tokenizer too for easy upload
    # Save metrics
    metrics = train_result.metrics
    max_train_samples = (
        data_args.max_train_samples
        if data_args.max_train_samples is not None
        else len(train_dataset)
    )
    metrics["train_samples"] = min(max_train_samples, len(train_dataset))
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    # Save state
    trainer.save_state()


if __name__ == "__main__":
    main()