mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-06-26 23:50:03 +00:00
Llama3Fast (#593)
* Llama3Fast * Update pkg/llms_from_scratch/tests/test_llama3.py
This commit is contained in:
parent
aedad7efc3
commit
43e25a5165
@ -67,7 +67,10 @@ class Llama3Model(nn.Module):
|
|||||||
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
|
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
|
||||||
|
|
||||||
# Reusuable utilities
|
# Reusuable utilities
|
||||||
self.register_buffer("mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool())
|
self.register_buffer(
|
||||||
|
"mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool(),
|
||||||
|
persistent=False
|
||||||
|
)
|
||||||
|
|
||||||
if cfg["orig_context_length"] != cfg["context_length"]:
|
if cfg["orig_context_length"] != cfg["context_length"]:
|
||||||
cfg["rope_base"] = rescale_theta(
|
cfg["rope_base"] = rescale_theta(
|
||||||
@ -86,7 +89,6 @@ class Llama3Model(nn.Module):
|
|||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
||||||
def forward(self, in_idx):
|
def forward(self, in_idx):
|
||||||
# Forward pass
|
|
||||||
tok_embeds = self.tok_emb(in_idx)
|
tok_embeds = self.tok_emb(in_idx)
|
||||||
x = tok_embeds
|
x = tok_embeds
|
||||||
|
|
||||||
@ -143,9 +145,7 @@ class FeedForward(nn.Module):
|
|||||||
|
|
||||||
class GroupedQueryAttention(nn.Module):
|
class GroupedQueryAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, d_in, d_out, num_heads,
|
self, d_in, d_out, num_heads, num_kv_groups, dtype=None
|
||||||
num_kv_groups,
|
|
||||||
dtype=None
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
|
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
|
||||||
@ -375,3 +375,136 @@ def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
|
|||||||
else:
|
else:
|
||||||
# If the token is not found, return the original text
|
# If the token is not found, return the original text
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
######################################################################
|
||||||
|
# Llama 3 fast (alternative code geared towards efficiency)
|
||||||
|
######################################################################
|
||||||
|
|
||||||
|
class GroupedQueryAttentionFast(nn.Module):
    """
    Grouped-query attention built on PyTorch's scaled_dot_product_attention.

    The constructor takes the same arguments as GroupedQueryAttention. Unlike
    that class, forward() receives no mask tensor: causal masking is requested
    through ``is_causal=True``. PyTorch may dispatch the call to a fused
    kernel such as FlashAttention, depending on hardware and dtype.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Combined width of all query heads; must be divisible by
            ``num_heads``.
        num_heads: Number of query heads.
        num_kv_groups: Number of key/value heads; must divide ``num_heads``.
        dtype: Optional dtype forwarded to every linear layer.
    """

    def __init__(self, d_in, d_out, num_heads, num_kv_groups, dtype=None):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        per_head = d_out // num_heads
        kv_width = num_kv_groups * per_head

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = per_head
        self.num_kv_groups = num_kv_groups
        # Number of query heads that share one key/value head
        self.group_size = num_heads // num_kv_groups

        # Keys and values are projected to the narrower grouped width; the
        # registration order is kept so state dicts stay interchangeable.
        self.W_key = nn.Linear(d_in, kv_width, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, kv_width, bias=False, dtype=dtype)
        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, cos, sin):
        """
        Apply causal grouped-query attention to ``x``.

        Args:
            x: Input whose shape unpacks as (batch, num_tokens, d_in).
            cos: Cosine table handed to apply_rope.
            sin: Sine table handed to apply_rope.

        Returns:
            Tensor of shape (batch, num_tokens, d_out).
        """
        batch, seq_len, _ = x.shape

        def split_heads(projected, n_heads):
            # (batch, seq_len, n_heads * head_dim) -> (batch, n_heads, seq_len, head_dim)
            return projected.view(batch, seq_len, n_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.W_query(x), self.num_heads)
        keys = split_heads(self.W_key(x), self.num_kv_groups)
        values = split_heads(self.W_value(x), self.num_kv_groups)

        # Rotary position embedding is applied to queries and keys only
        queries = apply_rope(queries, cos, sin)
        keys = apply_rope(keys, cos, sin)

        # Repeat each key/value head so the head count matches the queries
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        # is_causal=True applies the causal mask inside the attention call
        context = torch.nn.functional.scaled_dot_product_attention(
            queries, keys, values, is_causal=True
        )

        # Merge the heads back into one feature dimension, then project
        merged = context.transpose(1, 2).reshape(batch, seq_len, self.d_out)
        return self.out_proj(merged)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlockFast(nn.Module):
    """
    Transformer block that uses GroupedQueryAttentionFast for attention.

    Otherwise laid out like the original TransformerBlock: RMSNorm, attention
    and a residual add, followed by RMSNorm, feed-forward and a residual add.

    Args:
        cfg: Mapping read for "emb_dim", "n_heads", "n_kv_groups" and "dtype";
            it is also passed unchanged to FeedForward.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        dtype = cfg["dtype"]

        self.att = GroupedQueryAttentionFast(
            d_in=emb_dim,
            d_out=emb_dim,
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            dtype=dtype,
        )
        self.ff = FeedForward(cfg)
        self.norm1 = nn.RMSNorm(emb_dim, eps=1e-5, dtype=dtype)
        self.norm2 = nn.RMSNorm(emb_dim, eps=1e-5, dtype=dtype)

    def forward(self, x, cos, sin):
        """
        Run the attention and feed-forward sub-layers, each with a residual add.

        Args:
            x: Input of shape [batch_size, num_tokens, emb_size].
            cos: Cosine table forwarded to the attention module.
            sin: Sine table forwarded to the attention module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Attention sub-layer: normalize, attend, add the input back
        attn_out = self.att(self.norm1(x), cos, sin)
        x = attn_out + x

        # Feed-forward sub-layer: normalize, transform, add the input back
        ff_out = self.ff(self.norm2(x))
        x = ff_out + x

        return x
|
||||||
|
|
||||||
|
|
||||||
|
class Llama3ModelFast(nn.Module):
    """
    Llama 3 model assembled from TransformerBlockFast layers.

    Same as the original Llama3Model but stacks TransformerBlockFast blocks,
    which in turn use GroupedQueryAttentionFast instead of
    GroupedQueryAttention. No causal-mask buffer is registered here.

    Args:
        cfg: Mapping read for "vocab_size", "emb_dim", "n_layers", "n_heads",
            "dtype", "context_length", "orig_context_length", "rope_base" and
            "rope_freq"; it is also passed to every TransformerBlockFast.
            The caller's mapping is not modified. A shallow copy, holding the
            rescaled "rope_base" when rescaling applies, is stored as
            ``self.cfg``.
    """

    def __init__(self, cfg):
        super().__init__()

        # Work on a shallow copy: the rope_base rescaling below must not be
        # written back into the caller's dict, otherwise building a second
        # model from the same config would rescale an already-rescaled value.
        cfg = dict(cfg)

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        # ModuleList since Sequential can only accept one input, and we need `x, cos, sin`
        self.trf_blocks = nn.ModuleList(
            [TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])]
        )

        self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        # Adjust the RoPE base when the configured context length differs
        # from the original one
        if cfg["orig_context_length"] != cfg["context_length"]:
            cfg["rope_base"] = rescale_theta(
                cfg["rope_base"],
                cfg["orig_context_length"],
                cfg["context_length"]
            )

        cos, sin = compute_rope_params(
            head_dim=cfg["emb_dim"] // cfg["n_heads"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            freq_config=cfg["rope_freq"]
        )
        # Non-persistent: the tables are recomputed from cfg, so they are
        # left out of the state dict
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg

    def forward(self, in_idx):
        """
        Compute logits for a batch of token ids.

        Args:
            in_idx: Token ids passed to the embedding layer.

        Returns:
            Output of the final linear head applied to the normalized hidden
            states.
        """
        x = self.tok_emb(in_idx)

        for block in self.trf_blocks:
            x = block(x, self.cos, self.sin)

        x = self.final_norm(x)
        # Cast to the configured dtype before the output projection
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits
|
||||||
|
@ -9,7 +9,9 @@ from llms_from_scratch.llama3 import (
|
|||||||
apply_rope,
|
apply_rope,
|
||||||
rescale_theta,
|
rescale_theta,
|
||||||
LLAMA32_CONFIG_1B,
|
LLAMA32_CONFIG_1B,
|
||||||
Llama3Model
|
GroupedQueryAttention,
|
||||||
|
GroupedQueryAttentionFast,
|
||||||
|
Llama3Model,
|
||||||
)
|
)
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
@ -117,13 +119,63 @@ def test_rescale():
|
|||||||
assert old_theta == 500_000.
|
assert old_theta == 500_000.
|
||||||
|
|
||||||
|
|
||||||
|
def test_grouped_query_attention_equivalence():
    """Check that GroupedQueryAttentionFast matches GroupedQueryAttention when both share weights."""
    torch.manual_seed(42)
    batch_size, seq_len = 2, 8
    d_in, d_out = 32, 64
    num_heads, num_kv_groups = 4, 2

    x = torch.randn(batch_size, seq_len, d_in)

    rope_freq = {
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": seq_len,
    }
    cos, sin = compute_rope_params(
        head_dim=d_out // num_heads,
        theta_base=50_000,
        context_length=seq_len,
        freq_config=rope_freq,
    )

    # Only the reference implementation takes an explicit causal mask
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

    reference_attn = GroupedQueryAttention(d_in, d_out, num_heads, num_kv_groups)
    fast_attn = GroupedQueryAttentionFast(d_in, d_out, num_heads, num_kv_groups)

    # Give both modules the same weights so their outputs can be compared
    fast_attn.load_state_dict(reference_attn.state_dict())

    out_reference = reference_attn(x, causal_mask, cos, sin)
    out_fast = fast_attn(x, cos, sin)

    # Report the largest element-wise gap, then enforce the tolerance
    max_diff = (out_reference - out_fast).abs().max().item()
    print(f"Max difference between slow and fast outputs: {max_diff:.4e}")
    assert torch.allclose(out_reference, out_fast, atol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
def llama3_weights_path(tmp_path_factory):
    """Save the state dict of a seeded Llama3Model to a temp file and return the file's path."""
    weights_file = tmp_path_factory.mktemp("models") / "llama3_test_weights.pt"

    # Nothing to do when the weights file is already there
    if weights_file.exists():
        return weights_file

    # Fixed seed so the saved weights are reproducible
    torch.manual_seed(123)
    reference_model = Llama3Model(LLAMA32_CONFIG_1B)
    torch.save(reference_model.state_dict(), weights_file)

    return weights_file
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("ModelClass", [Llama3Model])
|
@pytest.mark.parametrize("ModelClass", [Llama3Model])
|
||||||
def test_gpt_model_variants(ModelClass):
|
def test_gpt_model_variants(ModelClass, llama3_weights_path):
|
||||||
torch.manual_seed(123)
|
torch.manual_seed(123)
|
||||||
model = ModelClass(LLAMA32_CONFIG_1B)
|
model = ModelClass(LLAMA32_CONFIG_1B)
|
||||||
|
model.load_state_dict(torch.load(llama3_weights_path))
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
start_context = "Hello, I am"
|
start_context = "Llamas eat"
|
||||||
|
|
||||||
tokenizer = tiktoken.get_encoding("gpt2")
|
tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
encoded = tokenizer.encode(start_context)
|
encoded = tokenizer.encode(start_context)
|
||||||
@ -137,11 +189,11 @@ def test_gpt_model_variants(ModelClass):
|
|||||||
out = generate_text_simple(
|
out = generate_text_simple(
|
||||||
model=model,
|
model=model,
|
||||||
idx=encoded_tensor,
|
idx=encoded_tensor,
|
||||||
max_new_tokens=10,
|
max_new_tokens=5,
|
||||||
context_size=LLAMA32_CONFIG_1B["context_length"]
|
context_size=LLAMA32_CONFIG_1B["context_length"]
|
||||||
)
|
)
|
||||||
|
print("Encoded output text:", out)
|
||||||
expect = torch.tensor([
|
expect = torch.tensor([
|
||||||
[15496, 11, 314, 716, 78563, 89362, 19616, 115725, 114917,
|
[43, 2543, 292, 4483, 100383, 8113, 21197, 33804, 54419]
|
||||||
97198, 60342, 19108, 100752, 98969]
|
|
||||||
])
|
])
|
||||||
assert torch.equal(expect, out)
|
assert torch.equal(expect, out)
|
||||||
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "llms-from-scratch"
|
name = "llms-from-scratch"
|
||||||
version = "1.0.5"
|
version = "1.0.6"
|
||||||
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
|
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user