Sebastian Raschka 43e25a5165
Llama3Fast (#593)
* Llama3Fast

* Update pkg/llms_from_scratch/tests/test_llama3.py
2025-04-01 12:56:11 -05:00

200 lines
5.8 KiB
Python

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.llama3 import (
compute_rope_params,
apply_rope,
rescale_theta,
LLAMA32_CONFIG_1B,
GroupedQueryAttention,
GroupedQueryAttentionFast,
Llama3Model,
)
import importlib
import pytest
import tiktoken
import torch
# `import importlib` alone does not guarantee that the `util` submodule is
# bound as an attribute, so import it explicitly before using it.
import importlib.util

# True when the optional Hugging Face `transformers` package can be found;
# used to skip the test that compares against its reference RoPE code.
transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_rope():
    """Compare the local RoPE tables and rotated tensors with the Hugging Face Llama ones."""
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

    # Settings
    batch_size, context_len = 1, 8192
    num_heads, head_dim = 4, 16
    rope_theta = 500_000
    rope_config = {
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": 8192,
    }

    # cos/sin tables from the implementation under test
    cos, sin = compute_rope_params(
        head_dim=head_dim,
        theta_base=rope_theta,
        context_length=context_len,
        freq_config=rope_config,
    )

    # Dummy query and key tensors (queries are drawn first, then keys)
    torch.manual_seed(123)
    qk_shape = (batch_size, num_heads, context_len, head_dim)
    queries = torch.randn(*qk_shape)
    keys = torch.randn(*qk_shape)

    # Rotate with the implementation under test
    queries_rot = apply_rope(queries, cos, sin)
    keys_rot = apply_rope(keys, cos, sin)

    # Reference RoPE via HF: same scaling values, HF key names
    hf_rope_params = {
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3"
    }

    class RoPEConfig:
        # Minimal stand-in for an HF config object. rope_theta and
        # max_position_embeddings repeat the literals from the settings above.
        rope_type = "llama3"
        rope_scaling = hf_rope_params
        factor = 1.0
        dim: int = head_dim
        rope_theta = 500_000
        max_position_embeddings: int = 8192
        hidden_size = head_dim * num_heads
        num_attention_heads = num_heads

    hf_rotary = LlamaRotaryEmbedding(config=RoPEConfig())
    positions = torch.arange(context_len, dtype=torch.long)[None, :]
    ref_cos, ref_sin = hf_rotary(queries, positions)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

    # The HF tables carry a leading batch axis of size 1; drop it before comparing
    expected_sin = ref_sin.squeeze(0)
    expected_cos = ref_cos.squeeze(0)
    torch.testing.assert_close(sin, expected_sin)
    torch.testing.assert_close(cos, expected_cos)
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)
# GPT configuration (the 124M variant, per the name). It is not referenced by
# the tests visible in this section of the file.
GPT_CONFIG_124M = dict(
    vocab_size=50257,     # Vocabulary size
    context_length=1024,  # Context length
    emb_dim=768,          # Embedding dimension
    n_heads=12,           # Number of attention heads
    n_layers=12,          # Number of layers
    drop_rate=0.1,        # Dropout rate
    qkv_bias=False,       # Query-Key-Value bias
)
def test_rescale():
    """Rescale theta from a 131_072 context to 8192 and back, checking both results exactly."""
    long_context, short_context = 131_072, 8192

    # Long -> short context
    new_theta = rescale_theta(
        theta_old=500_000.,
        context_length_old=long_context,
        context_length_new=short_context,
    )
    assert new_theta == 31250.

    # Short -> long context must recover the starting theta
    old_theta = rescale_theta(
        theta_old=new_theta,
        context_length_old=short_context,
        context_length_new=long_context,
    )
    assert old_theta == 500_000.
def test_grouped_query_attention_equivalence():
    """Check that GroupedQueryAttention and GroupedQueryAttentionFast agree.

    Both modules get identical weights and the same input and RoPE tables;
    their outputs must match within atol=1e-4 / rtol=1e-5.
    """
    torch.manual_seed(42)
    b, t, d_in, d_out, num_heads, num_kv_groups = 2, 8, 32, 64, 4, 2

    x = torch.randn(b, t, d_in)
    cos, sin = compute_rope_params(
        head_dim=d_out // num_heads,
        theta_base=50_000,
        context_length=t,
        freq_config={
            "factor": 32.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_context_length": t,
        }
    )

    # Causal mask for the slow version (True strictly above the diagonal)
    mask = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1)

    attn1 = GroupedQueryAttention(d_in, d_out, num_heads, num_kv_groups)
    attn2 = GroupedQueryAttentionFast(d_in, d_out, num_heads, num_kv_groups)

    # Copy weights to make both models identical
    attn2.load_state_dict(attn1.state_dict())

    # Run both; only the slow version takes an explicit mask argument
    y1 = attn1(x, mask, cos, sin)
    y2 = attn2(x, cos, sin)

    # Compare outputs
    max_diff = (y1 - y2).abs().max().item()
    print(f"Max difference between slow and fast outputs: {max_diff:.4e}")

    # assert_close (as used in test_rope) reports the mismatch on failure.
    # rtol=1e-5 is torch.allclose's default, so the tolerance is unchanged.
    torch.testing.assert_close(y1, y2, atol=1e-4, rtol=1e-5)
@pytest.fixture(scope="session")
def llama3_weights_path(tmp_path_factory):
    """Create and save the state dict of a deterministically seeded Llama3 model.

    Args:
        tmp_path_factory: pytest's factory for session-level temporary directories.

    Returns:
        Path to the saved ``llama3_test_weights.pt`` file.
    """
    # mktemp() returns a freshly created, unique directory, so the weights file
    # cannot already exist here and no existence check is needed. The session
    # scope of this fixture keeps the model from being built more than once per run.
    path = tmp_path_factory.mktemp("models") / "llama3_test_weights.pt"

    # Fixed seed so the randomly initialised weights are reproducible
    torch.manual_seed(123)
    model = Llama3Model(LLAMA32_CONFIG_1B)
    torch.save(model.state_dict(), path)
    return path
@pytest.mark.parametrize("ModelClass", [Llama3Model])
def test_gpt_model_variants(ModelClass, llama3_weights_path):
    """Generate five tokens from fixed weights and compare with the expected token ids.

    Args:
        ModelClass: model class to instantiate with LLAMA32_CONFIG_1B.
        llama3_weights_path: path to the state dict saved by the session fixture.
    """
    torch.manual_seed(123)
    model = ModelClass(LLAMA32_CONFIG_1B)
    # The file holds a plain state dict, so restrict unpickling to tensors and
    # containers. This avoids arbitrary-code execution and the warning that
    # some torch versions emit for the unrestricted default.
    model.load_state_dict(torch.load(llama3_weights_path, weights_only=True))
    model.eval()

    start_context = "Llamas eat"

    tokenizer = tiktoken.get_encoding("gpt2")
    encoded = tokenizer.encode(start_context)
    # Add a leading batch dimension
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)

    print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
    print("\nInput text:", start_context)
    print("Encoded input text:", encoded)
    print("encoded_tensor.shape:", encoded_tensor.shape)

    out = generate_text_simple(
        model=model,
        idx=encoded_tensor,
        max_new_tokens=5,
        context_size=LLAMA32_CONFIG_1B["context_length"]
    )
    print("Encoded output text:", out)

    # Prompt token ids followed by the five generated ids
    expect = torch.tensor([
        [43, 2543, 292, 4483, 100383, 8113, 21197, 33804, 54419]
    ])
    assert torch.equal(expect, out)