Sebastian Raschka 80d4732456
add HF equivalency tests for standalone nbs (#774)
* add HF equivalency tests for standalone nbs

* update

* update

* update

* update
2025-08-18 18:58:46 -05:00

117 lines
3.6 KiB
Python

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import importlib
from pathlib import Path
import pytest
import torch
from llms_from_scratch.utils import import_definitions_from_notebook
# `importlib.util` is a submodule: a bare `import importlib` only exposes it
# when some other module has already loaded it, so import it explicitly.
import importlib.util

# True when the optional `transformers` package is importable; used below to
# skip the Hugging Face equivalence test when it is missing.
transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
    """Expose the definitions from ``standalone-llama32.ipynb`` as a module.

    The notebook is looked up in the parent of the directory that contains
    this test file.
    """
    notebook_dir = Path(__file__).resolve().parents[1]
    return import_definitions_from_notebook(notebook_dir, "standalone-llama32.ipynb")
@pytest.fixture
def dummy_input():
    """Return a seeded tensor of random token IDs with shape (1, 8).

    IDs are drawn from [0, 100), which matches the vocabulary size of the
    dummy config fixture.
    """
    torch.manual_seed(123)
    batch_size, seq_len = 1, 8
    return torch.randint(0, 100, (batch_size, seq_len))
@pytest.fixture
def dummy_cfg_base():
    """Return a small Llama 3 model config for fast forward-pass tests."""
    rope_scaling = {
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": 8192,
    }
    return {
        "vocab_size": 100,
        "emb_dim": 32,        # hidden_size
        "hidden_dim": 64,     # intermediate_size (FFN)
        "n_layers": 2,
        "n_heads": 4,
        "head_dim": 8,        # equals emb_dim // n_heads here
        "n_kv_groups": 1,
        "dtype": torch.float32,
        "rope_base": 500_000.0,
        "rope_freq": rope_scaling,
        "context_length": 64,
    }
@torch.inference_mode()
def test_dummy_llama3_forward(dummy_cfg_base, dummy_input, nb_imports):
    """Smoke test: the forward pass yields one logit vector per input position."""
    torch.manual_seed(123)
    llama = nb_imports.Llama3Model(dummy_cfg_base)
    logits = llama(dummy_input)

    expected_shape = (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
    assert logits.shape == expected_shape
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_llama3_base_equivalence_with_transformers(nb_imports):
    """Check that the notebook's ``Llama3Model`` matches HF ``LlamaForCausalLM``.

    A small, randomly initialized Hugging Face model is built. Its state dict
    is loaded into the notebook model via ``load_weights_into_llama``. The
    logits of both models on the same random token IDs must then agree within
    rtol = atol = 1e-5.
    """
    from transformers.models.llama import LlamaConfig, LlamaForCausalLM

    # Seed so the HF weight initialization and the input tokens are the same on
    # every run; otherwise a tolerance failure could not be reproduced.
    torch.manual_seed(123)

    cfg = {
        "vocab_size": 257,
        "context_length": 8192,
        "emb_dim": 32,
        "n_heads": 4,
        "n_layers": 2,
        "hidden_dim": 64,
        "n_kv_groups": 2,
        "rope_base": 500_000.0,
        "rope_freq": {
            "factor": 32.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_context_length": 8192,
        },
        "dtype": torch.float32,
    }
    ours = nb_imports.Llama3Model(cfg)

    # Mirror the notebook config in Hugging Face's naming scheme.
    hf_cfg = LlamaConfig(
        vocab_size=cfg["vocab_size"],
        hidden_size=cfg["emb_dim"],
        num_attention_heads=cfg["n_heads"],
        num_key_value_heads=cfg["n_kv_groups"],
        num_hidden_layers=cfg["n_layers"],
        intermediate_size=cfg["hidden_dim"],
        max_position_embeddings=cfg["context_length"],
        rms_norm_eps=1e-5,
        attention_bias=False,
        rope_theta=cfg["rope_base"],
        tie_word_embeddings=False,
        attn_implementation="eager",
        torch_dtype=torch.float32,
        rope_scaling={
            "type": "llama3",
            "factor": cfg["rope_freq"]["factor"],
            "low_freq_factor": cfg["rope_freq"]["low_freq_factor"],
            "high_freq_factor": cfg["rope_freq"]["high_freq_factor"],
            "original_max_position_embeddings": cfg["rope_freq"]["original_context_length"],
        },
    )
    theirs = LlamaForCausalLM(hf_cfg)

    # Copy the HF weights into our model so both share identical parameters.
    hf_state = theirs.state_dict()
    nb_imports.load_weights_into_llama(
        ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state
    )

    # nn.Module instances start in training mode; switch both to eval so any
    # train-only behavior (e.g. dropout) cannot influence the comparison.
    ours.eval()
    theirs.eval()

    x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
    ours_logits = ours(x)
    theirs_logits = theirs(x).logits.to(ours_logits.dtype)
    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)