Sebastian Raschka aedad7efc3
Add Llama 3.2 to pkg (#591)
* Add Llama 3.2 to pkg

* remove redundant attributes

* update tests

* updates

* updates

* updates

* fix link

* fix link
2025-03-31 18:59:47 -05:00

148 lines
4.2 KiB
Python

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.llama3 import (
compute_rope_params,
apply_rope,
rescale_theta,
LLAMA32_CONFIG_1B,
Llama3Model
)
import importlib
import pytest
import tiktoken
import torch
# `import importlib` alone does not guarantee that the `importlib.util`
# submodule is loaded/bound on the package, so import it explicitly before use.
import importlib.util

# True when the optional `transformers` package can be located on the import
# path; tests decorated with `skipif(not transformers_installed)` are skipped otherwise.
transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_rope():
    """Compare the from-scratch Llama 3 RoPE with the Hugging Face reference.

    The cos/sin tables from ``compute_rope_params`` and the query/key tensors
    rotated by ``apply_rope`` must match those produced by HF's
    ``LlamaRotaryEmbedding`` / ``apply_rotary_pos_emb`` when both sides are
    configured from the same settings below.
    """
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

    # Settings shared by both implementations
    batch_size = 1
    context_len = 8192
    num_heads = 4
    head_dim = 16
    # Named `theta_base` (not `rope_theta`) so the RoPEConfig class body below
    # can read it as a free variable while defining its own `rope_theta`.
    theta_base = 500_000
    rope_config = {
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": 8192,
    }

    # Instantiate RoPE parameters
    cos, sin = compute_rope_params(
        head_dim=head_dim,
        theta_base=theta_base,
        context_length=context_len,
        freq_config=rope_config,
    )

    # Dummy query and key tensors
    torch.manual_seed(123)
    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
    keys = torch.randn(batch_size, num_heads, context_len, head_dim)

    # Apply rotary position embeddings
    queries_rot = apply_rope(queries, cos, sin)
    keys_rot = apply_rope(keys, cos, sin)

    # Generate reference RoPE via HF. The scaling parameters are derived from
    # `rope_config` (HF calls the original context length
    # "original_max_position_embeddings") so the two sides cannot drift apart.
    hf_rope_params = {
        "factor": rope_config["factor"],
        "low_freq_factor": rope_config["low_freq_factor"],
        "high_freq_factor": rope_config["high_freq_factor"],
        "original_max_position_embeddings": rope_config["original_context_length"],
        "rope_type": "llama3",
    }

    # Minimal config object handed to LlamaRotaryEmbedding in place of a full
    # HF LlamaConfig; values come from the shared settings above.
    class RoPEConfig:
        rope_type = "llama3"
        rope_scaling = hf_rope_params
        factor = 1.0
        dim: int = head_dim
        rope_theta = theta_base
        max_position_embeddings: int = context_len
        hidden_size = head_dim * num_heads
        num_attention_heads = num_heads

    config = RoPEConfig()
    rot_emb = LlamaRotaryEmbedding(config=config)
    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)

    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

    # position_ids carries a leading dim of size 1, so the reference tables are
    # squeezed on dim 0 before comparing with the from-scratch tables.
    torch.testing.assert_close(sin, ref_sin.squeeze(0))
    torch.testing.assert_close(cos, ref_cos.squeeze(0))
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)
# GPT model configuration (the name indicates the 124M variant).
# NOTE(review): no test in this chunk references it; confirm it is still
# needed elsewhere before removing.
GPT_CONFIG_124M = dict(
    vocab_size=50257,     # Vocabulary size
    context_length=1024,  # Context length
    emb_dim=768,          # Embedding dimension
    n_heads=12,           # Number of attention heads
    n_layers=12,          # Number of layers
    drop_rate=0.1,        # Dropout rate
    qkv_bias=False,       # Query-Key-Value bias
)
def test_rescale():
    """Rescaling RoPE theta to a shorter context and back must round-trip."""
    long_ctx, short_ctx = 131_072, 8192
    base_theta = 500_000.

    # 131_072 -> 8192 is a 16x shorter context; the expected theta is
    # likewise 16x smaller (500_000 / 16 == 31_250).
    shrunk = rescale_theta(
        theta_old=base_theta,
        context_length_old=long_ctx,
        context_length_new=short_ctx,
    )
    assert shrunk == 31250.

    # Scaling back up to the long context restores the original base.
    restored = rescale_theta(
        theta_old=shrunk,
        context_length_old=short_ctx,
        context_length_new=long_ctx,
    )
    assert restored == 500_000.
@pytest.mark.parametrize("ModelClass", [Llama3Model])
def test_gpt_model_variants(ModelClass):
    """Regression test: generate 10 tokens from a seeded model and pin the ids.

    Despite the name, the only parametrized class is ``Llama3Model``. The
    prompt is encoded with tiktoken's "gpt2" encoding, passed through
    ``generate_text_simple``, and the full output id tensor (prompt plus
    generated ids) must equal a fixed expected tensor.
    """
    # Seed before constructing the model so its (randomly initialized)
    # weights, and therefore the pinned output below, are reproducible.
    torch.manual_seed(123)
    model = ModelClass(LLAMA32_CONFIG_1B)
    model.eval()

    prompt = "Hello, I am"
    token_ids = tiktoken.get_encoding("gpt2").encode(prompt)
    batch = torch.tensor(token_ids).unsqueeze(0)  # shape (1, num_prompt_tokens)

    banner = "=" * 50
    print(f"\n{banner}\n{' ' * 22}IN\n{banner}")
    print("\nInput text:", prompt)
    print("Encoded input text:", token_ids)
    print("encoded_tensor.shape:", batch.shape)

    generated = generate_text_simple(
        model=model,
        idx=batch,
        max_new_tokens=10,
        context_size=LLAMA32_CONFIG_1B["context_length"],
    )

    # First four ids are the encoded prompt; the remaining ten are the
    # max_new_tokens=10 generated ids.
    expected = torch.tensor([[
        15496, 11, 314, 716,
        78563, 89362, 19616, 115725, 114917,
        97198, 60342, 19108, 100752, 98969,
    ]])
    assert torch.equal(expected, generated)