mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-30 11:31:08 +00:00
Add ModernBERT (#598)
This commit is contained in:
parent
396e96ab07
commit
14f976e024
@ -74,4 +74,4 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
|
||||
9. **Padding vs no padding (Row 1 vs. 14 & 15, and 16)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 15, we additionally enable gradient accumulation with 8 steps to achieve the same batch size as in the other experiments, which helps reduce overfitting and slightly boost the test set accuracy. In row 16, padding is applied, but the token position is selected based on the last non-padding token. Row 16 should be mathematically similar to row 15, which uses gradient accumulation. However, due to some challenges with gradient accumulation in cases of unequal token counts, there may be small discrepancies (this is discussed in [this](https://unsloth.ai/blog/gradient) blog post).
|
||||
10. **Disabling the causal attention mask (Row 1 vs. 17)**: Disables the causal attention mask used in the multi-head attention module. This means all tokens can attend all other tokens. The model accuracy is slightly improved compared to the GPT model with causal mask.
|
||||
11. **Ignoring the padding indices in the loss and backpropagation (Row 1 vs. 18)**: Setting `--ignore_index 50256` excludes the `|endoftext|` padding tokens in the `cross_entropy` loss function in PyTorch. In this case, it does not have any effect because we replaced the output layers so that the token IDs are either 0 or 1 for the binary classification example. However, this setting is useful when instruction finetuning models in chapter 7.
|
||||
13. **Averaging the embeddings over all tokens (Row 1 vs. 19)**: Setting `--average_embeddings` will average the embeddings over all tokens. If this option is not used (the default), only the output embeddings at the chosen token position (specified by `--trainable_token_pos`) are considered; for example, the embeddings of the last token. Enabling `--average_embeddings` will mean-pool the embeddings of all tokens into the position chosen by `--trainable_token_pos` (the last token by default). As we can see, this improves the performance from 95.00% to 96.33% with only a minimal increase in run time (0.28 min to 0.32 min) and might be worthwhile considering in practice.
|
||||
12. **Averaging the embeddings over all tokens (Row 1 vs. 19)**: Setting `--average_embeddings` will average the embeddings over all tokens. If this option is not used (the default), only the output embeddings at the chosen token position (specified by `--trainable_token_pos`) are considered; for example, the embeddings of the last token. Enabling `--average_embeddings` will mean-pool the embeddings of all tokens into the position chosen by `--trainable_token_pos` (the last token by default). As we can see, this improves the performance from 95.00% to 96.33% with only a minimal increase in run time (0.28 min to 0.32 min) and might be worthwhile considering in practice.
|
||||
|
@ -1,5 +1,28 @@
|
||||
# Additional Experiments Classifying the Sentiment of 50k IMDB Movie Reviews
|
||||
|
||||
## Overview
|
||||
|
||||
This folder contains additional experiments to compare the (decoder-style) GPT-2 (2018) model from chapter 6 to encoder-style LLMs like [BERT (2018)](https://arxiv.org/abs/1810.04805), [RoBERTa (2019)](https://arxiv.org/abs/1907.11692), and [ModernBERT (2024)](https://arxiv.org/abs/2412.13663). Instead of using the small SPAM dataset from Chapter 6, we are using the 50k movie review dataset from IMDb ([dataset source](https://ai.stanford.edu/~amaas/data/sentiment/)) with a binary classification objective, predicting whether a reviewer liked the movie or not. This is a balanced dataset, so a random prediction should yield 50% accuracy.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
| | Model | Test accuracy |
|
||||
| ----- | ---------------------------- | ------------- |
|
||||
| **1** | 124 M GPT-2 Baseline | 91.88% |
|
||||
| **2** | 340 M BERT | 90.89% |
|
||||
| **3** | 66 M DistilBERT | 91.40% |
|
||||
| **4** | 355 M RoBERTa | 92.95% |
|
||||
| **5** | 149 M ModernBERT Base | 93.79% |
|
||||
| **6** | 395 M ModernBERT Large | 95.07% |
|
||||
| **7** | Logistic Regression Baseline | 88.85% |
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## Step 1: Install Dependencies
|
||||
|
||||
@ -24,7 +47,10 @@ python download_prepare_dataset.py
|
||||
|
||||
## Step 3: Run Models
|
||||
|
||||
The 124M GPT-2 model used in the main chapter, starting with pretrained weights, and finetuning all weights:
|
||||
|
||||
### 1) 124 M GPT-2 Baseline
|
||||
|
||||
The 124M GPT-2 model used in chapter 6, starting with pretrained weights, and finetuning all weights:
|
||||
|
||||
```bash
|
||||
python train_gpt.py --trainable_layers "all" --num_epochs 1
|
||||
@ -53,6 +79,10 @@ Test accuracy: 91.88%
|
||||
|
||||
<br>
|
||||
|
||||
|
||||
### 2) 340 M BERT
|
||||
|
||||
|
||||
A 340M parameter encoder-style [BERT](https://arxiv.org/abs/1810.04805) model:
|
||||
|
||||
```bash
|
||||
@ -81,6 +111,9 @@ Test accuracy: 90.89%
|
||||
|
||||
<br>
|
||||
|
||||
|
||||
### 3) 66 M DistilBERT
|
||||
|
||||
A 66M parameter encoder-style [DistilBERT](https://arxiv.org/abs/1910.01108) model (distilled down from a 340M parameter BERT model), starting for the pretrained weights and only training the last transformer block plus output layers:
|
||||
|
||||
|
||||
@ -110,6 +143,9 @@ Test accuracy: 91.40%
|
||||
|
||||
<br>
|
||||
|
||||
|
||||
### 4) 355 M RoBERTa
|
||||
|
||||
A 355M parameter encoder-style [RoBERTa](https://arxiv.org/abs/1907.11692) model, starting for the pretrained weights and only training the last transformer block plus output layers:
|
||||
|
||||
|
||||
@ -133,6 +169,33 @@ Validation accuracy: 93.02%
|
||||
Test accuracy: 92.95%
|
||||
```
|
||||
|
||||
<br>
|
||||
|
||||
---
|
||||
|
||||
<br>
|
||||
|
||||
|
||||
|
||||
### 5) 149 M ModernBERT Base
|
||||
|
||||
[ModernBERT (2024)](https://arxiv.org/abs/2412.13663) is an optimized reimplementation of BERT that incorporates architectural improvements like parallel residual connections and gated linear units (GLUs) to boost efficiency and performance. It maintains BERT’s original pretraining objectives while achieving faster inference and better scalability on modern hardware.
|
||||
|
||||
```
|
||||
Ep 1 (Step 000000): Train loss 0.699, Val loss 0.698
|
||||
Ep 1 (Step 000050): Train loss 0.564, Val loss 0.606
|
||||
...
|
||||
Ep 1 (Step 004300): Train loss 0.086, Val loss 0.168
|
||||
Ep 1 (Step 004350): Train loss 0.160, Val loss 0.131
|
||||
Training accuracy: 95.62% | Validation accuracy: 93.75%
|
||||
Training completed in 10.27 minutes.
|
||||
|
||||
Evaluating on the full datasets ...
|
||||
|
||||
Training accuracy: 95.72%
|
||||
Validation accuracy: 94.00%
|
||||
Test accuracy: 93.79%
|
||||
```
|
||||
|
||||
<br>
|
||||
|
||||
@ -140,7 +203,44 @@ Test accuracy: 92.95%
|
||||
|
||||
<br>
|
||||
|
||||
A scikit-learn logistic regression classifier as a baseline:
|
||||
|
||||
|
||||
### 6) 395 M ModernBERT Large
|
||||
|
||||
Same as above but using the larger ModernBERT variant.
|
||||
|
||||
|
||||
|
||||
```
|
||||
Ep 1 (Step 000000): Train loss 0.666, Val loss 0.662
|
||||
Ep 1 (Step 000050): Train loss 0.548, Val loss 0.556
|
||||
...
|
||||
Ep 1 (Step 004300): Train loss 0.083, Val loss 0.115
|
||||
Ep 1 (Step 004350): Train loss 0.154, Val loss 0.116
|
||||
Training accuracy: 96.88% | Validation accuracy: 95.62%
|
||||
Training completed in 27.69 minutes.
|
||||
|
||||
Evaluating on the full datasets ...
|
||||
|
||||
Training accuracy: 97.04%
|
||||
Validation accuracy: 95.30%
|
||||
Test accuracy: 95.07%
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<br>
|
||||
|
||||
---
|
||||
|
||||
<br>
|
||||
|
||||
|
||||
### 7) Logistic Regression Baseline
|
||||
|
||||
A scikit-learn [logistic regression](https://sebastianraschka.com/blog/2022/losses-learned-part1.html) classifier as a baseline:
|
||||
|
||||
|
||||
```bash
|
||||
|
@ -8,55 +8,35 @@ from pathlib import Path
|
||||
import time
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from gpt_download import download_and_load_gpt2
|
||||
from previous_chapters import GPTModel, load_weights_into_gpt
|
||||
|
||||
|
||||
class IMDBDataset(Dataset):
|
||||
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, use_attention_mask=False):
|
||||
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
|
||||
self.data = pd.read_csv(csv_file)
|
||||
self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
|
||||
self.pad_token_id = pad_token_id
|
||||
self.use_attention_mask = use_attention_mask
|
||||
|
||||
# Pre-tokenize texts and create attention masks if required
|
||||
# Pre-tokenize texts
|
||||
self.encoded_texts = [
|
||||
tokenizer.encode(text, truncation=True, max_length=self.max_length)
|
||||
tokenizer.encode(text)[:self.max_length]
|
||||
for text in self.data["text"]
|
||||
]
|
||||
# Pad sequences to the longest sequence
|
||||
self.encoded_texts = [
|
||||
et + [pad_token_id] * (self.max_length - len(et))
|
||||
for et in self.encoded_texts
|
||||
]
|
||||
|
||||
if self.use_attention_mask:
|
||||
self.attention_masks = [
|
||||
self._create_attention_mask(et)
|
||||
for et in self.encoded_texts
|
||||
]
|
||||
else:
|
||||
self.attention_masks = None
|
||||
|
||||
def _create_attention_mask(self, encoded_text):
|
||||
return [1 if token_id != self.pad_token_id else 0 for token_id in encoded_text]
|
||||
|
||||
def __getitem__(self, index):
|
||||
encoded = self.encoded_texts[index]
|
||||
label = self.data.iloc[index]["label"]
|
||||
|
||||
if self.use_attention_mask:
|
||||
attention_mask = self.attention_masks[index]
|
||||
else:
|
||||
attention_mask = torch.ones(self.max_length, dtype=torch.long)
|
||||
|
||||
return (
|
||||
torch.tensor(encoded, dtype=torch.long),
|
||||
torch.tensor(attention_mask, dtype=torch.long),
|
||||
torch.tensor(label, dtype=torch.long)
|
||||
)
|
||||
return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@ -70,27 +50,71 @@ class IMDBDataset(Dataset):
|
||||
return max_length
|
||||
|
||||
|
||||
def calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device):
|
||||
attention_mask_batch = attention_mask_batch.to(device)
|
||||
def instantiate_model(choose_model, load_weights):
|
||||
|
||||
BASE_CONFIG = {
|
||||
"vocab_size": 50257, # Vocabulary size
|
||||
"context_length": 1024, # Context length
|
||||
"drop_rate": 0.0, # Dropout rate
|
||||
"qkv_bias": True # Query-key-value bias
|
||||
}
|
||||
|
||||
model_configs = {
|
||||
"gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
|
||||
"gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
|
||||
"gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
|
||||
"gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
|
||||
}
|
||||
|
||||
BASE_CONFIG.update(model_configs[choose_model])
|
||||
|
||||
if not load_weights:
|
||||
torch.manual_seed(123)
|
||||
model = GPTModel(BASE_CONFIG)
|
||||
|
||||
if load_weights:
|
||||
model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
|
||||
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
|
||||
load_weights_into_gpt(model, params)
|
||||
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def calc_loss_batch(input_batch, target_batch, model, device,
|
||||
trainable_token_pos=-1, average_embeddings=False):
|
||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||
# logits = model(input_batch)[:, -1, :] # Logits of last output token
|
||||
logits = model(input_batch, attention_mask=attention_mask_batch).logits
|
||||
|
||||
model_output = model(input_batch)
|
||||
if average_embeddings:
|
||||
# Average over the sequence dimension (dim=1)
|
||||
logits = model_output.mean(dim=1)
|
||||
else:
|
||||
# Select embeddings at the specified token position
|
||||
logits = model_output[:, trainable_token_pos, :]
|
||||
|
||||
loss = torch.nn.functional.cross_entropy(logits, target_batch)
|
||||
return loss
|
||||
|
||||
|
||||
# Same as in chapter 5
|
||||
def calc_loss_loader(data_loader, model, device, num_batches=None):
|
||||
def calc_loss_loader(data_loader, model, device,
|
||||
num_batches=None, trainable_token_pos=-1,
|
||||
average_embeddings=False):
|
||||
total_loss = 0.
|
||||
if num_batches is None:
|
||||
if len(data_loader) == 0:
|
||||
return float("nan")
|
||||
elif num_batches is None:
|
||||
num_batches = len(data_loader)
|
||||
else:
|
||||
# Reduce the number of batches to match the total number of batches in the data loader
|
||||
# if num_batches exceeds the number of batches in the data loader
|
||||
num_batches = min(num_batches, len(data_loader))
|
||||
for i, (input_batch, attention_mask_batch, target_batch) in enumerate(data_loader):
|
||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||
if i < num_batches:
|
||||
loss = calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device)
|
||||
loss = calc_loss_batch(
|
||||
input_batch, target_batch, model, device,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
total_loss += loss.item()
|
||||
else:
|
||||
break
|
||||
@ -98,7 +122,9 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
|
||||
|
||||
|
||||
@torch.no_grad() # Disable gradient tracking for efficiency
|
||||
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
|
||||
def calc_accuracy_loader(data_loader, model, device,
|
||||
num_batches=None, trainable_token_pos=-1,
|
||||
average_embeddings=False):
|
||||
model.eval()
|
||||
correct_predictions, num_examples = 0, 0
|
||||
|
||||
@ -106,13 +132,20 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None):
|
||||
num_batches = len(data_loader)
|
||||
else:
|
||||
num_batches = min(num_batches, len(data_loader))
|
||||
for i, (input_batch, attention_mask_batch, target_batch) in enumerate(data_loader):
|
||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||
if i < num_batches:
|
||||
attention_mask_batch = attention_mask_batch.to(device)
|
||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||
# logits = model(input_batch)[:, -1, :] # Logits of last output token
|
||||
logits = model(input_batch, attention_mask=attention_mask_batch).logits
|
||||
predicted_labels = torch.argmax(logits, dim=1)
|
||||
|
||||
model_output = model(input_batch)
|
||||
if average_embeddings:
|
||||
# Average over the sequence dimension (dim=1)
|
||||
logits = model_output.mean(dim=1)
|
||||
else:
|
||||
# Select embeddings at the specified token position
|
||||
logits = model_output[:, trainable_token_pos, :]
|
||||
|
||||
predicted_labels = torch.argmax(logits, dim=-1)
|
||||
|
||||
num_examples += predicted_labels.shape[0]
|
||||
correct_predictions += (predicted_labels == target_batch).sum().item()
|
||||
else:
|
||||
@ -120,17 +153,25 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None):
|
||||
return correct_predictions / num_examples
|
||||
|
||||
|
||||
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
|
||||
def evaluate_model(model, train_loader, val_loader, device, eval_iter,
|
||||
trainable_token_pos=-1, average_embeddings=False):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
|
||||
val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
|
||||
train_loss = calc_loss_loader(
|
||||
train_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
val_loss = calc_loss_loader(
|
||||
val_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
model.train()
|
||||
return train_loss, val_loss
|
||||
|
||||
|
||||
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
|
||||
eval_freq, eval_iter, max_steps=None):
|
||||
eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1,
|
||||
average_embeddings=False):
|
||||
# Initialize lists to track losses and tokens seen
|
||||
train_losses, val_losses, train_accs, val_accs = [], [], [], []
|
||||
examples_seen, global_step = 0, -1
|
||||
@ -139,9 +180,10 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
for epoch in range(num_epochs):
|
||||
model.train() # Set model to training mode
|
||||
|
||||
for input_batch, attention_mask_batch, target_batch in train_loader:
|
||||
for input_batch, target_batch in train_loader:
|
||||
optimizer.zero_grad() # Reset loss gradients from previous batch iteration
|
||||
loss = calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device)
|
||||
loss = calc_loss_batch(input_batch, target_batch, model, device,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings)
|
||||
loss.backward() # Calculate loss gradients
|
||||
optimizer.step() # Update model weights using loss gradients
|
||||
examples_seen += input_batch.shape[0] # New: track examples instead of tokens
|
||||
@ -150,7 +192,9 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
# Optional evaluation step
|
||||
if global_step % eval_freq == 0:
|
||||
train_loss, val_loss = evaluate_model(
|
||||
model, train_loader, val_loader, device, eval_iter)
|
||||
model, train_loader, val_loader, device, eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
train_losses.append(train_loss)
|
||||
val_losses.append(val_loss)
|
||||
print(f"Ep {epoch+1} (Step {global_step:06d}): "
|
||||
@ -160,8 +204,14 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
break
|
||||
|
||||
# New: Calculate accuracy after each epoch
|
||||
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
|
||||
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
|
||||
train_accuracy = calc_accuracy_loader(
|
||||
train_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
val_accuracy = calc_accuracy_loader(
|
||||
val_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
|
||||
)
|
||||
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
|
||||
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
|
||||
train_accs.append(train_accuracy)
|
||||
@ -176,28 +226,55 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_size",
|
||||
type=str,
|
||||
default="gpt2-small (124M)",
|
||||
help=(
|
||||
"Which GPT model to use. Options: 'gpt2-small (124M)', 'gpt2-medium (355M)',"
|
||||
" 'gpt2-large (774M)', 'gpt2-xl (1558M)'."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weights",
|
||||
type=str,
|
||||
default="pretrained",
|
||||
help=(
|
||||
"Whether to use 'pretrained' or 'random' weights."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trainable_layers",
|
||||
type=str,
|
||||
default="all",
|
||||
default="last_block",
|
||||
help=(
|
||||
"Which layers to train. Options: 'all', 'last_block', 'last_layer'."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_attention_mask",
|
||||
"--trainable_token_pos",
|
||||
type=str,
|
||||
default="true",
|
||||
default="last",
|
||||
help=(
|
||||
"Whether to use a attention mask for padding tokens. Options: 'true', 'false'."
|
||||
"Which token to train. Options: 'first', 'last'."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="distilbert",
|
||||
"--average_embeddings",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help=(
|
||||
"Which model to train. Options: 'distilbert', 'bert', 'roberta'."
|
||||
"Average the output embeddings from all tokens instead of using"
|
||||
" only the embedding at the token position specified by `--trainable_token_pos`."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context_length",
|
||||
type=str,
|
||||
default="256",
|
||||
help=(
|
||||
"The context length of the data inputs."
|
||||
"Options: 'longest_training_example', 'model_context_length' or integer value."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -211,96 +288,73 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-6,
|
||||
default=5e-5,
|
||||
help=(
|
||||
"Learning rate."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compile",
|
||||
action="store_true",
|
||||
help="If set, model compilation will be enabled."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.trainable_token_pos == "first":
|
||||
args.trainable_token_pos = 0
|
||||
elif args.trainable_token_pos == "last":
|
||||
args.trainable_token_pos = -1
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_token_pos argument")
|
||||
|
||||
###############################
|
||||
# Load model
|
||||
###############################
|
||||
|
||||
torch.manual_seed(123)
|
||||
if args.model == "distilbert":
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"distilbert-base-uncased", num_labels=2
|
||||
)
|
||||
model.out_head = torch.nn.Linear(in_features=768, out_features=2)
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
if args.trainable_layers == "last_layer":
|
||||
for param in model.out_head.parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "last_block":
|
||||
for param in model.pre_classifier.parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.distilbert.transformer.layer[-1].parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "all":
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
|
||||
|
||||
elif args.model == "bert":
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"bert-base-uncased", num_labels=2
|
||||
)
|
||||
model.classifier = torch.nn.Linear(in_features=768, out_features=2)
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
if args.trainable_layers == "last_layer":
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "last_block":
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.bert.pooler.dense.parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.bert.encoder.layer[-1].parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "all":
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
elif args.model == "roberta":
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"FacebookAI/roberta-large", num_labels=2
|
||||
)
|
||||
model.classifier.out_proj = torch.nn.Linear(in_features=1024, out_features=2)
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
if args.trainable_layers == "last_layer":
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "last_block":
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.roberta.encoder.layer[-1].parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "all":
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-large")
|
||||
if args.weights == "pretrained":
|
||||
load_weights = True
|
||||
elif args.weights == "random":
|
||||
load_weights = False
|
||||
else:
|
||||
raise ValueError("Selected --model {args.model} not supported.")
|
||||
raise ValueError("Invalid --weights argument.")
|
||||
|
||||
model = instantiate_model(args.model_size, load_weights)
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if args.model_size == "gpt2-small (124M)":
|
||||
in_features = 768
|
||||
elif args.model_size == "gpt2-medium (355M)":
|
||||
in_features = 1024
|
||||
elif args.model_size == "gpt2-large (774M)":
|
||||
in_features = 1280
|
||||
elif args.model_size == "gpt2-xl (1558M)":
|
||||
in_features = 1600
|
||||
else:
|
||||
raise ValueError("Invalid --model_size argument")
|
||||
|
||||
torch.manual_seed(123)
|
||||
model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)
|
||||
|
||||
if args.trainable_layers == "last_layer":
|
||||
pass
|
||||
elif args.trainable_layers == "last_block":
|
||||
for param in model.trf_blocks[-1].parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.final_norm.parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "all":
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if args.compile:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
model = torch.compile(model)
|
||||
|
||||
###############################
|
||||
# Instantiate dataloaders
|
||||
@ -308,34 +362,24 @@ if __name__ == "__main__":
|
||||
|
||||
base_path = Path(".")
|
||||
|
||||
if args.use_attention_mask.lower() == "true":
|
||||
use_attention_mask = True
|
||||
elif args.use_attention_mask.lower() == "false":
|
||||
use_attention_mask = False
|
||||
else:
|
||||
raise ValueError("Invalid argument for `use_attention_mask`.")
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
train_dataset = IMDBDataset(
|
||||
base_path / "train.csv",
|
||||
max_length=256,
|
||||
tokenizer=tokenizer,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
use_attention_mask=use_attention_mask
|
||||
)
|
||||
val_dataset = IMDBDataset(
|
||||
base_path / "validation.csv",
|
||||
max_length=256,
|
||||
tokenizer=tokenizer,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
use_attention_mask=use_attention_mask
|
||||
)
|
||||
test_dataset = IMDBDataset(
|
||||
base_path / "test.csv",
|
||||
max_length=256,
|
||||
tokenizer=tokenizer,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
use_attention_mask=use_attention_mask
|
||||
)
|
||||
train_dataset = None
|
||||
if args.context_length == "model_context_length":
|
||||
max_length = model.pos_emb.weight.shape[0]
|
||||
elif args.context_length == "longest_training_example":
|
||||
train_dataset = IMDBDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
|
||||
max_length = train_dataset.max_length
|
||||
else:
|
||||
try:
|
||||
max_length = int(args.context_length)
|
||||
except ValueError:
|
||||
raise ValueError("Invalid --context_length argument")
|
||||
|
||||
if train_dataset is None:
|
||||
train_dataset = IMDBDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
val_dataset = IMDBDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
test_dataset = IMDBDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
|
||||
|
||||
num_workers = 0
|
||||
batch_size = 8
|
||||
@ -373,7 +417,8 @@ if __name__ == "__main__":
|
||||
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
|
||||
model, train_loader, val_loader, optimizer, device,
|
||||
num_epochs=args.num_epochs, eval_freq=50, eval_iter=20,
|
||||
max_steps=None
|
||||
max_steps=None, trainable_token_pos=args.trainable_token_pos,
|
||||
average_embeddings=args.average_embeddings
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
@ -386,9 +431,18 @@ if __name__ == "__main__":
|
||||
|
||||
print("\nEvaluating on the full datasets ...\n")
|
||||
|
||||
train_accuracy = calc_accuracy_loader(train_loader, model, device)
|
||||
val_accuracy = calc_accuracy_loader(val_loader, model, device)
|
||||
test_accuracy = calc_accuracy_loader(test_loader, model, device)
|
||||
train_accuracy = calc_accuracy_loader(
|
||||
train_loader, model, device,
|
||||
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
|
||||
)
|
||||
val_accuracy = calc_accuracy_loader(
|
||||
val_loader, model, device,
|
||||
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
|
||||
)
|
||||
test_accuracy = calc_accuracy_loader(
|
||||
test_loader, model, device,
|
||||
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
|
||||
)
|
||||
|
||||
print(f"Training accuracy: {train_accuracy*100:.2f}%")
|
||||
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
|
||||
|
@ -293,6 +293,11 @@ if __name__ == "__main__":
|
||||
"Learning rate."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compile",
|
||||
action="store_true",
|
||||
help="If set, model compilation will be enabled."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.trainable_token_pos == "first":
|
||||
@ -347,6 +352,10 @@ if __name__ == "__main__":
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.to(device)
|
||||
|
||||
if args.compile:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
model = torch.compile(model)
|
||||
|
||||
###############################
|
||||
# Instantiate dataloaders
|
||||
###############################
|
||||
|
Loading…
x
Reference in New Issue
Block a user