From d1edfcb63f4f7393ec46a2f9dbc46c35743e4fca Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Sun, 28 Apr 2024 13:57:36 -0500 Subject: [PATCH] add roberta option (#135) --- ch06/03_bonus_imdb-classification/README.md | 9 +++ .../train-bert-hf.py | 72 ++++++++++++++----- 2 files changed, 64 insertions(+), 17 deletions(-) diff --git a/ch06/03_bonus_imdb-classification/README.md b/ch06/03_bonus_imdb-classification/README.md index a5e96f6..e5c35fa 100644 --- a/ch06/03_bonus_imdb-classification/README.md +++ b/ch06/03_bonus_imdb-classification/README.md @@ -97,6 +97,15 @@ Test accuracy: 90.81% --- +A 355M parameter encoder-style [RoBERTa](https://arxiv.org/abs/1907.11692) model, starting from the pretrained weights and only training the last transformer block plus output layers: + + +```bash +python train-bert-hf.py --bert_model roberta +``` + +--- + A scikit-learn Logistic Regression model as a basline. ```bash diff --git a/ch06/03_bonus_imdb-classification/train-bert-hf.py b/ch06/03_bonus_imdb-classification/train-bert-hf.py index 5337593..df78cd9 100644 --- a/ch06/03_bonus_imdb-classification/train-bert-hf.py +++ b/ch06/03_bonus_imdb-classification/train-bert-hf.py @@ -164,32 +164,71 @@ if __name__ == "__main__": "Which layers to train. Options: 'all', 'last_block', 'last_layer'." ) ) + parser.add_argument( + "--bert_model", + type=str, + default="distilbert", + help=( + "Which model to train. Options: 'distilbert', 'roberta'." 
+ ) + ) args = parser.parse_args() ############################### # Load model ############################### - model = AutoModelForSequenceClassification.from_pretrained( - "distilbert-base-uncased", num_labels=2 - ) - torch.manual_seed(123) - model.out_head = torch.nn.Linear(in_features=768, out_features=2) - if args.trainable_layers == "last_layer": - pass - elif args.trainable_layers == "last_block": - for param in model.pre_classifier.parameters(): - param.requires_grad = True - for param in model.distilbert.transformer.layer[-1].parameters(): - param.requires_grad = True - elif args.trainable_layers == "all": - for param in model.parameters(): - param.requires_grad = True + torch.manual_seed(123) + if args.bert_model == "distilbert": + + model = AutoModelForSequenceClassification.from_pretrained( + "distilbert-base-uncased", num_labels=2 + ) + model.out_head = torch.nn.Linear(in_features=768, out_features=2) + + if args.trainable_layers == "last_layer": + pass + elif args.trainable_layers == "last_block": + for param in model.pre_classifier.parameters(): + param.requires_grad = True + for param in model.distilbert.transformer.layer[-1].parameters(): + param.requires_grad = True + elif args.trainable_layers == "all": + for param in model.parameters(): + param.requires_grad = True + else: + raise ValueError("Invalid --trainable_layers argument.") + + tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") + + elif args.bert_model == "roberta": + + model = AutoModelForSequenceClassification.from_pretrained( + "FacebookAI/roberta-large", num_labels=2 + ) + model.classifier.out_proj = torch.nn.Linear(in_features=1024, out_features=2) + + if args.trainable_layers == "last_layer": + pass + elif args.trainable_layers == "last_block": + for param in model.classifier.parameters(): + param.requires_grad = True + for param in model.roberta.encoder.layer[-1].parameters(): + param.requires_grad = True + elif args.trainable_layers == "all": + for param in 
model.parameters(): + param.requires_grad = True + else: + raise ValueError("Invalid --trainable_layers argument.") + + tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-large") + else: - raise ValueError("Invalid --trainable_layers argument.") + raise ValueError("Selected --bert_model not supported.") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) + model.eval() ############################### # Instantiate dataloaders @@ -204,7 +243,6 @@ if __name__ == "__main__": file_names = ["train.csv", "val.csv", "test.csv"] all_exist = all((base_path / file_name).exists() for file_name in file_names) - tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") pad_token_id = tokenizer.encode(tokenizer.pad_token) train_dataset = IMDBDataset(base_path / "train.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)