From 43e25a5165cc3743a22cba378b18c44a96c6daae Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Tue, 1 Apr 2025 12:56:11 -0500
Subject: [PATCH] Llama3Fast (#593)

* Llama3Fast

* Update pkg/llms_from_scratch/tests/test_llama3.py
---
 pkg/llms_from_scratch/llama3.py            | 143 ++++++++++++++++++++-
 pkg/llms_from_scratch/tests/test_llama3.py |  64 ++++++++-
 pyproject.toml                             |   2 +-
 3 files changed, 197 insertions(+), 12 deletions(-)

diff --git a/pkg/llms_from_scratch/llama3.py b/pkg/llms_from_scratch/llama3.py
index 203b996..2776882 100644
--- a/pkg/llms_from_scratch/llama3.py
+++ b/pkg/llms_from_scratch/llama3.py
@@ -67,7 +67,10 @@ class Llama3Model(nn.Module):
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
 
         # Reusuable utilities
-        self.register_buffer("mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool())
+        self.register_buffer(
+            "mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool(),
+            persistent=False
+        )
 
         if cfg["orig_context_length"] != cfg["context_length"]:
             cfg["rope_base"] = rescale_theta(
@@ -86,7 +89,6 @@ class Llama3Model(nn.Module):
         self.cfg = cfg
 
     def forward(self, in_idx):
-        # Forward pass
         tok_embeds = self.tok_emb(in_idx)
         x = tok_embeds
 
@@ -143,9 +145,7 @@ class FeedForward(nn.Module):
 
 class GroupedQueryAttention(nn.Module):
     def __init__(
-            self, d_in, d_out, num_heads,
-            num_kv_groups,
-            dtype=None
+            self, d_in, d_out, num_heads, num_kv_groups, dtype=None
     ):
         super().__init__()
         assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
@@ -375,3 +375,136 @@ def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
     else:
         # If the token is not found, return the original text
         return text
+
+
+######################################################################
+# Llama 3 fast (alternative code geared towards efficiency)
+######################################################################
+
+class GroupedQueryAttentionFast(nn.Module):
+    """
+    Drop-in replacement for GroupedQueryAttention but using PyTorch's
+    scaled_dot_product_attention, which uses FlashAttention if run
+    on an Ampere GPU (like A100) or newer and uses float16/bfloat16 or lower.
+ """ + def __init__(self, d_in, d_out, num_heads, num_kv_groups, dtype=None): + super().__init__() + assert d_out % num_heads == 0, "d_out must be divisible by num_heads" + assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups" + + self.d_out = d_out + self.num_heads = num_heads + self.head_dim = d_out // num_heads + self.num_kv_groups = num_kv_groups + self.group_size = num_heads // num_kv_groups + + self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype) + self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype) + self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype) + self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype) + + def forward(self, x, cos, sin): + b, num_tokens, _ = x.shape + + # Project to queries, keys, values + q = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2) + k = self.W_key(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2) + v = self.W_value(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2) + + # Apply Rotary Positional Embedding + q = apply_rope(q, cos, sin) + k = apply_rope(k, cos, sin) + + # Expand key/value groups to full head count + k = k.repeat_interleave(self.group_size, dim=1) + v = v.repeat_interleave(self.group_size, dim=1) + + # Efficient scaled dot-product attention + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + is_causal=True # Enables Flash/FlexAttention kernels + ) + + # Combine heads and project + attn_output = attn_output.transpose(1, 2).reshape(b, num_tokens, self.d_out) + return self.out_proj(attn_output) + + +class TransformerBlockFast(nn.Module): + """ + Same as original TransformerBlock but uses + GroupedQueryAttentionFast instead of GroupedQueryAttention. + """ + def __init__(self, cfg): + super().__init__() + self.att = GroupedQueryAttentionFast( + d_in=cfg["emb_dim"], + d_out=cfg["emb_dim"], + num_heads=cfg["n_heads"], + num_kv_groups=cfg["n_kv_groups"], + dtype=cfg["dtype"] + ) + self.ff = FeedForward(cfg) + self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) + self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) + + def forward(self, x, cos, sin): + # Shortcut connection for attention block + shortcut = x + x = self.norm1(x) + x = self.att(x, cos, sin) # Shape [batch_size, num_tokens, emb_size] + x = x + shortcut # Add the original input back + + # Shortcut connection for feed-forward block + shortcut = x + x = self.norm2(x) + x = self.ff(x) + x = x + shortcut # Add the original input back + + return x + + +class Llama3ModelFast(nn.Module): + """ + Same as original Llama3Model but uses TransformerBlockFast + instead of TransformerBlock, which in turn uses + GroupedQueryAttentionFast instead of GroupedQueryAttention. 
+ """ + def __init__(self, cfg): + super().__init__() + + # Main model parameters + self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"]) + + self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, cos, sin` + [TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])] + ) + + self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) + self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"]) + + if cfg["orig_context_length"] != cfg["context_length"]: + cfg["rope_base"] = rescale_theta( + cfg["rope_base"], + cfg["orig_context_length"], + cfg["context_length"] + ) + cos, sin = compute_rope_params( + head_dim=cfg["emb_dim"] // cfg["n_heads"], + theta_base=cfg["rope_base"], + context_length=cfg["context_length"], + freq_config=cfg["rope_freq"] + ) + self.register_buffer("cos", cos, persistent=False) + self.register_buffer("sin", sin, persistent=False) + self.cfg = cfg + + def forward(self, in_idx): + tok_embeds = self.tok_emb(in_idx) + x = tok_embeds + + for block in self.trf_blocks: + x = block(x, self.cos, self.sin) + x = self.final_norm(x) + logits = self.out_head(x.to(self.cfg["dtype"])) + return logits diff --git a/pkg/llms_from_scratch/tests/test_llama3.py b/pkg/llms_from_scratch/tests/test_llama3.py index 70ff8f5..0ffdc09 100644 --- a/pkg/llms_from_scratch/tests/test_llama3.py +++ b/pkg/llms_from_scratch/tests/test_llama3.py @@ -9,7 +9,9 @@ from llms_from_scratch.llama3 import ( apply_rope, rescale_theta, LLAMA32_CONFIG_1B, - Llama3Model + GroupedQueryAttention, + GroupedQueryAttentionFast, + Llama3Model, ) import importlib @@ -117,13 +119,63 @@ def test_rescale(): assert old_theta == 500_000. +def test_grouped_query_attention_equivalence(): + torch.manual_seed(42) + b, t, d_in, d_out, num_heads, num_kv_groups = 2, 8, 32, 64, 4, 2 + + x = torch.randn(b, t, d_in) + cos, sin = compute_rope_params( + head_dim=d_out // num_heads, + theta_base=50_000, + context_length=t, + freq_config={ + "factor": 32.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_context_length": t, + } + ) + + # Causal mask for the slow version + mask = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1) + + attn1 = GroupedQueryAttention(d_in, d_out, num_heads, num_kv_groups) + attn2 = GroupedQueryAttentionFast(d_in, d_out, num_heads, num_kv_groups) + + # Copy weights to make both models identical + attn2.load_state_dict(attn1.state_dict()) + + # Run both + y1 = attn1(x, mask, cos, sin) + y2 = attn2(x, cos, sin) + + # Compare outputs + max_diff = (y1 - y2).abs().max().item() + print(f"Max difference between slow and fast outputs: {max_diff:.4e}") + assert torch.allclose(y1, y2, atol=1e-4) + + +@pytest.fixture(scope="session") +def llama3_weights_path(tmp_path_factory): + """Creates and saves a deterministic Llama3 model for testing.""" + path = tmp_path_factory.mktemp("models") / "llama3_test_weights.pt" + + if not path.exists(): + torch.manual_seed(123) + model = Llama3Model(LLAMA32_CONFIG_1B) + torch.save(model.state_dict(), path) + + return path + + @pytest.mark.parametrize("ModelClass", [Llama3Model]) -def test_gpt_model_variants(ModelClass): +def test_gpt_model_variants(ModelClass, llama3_weights_path): torch.manual_seed(123) model = ModelClass(LLAMA32_CONFIG_1B) + model.load_state_dict(torch.load(llama3_weights_path)) model.eval() - start_context = "Hello, I am" + start_context = "Llamas eat" tokenizer = tiktoken.get_encoding("gpt2") encoded = 
@@ -137,11 +189,11 @@ def test_gpt_model_variants(ModelClass):
     out = generate_text_simple(
         model=model,
         idx=encoded_tensor,
-        max_new_tokens=10,
+        max_new_tokens=5,
         context_size=LLAMA32_CONFIG_1B["context_length"]
     )
+    print("Encoded output text:", out)
     expect = torch.tensor([
-        [15496, 11, 314, 716, 78563, 89362, 19616, 115725, 114917,
-         97198, 60342, 19108, 100752, 98969]
+        [43, 2543, 292, 4483, 100383, 8113, 21197, 33804, 54419]
     ])
     assert torch.equal(expect, out)
diff --git a/pyproject.toml b/pyproject.toml
index f0805c0..d9997a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "llms-from-scratch"
-version = "1.0.5"
+version = "1.0.6"
 description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
 readme = "README.md"
 requires-python = ">=3.10"
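
Note for reviewers (not part of the patch): below is a minimal sketch of how the new fast variant could be exercised locally, assuming the package from this PR is installed so that Llama3Model, Llama3ModelFast, and LLAMA32_CONFIG_1B are importable from llms_from_scratch.llama3 (as the test file above does). The time_forward helper, batch size, and sequence length are my own illustrative choices; this is a smoke test, not a rigorous benchmark, and the SDPA speedup mainly shows up on an Ampere-or-newer GPU with float16/bfloat16.

# benchmark_llama3_fast.py -- illustrative sketch only
import time

import torch

from llms_from_scratch.llama3 import LLAMA32_CONFIG_1B, Llama3Model, Llama3ModelFast


def time_forward(model, idx, device, num_iters=3):
    """Average seconds per forward pass, with one warm-up call first."""
    model = model.to(device).eval()
    idx = idx.to(device)
    with torch.no_grad():
        model(idx)  # warm-up so kernel selection is not included in the timing
        if device.type == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_iters):
            model(idx)
        if device.type == "cuda":
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / num_iters


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.manual_seed(123)
    idx = torch.randint(0, LLAMA32_CONFIG_1B["vocab_size"], (1, 128))  # 1 sequence, 128 tokens

    # Give each constructor its own shallow config copy, since the constructors
    # may rescale cfg["rope_base"] in place (see rescale_theta in the diff above).
    slow = Llama3Model(dict(LLAMA32_CONFIG_1B))
    fast = Llama3ModelFast(dict(LLAMA32_CONFIG_1B))

    print(f"Llama3Model:     {time_forward(slow, idx, device):.3f} s / forward")
    print(f"Llama3ModelFast: {time_forward(fast, idx, device):.3f} s / forward")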