Add ModernBERT (#598)

This commit is contained in:
Sebastian Raschka 2025-04-05 09:13:30 -05:00 committed by GitHub
parent 396e96ab07
commit 14f976e024
4 changed files with 333 additions and 170 deletions

View File

@ -74,4 +74,4 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
9. **Padding vs no padding (Row 1 vs. 14 & 15, and 16)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 15, we additionally enable gradient accumulation with 8 steps to achieve the same batch size as in the other experiments, which helps reduce overfitting and slightly boost the test set accuracy. In row 16, padding is applied, but the token position is selected based on the last non-padding token. Row 16 should be mathematically similar to row 15, which uses gradient accumulation. However, due to some challenges with gradient accumulation in cases of unequal token counts, there may be small discrepancies (this is discussed in [this](https://unsloth.ai/blog/gradient) blog post).
10. **Disabling the causal attention mask (Row 1 vs. 17)**: Disables the causal attention mask used in the multi-head attention module. This means all tokens can attend all other tokens. The model accuracy is slightly improved compared to the GPT model with causal mask.
11. **Ignoring the padding indices in the loss and backpropagation (Row 1 vs. 18)**: Setting `--ignore_index 50256` excludes the `|endoftext|` padding tokens in the `cross_entropy` loss function in PyTorch. In this case, it does not have any effect because we replaced the output layers so that the token IDs are either 0 or 1 for the binary classification example. However, this setting is useful when instruction finetuning models in chapter 7.
13. **Averaging the embeddings over all tokens (Row 1 vs. 19)**: Setting `--average_embeddings` will average the embeddings over all tokens. If this option is not used (the default), only the output embeddings at the chosen token position (specified by `--trainable_token_pos`) are considered; for example, the embeddings of the last token. Enabling `--average_embeddings` will mean-pool the embeddings of all tokens into the position chosen by `--trainable_token_pos` (the last token by default). As we can see, this improves the performance from 95.00% to 96.33% with only a minimal increase in run time (0.28 min to 0.32 min) and might be worthwhile considering in practice.
12. **Averaging the embeddings over all tokens (Row 1 vs. 19)**: Setting `--average_embeddings` will average the embeddings over all tokens. If this option is not used (the default), only the output embeddings at the chosen token position (specified by `--trainable_token_pos`) are considered; for example, the embeddings of the last token. Enabling `--average_embeddings` will mean-pool the embeddings of all tokens into the position chosen by `--trainable_token_pos` (the last token by default). As we can see, this improves the performance from 95.00% to 96.33% with only a minimal increase in run time (0.28 min to 0.32 min) and might be worthwhile considering in practice.

View File

@ -1,5 +1,28 @@
# Additional Experiments Classifying the Sentiment of 50k IMDB Movie Reviews
## Overview
This folder contains additional experiments to compare the (decoder-style) GPT-2 (2018) model from chapter 6 to encoder-style LLMs like [BERT (2018)](https://arxiv.org/abs/1810.04805), [RoBERTa (2019)](https://arxiv.org/abs/1907.11692), and [ModernBERT (2024)](https://arxiv.org/abs/2412.13663). Instead of using the small SPAM dataset from Chapter 6, we are using the 50k movie review dataset from IMDb ([dataset source](https://ai.stanford.edu/~amaas/data/sentiment/)) with a binary classification objective, predicting whether a reviewer liked the movie or not. This is a balanced dataset, so a random prediction should yield 50% accuracy.
| | Model | Test accuracy |
| ----- | ---------------------------- | ------------- |
| **1** | 124 M GPT-2 Baseline | 91.88% |
| **2** | 340 M BERT | 90.89% |
| **3** | 66 M DistilBERT | 91.40% |
| **4** | 355 M RoBERTa | 92.95% |
| **5** | 149 M ModernBERT Base | 93.79% |
| **6** | 395 M ModernBERT Large | 95.07% |
| **7** | Logistic Regression Baseline | 88.85% |
 
## Step 1: Install Dependencies
@ -24,7 +47,10 @@ python download_prepare_dataset.py
 
## Step 3: Run Models
The 124M GPT-2 model used in the main chapter, starting with pretrained weights, and finetuning all weights:
 
### 1) 124 M GPT-2 Baseline
The 124M GPT-2 model used in chapter 6, starting with pretrained weights, and finetuning all weights:
```bash
python train_gpt.py --trainable_layers "all" --num_epochs 1
@ -53,6 +79,10 @@ Test accuracy: 91.88%
<br>
&nbsp;
### 2) 340 M BERT
A 340M parameter encoder-style [BERT](https://arxiv.org/abs/1810.04805) model:
```bash
@ -81,6 +111,9 @@ Test accuracy: 90.89%
<br>
&nbsp;
### 3) 66 M DistilBERT
A 66M parameter encoder-style [DistilBERT](https://arxiv.org/abs/1910.01108) model (distilled down from a 340M parameter BERT model), starting for the pretrained weights and only training the last transformer block plus output layers:
@ -110,6 +143,9 @@ Test accuracy: 91.40%
<br>
&nbsp;
### 4) 355 M RoBERTa
A 355M parameter encoder-style [RoBERTa](https://arxiv.org/abs/1907.11692) model, starting for the pretrained weights and only training the last transformer block plus output layers:
@ -133,6 +169,33 @@ Validation accuracy: 93.02%
Test accuracy: 92.95%
```
<br>
---
<br>
&nbsp;
### 5) 149 M ModernBERT Base
[ModernBERT (2024)](https://arxiv.org/abs/2412.13663) is an optimized reimplementation of BERT that incorporates architectural improvements like parallel residual connections and gated linear units (GLUs) to boost efficiency and performance. It maintains BERTs original pretraining objectives while achieving faster inference and better scalability on modern hardware.
```
Ep 1 (Step 000000): Train loss 0.699, Val loss 0.698
Ep 1 (Step 000050): Train loss 0.564, Val loss 0.606
...
Ep 1 (Step 004300): Train loss 0.086, Val loss 0.168
Ep 1 (Step 004350): Train loss 0.160, Val loss 0.131
Training accuracy: 95.62% | Validation accuracy: 93.75%
Training completed in 10.27 minutes.
Evaluating on the full datasets ...
Training accuracy: 95.72%
Validation accuracy: 94.00%
Test accuracy: 93.79%
```
<br>
@ -140,7 +203,44 @@ Test accuracy: 92.95%
<br>
A scikit-learn logistic regression classifier as a baseline:
&nbsp;
### 6) 395 M ModernBERT Large
Same as above but using the larger ModernBERT variant.
```
Ep 1 (Step 000000): Train loss 0.666, Val loss 0.662
Ep 1 (Step 000050): Train loss 0.548, Val loss 0.556
...
Ep 1 (Step 004300): Train loss 0.083, Val loss 0.115
Ep 1 (Step 004350): Train loss 0.154, Val loss 0.116
Training accuracy: 96.88% | Validation accuracy: 95.62%
Training completed in 27.69 minutes.
Evaluating on the full datasets ...
Training accuracy: 97.04%
Validation accuracy: 95.30%
Test accuracy: 95.07%
```
<br>
---
<br>
&nbsp;
### 7) Logistic Regression Baseline
A scikit-learn [logistic regression](https://sebastianraschka.com/blog/2022/losses-learned-part1.html) classifier as a baseline:
```bash

View File

@ -8,55 +8,35 @@ from pathlib import Path
import time
import pandas as pd
import tiktoken
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt
class IMDBDataset(Dataset):
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, use_attention_mask=False):
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
self.data = pd.read_csv(csv_file)
self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
self.pad_token_id = pad_token_id
self.use_attention_mask = use_attention_mask
# Pre-tokenize texts and create attention masks if required
# Pre-tokenize texts
self.encoded_texts = [
tokenizer.encode(text, truncation=True, max_length=self.max_length)
tokenizer.encode(text)[:self.max_length]
for text in self.data["text"]
]
# Pad sequences to the longest sequence
self.encoded_texts = [
et + [pad_token_id] * (self.max_length - len(et))
for et in self.encoded_texts
]
if self.use_attention_mask:
self.attention_masks = [
self._create_attention_mask(et)
for et in self.encoded_texts
]
else:
self.attention_masks = None
def _create_attention_mask(self, encoded_text):
return [1 if token_id != self.pad_token_id else 0 for token_id in encoded_text]
def __getitem__(self, index):
encoded = self.encoded_texts[index]
label = self.data.iloc[index]["label"]
if self.use_attention_mask:
attention_mask = self.attention_masks[index]
else:
attention_mask = torch.ones(self.max_length, dtype=torch.long)
return (
torch.tensor(encoded, dtype=torch.long),
torch.tensor(attention_mask, dtype=torch.long),
torch.tensor(label, dtype=torch.long)
)
return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)
def __len__(self):
return len(self.data)
@ -70,27 +50,71 @@ class IMDBDataset(Dataset):
return max_length
def calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device):
attention_mask_batch = attention_mask_batch.to(device)
def instantiate_model(choose_model, load_weights):
BASE_CONFIG = {
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"drop_rate": 0.0, # Dropout rate
"qkv_bias": True # Query-key-value bias
}
model_configs = {
"gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
"gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
"gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
"gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
BASE_CONFIG.update(model_configs[choose_model])
if not load_weights:
torch.manual_seed(123)
model = GPTModel(BASE_CONFIG)
if load_weights:
model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
load_weights_into_gpt(model, params)
model.eval()
return model
def calc_loss_batch(input_batch, target_batch, model, device,
trainable_token_pos=-1, average_embeddings=False):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
# logits = model(input_batch)[:, -1, :] # Logits of last output token
logits = model(input_batch, attention_mask=attention_mask_batch).logits
model_output = model(input_batch)
if average_embeddings:
# Average over the sequence dimension (dim=1)
logits = model_output.mean(dim=1)
else:
# Select embeddings at the specified token position
logits = model_output[:, trainable_token_pos, :]
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
# Same as in chapter 5
def calc_loss_loader(data_loader, model, device, num_batches=None):
def calc_loss_loader(data_loader, model, device,
num_batches=None, trainable_token_pos=-1,
average_embeddings=False):
total_loss = 0.
if num_batches is None:
if len(data_loader) == 0:
return float("nan")
elif num_batches is None:
num_batches = len(data_loader)
else:
# Reduce the number of batches to match the total number of batches in the data loader
# if num_batches exceeds the number of batches in the data loader
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, attention_mask_batch, target_batch) in enumerate(data_loader):
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
loss = calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device)
loss = calc_loss_batch(
input_batch, target_batch, model, device,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
total_loss += loss.item()
else:
break
@ -98,7 +122,9 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
@torch.no_grad() # Disable gradient tracking for efficiency
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
def calc_accuracy_loader(data_loader, model, device,
num_batches=None, trainable_token_pos=-1,
average_embeddings=False):
model.eval()
correct_predictions, num_examples = 0, 0
@ -106,13 +132,20 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None):
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, attention_mask_batch, target_batch) in enumerate(data_loader):
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
attention_mask_batch = attention_mask_batch.to(device)
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
# logits = model(input_batch)[:, -1, :] # Logits of last output token
logits = model(input_batch, attention_mask=attention_mask_batch).logits
predicted_labels = torch.argmax(logits, dim=1)
model_output = model(input_batch)
if average_embeddings:
# Average over the sequence dimension (dim=1)
logits = model_output.mean(dim=1)
else:
# Select embeddings at the specified token position
logits = model_output[:, trainable_token_pos, :]
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
@ -120,17 +153,25 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None):
return correct_predictions / num_examples
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
def evaluate_model(model, train_loader, val_loader, device, eval_iter,
trainable_token_pos=-1, average_embeddings=False):
model.eval()
with torch.no_grad():
train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
train_loss = calc_loss_loader(
train_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
val_loss = calc_loss_loader(
val_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
model.train()
return train_loss, val_loss
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
eval_freq, eval_iter, max_steps=None):
eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1,
average_embeddings=False):
# Initialize lists to track losses and tokens seen
train_losses, val_losses, train_accs, val_accs = [], [], [], []
examples_seen, global_step = 0, -1
@ -139,9 +180,10 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
for epoch in range(num_epochs):
model.train() # Set model to training mode
for input_batch, attention_mask_batch, target_batch in train_loader:
for input_batch, target_batch in train_loader:
optimizer.zero_grad() # Reset loss gradients from previous batch iteration
loss = calc_loss_batch(input_batch, attention_mask_batch, target_batch, model, device)
loss = calc_loss_batch(input_batch, target_batch, model, device,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings)
loss.backward() # Calculate loss gradients
optimizer.step() # Update model weights using loss gradients
examples_seen += input_batch.shape[0] # New: track examples instead of tokens
@ -150,7 +192,9 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
# Optional evaluation step
if global_step % eval_freq == 0:
train_loss, val_loss = evaluate_model(
model, train_loader, val_loader, device, eval_iter)
model, train_loader, val_loader, device, eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
train_losses.append(train_loss)
val_losses.append(val_loss)
print(f"Ep {epoch+1} (Step {global_step:06d}): "
@ -160,8 +204,14 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
break
# New: Calculate accuracy after each epoch
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
train_accuracy = calc_accuracy_loader(
train_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
val_accuracy = calc_accuracy_loader(
val_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
train_accs.append(train_accuracy)
@ -176,28 +226,55 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_size",
type=str,
default="gpt2-small (124M)",
help=(
"Which GPT model to use. Options: 'gpt2-small (124M)', 'gpt2-medium (355M)',"
" 'gpt2-large (774M)', 'gpt2-xl (1558M)'."
)
)
parser.add_argument(
"--weights",
type=str,
default="pretrained",
help=(
"Whether to use 'pretrained' or 'random' weights."
)
)
parser.add_argument(
"--trainable_layers",
type=str,
default="all",
default="last_block",
help=(
"Which layers to train. Options: 'all', 'last_block', 'last_layer'."
)
)
parser.add_argument(
"--use_attention_mask",
"--trainable_token_pos",
type=str,
default="true",
default="last",
help=(
"Whether to use a attention mask for padding tokens. Options: 'true', 'false'."
"Which token to train. Options: 'first', 'last'."
)
)
parser.add_argument(
"--model",
type=str,
default="distilbert",
"--average_embeddings",
action='store_true',
default=False,
help=(
"Which model to train. Options: 'distilbert', 'bert', 'roberta'."
"Average the output embeddings from all tokens instead of using"
" only the embedding at the token position specified by `--trainable_token_pos`."
)
)
parser.add_argument(
"--context_length",
type=str,
default="256",
help=(
"The context length of the data inputs."
"Options: 'longest_training_example', 'model_context_length' or integer value."
)
)
parser.add_argument(
@ -211,96 +288,73 @@ if __name__ == "__main__":
parser.add_argument(
"--learning_rate",
type=float,
default=5e-6,
default=5e-5,
help=(
"Learning rate."
)
)
parser.add_argument(
"--compile",
action="store_true",
help="If set, model compilation will be enabled."
)
args = parser.parse_args()
if args.trainable_token_pos == "first":
args.trainable_token_pos = 0
elif args.trainable_token_pos == "last":
args.trainable_token_pos = -1
else:
raise ValueError("Invalid --trainable_token_pos argument")
###############################
# Load model
###############################
torch.manual_seed(123)
if args.model == "distilbert":
model = AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=2
)
model.out_head = torch.nn.Linear(in_features=768, out_features=2)
for param in model.parameters():
param.requires_grad = False
if args.trainable_layers == "last_layer":
for param in model.out_head.parameters():
param.requires_grad = True
elif args.trainable_layers == "last_block":
for param in model.pre_classifier.parameters():
param.requires_grad = True
for param in model.distilbert.transformer.layer[-1].parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
else:
raise ValueError("Invalid --trainable_layers argument.")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
elif args.model == "bert":
model = AutoModelForSequenceClassification.from_pretrained(
"bert-base-uncased", num_labels=2
)
model.classifier = torch.nn.Linear(in_features=768, out_features=2)
for param in model.parameters():
param.requires_grad = False
if args.trainable_layers == "last_layer":
for param in model.classifier.parameters():
param.requires_grad = True
elif args.trainable_layers == "last_block":
for param in model.classifier.parameters():
param.requires_grad = True
for param in model.bert.pooler.dense.parameters():
param.requires_grad = True
for param in model.bert.encoder.layer[-1].parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
else:
raise ValueError("Invalid --trainable_layers argument.")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
elif args.model == "roberta":
model = AutoModelForSequenceClassification.from_pretrained(
"FacebookAI/roberta-large", num_labels=2
)
model.classifier.out_proj = torch.nn.Linear(in_features=1024, out_features=2)
for param in model.parameters():
param.requires_grad = False
if args.trainable_layers == "last_layer":
for param in model.classifier.parameters():
param.requires_grad = True
elif args.trainable_layers == "last_block":
for param in model.classifier.parameters():
param.requires_grad = True
for param in model.roberta.encoder.layer[-1].parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
else:
raise ValueError("Invalid --trainable_layers argument.")
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-large")
if args.weights == "pretrained":
load_weights = True
elif args.weights == "random":
load_weights = False
else:
raise ValueError("Selected --model {args.model} not supported.")
raise ValueError("Invalid --weights argument.")
model = instantiate_model(args.model_size, load_weights)
for param in model.parameters():
param.requires_grad = False
if args.model_size == "gpt2-small (124M)":
in_features = 768
elif args.model_size == "gpt2-medium (355M)":
in_features = 1024
elif args.model_size == "gpt2-large (774M)":
in_features = 1280
elif args.model_size == "gpt2-xl (1558M)":
in_features = 1600
else:
raise ValueError("Invalid --model_size argument")
torch.manual_seed(123)
model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)
if args.trainable_layers == "last_layer":
pass
elif args.trainable_layers == "last_block":
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
for param in model.final_norm.parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
else:
raise ValueError("Invalid --trainable_layers argument.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
if args.compile:
torch.set_float32_matmul_precision("high")
model = torch.compile(model)
###############################
# Instantiate dataloaders
@ -308,34 +362,24 @@ if __name__ == "__main__":
base_path = Path(".")
if args.use_attention_mask.lower() == "true":
use_attention_mask = True
elif args.use_attention_mask.lower() == "false":
use_attention_mask = False
else:
raise ValueError("Invalid argument for `use_attention_mask`.")
tokenizer = tiktoken.get_encoding("gpt2")
train_dataset = IMDBDataset(
base_path / "train.csv",
max_length=256,
tokenizer=tokenizer,
pad_token_id=tokenizer.pad_token_id,
use_attention_mask=use_attention_mask
)
val_dataset = IMDBDataset(
base_path / "validation.csv",
max_length=256,
tokenizer=tokenizer,
pad_token_id=tokenizer.pad_token_id,
use_attention_mask=use_attention_mask
)
test_dataset = IMDBDataset(
base_path / "test.csv",
max_length=256,
tokenizer=tokenizer,
pad_token_id=tokenizer.pad_token_id,
use_attention_mask=use_attention_mask
)
train_dataset = None
if args.context_length == "model_context_length":
max_length = model.pos_emb.weight.shape[0]
elif args.context_length == "longest_training_example":
train_dataset = IMDBDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
max_length = train_dataset.max_length
else:
try:
max_length = int(args.context_length)
except ValueError:
raise ValueError("Invalid --context_length argument")
if train_dataset is None:
train_dataset = IMDBDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
val_dataset = IMDBDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
test_dataset = IMDBDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
num_workers = 0
batch_size = 8
@ -373,7 +417,8 @@ if __name__ == "__main__":
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
model, train_loader, val_loader, optimizer, device,
num_epochs=args.num_epochs, eval_freq=50, eval_iter=20,
max_steps=None
max_steps=None, trainable_token_pos=args.trainable_token_pos,
average_embeddings=args.average_embeddings
)
end_time = time.time()
@ -386,9 +431,18 @@ if __name__ == "__main__":
print("\nEvaluating on the full datasets ...\n")
train_accuracy = calc_accuracy_loader(train_loader, model, device)
val_accuracy = calc_accuracy_loader(val_loader, model, device)
test_accuracy = calc_accuracy_loader(test_loader, model, device)
train_accuracy = calc_accuracy_loader(
train_loader, model, device,
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
)
val_accuracy = calc_accuracy_loader(
val_loader, model, device,
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
)
test_accuracy = calc_accuracy_loader(
test_loader, model, device,
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
)
print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")

View File

@ -293,6 +293,11 @@ if __name__ == "__main__":
"Learning rate."
)
)
parser.add_argument(
"--compile",
action="store_true",
help="If set, model compilation will be enabled."
)
args = parser.parse_args()
if args.trainable_token_pos == "first":
@ -347,6 +352,10 @@ if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
if args.compile:
torch.set_float32_matmul_precision("high")
model = torch.compile(model)
###############################
# Instantiate dataloaders
###############################