DeBERTa-v3 baseline (#630)

* Llama3 from scratch improvements

* deberta-baseline

* restore
Sebastian Raschka 2025-04-19 21:16:17 -05:00 committed by GitHub
parent 4ff743051e
commit c278745aff
2 changed files with 55 additions and 20 deletions

View File

@@ -10,13 +10,14 @@ This folder contains additional experiments to compare the (decoder-style) GPT-2
|       | Model                        | Test accuracy |
| ----- | ---------------------------- | ------------- |
-| **1** | 124 M GPT-2 Baseline         | 91.88%        |
-| **2** | 340 M BERT                   | 90.89%        |
-| **3** | 66 M DistilBERT              | 91.40%        |
-| **4** | 355 M RoBERTa                | 92.95%        |
-| **5** | 149 M ModernBERT Base        | 93.79%        |
-| **6** | 395 M ModernBERT Large       | 95.07%        |
-| **7** | Logistic Regression Baseline | 88.85%        |
+| **1** | 124M GPT-2 Baseline          | 91.88%        |
+| **2** | 340M BERT                    | 90.89%        |
+| **3** | 66M DistilBERT               | 91.40%        |
+| **4** | 355M RoBERTa                 | 92.95%        |
+| **5** | 304M DeBERTa-v3              | 94.69%        |
+| **6** | 149M ModernBERT Base         | 93.79%        |
+| **7** | 395M ModernBERT Large        | 95.07%        |
+| **8** | Logistic Regression Baseline | 88.85%        |
@@ -48,7 +49,7 @@ python download_prepare_dataset.py
## Step 3: Run Models

-### 1) 124 M GPT-2 Baseline
+### 1) 124M GPT-2 Baseline
The 124M GPT-2 model used in chapter 6, starting with pretrained weights, and finetuning all weights:
@@ -80,7 +81,7 @@ Test accuracy: 91.88%
<br>
&nbsp;
-### 2) 340 M BERT
+### 2) 340M BERT
A 340M parameter encoder-style [BERT](https://arxiv.org/abs/1810.04805) model:
@@ -112,7 +113,7 @@ Test accuracy: 90.89%
<br>
&nbsp;
-### 3) 66 M DistilBERT
+### 3) 66M DistilBERT
A 66M parameter encoder-style [DistilBERT](https://arxiv.org/abs/1910.01108) model (distilled down from a 340M parameter BERT model), starting from the pretrained weights and only training the last transformer block plus output layers:
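For readers following along, the "last transformer block plus output layers" setup amounts to freezing every parameter and then re-enabling gradients only for the final block and the classification head. A minimal sketch of that idea, assuming the standard Hugging Face `DistilBertForSequenceClassification` module layout and the `distilbert-base-uncased` checkpoint (the actual logic lives in `train_bert_hf.py` and may differ in details):

```python
# Minimal sketch of the "last_block" finetuning setup for DistilBERT.
# Assumes the standard Hugging Face module layout; train_bert_hf.py may differ in details.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2  # two labels: negative/positive
)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Freeze everything first ...
for param in model.parameters():
    param.requires_grad = False

# ... then unfreeze only the output layers and the last transformer block.
for param in model.pre_classifier.parameters():
    param.requires_grad = True
for param in model.classifier.parameters():
    param.requires_grad = True
for param in model.distilbert.transformer.layer[-1].parameters():
    param.requires_grad = True
```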
@@ -144,7 +145,7 @@ Test accuracy: 91.40%
<br>
&nbsp;
-### 4) 355 M RoBERTa
+### 4) 355M RoBERTa
A 355M parameter encoder-style [RoBERTa](https://arxiv.org/abs/1907.11692) model, starting from the pretrained weights and only training the last transformer block plus output layers:
@@ -157,6 +158,38 @@ python train_bert_hf.py --trainable_layers "last_block" --num_epochs 1 --model "
Ep 1 (Step 000000): Train loss 0.695, Val loss 0.698
Ep 1 (Step 000050): Train loss 0.670, Val loss 0.690
...
+Ep 1 (Step 004300): Train loss 0.083, Val loss 0.098
+Ep 1 (Step 004350): Train loss 0.170, Val loss 0.086
+Training accuracy: 98.12% | Validation accuracy: 96.88%
+Training completed in 11.22 minutes.
+Evaluating on the full datasets ...
+Training accuracy: 96.23%
+Validation accuracy: 94.52%
+Test accuracy: 94.69%
+```
+<br>
+---
+<br>
+&nbsp;
+### 5) 304M DeBERTa-v3
+A 304M parameter encoder-style [DeBERTa-v3](https://arxiv.org/abs/2111.09543) model. DeBERTa-v3 improves upon earlier versions with disentangled attention and improved position encoding.
+```bash
+python train_bert_hf.py --trainable_layers "all" --num_epochs 1 --model "deberta-v3-base"
+```
+```
+Ep 1 (Step 000000): Train loss 0.689, Val loss 0.694
+Ep 1 (Step 000050): Train loss 0.673, Val loss 0.683
+...
Ep 1 (Step 004300): Train loss 0.126, Val loss 0.149
Ep 1 (Step 004350): Train loss 0.211, Val loss 0.138
Training accuracy: 92.50% | Validation accuracy: 94.38%
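For orientation, the new README section corresponds to loading the pretrained DeBERTa-v3 checkpoint with a two-label classification head and finetuning all parameters (`--trainable_layers "all"`). A minimal loading sketch with the Hugging Face `transformers` API follows; note that the DeBERTa-v3 tokenizer is SentencePiece-based, so the `sentencepiece` package typically needs to be installed. The training loop itself lives in `train_bert_hf.py`.

```python
# Minimal sketch: DeBERTa-v3 with a binary classification head, as used by the new
# `--model "deberta-v3-base"` option in train_bert_hf.py.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/deberta-v3-base", num_labels=2
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")

# With `--trainable_layers "all"` every parameter stays trainable (requires_grad=True by default).
batch = tokenizer(["This movie was great!"], return_tensors="pt")
logits = model(**batch).logits  # shape: (1, 2) -> scores for the two classes
```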
@@ -176,8 +209,9 @@ Test accuracy: 92.95%
<br>
&nbsp;
-### 5) 149 M ModernBERT Base
+### 6) 149M ModernBERT Base
[ModernBERT (2024)](https://arxiv.org/abs/2412.13663) is an optimized reimplementation of BERT that incorporates architectural improvements like parallel residual connections and gated linear units (GLUs) to boost efficiency and performance. It maintains BERT's original pretraining objectives while achieving faster inference and better scalability on modern hardware.
@@ -211,7 +245,7 @@ Test accuracy: 93.79%
&nbsp;
-### 6) 395 M ModernBERT Large
+### 7) 395M ModernBERT Large
Same as above but using the larger ModernBERT variant.
@@ -248,7 +282,7 @@ Test accuracy: 95.07%
<br>
&nbsp;
-### 7) Logistic Regression Baseline
+### 8) Logistic Regression Baseline
A scikit-learn [logistic regression](https://sebastianraschka.com/blog/2022/losses-learned-part1.html) classifier as a baseline:
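As a rough illustration of what such a baseline looks like, here is a bag-of-words plus logistic regression classifier in scikit-learn. The repository's own baseline script is not part of this diff, so the feature choice and toy data below are only an assumption for illustration.

```python
# Hypothetical sketch of a scikit-learn logistic regression baseline on IMDb-style text;
# the repository's baseline script may use different features and settings.
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

train_texts = ["a wonderful film", "utterly boring"]   # placeholder data
train_labels = [1, 0]
test_texts = ["boring from start to finish"]
test_labels = [0]

vectorizer = CountVectorizer()                 # bag-of-words features
X_train = vectorizer.fit_transform(train_texts)
X_test = vectorizer.transform(test_texts)

clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, train_labels)
print("Test accuracy:", accuracy_score(test_labels, clf.predict(X_test)))
```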

View File

@@ -197,7 +197,7 @@ if __name__ == "__main__":
        type=str,
        default="distilbert",
        help=(
-            "Which model to train. Options: 'distilbert', 'bert', 'roberta', 'modernbert-base/-large'."
+            "Which model to train. Options: 'distilbert', 'bert', 'roberta', 'modernbert-base/-large', 'deberta-v3-base'."
        )
    )
    parser.add_argument(
@@ -330,11 +330,10 @@ if __name__ == "__main__":
        tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
-    elif args.model == "modernbert-base":
+    elif args.model == "deberta-v3-base":
        model = AutoModelForSequenceClassification.from_pretrained(
-            "answerdotai/ModernBERT-base", num_labels=2
+            "microsoft/deberta-v3-base", num_labels=2
        )
-        print(model)
        model.classifier = torch.nn.Linear(in_features=768, out_features=2)
        for param in model.parameters():
            param.requires_grad = False
@@ -344,7 +343,9 @@ if __name__ == "__main__":
        elif args.trainable_layers == "last_block":
            for param in model.classifier.parameters():
                param.requires_grad = True
-            for param in model.layers.layer[-1].parameters():
+            for param in model.pooler.parameters():
+                param.requires_grad = True
+            for param in model.deberta.encoder.layer[-1].parameters():
                param.requires_grad = True
        elif args.trainable_layers == "all":
            for param in model.parameters():
@@ -352,7 +353,7 @@ if __name__ == "__main__":
        else:
            raise ValueError("Invalid --trainable_layers argument.")
-        tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+        tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
    else:
        raise ValueError("Selected --model {args.model} not supported.")