Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-10-27 07:49:25 +00:00)
Parent: c9271ac427
Commit: e07a7abdd5
@@ -50,11 +50,12 @@ Once installed, you can import code from any chapter using:
 from llms_from_scratch.ch02 import GPTDatasetV1, create_dataloader_v1

 from llms_from_scratch.ch03 import (
-    MultiHeadAttention,
     SelfAttention_v1,
     SelfAttention_v2,
     CausalAttention,
-    MultiHeadAttentionWrapper
+    MultiHeadAttentionWrapper,
+    MultiHeadAttention,
+    PyTorchMultiHeadAttention  # Bonus: Faster variant using PyTorch's scaled_dot_product_attention
 )

 from llms_from_scratch.ch04 import (
@@ -63,6 +64,7 @@ from llms_from_scratch.ch04 import (
     FeedForward,
     TransformerBlock,
     GPTModel,
+    GPTModelFast,  # Bonus: Faster variant using PyTorch's scaled_dot_product_attention
     generate_text_simple
 )

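As a side note, the two attention classes imported above are interchangeable in shape: both map a (batch_size, num_tokens, d_in) tensor to (batch_size, num_tokens, d_out). A minimal usage sketch (not part of this commit; the dimensions mirror the unit test further down, and context_length is an arbitrary value >= num_tokens):

    import torch
    from llms_from_scratch.ch03 import MultiHeadAttention, PyTorchMultiHeadAttention

    torch.manual_seed(123)
    batch = torch.rand(8, 6, 256)  # (batch_size, num_tokens, d_in)

    mha = MultiHeadAttention(d_in=256, d_out=16, context_length=100,
                             dropout=0.0, num_heads=2)
    mha_fast = PyTorchMultiHeadAttention(d_in=256, d_out=16, num_heads=2)

    print(mha(batch).shape)       # torch.Size([8, 6, 16])
    print(mha_fast(batch).shape)  # torch.Size([8, 6, 16])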
@@ -149,3 +149,50 @@ class MultiHeadAttention(nn.Module):
         context_vec = self.out_proj(context_vec)  # optional projection

         return context_vec
+
+
+######################
+# Bonus
+######################
+
+
+class PyTorchMultiHeadAttention(nn.Module):
+    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
+        super().__init__()
+
+        assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"
+
+        self.num_heads = num_heads
+        self.head_dim = d_out // num_heads
+        self.d_out = d_out
+
+        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
+        self.proj = nn.Linear(d_out, d_out)
+        self.dropout = dropout
+
+    def forward(self, x):
+        batch_size, num_tokens, embed_dim = x.shape
+
+        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)
+        qkv = self.qkv(x)
+
+        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)
+        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)
+
+        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+
+        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
+        queries, keys, values = qkv
+
+        use_dropout = 0. if not self.training else self.dropout
+
+        context_vec = nn.functional.scaled_dot_product_attention(
+            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)
+
+        # Combine heads, where self.d_out = self.num_heads * self.head_dim
+        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)
+
+        context_vec = self.proj(context_vec)
+
+        return context_vec
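A quick aside (not part of the commit): is_causal=True above lets scaled_dot_product_attention apply the causal constraint internally, which is what the from-scratch MultiHeadAttention does with an explicit triangular mask. A small equivalence check, assuming PyTorch >= 2.0:

    import torch
    import torch.nn.functional as F

    q = k = v = torch.rand(1, 2, 6, 8)  # (batch, num_heads, num_tokens, head_dim)

    causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Boolean mask: True = may attend; the lower triangle keeps only current and past tokens
    mask = torch.tril(torch.ones(6, 6, dtype=torch.bool))
    explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    print(torch.allclose(causal, explicit, atol=1e-6))  # True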
@@ -3,7 +3,7 @@
 # - https://www.manning.com/books/build-a-large-language-model-from-scratch
 # Code: https://github.com/rasbt/LLMs-from-scratch

-from .ch03 import MultiHeadAttention
+from .ch03 import MultiHeadAttention, PyTorchMultiHeadAttention
 import torch
 import torch.nn as nn

@@ -128,3 +128,90 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
         idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)

     return idx
+
+
+######################
+# Bonus
+######################
+
+
+class FeedForwardFast(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class TransformerBlockFast(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.att = PyTorchMultiHeadAttention(
+            d_in=cfg["emb_dim"],
+            d_out=cfg["emb_dim"],
+            num_heads=cfg["n_heads"],
+            dropout=cfg["drop_rate"],
+            qkv_bias=cfg["qkv_bias"])
+        self.ff = FeedForwardFast(cfg)
+        self.norm1 = nn.LayerNorm(cfg["emb_dim"])
+        self.norm2 = nn.LayerNorm(cfg["emb_dim"])
+        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
+
+    def forward(self, x):
+        # Shortcut connection for attention block
+        shortcut = x
+        x = self.norm1(x)
+        x = self.att(x)  # Shape [batch_size, num_tokens, emb_size]
+        x = self.drop_shortcut(x)
+        x = x + shortcut  # Add the original input back
+
+        # Shortcut connection for feed-forward block
+        shortcut = x
+        x = self.norm2(x)
+        x = self.ff(x)
+        x = self.drop_shortcut(x)
+        x = x + shortcut  # Add the original input back
+
+        return x
+
+
+class GPTModelFast(nn.Module):
+    """
+    A faster variant of GPTModel optimized for training speed.
+
+    This version is only marginally faster on CPU (~1.02x) but significantly
+    faster on GPU (~2.05x) during training, thanks to optimized CUDA kernels
+    and FlashAttention support.
+
+    Key differences from the original GPTModel:
+    1. Uses PyTorch's built-in LayerNorm instead of a custom implementation.
+    2. Uses PyTorch's built-in GELU instead of a custom implementation.
+    3. Uses PyTorch's scaled_dot_product_attention instead of a custom MultiHeadAttention.
+    4. Automatically enables FlashAttention on compatible GPUs.
+    """
+    def __init__(self, cfg):
+        super().__init__()
+        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
+        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
+        self.drop_emb = nn.Dropout(cfg["drop_rate"])
+
+        self.trf_blocks = nn.Sequential(
+            *[TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])])
+
+        self.final_norm = nn.LayerNorm(cfg["emb_dim"])
+        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+
+    def forward(self, in_idx):
+        batch_size, seq_len = in_idx.shape
+        tok_embeds = self.tok_emb(in_idx)
+        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
+        x = tok_embeds + pos_embeds
+        x = self.drop_emb(x)
+        x = self.trf_blocks(x)
+        x = self.final_norm(x)
+        logits = self.out_head(x)
+        return logits
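For reference, a minimal sketch (not part of the commit) showing GPTModelFast as a drop-in replacement for GPTModel; the config values, prompt, and the generate_text_simple signature are taken from the tests below:

    import torch
    import tiktoken
    from llms_from_scratch.ch04 import GPTModelFast, generate_text_simple

    GPT_CONFIG_124M = {
        "vocab_size": 50257,
        "context_length": 1024,
        "emb_dim": 768,
        "n_heads": 12,
        "n_layers": 12,
        "drop_rate": 0.1,
        "qkv_bias": False
    }

    torch.manual_seed(123)
    model = GPTModelFast(GPT_CONFIG_124M)
    model.eval()  # disable dropout

    tokenizer = tiktoken.get_encoding("gpt2")
    idx = torch.tensor([tokenizer.encode("Hello, I am")])

    out = generate_text_simple(model, idx, max_new_tokens=10,
                               context_size=GPT_CONFIG_124M["context_length"])
    print(out.shape)  # torch.Size([1, 14]): 4 prompt tokens + 10 generated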
@@ -4,7 +4,7 @@
 # Code: https://github.com/rasbt/LLMs-from-scratch


-from llms_from_scratch.ch03 import MultiHeadAttention
+from llms_from_scratch.ch03 import MultiHeadAttention, PyTorchMultiHeadAttention
 import torch


@@ -14,7 +14,15 @@ def test_mha():
     d_in = 256
     d_out = 16

-    mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
+    mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=2)

     batch = torch.rand(8, 6, d_in)
     context_vecs = mha(batch)

+    assert context_vecs.shape == torch.Size([8, 6, d_out])
+
+    # Test bonus class
+    mha = PyTorchMultiHeadAttention(d_in, d_out, num_heads=2)
+
+    batch = torch.rand(8, 6, d_in)
+    context_vecs = mha(batch)
+
@@ -3,14 +3,14 @@
 # - https://www.manning.com/books/build-a-large-language-model-from-scratch
 # Code: https://github.com/rasbt/LLMs-from-scratch

-from llms_from_scratch.ch04 import GPTModel
+from llms_from_scratch.ch04 import GPTModel, GPTModelFast
 from llms_from_scratch.ch04 import generate_text_simple

 import pytest
 import torch
 import tiktoken


 def test_GPTModel():
     GPT_CONFIG_124M = {
         "vocab_size": 50257,     # Vocabulary size
         "context_length": 1024,  # Context length
@@ -21,8 +21,11 @@ def test_GPTModel():
         "qkv_bias": False        # Query-Key-Value bias
     }

+
+@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast])
+def test_gpt_model_variants(ModelClass):
     torch.manual_seed(123)
-    model = GPTModel(GPT_CONFIG_124M)
+    model = ModelClass(GPT_CONFIG_124M)
     model.eval()  # disable dropout

     start_context = "Hello, I am"
@@ -47,4 +50,4 @@
         [15496, 11, 314, 716, 27018, 24086, 47843, 30961, 42348, 7267,
          49706, 43231, 47062, 34657]
     ])
-    torch.equal(expect, out)
+    assert torch.equal(expect, out), "Generated output does not match expected output"

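Aside (not part of the commit): @pytest.mark.parametrize runs the same test body once per listed class, so GPTModel and GPTModelFast are held to the identical expected token IDs. A toy illustration of the mechanism, using hypothetical placeholder classes:

    import pytest

    class Slow: ...
    class Fast: ...

    @pytest.mark.parametrize("Cls", [Slow, Fast])
    def test_both_variants(Cls):  # collected twice, once per class
        assert isinstance(Cls(), Cls)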
@@ -4,7 +4,7 @@
 # Code: https://github.com/rasbt/LLMs-from-scratch

 from llms_from_scratch.ch02 import create_dataloader_v1
-from llms_from_scratch.ch04 import GPTModel
+from llms_from_scratch.ch04 import GPTModel, GPTModelFast
 from llms_from_scratch.ch05 import train_model_simple

 import os
@@ -16,17 +16,14 @@ import torch
 from torch.utils.data import Subset, DataLoader


-@pytest.mark.parametrize("file_name", ["the-verdict.txt"])
-def test_train_simple(tmp_path, file_name):
-
 GPT_CONFIG_124M = {
-    "vocab_size": 50257,     # Vocabulary size
-    "context_length": 256,   # Shortened context length (orig: 1024)
-    "emb_dim": 768,          # Embedding dimension
-    "n_heads": 12,           # Number of attention heads
-    "n_layers": 12,          # Number of layers
-    "drop_rate": 0.1,        # Dropout rate
-    "qkv_bias": False        # Query-key-value bias
+    "vocab_size": 50257,
+    "context_length": 256,  # Shortened for test speed
+    "emb_dim": 768,
+    "n_heads": 12,
+    "n_layers": 12,
+    "drop_rate": 0.1,
+    "qkv_bias": False
 }

 OTHER_SETTINGS = {
@@ -36,40 +33,30 @@ def test_train_simple(tmp_path, file_name):
     "weight_decay": 0.1
 }

+
+@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast])
+def test_train_simple(tmp_path, ModelClass):
     torch.manual_seed(123)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     ##############################
     # Download data if necessary
     ##############################

     file_path = tmp_path / "the-verdict.txt"
     url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"

     if not os.path.exists(file_path):
         with urllib.request.urlopen(url) as response:
-            text_data = response.read().decode('utf-8')
-        with open(file_path, "w", encoding="utf-8") as file:
-            file.write(text_data)
+            text_data = response.read().decode("utf-8")
+        with open(file_path, "w", encoding="utf-8") as f:
+            f.write(text_data)
     else:
-        with open(file_path, "r", encoding="utf-8") as file:
-            text_data = file.read()
-
-    ##############################
-    # Initialize model
-    ##############################
-
-    model = GPTModel(GPT_CONFIG_124M)
-    model.to(device)  # no assignment model = model.to(device) necessary for nn.Module classes
-    optimizer = torch.optim.AdamW(
-        model.parameters(), lr=OTHER_SETTINGS["learning_rate"], weight_decay=OTHER_SETTINGS["weight_decay"]
-    )
+        with open(file_path, "r", encoding="utf-8") as f:
+            text_data = f.read()

     ##############################
     # Set up dataloaders
     ##############################

     # Train/validation ratio
     train_ratio = 0.90
     split_idx = int(train_ratio * len(text_data))

@@ -93,17 +80,26 @@
         num_workers=0
     )

-    ##############################
-    # Train model
-    ##############################
-
-    tokenizer = tiktoken.get_encoding("gpt2")
-
     # Limit to 1 batch for speed
     train_subset = Subset(train_loader.dataset, range(1))
     one_batch_train_loader = DataLoader(train_subset, batch_size=1)
     val_subset = Subset(val_loader.dataset, range(1))
     one_batch_val_loader = DataLoader(val_subset, batch_size=1)

+    ##############################
+    # Train model
+    ##############################
+    model = ModelClass(GPT_CONFIG_124M)
+    model.to(device)
+
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=OTHER_SETTINGS["learning_rate"],
+        weight_decay=OTHER_SETTINGS["weight_decay"]
+    )
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+
     train_losses, val_losses, tokens_seen = train_model_simple(
         model, one_batch_train_loader, one_batch_val_loader, optimizer, device,
         num_epochs=OTHER_SETTINGS["num_epochs"], eval_freq=1, eval_iter=1,
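Aside (not part of the commit): the Subset/DataLoader combination above is what keeps the training loop to a single batch so the test stays fast. A minimal sketch of the idea using only standard torch.utils.data:

    import torch
    from torch.utils.data import DataLoader, Subset, TensorDataset

    full_dataset = TensorDataset(torch.arange(100).unsqueeze(1))
    one_item = Subset(full_dataset, range(1))    # keep only index 0
    loader = DataLoader(one_item, batch_size=1)

    print(len(full_dataset), len(one_item), len(loader))  # 100 1 1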
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "llms-from-scratch"
-version = "1.0.0"
+version = "1.0.1"
 description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
 readme = "README.md"
 requires-python = ">=3.10"