add roberta option (#135)

This commit is contained in:
Sebastian Raschka 2024-04-28 13:57:36 -05:00 committed by GitHub
parent d088753fca
commit d1edfcb63f
2 changed files with 64 additions and 17 deletions

View File

@ -97,6 +97,15 @@ Test accuracy: 90.81%
---
A 355M-parameter encoder-style [RoBERTa](https://arxiv.org/abs/1907.11692) model, starting from the pretrained weights and only training the last transformer block plus output layers:
```bash
python train-bert-hf.py --bert_model roberta
```
---
A scikit-learn Logistic Regression model as a baseline:
```bash

View File

@ -164,15 +164,26 @@ if __name__ == "__main__":
"Which layers to train. Options: 'all', 'last_block', 'last_layer'."
)
)
# Selects which pretrained encoder is loaded and fine-tuned further below;
# the model-loading code accepts "distilbert" and "roberta" and raises a
# ValueError for anything else.
parser.add_argument(
    "--bert_model",
    type=str,
    default="distilbert",
    help=(
        "Which model to train. Options: 'distilbert', 'roberta'."
    )
)
args = parser.parse_args()
###############################
# Load model
###############################
torch.manual_seed(123)
if args.bert_model == "distilbert":
model = AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=2
)
torch.manual_seed(123)
model.out_head = torch.nn.Linear(in_features=768, out_features=2)
if args.trainable_layers == "last_layer":
@ -188,8 +199,36 @@ if __name__ == "__main__":
else:
raise ValueError("Invalid --trainable_layers argument.")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
elif args.bert_model == "roberta":
model = AutoModelForSequenceClassification.from_pretrained(
"FacebookAI/roberta-large", num_labels=2
)
model.classifier.out_proj = torch.nn.Linear(in_features=1024, out_features=2)
if args.trainable_layers == "last_layer":
pass
elif args.trainable_layers == "last_block":
for param in model.classifier.parameters():
param.requires_grad = True
for param in model.roberta.encoder.layer[-1].parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
else:
raise ValueError("Invalid --trainable_layers argument.")
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-large")
else:
raise ValueError("Selected --bert_model not supported.")
# Use the CUDA device when torch reports one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model's parameters to the selected device.
model.to(device)
# Put the model in evaluation mode at this point.
# NOTE(review): presumably the training loop (not visible here) switches back
# to train mode — confirm.
model.eval()
###############################
# Instantiate dataloaders
@ -204,7 +243,6 @@ if __name__ == "__main__":
file_names = ["train.csv", "val.csv", "test.csv"]
all_exist = all((base_path / file_name).exists() for file_name in file_names)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
pad_token_id = tokenizer.encode(tokenizer.pad_token)
train_dataset = IMDBDataset(base_path / "train.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)