diff --git a/pkg/llms_from_scratch/README.md b/pkg/llms_from_scratch/README.md
index 7f61c63..2cffbec 100644
--- a/pkg/llms_from_scratch/README.md
+++ b/pkg/llms_from_scratch/README.md
@@ -50,11 +50,12 @@ Once installed, you can import code from any chapter using:
 from llms_from_scratch.ch02 import GPTDatasetV1, create_dataloader_v1
 
 from llms_from_scratch.ch03 import (
-    MultiHeadAttention,
     SelfAttention_v1,
     SelfAttention_v2,
     CausalAttention,
-    MultiHeadAttentionWrapper
+    MultiHeadAttentionWrapper,
+    MultiHeadAttention,
+    PyTorchMultiHeadAttention  # Bonus: Faster variant using PyTorch's scaled_dot_product_attention
 )
 
 from llms_from_scratch.ch04 import (
@@ -63,6 +64,7 @@ from llms_from_scratch.ch04 import (
     FeedForward,
     TransformerBlock,
     GPTModel,
+    GPTModelFast,  # Bonus: Faster variant using PyTorch's scaled_dot_product_attention
     generate_text_simple
 )
 
diff --git a/pkg/llms_from_scratch/ch03.py b/pkg/llms_from_scratch/ch03.py
index 7f64439..99ff4a7 100644
--- a/pkg/llms_from_scratch/ch03.py
+++ b/pkg/llms_from_scratch/ch03.py
@@ -149,3 +149,50 @@ class MultiHeadAttention(nn.Module):
         context_vec = self.out_proj(context_vec)  # optional projection
 
         return context_vec
+
+
+######################
+# Bonus
+######################
+
+
+class PyTorchMultiHeadAttention(nn.Module):
+    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
+        super().__init__()
+
+        assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"
+
+        self.num_heads = num_heads
+        self.head_dim = d_out // num_heads
+        self.d_out = d_out
+
+        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
+        self.proj = nn.Linear(d_out, d_out)
+        self.dropout = dropout
+
+    def forward(self, x):
+        batch_size, num_tokens, embed_dim = x.shape
+
+        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)
+        qkv = self.qkv(x)
+
+        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)
+        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)
+
+        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+
+        # (3, b, num_heads, num_tokens, head_dim) --> 3 times (b, num_heads, num_tokens, head_dim)
+        queries, keys, values = qkv
+
+        use_dropout = 0. if not self.training else self.dropout
+
+        context_vec = nn.functional.scaled_dot_product_attention(
+            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)
+
+        # Combine heads, where self.d_out = self.num_heads * self.head_dim
+        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)
+
+        context_vec = self.proj(context_vec)
+
+        return context_vec
diff --git a/pkg/llms_from_scratch/ch04.py b/pkg/llms_from_scratch/ch04.py
index 1a353a1..5e2e43c 100644
--- a/pkg/llms_from_scratch/ch04.py
+++ b/pkg/llms_from_scratch/ch04.py
@@ -3,7 +3,7 @@
 #   - https://www.manning.com/books/build-a-large-language-model-from-scratch
 # Code: https://github.com/rasbt/LLMs-from-scratch
 
-from .ch03 import MultiHeadAttention
+from .ch03 import MultiHeadAttention, PyTorchMultiHeadAttention
 
 import torch
 import torch.nn as nn
@@ -128,3 +128,90 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
         idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)
 
     return idx
+
+######################
+# Bonus
+######################
+
+
+class FeedForwardFast(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class TransformerBlockFast(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.att = PyTorchMultiHeadAttention(
+            d_in=cfg["emb_dim"],
+            d_out=cfg["emb_dim"],
+            num_heads=cfg["n_heads"],
+            dropout=cfg["drop_rate"],
+            qkv_bias=cfg["qkv_bias"])
+        self.ff = FeedForwardFast(cfg)
+        self.norm1 = nn.LayerNorm(cfg["emb_dim"])
+        self.norm2 = nn.LayerNorm(cfg["emb_dim"])
+        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
+
+    def forward(self, x):
+        # Shortcut connection for attention block
+        shortcut = x
+        x = self.norm1(x)
+        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
+        x = self.drop_shortcut(x)
+        x = x + shortcut  # Add the original input back
+
+        # Shortcut connection for feed-forward block
+        shortcut = x
+        x = self.norm2(x)
+        x = self.ff(x)
+        x = self.drop_shortcut(x)
+        x = x + shortcut  # Add the original input back
+
+        return x
+
+
+class GPTModelFast(nn.Module):
+    """
+    A faster variant of GPTModel optimized for training speed.
+
+    This version is only marginally faster on CPU (~1.02x) but significantly
+    faster on GPU (~2.05x) during training, thanks to optimized CUDA kernels
+    and FlashAttention support.
+
+    Key differences from the original GPTModel:
+    1. Uses PyTorch's built-in LayerNorm instead of a custom implementation.
+    2. Uses PyTorch's built-in GELU instead of a custom implementation.
+    3. Uses PyTorch's scaled_dot_product_attention instead of a custom MultiHeadAttention.
+    4. Automatically enables FlashAttention on compatible GPUs.
+ """ + def __init__(self, cfg): + super().__init__() + self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) + self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"]) + self.drop_emb = nn.Dropout(cfg["drop_rate"]) + + self.trf_blocks = nn.Sequential( + *[TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])]) + + self.final_norm = nn.LayerNorm(cfg["emb_dim"]) + self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False) + + def forward(self, in_idx): + batch_size, seq_len = in_idx.shape + tok_embeds = self.tok_emb(in_idx) + pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) + x = tok_embeds + pos_embeds + x = self.drop_emb(x) + x = self.trf_blocks(x) + x = self.final_norm(x) + logits = self.out_head(x) + return logits diff --git a/pkg/llms_from_scratch/tests/test_ch03.py b/pkg/llms_from_scratch/tests/test_ch03.py index 91d2606..dee3b2f 100644 --- a/pkg/llms_from_scratch/tests/test_ch03.py +++ b/pkg/llms_from_scratch/tests/test_ch03.py @@ -4,7 +4,7 @@ # Code: https://github.com/rasbt/LLMs-from-scratch -from llms_from_scratch.ch03 import MultiHeadAttention +from llms_from_scratch.ch03 import MultiHeadAttention, PyTorchMultiHeadAttention import torch @@ -14,7 +14,15 @@ def test_mha(): d_in = 256 d_out = 16 - mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2) + mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=2) + + batch = torch.rand(8, 6, d_in) + context_vecs = mha(batch) + + context_vecs.shape == torch.Size([8, 6, d_out]) + + # Test bonus class + mha = PyTorchMultiHeadAttention(d_in, d_out, num_heads=2) batch = torch.rand(8, 6, d_in) context_vecs = mha(batch) diff --git a/pkg/llms_from_scratch/tests/test_ch04.py b/pkg/llms_from_scratch/tests/test_ch04.py index c84ad15..4f1cdc4 100644 --- a/pkg/llms_from_scratch/tests/test_ch04.py +++ b/pkg/llms_from_scratch/tests/test_ch04.py @@ -3,26 +3,29 @@ # - https://www.manning.com/books/build-a-large-language-model-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch -from llms_from_scratch.ch04 import GPTModel +from llms_from_scratch.ch04 import GPTModel, GPTModelFast from llms_from_scratch.ch04 import generate_text_simple +import pytest import torch import tiktoken -def test_GPTModel(): - GPT_CONFIG_124M = { - "vocab_size": 50257, # Vocabulary size - "context_length": 1024, # Context length - "emb_dim": 768, # Embedding dimension - "n_heads": 12, # Number of attention heads - "n_layers": 12, # Number of layers - "drop_rate": 0.1, # Dropout rate - "qkv_bias": False # Query-Key-Value bias - } +GPT_CONFIG_124M = { + "vocab_size": 50257, # Vocabulary size + "context_length": 1024, # Context length + "emb_dim": 768, # Embedding dimension + "n_heads": 12, # Number of attention heads + "n_layers": 12, # Number of layers + "drop_rate": 0.1, # Dropout rate + "qkv_bias": False # Query-Key-Value bias +} + +@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast]) +def test_gpt_model_variants(ModelClass): torch.manual_seed(123) - model = GPTModel(GPT_CONFIG_124M) + model = ModelClass(GPT_CONFIG_124M) model.eval() # disable dropout start_context = "Hello, I am" @@ -47,4 +50,4 @@ def test_GPTModel(): [15496, 11, 314, 716, 27018, 24086, 47843, 30961, 42348, 7267, 49706, 43231, 47062, 34657] ]) - torch.equal(expect, out) + assert torch.equal(expect, out), "Generated output does not match expected output" diff --git a/pkg/llms_from_scratch/tests/test_ch05.py b/pkg/llms_from_scratch/tests/test_ch05.py index 617440f..3a9778a 100644 --- 
a/pkg/llms_from_scratch/tests/test_ch05.py +++ b/pkg/llms_from_scratch/tests/test_ch05.py @@ -4,7 +4,7 @@ # Code: https://github.com/rasbt/LLMs-from-scratch from llms_from_scratch.ch02 import create_dataloader_v1 -from llms_from_scratch.ch04 import GPTModel +from llms_from_scratch.ch04 import GPTModel, GPTModelFast from llms_from_scratch.ch05 import train_model_simple import os @@ -16,60 +16,47 @@ import torch from torch.utils.data import Subset, DataLoader -@pytest.mark.parametrize("file_name", ["the-verdict.txt"]) -def test_train_simple(tmp_path, file_name): +GPT_CONFIG_124M = { + "vocab_size": 50257, + "context_length": 256, # Shortened for test speed + "emb_dim": 768, + "n_heads": 12, + "n_layers": 12, + "drop_rate": 0.1, + "qkv_bias": False +} - GPT_CONFIG_124M = { - "vocab_size": 50257, # Vocabulary size - "context_length": 256, # Shortened context length (orig: 1024) - "emb_dim": 768, # Embedding dimension - "n_heads": 12, # Number of attention heads - "n_layers": 12, # Number of layers - "drop_rate": 0.1, # Dropout rate - "qkv_bias": False # Query-key-value bias - } +OTHER_SETTINGS = { + "learning_rate": 5e-4, + "num_epochs": 2, + "batch_size": 1, + "weight_decay": 0.1 +} - OTHER_SETTINGS = { - "learning_rate": 5e-4, - "num_epochs": 2, - "batch_size": 1, - "weight_decay": 0.1 - } +@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast]) +def test_train_simple(tmp_path, ModelClass): torch.manual_seed(123) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ############################## # Download data if necessary ############################## - file_path = tmp_path / "the-verdict.txt" url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt" if not os.path.exists(file_path): with urllib.request.urlopen(url) as response: - text_data = response.read().decode('utf-8') - with open(file_path, "w", encoding="utf-8") as file: - file.write(text_data) + text_data = response.read().decode("utf-8") + with open(file_path, "w", encoding="utf-8") as f: + f.write(text_data) else: - with open(file_path, "r", encoding="utf-8") as file: - text_data = file.read() - - ############################## - # Initialize model - ############################## - - model = GPTModel(GPT_CONFIG_124M) - model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes - optimizer = torch.optim.AdamW( - model.parameters(), lr=OTHER_SETTINGS["learning_rate"], weight_decay=OTHER_SETTINGS["weight_decay"] - ) + with open(file_path, "r", encoding="utf-8") as f: + text_data = f.read() ############################## # Set up dataloaders ############################## - - # Train/validation ratio train_ratio = 0.90 split_idx = int(train_ratio * len(text_data)) @@ -93,17 +80,26 @@ def test_train_simple(tmp_path, file_name): num_workers=0 ) - ############################## - # Train model - ############################## - - tokenizer = tiktoken.get_encoding("gpt2") - + # Limit to 1 batch for speed train_subset = Subset(train_loader.dataset, range(1)) one_batch_train_loader = DataLoader(train_subset, batch_size=1) val_subset = Subset(val_loader.dataset, range(1)) one_batch_val_loader = DataLoader(val_subset, batch_size=1) + ############################## + # Train model + ############################## + model = ModelClass(GPT_CONFIG_124M) + model.to(device) + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=OTHER_SETTINGS["learning_rate"], + weight_decay=OTHER_SETTINGS["weight_decay"] + ) + + tokenizer = 
+
     train_losses, val_losses, tokens_seen = train_model_simple(
         model, one_batch_train_loader, one_batch_val_loader, optimizer, device,
         num_epochs=OTHER_SETTINGS["num_epochs"], eval_freq=1, eval_iter=1,
diff --git a/pyproject.toml b/pyproject.toml
index e543439..690dc24 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "llms-from-scratch"
-version = "1.0.0"
+version = "1.0.1"
 description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
 readme = "README.md"
 requires-python = ">=3.10"
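
Usage note (editorial addition, not part of the patch): a minimal sketch of how the bonus classes added above are meant to be used once the package is installed. The imports mirror the README diff, the config mirrors GPT_CONFIG_124M from test_ch04.py, and the prompt, seed, and token count follow that test, where GPTModelFast is exercised as a drop-in replacement for GPTModel.

# Hypothetical usage sketch (not part of the diff above)
import tiktoken
import torch

from llms_from_scratch.ch04 import GPTModelFast, generate_text_simple

GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "emb_dim": 768,          # Embedding dimension
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": False        # Query-Key-Value bias
}

torch.manual_seed(123)
model = GPTModelFast(GPT_CONFIG_124M)  # used like GPTModel in the tests
model.eval()  # disable dropout for deterministic generation

tokenizer = tiktoken.get_encoding("gpt2")
encoded = tokenizer.encode("Hello, I am")
encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add a batch dimension

out = generate_text_simple(
    model=model,
    idx=encoded_tensor,
    max_new_tokens=10,
    context_size=GPT_CONFIG_124M["context_length"]
)
print(tokenizer.decode(out.squeeze(0).tolist()))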