Add GPTModelFast (#584)

* Add GPTModelFast

* update
This commit is contained in:
Sebastian Raschka 2025-03-27 14:00:25 -05:00 committed by GitHub
parent c9271ac427
commit e07a7abdd5
7 changed files with 204 additions and 61 deletions

View File

@ -50,11 +50,12 @@ Once installed, you can import code from any chapter using:
from llms_from_scratch.ch02 import GPTDatasetV1, create_dataloader_v1 from llms_from_scratch.ch02 import GPTDatasetV1, create_dataloader_v1
from llms_from_scratch.ch03 import ( from llms_from_scratch.ch03 import (
MultiHeadAttention,
SelfAttention_v1, SelfAttention_v1,
SelfAttention_v2, SelfAttention_v2,
CausalAttention, CausalAttention,
MultiHeadAttentionWrapper MultiHeadAttentionWrapper,
MultiHeadAttention,
PyTorchMultiHeadAttention # Bonus: Faster variant using PyTorch's scaled_dot_product_attention
) )
from llms_from_scratch.ch04 import ( from llms_from_scratch.ch04 import (
@ -63,6 +64,7 @@ from llms_from_scratch.ch04 import (
FeedForward, FeedForward,
TransformerBlock, TransformerBlock,
GPTModel, GPTModel,
GPTModelFast # Bonus: Faster variant using PyTorch's scaled_dot_product_attention
generate_text_simple generate_text_simple
) )

View File

@ -149,3 +149,50 @@ class MultiHeadAttention(nn.Module):
context_vec = self.out_proj(context_vec) # optional projection context_vec = self.out_proj(context_vec) # optional projection
return context_vec return context_vec
######################
# Bonus
######################
class PyTorchMultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"
self.num_heads = num_heads
self.head_dim = d_out // num_heads
self.d_out = d_out
self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
self.proj = nn.Linear(d_out, d_out)
self.dropout = dropout
def forward(self, x):
batch_size, num_tokens, embed_dim = x.shape
# (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)
qkv = self.qkv(x)
# (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)
qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)
# (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
# (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
queries, keys, values = qkv
use_dropout = 0. if not self.training else self.dropout
context_vec = nn.functional.scaled_dot_product_attention(
queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)
# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)
context_vec = self.proj(context_vec)
return context_vec

View File

@ -3,7 +3,7 @@
# - https://www.manning.com/books/build-a-large-language-model-from-scratch # - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch
from .ch03 import MultiHeadAttention from .ch03 import MultiHeadAttention, PyTorchMultiHeadAttention
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -128,3 +128,90 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1) idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
return idx return idx
######################
# Bonus
######################
class FeedForwardFast(nn.Module):
def __init__(self, cfg):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
nn.GELU(approximate="tanh"),
nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
)
def forward(self, x):
return self.layers(x)
class TransformerBlockFast(nn.Module):
def __init__(self, cfg):
super().__init__()
self.att = PyTorchMultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])
self.ff = FeedForwardFast(cfg)
self.norm1 = nn.LayerNorm(cfg["emb_dim"])
self.norm2 = nn.LayerNorm(cfg["emb_dim"])
self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
def forward(self, x):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back
return x
class GPTModelFast(nn.Module):
"""
A faster variant of GPTModel optimized for training speed.
This version is only marginally faster on CPU (~1.02x) but significantly
faster on GPU (~2.05x) during training, thanks to optimized CUDA kernels
and FlashAttention support.
Key differences from the original GPTModel:
1. Uses PyTorch's built-in LayerNorm instead of a custom implementation.
2. Uses PyTorch's built-in GELU instead of a custom implementation.
3. Uses PyTorch's scaled_dot_product_attention instead of a custom MultiHeadAttention.
4. Automatically enables FlashAttention on compatible GPUs.
"""
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
self.trf_blocks = nn.Sequential(
*[TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])])
self.final_norm = nn.LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
def forward(self, in_idx):
batch_size, seq_len = in_idx.shape
tok_embeds = self.tok_emb(in_idx)
pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
x = tok_embeds + pos_embeds
x = self.drop_emb(x)
x = self.trf_blocks(x)
x = self.final_norm(x)
logits = self.out_head(x)
return logits

View File

@ -4,7 +4,7 @@
# Code: https://github.com/rasbt/LLMs-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch
from llms_from_scratch.ch03 import MultiHeadAttention from llms_from_scratch.ch03 import MultiHeadAttention, PyTorchMultiHeadAttention
import torch import torch
@ -14,7 +14,15 @@ def test_mha():
d_in = 256 d_in = 256
d_out = 16 d_out = 16
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2) mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=2)
batch = torch.rand(8, 6, d_in)
context_vecs = mha(batch)
context_vecs.shape == torch.Size([8, 6, d_out])
# Test bonus class
mha = PyTorchMultiHeadAttention(d_in, d_out, num_heads=2)
batch = torch.rand(8, 6, d_in) batch = torch.rand(8, 6, d_in)
context_vecs = mha(batch) context_vecs = mha(batch)

View File

@ -3,15 +3,15 @@
# - https://www.manning.com/books/build-a-large-language-model-from-scratch # - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch
from llms_from_scratch.ch04 import GPTModel from llms_from_scratch.ch04 import GPTModel, GPTModelFast
from llms_from_scratch.ch04 import generate_text_simple from llms_from_scratch.ch04 import generate_text_simple
import pytest
import torch import torch
import tiktoken import tiktoken
def test_GPTModel(): GPT_CONFIG_124M = {
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size "vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length "context_length": 1024, # Context length
"emb_dim": 768, # Embedding dimension "emb_dim": 768, # Embedding dimension
@ -19,10 +19,13 @@ def test_GPTModel():
"n_layers": 12, # Number of layers "n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate "drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-Key-Value bias "qkv_bias": False # Query-Key-Value bias
} }
@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast])
def test_gpt_model_variants(ModelClass):
torch.manual_seed(123) torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M) model = ModelClass(GPT_CONFIG_124M)
model.eval() # disable dropout model.eval() # disable dropout
start_context = "Hello, I am" start_context = "Hello, I am"
@ -47,4 +50,4 @@ def test_GPTModel():
[15496, 11, 314, 716, 27018, 24086, 47843, 30961, 42348, 7267, [15496, 11, 314, 716, 27018, 24086, 47843, 30961, 42348, 7267,
49706, 43231, 47062, 34657] 49706, 43231, 47062, 34657]
]) ])
torch.equal(expect, out) assert torch.equal(expect, out), "Generated output does not match expected output"

View File

@ -4,7 +4,7 @@
# Code: https://github.com/rasbt/LLMs-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch
from llms_from_scratch.ch02 import create_dataloader_v1 from llms_from_scratch.ch02 import create_dataloader_v1
from llms_from_scratch.ch04 import GPTModel from llms_from_scratch.ch04 import GPTModel, GPTModelFast
from llms_from_scratch.ch05 import train_model_simple from llms_from_scratch.ch05 import train_model_simple
import os import os
@ -16,60 +16,47 @@ import torch
from torch.utils.data import Subset, DataLoader from torch.utils.data import Subset, DataLoader
@pytest.mark.parametrize("file_name", ["the-verdict.txt"]) GPT_CONFIG_124M = {
def test_train_simple(tmp_path, file_name): "vocab_size": 50257,
"context_length": 256, # Shortened for test speed
"emb_dim": 768,
"n_heads": 12,
"n_layers": 12,
"drop_rate": 0.1,
"qkv_bias": False
}
GPT_CONFIG_124M = { OTHER_SETTINGS = {
"vocab_size": 50257, # Vocabulary size
"context_length": 256, # Shortened context length (orig: 1024)
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-key-value bias
}
OTHER_SETTINGS = {
"learning_rate": 5e-4, "learning_rate": 5e-4,
"num_epochs": 2, "num_epochs": 2,
"batch_size": 1, "batch_size": 1,
"weight_decay": 0.1 "weight_decay": 0.1
} }
@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast])
def test_train_simple(tmp_path, ModelClass):
torch.manual_seed(123) torch.manual_seed(123)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
############################## ##############################
# Download data if necessary # Download data if necessary
############################## ##############################
file_path = tmp_path / "the-verdict.txt" file_path = tmp_path / "the-verdict.txt"
url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt" url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"
if not os.path.exists(file_path): if not os.path.exists(file_path):
with urllib.request.urlopen(url) as response: with urllib.request.urlopen(url) as response:
text_data = response.read().decode('utf-8') text_data = response.read().decode("utf-8")
with open(file_path, "w", encoding="utf-8") as file: with open(file_path, "w", encoding="utf-8") as f:
file.write(text_data) f.write(text_data)
else: else:
with open(file_path, "r", encoding="utf-8") as file: with open(file_path, "r", encoding="utf-8") as f:
text_data = file.read() text_data = f.read()
##############################
# Initialize model
##############################
model = GPTModel(GPT_CONFIG_124M)
model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes
optimizer = torch.optim.AdamW(
model.parameters(), lr=OTHER_SETTINGS["learning_rate"], weight_decay=OTHER_SETTINGS["weight_decay"]
)
############################## ##############################
# Set up dataloaders # Set up dataloaders
############################## ##############################
# Train/validation ratio
train_ratio = 0.90 train_ratio = 0.90
split_idx = int(train_ratio * len(text_data)) split_idx = int(train_ratio * len(text_data))
@ -93,17 +80,26 @@ def test_train_simple(tmp_path, file_name):
num_workers=0 num_workers=0
) )
############################## # Limit to 1 batch for speed
# Train model
##############################
tokenizer = tiktoken.get_encoding("gpt2")
train_subset = Subset(train_loader.dataset, range(1)) train_subset = Subset(train_loader.dataset, range(1))
one_batch_train_loader = DataLoader(train_subset, batch_size=1) one_batch_train_loader = DataLoader(train_subset, batch_size=1)
val_subset = Subset(val_loader.dataset, range(1)) val_subset = Subset(val_loader.dataset, range(1))
one_batch_val_loader = DataLoader(val_subset, batch_size=1) one_batch_val_loader = DataLoader(val_subset, batch_size=1)
##############################
# Train model
##############################
model = ModelClass(GPT_CONFIG_124M)
model.to(device)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=OTHER_SETTINGS["learning_rate"],
weight_decay=OTHER_SETTINGS["weight_decay"]
)
tokenizer = tiktoken.get_encoding("gpt2")
train_losses, val_losses, tokens_seen = train_model_simple( train_losses, val_losses, tokens_seen = train_model_simple(
model, one_batch_train_loader, one_batch_val_loader, optimizer, device, model, one_batch_train_loader, one_batch_val_loader, optimizer, device,
num_epochs=OTHER_SETTINGS["num_epochs"], eval_freq=1, eval_iter=1, num_epochs=OTHER_SETTINGS["num_epochs"], eval_freq=1, eval_iter=1,

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "llms-from-scratch" name = "llms-from-scratch"
version = "1.0.0" version = "1.0.1"
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step" description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"