Add ModernBERT (#598)

Sebastian Raschka 2025-04-05 09:13:30 -05:00 committed by GitHub
parent 396e96ab07
commit 14f976e024
4 changed files with 333 additions and 170 deletions


@ -74,4 +74,4 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
9. **Padding vs no padding (Row 1 vs. 14 & 15, and 16)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 15, we additionally enable gradient accumulation with 8 steps to achieve the same batch size as in the other experiments, which helps reduce overfitting and slightly boost the test set accuracy. In row 16, padding is applied, but the token position is selected based on the last non-padding token. Row 16 should be mathematically similar to row 15, which uses gradient accumulation. However, due to some challenges with gradient accumulation in cases of unequal token counts, there may be small discrepancies (this is discussed in [this](https://unsloth.ai/blog/gradient) blog post). 9. **Padding vs no padding (Row 1 vs. 14 & 15, and 16)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 15, we additionally enable gradient accumulation with 8 steps to achieve the same batch size as in the other experiments, which helps reduce overfitting and slightly boost the test set accuracy. In row 16, padding is applied, but the token position is selected based on the last non-padding token. Row 16 should be mathematically similar to row 15, which uses gradient accumulation. However, due to some challenges with gradient accumulation in cases of unequal token counts, there may be small discrepancies (this is discussed in [this](https://unsloth.ai/blog/gradient) blog post).
10. **Disabling the causal attention mask (Row 1 vs. 17)**: Disables the causal attention mask used in the multi-head attention module. This means all tokens can attend to all other tokens. The model accuracy is slightly improved compared to the GPT model with a causal mask. 10. **Disabling the causal attention mask (Row 1 vs. 17)**: Disables the causal attention mask used in the multi-head attention module. This means all tokens can attend to all other tokens. The model accuracy is slightly improved compared to the GPT model with a causal mask.
11. **Ignoring the padding indices in the loss and backpropagation (Row 1 vs. 18)**: Setting `--ignore_index 50256` excludes the `<|endoftext|>` padding tokens in the `cross_entropy` loss function in PyTorch. In this case, it does not have any effect because we replaced the output layers so that the target class labels are either 0 or 1 for the binary classification example. However, this setting is useful when instruction finetuning models in chapter 7. 11. **Ignoring the padding indices in the loss and backpropagation (Row 1 vs. 18)**: Setting `--ignore_index 50256` excludes the `<|endoftext|>` padding tokens in the `cross_entropy` loss function in PyTorch. In this case, it does not have any effect because we replaced the output layers so that the target class labels are either 0 or 1 for the binary classification example. However, this setting is useful when instruction finetuning models in chapter 7.
13. **Averaging the embeddings over all tokens (Row 1 vs. 19)**: Setting `--average_embeddings` will average the embeddings over all tokens. If this option is not used (the default), only the output embeddings at the chosen token position (specified by `--trainable_token_pos`) are considered; for example, the embeddings of the last token. Enabling `--average_embeddings` will mean-pool the embeddings of all tokens into the position chosen by `--trainable_token_pos` (the last token by default). As we can see, this improves the performance from 95.00% to 96.33% with only a minimal increase in run time (0.28 min to 0.32 min) and might be worthwhile considering in practice. 12. **Averaging the embeddings over all tokens (Row 1 vs. 19)**: Setting `--average_embeddings` will average the embeddings over all tokens. If this option is not used (the default), only the output embeddings at the chosen token position (specified by `--trainable_token_pos`) are considered; for example, the embeddings of the last token. Enabling `--average_embeddings` will mean-pool the embeddings of all tokens into the position chosen by `--trainable_token_pos` (the last token by default). As we can see, this improves the performance from 95.00% to 96.33% with only a minimal increase in run time (0.28 min to 0.32 min) and might be worthwhile considering in practice.
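
The difference between selecting a single token position and averaging over all tokens can be sketched in a few lines of PyTorch. This is a minimal illustration rather than the training script itself: the `select_logits` helper name and the dummy tensor shapes are assumptions made for this example.

```python
import torch


def select_logits(model_output, trainable_token_pos=-1, average_embeddings=False):
    """Reduce (batch, num_tokens, num_classes) outputs to (batch, num_classes).

    Hypothetical helper for illustration only.
    """
    if average_embeddings:
        # Mean-pool over the token (sequence) dimension
        return model_output.mean(dim=1)
    # Otherwise keep only the output at the chosen token position
    return model_output[:, trainable_token_pos, :]


torch.manual_seed(123)
model_output = torch.randn(2, 5, 2)  # dummy outputs: 2 examples, 5 tokens, 2 classes

last_token = select_logits(model_output)
mean_pooled = select_logits(model_output, average_embeddings=True)
print(last_token.shape, mean_pooled.shape)
```

Note that a plain mean like this also averages over any padding positions in the input, which may or may not be desirable depending on how the inputs are padded.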


@ -1,5 +1,28 @@
# Additional Experiments Classifying the Sentiment of 50k IMDB Movie Reviews # Additional Experiments Classifying the Sentiment of 50k IMDB Movie Reviews
## Overview
This folder contains additional experiments to compare the (decoder-style) GPT-2 (2019) model from chapter 6 to encoder-style LLMs like [BERT (2018)](https://arxiv.org/abs/1810.04805), [RoBERTa (2019)](https://arxiv.org/abs/1907.11692), and [ModernBERT (2024)](https://arxiv.org/abs/2412.13663). Instead of using the small SPAM dataset from chapter 6, we use the 50k movie review dataset from IMDb ([dataset source](https://ai.stanford.edu/~amaas/data/sentiment/)) with a binary classification objective: predicting whether a reviewer liked the movie or not. This is a balanced dataset, so a random prediction should yield 50% accuracy.
| | Model | Test accuracy |
| ----- | ---------------------------- | ------------- |
| **1** | 124 M GPT-2 Baseline | 91.88% |
| **2** | 340 M BERT | 90.89% |
| **3** | 66 M DistilBERT | 91.40% |
| **4** | 355 M RoBERTa | 92.95% |
| **5** | 149 M ModernBERT Base | 93.79% |
| **6** | 395 M ModernBERT Large | 95.07% |
| **7** | Logistic Regression Baseline | 88.85% |
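
As a quick sanity check on the 50% chance level mentioned above, the following sketch uses only the Python standard library. The label list is a hypothetical stand-in for a balanced binary dataset, not the actual IMDb files.

```python
import random

random.seed(123)

# Hypothetical balanced label set: half negative (0), half positive (1)
labels = [0] * 25_000 + [1] * 25_000

# Predicting a single class for every example is right exactly half the time
constant_acc = sum(1 for y in labels if y == 1) / len(labels)

# Uniformly random guesses land close to 50% as well
guesses = [random.randint(0, 1) for _ in labels]
random_acc = sum(g == y for g, y in zip(guesses, labels)) / len(labels)

print(f"Constant prediction: {constant_acc:.2%}")
print(f"Random prediction:   {random_acc:.2%}")
```

Any model in the table should therefore be judged against 50%, not against 0%.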
   
## Step 1: Install Dependencies ## Step 1: Install Dependencies
@ -24,7 +47,10 @@ python download_prepare_dataset.py
   
## Step 3: Run Models ## Step 3: Run Models
The 124M GPT-2 model used in the main chapter, starting with pretrained weights, and finetuning all weights:  
### 1) 124 M GPT-2 Baseline
The 124M GPT-2 model used in chapter 6, starting with pretrained weights, and finetuning all weights:
```bash ```bash
python train_gpt.py --trainable_layers "all" --num_epochs 1 python train_gpt.py --trainable_layers "all" --num_epochs 1
@ -53,6 +79,10 @@ Test accuracy: 91.88%
<br> <br>
&nbsp;
### 2) 340 M BERT
A 340M parameter encoder-style [BERT](https://arxiv.org/abs/1810.04805) model: A 340M parameter encoder-style [BERT](https://arxiv.org/abs/1810.04805) model:
```bash ```bash
@ -81,6 +111,9 @@ Test accuracy: 90.89%
<br> <br>
&nbsp;
### 3) 66 M DistilBERT
A 66M parameter encoder-style [DistilBERT](https://arxiv.org/abs/1910.01108) model (distilled down from a 340M parameter BERT model), starting from the pretrained weights and only training the last transformer block plus output layers: A 66M parameter encoder-style [DistilBERT](https://arxiv.org/abs/1910.01108) model (distilled down from a 340M parameter BERT model), starting from the pretrained weights and only training the last transformer block plus output layers:
@ -110,6 +143,9 @@ Test accuracy: 91.40%
<br> <br>
&nbsp;
### 4) 355 M RoBERTa
A 355M parameter encoder-style [RoBERTa](https://arxiv.org/abs/1907.11692) model, starting from the pretrained weights and only training the last transformer block plus output layers: A 355M parameter encoder-style [RoBERTa](https://arxiv.org/abs/1907.11692) model, starting from the pretrained weights and only training the last transformer block plus output layers:
@ -133,6 +169,33 @@ Validation accuracy: 93.02%
Test accuracy: 92.95% Test accuracy: 92.95%
``` ```
<br>
---
<br>
&nbsp;
### 5) 149 M ModernBERT Base
[ModernBERT (2024)](https://arxiv.org/abs/2412.13663) is an optimized reimplementation of BERT that incorporates architectural improvements like parallel residual connections and gated linear units (GLUs) to boost efficiency and performance. It maintains BERT's original pretraining objectives while achieving faster inference and better scalability on modern hardware.
```
Ep 1 (Step 000000): Train loss 0.699, Val loss 0.698
Ep 1 (Step 000050): Train loss 0.564, Val loss 0.606
...
Ep 1 (Step 004300): Train loss 0.086, Val loss 0.168
Ep 1 (Step 004350): Train loss 0.160, Val loss 0.131
Training accuracy: 95.62% | Validation accuracy: 93.75%
Training completed in 10.27 minutes.
Evaluating on the full datasets ...
Training accuracy: 95.72%
Validation accuracy: 94.00%
Test accuracy: 93.79%
```
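
As a rough illustration of the gated linear unit idea mentioned in the description above, here is a minimal sketch of a GLU-style feed-forward layer in PyTorch. It is not ModernBERT's actual implementation; the class name `GatedFeedForward`, the layer sizes, and the choice of GELU as the gate activation are assumptions made for this example.

```python
import torch
import torch.nn as nn


class GatedFeedForward(nn.Module):
    """Hypothetical GLU-style feed-forward block (illustration only)."""

    def __init__(self, emb_dim, hidden_dim):
        super().__init__()
        # One projection produces both the gate and the value branch
        self.in_proj = nn.Linear(emb_dim, 2 * hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, emb_dim)
        self.act = nn.GELU()

    def forward(self, x):
        gate, value = self.in_proj(x).chunk(2, dim=-1)
        # Element-wise gating: the activated gate scales the value branch
        return self.out_proj(self.act(gate) * value)


torch.manual_seed(123)
ffn = GatedFeedForward(emb_dim=16, hidden_dim=32)
x = torch.randn(2, 5, 16)  # (batch, tokens, emb_dim)
out = ffn(x)
print(out.shape)
```

Compared to a plain two-layer feed-forward block, the only structural change is that the hidden activation is multiplied element-wise by a second, learned projection of the same input.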
<br> <br>
@ -140,7 +203,44 @@ Test accuracy: 92.95%
<br> <br>
A scikit-learn logistic regression classifier as a baseline:
&nbsp;
### 6) 395 M ModernBERT Large
Same as above but using the larger ModernBERT variant.
```
Ep 1 (Step 000000): Train loss 0.666, Val loss 0.662
Ep 1 (Step 000050): Train loss 0.548, Val loss 0.556
...
Ep 1 (Step 004300): Train loss 0.083, Val loss 0.115
Ep 1 (Step 004350): Train loss 0.154, Val loss 0.116
Training accuracy: 96.88% | Validation accuracy: 95.62%
Training completed in 27.69 minutes.
Evaluating on the full datasets ...
Training accuracy: 97.04%
Validation accuracy: 95.30%
Test accuracy: 95.07%
```
<br>
---
<br>
&nbsp;
### 7) Logistic Regression Baseline
A scikit-learn [logistic regression](https://sebastianraschka.com/blog/2022/losses-learned-part1.html) classifier as a baseline:
```bash ```bash


@ -8,55 +8,35 @@ from pathlib import Path
import time import time
import pandas as pd import pandas as pd
import tiktoken
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data import Dataset from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt
class IMDBDataset(Dataset): class IMDBDataset(Dataset):
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, use_attention_mask=False): def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
self.data = pd.read_csv(csv_file) self.data = pd.read_csv(csv_file)
self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer) self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
self.pad_token_id = pad_token_id
self.use_attention_mask = use_attention_mask
# Pre-tokenize texts and create attention masks if required # Pre-tokenize texts
self.encoded_texts = [ self.encoded_texts = [
tokenizer.encode(text, truncation=True, max_length=self.max_length) tokenizer.encode(text)[:self.max_length]
for text in self.data["text"] for text in self.data["text"]
] ]
# Pad sequences to the longest sequence
self.encoded_texts = [ self.encoded_texts = [
et + [pad_token_id] * (self.max_length - len(et)) et + [pad_token_id] * (self.max_length - len(et))
for et in self.encoded_texts for et in self.encoded_texts
] ]
if self.use_attention_mask:
self.attention_masks = [
self._create_attention_mask(et)
for et in self.encoded_texts
]
else:
self.attention_masks = None
def _create_attention_mask(self, encoded_text):
return [1 if token_id != self.pad_token_id else 0 for token_id in encoded_text]
def __getitem__(self, index): def __getitem__(self, index):
encoded = self.encoded_texts[index] encoded = self.encoded_texts[index]
label = self.data.iloc[index]["label"] label = self.data.iloc[index]["label"]
return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)
if self.use_attention_mask:
attention_mask = self.attention_masks[index]
else:
attention_mask = torch.ones(self.max_length, dtype=torch.long)
return (
torch.tensor(encoded, dtype=torch.long),
torch.tensor(attention_mask, dtype=torch.long),
torch.tensor(label, dtype=torch.long)
)
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
@ -70,27 +50,71 @@ class IMDBDataset(Dataset):
return max_length return max_length
def calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device): def instantiate_model(choose_model, load_weights):
attention_mask_batch = attention_mask_batch.to(device)
BASE_CONFIG = {
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"drop_rate": 0.0, # Dropout rate
"qkv_bias": True # Query-key-value bias
}
model_configs = {
"gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
"gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
"gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
"gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
BASE_CONFIG.update(model_configs[choose_model])
if not load_weights:
torch.manual_seed(123)
model = GPTModel(BASE_CONFIG)
if load_weights:
model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
load_weights_into_gpt(model, params)
model.eval()
return model
def calc_loss_batch(input_batch, target_batch, model, device,
trainable_token_pos=-1, average_embeddings=False):
input_batch, target_batch = input_batch.to(device), target_batch.to(device) input_batch, target_batch = input_batch.to(device), target_batch.to(device)
# logits = model(input_batch)[:, -1, :] # Logits of last output token
logits = model(input_batch, attention_mask=attention_mask_batch).logits model_output = model(input_batch)
if average_embeddings:
# Average over the sequence dimension (dim=1)
logits = model_output.mean(dim=1)
else:
# Select embeddings at the specified token position
logits = model_output[:, trainable_token_pos, :]
loss = torch.nn.functional.cross_entropy(logits, target_batch) loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss return loss
# Same as in chapter 5 def calc_loss_loader(data_loader, model, device,
def calc_loss_loader(data_loader, model, device, num_batches=None): num_batches=None, trainable_token_pos=-1,
average_embeddings=False):
total_loss = 0. total_loss = 0.
if num_batches is None: if len(data_loader) == 0:
return float("nan")
elif num_batches is None:
num_batches = len(data_loader) num_batches = len(data_loader)
else: else:
# Reduce the number of batches to match the total number of batches in the data loader # Reduce the number of batches to match the total number of batches in the data loader
# if num_batches exceeds the number of batches in the data loader # if num_batches exceeds the number of batches in the data loader
num_batches = min(num_batches, len(data_loader)) num_batches = min(num_batches, len(data_loader))
for i, (input_batch, attention_mask_batch, target_batch) in enumerate(data_loader): for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches: if i < num_batches:
loss = calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device) loss = calc_loss_batch(
input_batch, target_batch, model, device,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
total_loss += loss.item() total_loss += loss.item()
else: else:
break break
@ -98,7 +122,9 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
@torch.no_grad() # Disable gradient tracking for efficiency @torch.no_grad() # Disable gradient tracking for efficiency
def calc_accuracy_loader(data_loader, model, device, num_batches=None): def calc_accuracy_loader(data_loader, model, device,
num_batches=None, trainable_token_pos=-1,
average_embeddings=False):
model.eval() model.eval()
correct_predictions, num_examples = 0, 0 correct_predictions, num_examples = 0, 0
@ -106,13 +132,20 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None):
num_batches = len(data_loader) num_batches = len(data_loader)
else: else:
num_batches = min(num_batches, len(data_loader)) num_batches = min(num_batches, len(data_loader))
for i, (input_batch, attention_mask_batch, target_batch) in enumerate(data_loader): for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches: if i < num_batches:
attention_mask_batch = attention_mask_batch.to(device)
input_batch, target_batch = input_batch.to(device), target_batch.to(device) input_batch, target_batch = input_batch.to(device), target_batch.to(device)
# logits = model(input_batch)[:, -1, :] # Logits of last output token
logits = model(input_batch, attention_mask=attention_mask_batch).logits model_output = model(input_batch)
predicted_labels = torch.argmax(logits, dim=1) if average_embeddings:
# Average over the sequence dimension (dim=1)
logits = model_output.mean(dim=1)
else:
# Select embeddings at the specified token position
logits = model_output[:, trainable_token_pos, :]
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0] num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item() correct_predictions += (predicted_labels == target_batch).sum().item()
else: else:
@ -120,17 +153,25 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None):
return correct_predictions / num_examples return correct_predictions / num_examples
def evaluate_model(model, train_loader, val_loader, device, eval_iter): def evaluate_model(model, train_loader, val_loader, device, eval_iter,
trainable_token_pos=-1, average_embeddings=False):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter) train_loss = calc_loss_loader(
val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter) train_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
val_loss = calc_loss_loader(
val_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
model.train() model.train()
return train_loss, val_loss return train_loss, val_loss
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs, def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
eval_freq, eval_iter, max_steps=None): eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1,
average_embeddings=False):
# Initialize lists to track losses and examples seen # Initialize lists to track losses and examples seen
train_losses, val_losses, train_accs, val_accs = [], [], [], [] train_losses, val_losses, train_accs, val_accs = [], [], [], []
examples_seen, global_step = 0, -1 examples_seen, global_step = 0, -1
@ -139,9 +180,10 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
for epoch in range(num_epochs): for epoch in range(num_epochs):
model.train() # Set model to training mode model.train() # Set model to training mode
for input_batch, attention_mask_batch, target_batch in train_loader: for input_batch, target_batch in train_loader:
optimizer.zero_grad() # Reset loss gradients from previous batch iteration optimizer.zero_grad() # Reset loss gradients from previous batch iteration
loss = calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device) loss = calc_loss_batch(input_batch, target_batch, model, device,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings)
loss.backward() # Calculate loss gradients loss.backward() # Calculate loss gradients
optimizer.step() # Update model weights using loss gradients optimizer.step() # Update model weights using loss gradients
examples_seen += input_batch.shape[0] # New: track examples instead of tokens examples_seen += input_batch.shape[0] # New: track examples instead of tokens
@ -150,7 +192,9 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
# Optional evaluation step # Optional evaluation step
if global_step % eval_freq == 0: if global_step % eval_freq == 0:
train_loss, val_loss = evaluate_model( train_loss, val_loss = evaluate_model(
model, train_loader, val_loader, device, eval_iter) model, train_loader, val_loader, device, eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
train_losses.append(train_loss) train_losses.append(train_loss)
val_losses.append(val_loss) val_losses.append(val_loss)
print(f"Ep {epoch+1} (Step {global_step:06d}): " print(f"Ep {epoch+1} (Step {global_step:06d}): "
@ -160,8 +204,14 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
break break
# New: Calculate accuracy after each epoch # New: Calculate accuracy after each epoch
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter) train_accuracy = calc_accuracy_loader(
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter) train_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
val_accuracy = calc_accuracy_loader(
val_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="") print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
print(f"Validation accuracy: {val_accuracy*100:.2f}%") print(f"Validation accuracy: {val_accuracy*100:.2f}%")
train_accs.append(train_accuracy) train_accs.append(train_accuracy)
@ -176,28 +226,55 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
"--model_size",
type=str,
default="gpt2-small (124M)",
help=(
"Which GPT model to use. Options: 'gpt2-small (124M)', 'gpt2-medium (355M)',"
" 'gpt2-large (774M)', 'gpt2-xl (1558M)'."
)
)
parser.add_argument(
"--weights",
type=str,
default="pretrained",
help=(
"Whether to use 'pretrained' or 'random' weights."
)
)
parser.add_argument( parser.add_argument(
"--trainable_layers", "--trainable_layers",
type=str, type=str,
default="all", default="last_block",
help=( help=(
"Which layers to train. Options: 'all', 'last_block', 'last_layer'." "Which layers to train. Options: 'all', 'last_block', 'last_layer'."
) )
) )
parser.add_argument( parser.add_argument(
"--use_attention_mask", "--trainable_token_pos",
type=str, type=str,
default="true", default="last",
help=( help=(
"Whether to use an attention mask for padding tokens. Options: 'true', 'false'." "Which token to train. Options: 'first', 'last'."
) )
) )
parser.add_argument( parser.add_argument(
"--model", "--average_embeddings",
type=str, action='store_true',
default="distilbert", default=False,
help=( help=(
"Which model to train. Options: 'distilbert', 'bert', 'roberta'." "Average the output embeddings from all tokens instead of using"
" only the embedding at the token position specified by `--trainable_token_pos`."
)
)
parser.add_argument(
"--context_length",
type=str,
default="256",
help=(
"The context length of the data inputs. "
"Options: 'longest_training_example', 'model_context_length' or integer value."
) )
) )
parser.add_argument( parser.add_argument(
@ -211,96 +288,73 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--learning_rate", "--learning_rate",
type=float, type=float,
default=5e-6, default=5e-5,
help=( help=(
"Learning rate." "Learning rate."
) )
) )
parser.add_argument(
"--compile",
action="store_true",
help="If set, model compilation will be enabled."
)
args = parser.parse_args() args = parser.parse_args()
if args.trainable_token_pos == "first":
args.trainable_token_pos = 0
elif args.trainable_token_pos == "last":
args.trainable_token_pos = -1
else:
raise ValueError("Invalid --trainable_token_pos argument")
############################### ###############################
# Load model # Load model
############################### ###############################
torch.manual_seed(123) if args.weights == "pretrained":
if args.model == "distilbert": load_weights = True
elif args.weights == "random":
model = AutoModelForSequenceClassification.from_pretrained( load_weights = False
"distilbert-base-uncased", num_labels=2
)
model.out_head = torch.nn.Linear(in_features=768, out_features=2)
for param in model.parameters():
param.requires_grad = False
if args.trainable_layers == "last_layer":
for param in model.out_head.parameters():
param.requires_grad = True
elif args.trainable_layers == "last_block":
for param in model.pre_classifier.parameters():
param.requires_grad = True
for param in model.distilbert.transformer.layer[-1].parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
else:
raise ValueError("Invalid --trainable_layers argument.")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
elif args.model == "bert":
model = AutoModelForSequenceClassification.from_pretrained(
"bert-base-uncased", num_labels=2
)
model.classifier = torch.nn.Linear(in_features=768, out_features=2)
for param in model.parameters():
param.requires_grad = False
if args.trainable_layers == "last_layer":
for param in model.classifier.parameters():
param.requires_grad = True
elif args.trainable_layers == "last_block":
for param in model.classifier.parameters():
param.requires_grad = True
for param in model.bert.pooler.dense.parameters():
param.requires_grad = True
for param in model.bert.encoder.layer[-1].parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
else:
raise ValueError("Invalid --trainable_layers argument.")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
elif args.model == "roberta":
model = AutoModelForSequenceClassification.from_pretrained(
"FacebookAI/roberta-large", num_labels=2
)
model.classifier.out_proj = torch.nn.Linear(in_features=1024, out_features=2)
for param in model.parameters():
param.requires_grad = False
if args.trainable_layers == "last_layer":
for param in model.classifier.parameters():
param.requires_grad = True
elif args.trainable_layers == "last_block":
for param in model.classifier.parameters():
param.requires_grad = True
for param in model.roberta.encoder.layer[-1].parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
else:
raise ValueError("Invalid --trainable_layers argument.")
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-large")
else: else:
raise ValueError(f"Selected --model {args.model} not supported.") raise ValueError("Invalid --weights argument.")
model = instantiate_model(args.model_size, load_weights)
for param in model.parameters():
param.requires_grad = False
if args.model_size == "gpt2-small (124M)":
in_features = 768
elif args.model_size == "gpt2-medium (355M)":
in_features = 1024
elif args.model_size == "gpt2-large (774M)":
in_features = 1280
elif args.model_size == "gpt2-xl (1558M)":
in_features = 1600
else:
raise ValueError("Invalid --model_size argument")
torch.manual_seed(123)
model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)
if args.trainable_layers == "last_layer":
pass
elif args.trainable_layers == "last_block":
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
for param in model.final_norm.parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
else:
raise ValueError("Invalid --trainable_layers argument.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device) model.to(device)
model.eval()
if args.compile:
torch.set_float32_matmul_precision("high")
model = torch.compile(model)
############################### ###############################
# Instantiate dataloaders # Instantiate dataloaders
@ -308,34 +362,24 @@ if __name__ == "__main__":
base_path = Path(".") base_path = Path(".")
if args.use_attention_mask.lower() == "true": tokenizer = tiktoken.get_encoding("gpt2")
use_attention_mask = True
elif args.use_attention_mask.lower() == "false":
use_attention_mask = False
else:
raise ValueError("Invalid argument for `use_attention_mask`.")
train_dataset = IMDBDataset( train_dataset = None
base_path / "train.csv", if args.context_length == "model_context_length":
max_length=256, max_length = model.pos_emb.weight.shape[0]
tokenizer=tokenizer, elif args.context_length == "longest_training_example":
pad_token_id=tokenizer.pad_token_id, train_dataset = IMDBDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
use_attention_mask=use_attention_mask max_length = train_dataset.max_length
) else:
val_dataset = IMDBDataset( try:
base_path / "validation.csv", max_length = int(args.context_length)
max_length=256, except ValueError:
tokenizer=tokenizer, raise ValueError("Invalid --context_length argument")
pad_token_id=tokenizer.pad_token_id,
use_attention_mask=use_attention_mask if train_dataset is None:
) train_dataset = IMDBDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
test_dataset = IMDBDataset( val_dataset = IMDBDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
base_path / "test.csv", test_dataset = IMDBDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
max_length=256,
tokenizer=tokenizer,
pad_token_id=tokenizer.pad_token_id,
use_attention_mask=use_attention_mask
)
num_workers = 0 num_workers = 0
batch_size = 8 batch_size = 8
@ -373,7 +417,8 @@ if __name__ == "__main__":
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple( train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
model, train_loader, val_loader, optimizer, device, model, train_loader, val_loader, optimizer, device,
num_epochs=args.num_epochs, eval_freq=50, eval_iter=20, num_epochs=args.num_epochs, eval_freq=50, eval_iter=20,
max_steps=None max_steps=None, trainable_token_pos=args.trainable_token_pos,
average_embeddings=args.average_embeddings
) )
end_time = time.time() end_time = time.time()
@ -386,9 +431,18 @@ if __name__ == "__main__":
print("\nEvaluating on the full datasets ...\n") print("\nEvaluating on the full datasets ...\n")
train_accuracy = calc_accuracy_loader(train_loader, model, device) train_accuracy = calc_accuracy_loader(
val_accuracy = calc_accuracy_loader(val_loader, model, device) train_loader, model, device,
test_accuracy = calc_accuracy_loader(test_loader, model, device) trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
)
val_accuracy = calc_accuracy_loader(
val_loader, model, device,
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
)
test_accuracy = calc_accuracy_loader(
test_loader, model, device,
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
)
print(f"Training accuracy: {train_accuracy*100:.2f}%") print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%") print(f"Validation accuracy: {val_accuracy*100:.2f}%")
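
The way the `--context_length` string is resolved in the script above can be summarized as a small standalone function. This is a sketch under stated assumptions: `resolve_max_length` is a hypothetical helper that does not exist in the script, and the numbers passed to it are made up. In the script itself, the model's context length comes from `model.pos_emb.weight.shape[0]` and the longest training example from `train_dataset.max_length`.

```python
def resolve_max_length(context_length, model_context_length, longest_training_example):
    """Hypothetical helper mirroring the three --context_length options."""
    if context_length == "model_context_length":
        return model_context_length
    if context_length == "longest_training_example":
        return longest_training_example
    try:
        # Any other value must be parseable as an integer
        return int(context_length)
    except ValueError:
        raise ValueError("Invalid --context_length argument") from None


# Made-up values for illustration
from_int = resolve_max_length("256", 1024, 3000)
from_model = resolve_max_length("model_context_length", 1024, 3000)
from_data = resolve_max_length("longest_training_example", 1024, 3000)
print(from_int, from_model, from_data)
```

Keeping the argument as a string lets one flag accept both the two named options and an arbitrary integer.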


@ -293,6 +293,11 @@ if __name__ == "__main__":
"Learning rate." "Learning rate."
) )
) )
parser.add_argument(
"--compile",
action="store_true",
help="If set, model compilation will be enabled."
)
args = parser.parse_args() args = parser.parse_args()
if args.trainable_token_pos == "first": if args.trainable_token_pos == "first":
@ -347,6 +352,10 @@ if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device) model.to(device)
if args.compile:
torch.set_float32_matmul_precision("high")
model = torch.compile(model)
############################### ###############################
# Instantiate dataloaders # Instantiate dataloaders
############################### ###############################