mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-29 19:10:19 +00:00
add RoBERTa and params frozen (#335)
* add roberta experiment result * add roberta & params frozen * Update README.md * modify lr * modify lr --------- Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
This commit is contained in:
parent
12655594d5
commit
5acab58d41
@ -56,7 +56,7 @@ Test accuracy: 91.88%
|
|||||||
A 340M parameter encoder-style [BERT](https://arxiv.org/abs/1810.04805) model:
|
A 340M parameter encoder-style [BERT](https://arxiv.org/abs/1810.04805) model:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
!python train_bert_hf --trainable_layers "all" --num_epochs 1 --model "bert"
|
python train_bert_hf --trainable_layers "all" --num_epochs 1 --model "bert"
|
||||||
```
|
```
|
||||||
|
|
||||||
```
|
```
|
||||||
@ -86,7 +86,7 @@ A 66M parameter encoder-style [DistilBERT](https://arxiv.org/abs/1910.01108) mod
|
|||||||
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
!python train_bert_hf.py --trainable_layers "all" --num_epochs 1 --model "distilbert"
|
python train_bert_hf.py --trainable_layers "all" --num_epochs 1 --model "distilbert"
|
||||||
```
|
```
|
||||||
|
|
||||||
```
|
```
|
||||||
@ -104,4 +104,31 @@ Training accuracy: 95.30%
|
|||||||
Validation accuracy: 91.12%
|
Validation accuracy: 91.12%
|
||||||
Test accuracy: 91.40%
|
Test accuracy: 91.40%
|
||||||
```
|
```
|
||||||
|
<br>
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
A 355M parameter encoder-style [RoBERTa](https://arxiv.org/abs/1907.11692) model, starting for the pretrained weights and only training the last transformer block plus output layers:
|
||||||
|
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python train_bert_hf.py --trainable_layers "last_block" --num_epochs 1 --bert_model "roberta"
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
Ep 1 (Step 000000): Train loss 0.695, Val loss 0.698
|
||||||
|
Ep 1 (Step 000050): Train loss 0.670, Val loss 0.690
|
||||||
|
...
|
||||||
|
Ep 1 (Step 004300): Train loss 0.126, Val loss 0.149
|
||||||
|
Ep 1 (Step 004350): Train loss 0.211, Val loss 0.138
|
||||||
|
Training accuracy: 92.50% | Validation accuracy: 94.38%
|
||||||
|
Training completed in 7.20 minutes.
|
||||||
|
|
||||||
|
Evaluating on the full datasets ...
|
||||||
|
|
||||||
|
Training accuracy: 93.44%
|
||||||
|
Validation accuracy: 93.02%
|
||||||
|
Test accuracy: 92.95%
|
||||||
|
```
|
||||||
|
@ -208,6 +208,14 @@ if __name__ == "__main__":
|
|||||||
"Number of epochs."
|
"Number of epochs."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--learning_rate",
|
||||||
|
type=float,
|
||||||
|
default=5e-6,
|
||||||
|
help=(
|
||||||
|
"Learning rate."
|
||||||
|
)
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
###############################
|
###############################
|
||||||
@ -221,9 +229,11 @@ if __name__ == "__main__":
|
|||||||
"distilbert-base-uncased", num_labels=2
|
"distilbert-base-uncased", num_labels=2
|
||||||
)
|
)
|
||||||
model.out_head = torch.nn.Linear(in_features=768, out_features=2)
|
model.out_head = torch.nn.Linear(in_features=768, out_features=2)
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
if args.trainable_layers == "last_layer":
|
if args.trainable_layers == "last_layer":
|
||||||
pass
|
for param in model.out_head.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
elif args.trainable_layers == "last_block":
|
elif args.trainable_layers == "last_block":
|
||||||
for param in model.pre_classifier.parameters():
|
for param in model.pre_classifier.parameters():
|
||||||
param.requires_grad = True
|
param.requires_grad = True
|
||||||
@ -243,9 +253,11 @@ if __name__ == "__main__":
|
|||||||
"bert-base-uncased", num_labels=2
|
"bert-base-uncased", num_labels=2
|
||||||
)
|
)
|
||||||
model.classifier = torch.nn.Linear(in_features=768, out_features=2)
|
model.classifier = torch.nn.Linear(in_features=768, out_features=2)
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
if args.trainable_layers == "last_layer":
|
if args.trainable_layers == "last_layer":
|
||||||
pass
|
for param in model.classifier.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
elif args.trainable_layers == "last_block":
|
elif args.trainable_layers == "last_block":
|
||||||
for param in model.classifier.parameters():
|
for param in model.classifier.parameters():
|
||||||
param.requires_grad = True
|
param.requires_grad = True
|
||||||
@ -260,7 +272,29 @@ if __name__ == "__main__":
|
|||||||
raise ValueError("Invalid --trainable_layers argument.")
|
raise ValueError("Invalid --trainable_layers argument.")
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
elif args.bert_model == "roberta":
|
||||||
|
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
"FacebookAI/roberta-large", num_labels=2
|
||||||
|
)
|
||||||
|
model.classifier.out_proj = torch.nn.Linear(in_features=1024, out_features=2)
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
if args.trainable_layers == "last_layer":
|
||||||
|
for param in model.classifier.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
elif args.trainable_layers == "last_block":
|
||||||
|
for param in model.classifier.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
for param in model.roberta.encoder.layer[-1].parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
elif args.trainable_layers == "all":
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid --trainable_layers argument.")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-large")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Selected --bert_model not supported.")
|
raise ValueError("Selected --bert_model not supported.")
|
||||||
|
|
||||||
@ -334,7 +368,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
torch.manual_seed(123)
|
torch.manual_seed(123)
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.1)
|
||||||
|
|
||||||
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
|
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
|
||||||
model, train_loader, val_loader, optimizer, device,
|
model, train_loader, val_loader, optimizer, device,
|
||||||
|
@ -299,6 +299,14 @@ if __name__ == "__main__":
|
|||||||
"Number of epochs."
|
"Number of epochs."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--learning_rate",
|
||||||
|
type=float,
|
||||||
|
default=5e-6,
|
||||||
|
help=(
|
||||||
|
"Learning rate."
|
||||||
|
)
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
###############################
|
###############################
|
||||||
@ -312,9 +320,11 @@ if __name__ == "__main__":
|
|||||||
"distilbert-base-uncased", num_labels=2
|
"distilbert-base-uncased", num_labels=2
|
||||||
)
|
)
|
||||||
model.out_head = torch.nn.Linear(in_features=768, out_features=2)
|
model.out_head = torch.nn.Linear(in_features=768, out_features=2)
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
if args.trainable_layers == "last_layer":
|
if args.trainable_layers == "last_layer":
|
||||||
pass
|
for param in model.out_head.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
elif args.trainable_layers == "last_block":
|
elif args.trainable_layers == "last_block":
|
||||||
for param in model.pre_classifier.parameters():
|
for param in model.pre_classifier.parameters():
|
||||||
param.requires_grad = True
|
param.requires_grad = True
|
||||||
@ -334,9 +344,11 @@ if __name__ == "__main__":
|
|||||||
"bert-base-uncased", num_labels=2
|
"bert-base-uncased", num_labels=2
|
||||||
)
|
)
|
||||||
model.classifier = torch.nn.Linear(in_features=768, out_features=2)
|
model.classifier = torch.nn.Linear(in_features=768, out_features=2)
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
if args.trainable_layers == "last_layer":
|
if args.trainable_layers == "last_layer":
|
||||||
pass
|
for param in model.classifier.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
elif args.trainable_layers == "last_block":
|
elif args.trainable_layers == "last_block":
|
||||||
for param in model.classifier.parameters():
|
for param in model.classifier.parameters():
|
||||||
param.requires_grad = True
|
param.requires_grad = True
|
||||||
@ -351,7 +363,29 @@ if __name__ == "__main__":
|
|||||||
raise ValueError("Invalid --trainable_layers argument.")
|
raise ValueError("Invalid --trainable_layers argument.")
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
elif args.bert_model == "roberta":
|
||||||
|
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
"FacebookAI/roberta-large", num_labels=2
|
||||||
|
)
|
||||||
|
model.classifier.out_proj = torch.nn.Linear(in_features=1024, out_features=2)
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
if args.trainable_layers == "last_layer":
|
||||||
|
for param in model.classifier.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
elif args.trainable_layers == "last_block":
|
||||||
|
for param in model.classifier.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
for param in model.roberta.encoder.layer[-1].parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
elif args.trainable_layers == "all":
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid --trainable_layers argument.")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-large")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Selected --bert_model not supported.")
|
raise ValueError("Selected --bert_model not supported.")
|
||||||
|
|
||||||
@ -436,7 +470,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
torch.manual_seed(123)
|
torch.manual_seed(123)
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.1)
|
||||||
|
|
||||||
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
|
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
|
||||||
model, train_loader, val_loader, optimizer, device,
|
model, train_loader, val_loader, optimizer, device,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user