mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-07 21:37:18 +00:00
add HF equivalency tests for standalone nbs (#774)
* add HF equivalency tests for standalone nbs * update * update * update * update
This commit is contained in:
parent
a6b883c9f9
commit
80d4732456
6
.github/workflows/basic-tests-linux-uv.yml
vendored
6
.github/workflows/basic-tests-linux-uv.yml
vendored
@ -51,8 +51,10 @@ jobs:
|
|||||||
pytest --ruff ch04/01_main-chapter-code/tests.py
|
pytest --ruff ch04/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch04/03_kv-cache/tests.py
|
pytest --ruff ch04/03_kv-cache/tests.py
|
||||||
pytest --ruff ch05/01_main-chapter-code/tests.py
|
pytest --ruff ch05/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
|
pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
|
||||||
pytest --ruff ch05/12_gemma3/tests/test_gemma3.py
|
pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py
|
||||||
|
pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py
|
||||||
|
pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py
|
||||||
pytest --ruff ch06/01_main-chapter-code/tests.py
|
pytest --ruff ch06/01_main-chapter-code/tests.py
|
||||||
|
|
||||||
- name: Validate Selected Jupyter Notebooks (uv)
|
- name: Validate Selected Jupyter Notebooks (uv)
|
||||||
|
|||||||
6
.github/workflows/basic-tests-macos-uv.yml
vendored
6
.github/workflows/basic-tests-macos-uv.yml
vendored
@ -50,8 +50,10 @@ jobs:
|
|||||||
pytest --ruff setup/02_installing-python-libraries/tests.py
|
pytest --ruff setup/02_installing-python-libraries/tests.py
|
||||||
pytest --ruff ch04/01_main-chapter-code/tests.py
|
pytest --ruff ch04/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/01_main-chapter-code/tests.py
|
pytest --ruff ch05/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
|
pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
|
||||||
pytest --ruff ch05/12_gemma3/tests/test_gemma3.py
|
pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py
|
||||||
|
pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py
|
||||||
|
pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py
|
||||||
pytest --ruff ch06/01_main-chapter-code/tests.py
|
pytest --ruff ch06/01_main-chapter-code/tests.py
|
||||||
|
|
||||||
- name: Validate Selected Jupyter Notebooks (uv)
|
- name: Validate Selected Jupyter Notebooks (uv)
|
||||||
|
|||||||
@ -47,7 +47,6 @@ jobs:
|
|||||||
pytest --ruff setup/02_installing-python-libraries/tests.py
|
pytest --ruff setup/02_installing-python-libraries/tests.py
|
||||||
pytest --ruff ch04/01_main-chapter-code/tests.py
|
pytest --ruff ch04/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/01_main-chapter-code/tests.py
|
pytest --ruff ch05/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
|
|
||||||
pytest --ruff ch06/01_main-chapter-code/tests.py
|
pytest --ruff ch06/01_main-chapter-code/tests.py
|
||||||
|
|
||||||
- name: Validate Selected Jupyter Notebooks
|
- name: Validate Selected Jupyter Notebooks
|
||||||
|
|||||||
2
.github/workflows/basic-tests-pip.yml
vendored
2
.github/workflows/basic-tests-pip.yml
vendored
@ -41,7 +41,6 @@ jobs:
|
|||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
|
|
||||||
pip install pytest pytest-ruff nbval
|
pip install pytest pytest-ruff nbval
|
||||||
|
|
||||||
- name: Test Selected Python Scripts
|
- name: Test Selected Python Scripts
|
||||||
@ -50,7 +49,6 @@ jobs:
|
|||||||
pytest --ruff setup/02_installing-python-libraries/tests.py
|
pytest --ruff setup/02_installing-python-libraries/tests.py
|
||||||
pytest --ruff ch04/01_main-chapter-code/tests.py
|
pytest --ruff ch04/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/01_main-chapter-code/tests.py
|
pytest --ruff ch05/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
|
|
||||||
pytest --ruff ch06/01_main-chapter-code/tests.py
|
pytest --ruff ch06/01_main-chapter-code/tests.py
|
||||||
|
|
||||||
- name: Validate Selected Jupyter Notebooks
|
- name: Validate Selected Jupyter Notebooks
|
||||||
|
|||||||
1
.github/workflows/basic-tests-pixi.yml
vendored
1
.github/workflows/basic-tests-pixi.yml
vendored
@ -50,7 +50,6 @@ jobs:
|
|||||||
pytest --ruff setup/02_installing-python-libraries/tests.py
|
pytest --ruff setup/02_installing-python-libraries/tests.py
|
||||||
pytest --ruff ch04/01_main-chapter-code/tests.py
|
pytest --ruff ch04/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/01_main-chapter-code/tests.py
|
pytest --ruff ch05/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
|
|
||||||
pytest --ruff ch06/01_main-chapter-code/tests.py
|
pytest --ruff ch06/01_main-chapter-code/tests.py
|
||||||
|
|
||||||
- name: Validate Selected Jupyter Notebooks
|
- name: Validate Selected Jupyter Notebooks
|
||||||
|
|||||||
2
.github/workflows/basic-tests-pytorch-rc.yml
vendored
2
.github/workflows/basic-tests-pytorch-rc.yml
vendored
@ -33,7 +33,6 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
uv sync --dev --python=3.10 # tests for backwards compatibility
|
uv sync --dev --python=3.10 # tests for backwards compatibility
|
||||||
uv pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
|
|
||||||
uv add pytest-ruff nbval
|
uv add pytest-ruff nbval
|
||||||
uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
|
uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
|
||||||
|
|
||||||
@ -43,7 +42,6 @@ jobs:
|
|||||||
pytest --ruff setup/02_installing-python-libraries/tests.py
|
pytest --ruff setup/02_installing-python-libraries/tests.py
|
||||||
pytest --ruff ch04/01_main-chapter-code/tests.py
|
pytest --ruff ch04/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/01_main-chapter-code/tests.py
|
pytest --ruff ch05/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
|
|
||||||
pytest --ruff ch06/01_main-chapter-code/tests.py
|
pytest --ruff ch06/01_main-chapter-code/tests.py
|
||||||
|
|
||||||
- name: Validate Selected Jupyter Notebooks
|
- name: Validate Selected Jupyter Notebooks
|
||||||
|
|||||||
@ -43,6 +43,7 @@ jobs:
|
|||||||
pip install tensorflow-io-gcs-filesystem==0.31.0 # Explicit for Windows
|
pip install tensorflow-io-gcs-filesystem==0.31.0 # Explicit for Windows
|
||||||
pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
|
pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
|
||||||
pip install pytest-ruff nbval
|
pip install pytest-ruff nbval
|
||||||
|
pip install -e .
|
||||||
|
|
||||||
- name: Run Python Tests
|
- name: Run Python Tests
|
||||||
shell: bash
|
shell: bash
|
||||||
@ -51,7 +52,9 @@ jobs:
|
|||||||
pytest --ruff setup/02_installing-python-libraries/tests.py
|
pytest --ruff setup/02_installing-python-libraries/tests.py
|
||||||
pytest --ruff ch04/01_main-chapter-code/tests.py
|
pytest --ruff ch04/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/01_main-chapter-code/tests.py
|
pytest --ruff ch05/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
|
pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
|
||||||
|
pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py
|
||||||
|
pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py
|
||||||
pytest --ruff ch06/01_main-chapter-code/tests.py
|
pytest --ruff ch06/01_main-chapter-code/tests.py
|
||||||
|
|
||||||
- name: Run Jupyter Notebook Tests
|
- name: Run Jupyter Notebook Tests
|
||||||
|
|||||||
@ -51,7 +51,6 @@ jobs:
|
|||||||
pytest --ruff setup/02_installing-python-libraries/tests.py
|
pytest --ruff setup/02_installing-python-libraries/tests.py
|
||||||
pytest --ruff ch04/01_main-chapter-code/tests.py
|
pytest --ruff ch04/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/01_main-chapter-code/tests.py
|
pytest --ruff ch05/01_main-chapter-code/tests.py
|
||||||
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
|
|
||||||
pytest --ruff ch06/01_main-chapter-code/tests.py
|
pytest --ruff ch06/01_main-chapter-code/tests.py
|
||||||
|
|
||||||
- name: Run Jupyter Notebook Tests
|
- name: Run Jupyter Notebook Tests
|
||||||
|
|||||||
116
ch05/07_gpt_to_llama/tests/test_llama32_nb.py
Normal file
116
ch05/07_gpt_to_llama/tests/test_llama32_nb.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||||
|
# Source for "Build a Large Language Model From Scratch"
|
||||||
|
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||||
|
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from llms_from_scratch.utils import import_definitions_from_notebook
|
||||||
|
|
||||||
|
|
||||||
|
transformers_installed = importlib.util.find_spec("transformers") is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def nb_imports():
|
||||||
|
nb_dir = Path(__file__).resolve().parents[1]
|
||||||
|
mod = import_definitions_from_notebook(nb_dir, "standalone-llama32.ipynb")
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dummy_input():
|
||||||
|
torch.manual_seed(123)
|
||||||
|
return torch.randint(0, 100, (1, 8)) # batch size 1, seq length 8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dummy_cfg_base():
|
||||||
|
return {
|
||||||
|
"vocab_size": 100,
|
||||||
|
"emb_dim": 32, # hidden_size
|
||||||
|
"hidden_dim": 64, # intermediate_size (FFN)
|
||||||
|
"n_layers": 2,
|
||||||
|
"n_heads": 4,
|
||||||
|
"head_dim": 8,
|
||||||
|
"n_kv_groups": 1,
|
||||||
|
"dtype": torch.float32,
|
||||||
|
"rope_base": 500_000.0,
|
||||||
|
"rope_freq": {
|
||||||
|
"factor": 8.0,
|
||||||
|
"low_freq_factor": 1.0,
|
||||||
|
"high_freq_factor": 4.0,
|
||||||
|
"original_context_length": 8192,
|
||||||
|
},
|
||||||
|
"context_length": 64,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_dummy_llama3_forward(dummy_cfg_base, dummy_input, nb_imports):
|
||||||
|
torch.manual_seed(123)
|
||||||
|
model = nb_imports.Llama3Model(dummy_cfg_base)
|
||||||
|
out = model(dummy_input)
|
||||||
|
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||||
|
def test_llama3_base_equivalence_with_transformers(nb_imports):
|
||||||
|
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
|
||||||
|
cfg = {
|
||||||
|
"vocab_size": 257,
|
||||||
|
"context_length": 8192,
|
||||||
|
"emb_dim": 32,
|
||||||
|
"n_heads": 4,
|
||||||
|
"n_layers": 2,
|
||||||
|
"hidden_dim": 64,
|
||||||
|
"n_kv_groups": 2,
|
||||||
|
"rope_base": 500_000.0,
|
||||||
|
"rope_freq": {
|
||||||
|
"factor": 32.0,
|
||||||
|
"low_freq_factor": 1.0,
|
||||||
|
"high_freq_factor": 4.0,
|
||||||
|
"original_context_length": 8192,
|
||||||
|
},
|
||||||
|
"dtype": torch.float32,
|
||||||
|
}
|
||||||
|
|
||||||
|
ours = nb_imports.Llama3Model(cfg)
|
||||||
|
|
||||||
|
hf_cfg = LlamaConfig(
|
||||||
|
vocab_size=cfg["vocab_size"],
|
||||||
|
hidden_size=cfg["emb_dim"],
|
||||||
|
num_attention_heads=cfg["n_heads"],
|
||||||
|
num_key_value_heads=cfg["n_kv_groups"],
|
||||||
|
num_hidden_layers=cfg["n_layers"],
|
||||||
|
intermediate_size=cfg["hidden_dim"],
|
||||||
|
max_position_embeddings=cfg["context_length"],
|
||||||
|
rms_norm_eps=1e-5,
|
||||||
|
attention_bias=False,
|
||||||
|
rope_theta=cfg["rope_base"],
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
attn_implementation="eager",
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
rope_scaling={
|
||||||
|
"type": "llama3",
|
||||||
|
"factor": cfg["rope_freq"]["factor"],
|
||||||
|
"low_freq_factor": cfg["rope_freq"]["low_freq_factor"],
|
||||||
|
"high_freq_factor": cfg["rope_freq"]["high_freq_factor"],
|
||||||
|
"original_max_position_embeddings": cfg["rope_freq"]["original_context_length"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
theirs = LlamaForCausalLM(hf_cfg)
|
||||||
|
|
||||||
|
hf_state = theirs.state_dict()
|
||||||
|
nb_imports.load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)
|
||||||
|
|
||||||
|
x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
|
||||||
|
ours_logits = ours(x)
|
||||||
|
theirs_logits = theirs(x).logits.to(ours_logits.dtype)
|
||||||
|
|
||||||
|
torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
|
||||||
122
ch05/11_qwen3/tests/test_qwen3_nb.py
Normal file
122
ch05/11_qwen3/tests/test_qwen3_nb.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||||
|
# Source for "Build a Large Language Model From Scratch"
|
||||||
|
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||||
|
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from llms_from_scratch.utils import import_definitions_from_notebook
|
||||||
|
|
||||||
|
|
||||||
|
transformers_installed = importlib.util.find_spec("transformers") is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def nb_imports():
|
||||||
|
nb_dir = Path(__file__).resolve().parents[1]
|
||||||
|
mod = import_definitions_from_notebook(nb_dir, "standalone-qwen3.ipynb")
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dummy_input():
|
||||||
|
torch.manual_seed(123)
|
||||||
|
return torch.randint(0, 100, (1, 8)) # batch size 1, seq length 8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dummy_cfg_base():
|
||||||
|
return {
|
||||||
|
"vocab_size": 100,
|
||||||
|
"emb_dim": 32,
|
||||||
|
"hidden_dim": 64,
|
||||||
|
"n_layers": 2,
|
||||||
|
"n_heads": 4,
|
||||||
|
"head_dim": 8,
|
||||||
|
"n_kv_groups": 1,
|
||||||
|
"qk_norm": False,
|
||||||
|
"dtype": torch.float32,
|
||||||
|
"rope_base": 10000,
|
||||||
|
"context_length": 64,
|
||||||
|
"num_experts": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dummy_cfg_moe(dummy_cfg_base):
|
||||||
|
cfg = dummy_cfg_base.copy()
|
||||||
|
cfg.update({
|
||||||
|
"num_experts": 4,
|
||||||
|
"num_experts_per_tok": 2,
|
||||||
|
"moe_intermediate_size": 64,
|
||||||
|
})
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, nb_imports):
|
||||||
|
torch.manual_seed(123)
|
||||||
|
model = nb_imports.Qwen3Model(dummy_cfg_base)
|
||||||
|
out = model(dummy_input)
|
||||||
|
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \
|
||||||
|
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||||
|
def test_qwen3_base_equivalence_with_transformers(nb_imports):
|
||||||
|
from transformers import Qwen3Config, Qwen3ForCausalLM
|
||||||
|
|
||||||
|
# Tiny config so the test is fast
|
||||||
|
cfg = {
|
||||||
|
"vocab_size": 257,
|
||||||
|
"context_length": 8,
|
||||||
|
"emb_dim": 32,
|
||||||
|
"n_heads": 4,
|
||||||
|
"n_layers": 2,
|
||||||
|
"hidden_dim": 64,
|
||||||
|
"head_dim": 8,
|
||||||
|
"qk_norm": True,
|
||||||
|
"n_kv_groups": 2,
|
||||||
|
"rope_base": 1_000_000.0,
|
||||||
|
"rope_local_base": 10_000.0,
|
||||||
|
"sliding_window": 4,
|
||||||
|
"layer_types": ["full_attention", "full_attention"],
|
||||||
|
"dtype": torch.float32,
|
||||||
|
"query_pre_attn_scalar": 256,
|
||||||
|
}
|
||||||
|
model = nb_imports.Qwen3Model(cfg)
|
||||||
|
|
||||||
|
hf_cfg = Qwen3Config(
|
||||||
|
vocab_size=cfg["vocab_size"],
|
||||||
|
max_position_embeddings=cfg["context_length"],
|
||||||
|
hidden_size=cfg["emb_dim"],
|
||||||
|
num_attention_heads=cfg["n_heads"],
|
||||||
|
num_hidden_layers=cfg["n_layers"],
|
||||||
|
intermediate_size=cfg["hidden_dim"],
|
||||||
|
head_dim=cfg["head_dim"],
|
||||||
|
num_key_value_heads=cfg["n_kv_groups"],
|
||||||
|
rope_theta=cfg["rope_base"],
|
||||||
|
rope_local_base_freq=cfg["rope_local_base"],
|
||||||
|
layer_types=cfg["layer_types"],
|
||||||
|
sliding_window=cfg["sliding_window"],
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
attn_implementation="eager",
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
|
||||||
|
rope_scaling={"rope_type": "default"},
|
||||||
|
)
|
||||||
|
hf_model = Qwen3ForCausalLM(hf_cfg)
|
||||||
|
|
||||||
|
hf_state = hf_model.state_dict()
|
||||||
|
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
|
||||||
|
nb_imports.load_weights_into_qwen(model, param_config, hf_state)
|
||||||
|
|
||||||
|
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
|
||||||
|
ours_logits = model(x)
|
||||||
|
theirs_logits = hf_model(x).logits
|
||||||
|
torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
|
||||||
@ -4,77 +4,21 @@
|
|||||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import types
|
|
||||||
import re
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import nbformat
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from llms_from_scratch.utils import import_definitions_from_notebook
|
||||||
|
|
||||||
|
|
||||||
transformers_installed = importlib.util.find_spec("transformers") is not None
|
transformers_installed = importlib.util.find_spec("transformers") is not None
|
||||||
|
|
||||||
|
|
||||||
def _extract_defs_and_classes_from_code(src):
|
@pytest.fixture
|
||||||
lines = src.splitlines()
|
def nb_imports():
|
||||||
kept = []
|
nb_dir = Path(__file__).resolve().parents[1]
|
||||||
i = 0
|
mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb")
|
||||||
while i < len(lines):
|
|
||||||
line = lines[i]
|
|
||||||
stripped = line.lstrip()
|
|
||||||
# Keep decorators attached to the next def/class
|
|
||||||
if stripped.startswith("@"):
|
|
||||||
# Look ahead: if the next non-empty line starts with def/class, keep decorator
|
|
||||||
j = i + 1
|
|
||||||
while j < len(lines) and not lines[j].strip():
|
|
||||||
j += 1
|
|
||||||
if j < len(lines) and lines[j].lstrip().startswith(("def ", "class ")):
|
|
||||||
kept.append(line)
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
if stripped.startswith("def ") or stripped.startswith("class "):
|
|
||||||
kept.append(line)
|
|
||||||
# capture until we leave the indentation block
|
|
||||||
base_indent = len(line) - len(stripped)
|
|
||||||
i += 1
|
|
||||||
while i < len(lines):
|
|
||||||
nxt = lines[i]
|
|
||||||
if nxt.strip() == "":
|
|
||||||
kept.append(nxt)
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
indent = len(nxt) - len(nxt.lstrip())
|
|
||||||
if indent <= base_indent and not nxt.lstrip().startswith(("#", "@")):
|
|
||||||
break
|
|
||||||
kept.append(nxt)
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
i += 1
|
|
||||||
code = "\n".join(kept)
|
|
||||||
code = re.sub(r"def\s+load_weights_into_gemma\s*\(\s*Gemma3Model\s*,",
|
|
||||||
"def load_weights_into_gemma(model,",
|
|
||||||
code)
|
|
||||||
return code
|
|
||||||
|
|
||||||
|
|
||||||
def import_definitions_from_notebook(nb_dir_or_path, notebook_name):
|
|
||||||
nb_path = Path(nb_dir_or_path)
|
|
||||||
if nb_path.is_dir():
|
|
||||||
nb_file = nb_path / notebook_name
|
|
||||||
else:
|
|
||||||
nb_file = nb_path
|
|
||||||
if not nb_file.exists():
|
|
||||||
raise FileNotFoundError(f"Notebook not found: {nb_file}")
|
|
||||||
|
|
||||||
nb = nbformat.read(nb_file, as_version=4)
|
|
||||||
pieces = ["import torch", "import torch.nn as nn"]
|
|
||||||
for cell in nb.cells:
|
|
||||||
if cell.cell_type == "code":
|
|
||||||
pieces.append(_extract_defs_and_classes_from_code(cell.source))
|
|
||||||
src = "\n\n".join(pieces)
|
|
||||||
|
|
||||||
mod = types.ModuleType("gemma3_defs")
|
|
||||||
exec(src, mod.__dict__)
|
|
||||||
return mod
|
return mod
|
||||||
|
|
||||||
|
|
||||||
@ -106,25 +50,16 @@ def dummy_cfg_base():
|
|||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input):
|
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports):
|
||||||
nb_dir = Path(__file__).resolve().parents[1]
|
|
||||||
mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb")
|
|
||||||
Gemma3Model = mod.Gemma3Model
|
|
||||||
|
|
||||||
torch.manual_seed(123)
|
torch.manual_seed(123)
|
||||||
model = Gemma3Model(dummy_cfg_base)
|
model = nb_imports.Gemma3Model(dummy_cfg_base)
|
||||||
out = model(dummy_input)
|
out = model(dummy_input)
|
||||||
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
|
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||||
def test_gemma3_base_equivalence_with_transformers():
|
def test_gemma3_base_equivalence_with_transformers(nb_imports):
|
||||||
nb_dir = Path(__file__).resolve().parents[1]
|
|
||||||
mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb")
|
|
||||||
Gemma3Model = mod.Gemma3Model
|
|
||||||
load_weights_into_gemma = mod.load_weights_into_gemma
|
|
||||||
|
|
||||||
from transformers import Gemma3TextConfig, Gemma3ForCausalLM
|
from transformers import Gemma3TextConfig, Gemma3ForCausalLM
|
||||||
|
|
||||||
# Tiny config so the test is fast
|
# Tiny config so the test is fast
|
||||||
@ -145,7 +80,7 @@ def test_gemma3_base_equivalence_with_transformers():
|
|||||||
"dtype": torch.float32,
|
"dtype": torch.float32,
|
||||||
"query_pre_attn_scalar": 256,
|
"query_pre_attn_scalar": 256,
|
||||||
}
|
}
|
||||||
model = Gemma3Model(cfg)
|
model = nb_imports.Gemma3Model(cfg)
|
||||||
|
|
||||||
hf_cfg = Gemma3TextConfig(
|
hf_cfg = Gemma3TextConfig(
|
||||||
vocab_size=cfg["vocab_size"],
|
vocab_size=cfg["vocab_size"],
|
||||||
@ -170,7 +105,7 @@ def test_gemma3_base_equivalence_with_transformers():
|
|||||||
|
|
||||||
hf_state = hf_model.state_dict()
|
hf_state = hf_model.state_dict()
|
||||||
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
|
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
|
||||||
load_weights_into_gemma(model, param_config, hf_state)
|
nb_imports.load_weights_into_gemma(model, param_config, hf_state)
|
||||||
|
|
||||||
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
|
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
|
||||||
ours_logits = model(x)
|
ours_logits = model(x)
|
||||||
124
pkg/llms_from_scratch/utils.py
Normal file
124
pkg/llms_from_scratch/utils.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||||
|
# Source for "Build a Large Language Model From Scratch"
|
||||||
|
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||||
|
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||||
|
|
||||||
|
# Internal utility functions (not intended for public use)
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import re
|
||||||
|
import types
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import nbformat
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_imports(src: str):
|
||||||
|
out = []
|
||||||
|
try:
|
||||||
|
tree = ast.parse(src)
|
||||||
|
except SyntaxError:
|
||||||
|
return out
|
||||||
|
for node in tree.body:
|
||||||
|
if isinstance(node, ast.Import):
|
||||||
|
parts = []
|
||||||
|
for n in node.names:
|
||||||
|
parts.append(f"{n.name} as {n.asname}" if n.asname else n.name)
|
||||||
|
out.append("import " + ", ".join(parts))
|
||||||
|
elif isinstance(node, ast.ImportFrom):
|
||||||
|
module = node.module or ""
|
||||||
|
parts = []
|
||||||
|
for n in node.names:
|
||||||
|
parts.append(f"{n.name} as {n.asname}" if n.asname else n.name)
|
||||||
|
level = "." * node.level if getattr(node, "level", 0) else ""
|
||||||
|
out.append(f"from {level}{module} import " + ", ".join(parts))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_defs_and_classes_from_code(src):
|
||||||
|
lines = src.splitlines()
|
||||||
|
kept = []
|
||||||
|
i = 0
|
||||||
|
while i < len(lines):
|
||||||
|
line = lines[i]
|
||||||
|
stripped = line.lstrip()
|
||||||
|
if stripped.startswith("@"):
|
||||||
|
j = i + 1
|
||||||
|
while j < len(lines) and not lines[j].strip():
|
||||||
|
j += 1
|
||||||
|
if j < len(lines) and lines[j].lstrip().startswith(("def ", "class ")):
|
||||||
|
kept.append(line)
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
if stripped.startswith("def ") or stripped.startswith("class "):
|
||||||
|
kept.append(line)
|
||||||
|
base_indent = len(line) - len(stripped)
|
||||||
|
i += 1
|
||||||
|
while i < len(lines):
|
||||||
|
nxt = lines[i]
|
||||||
|
if nxt.strip() == "":
|
||||||
|
kept.append(nxt)
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
indent = len(nxt) - len(nxt.lstrip())
|
||||||
|
if indent <= base_indent and not nxt.lstrip().startswith(("#", "@")):
|
||||||
|
break
|
||||||
|
kept.append(nxt)
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
code = "\n".join(kept)
|
||||||
|
|
||||||
|
# General rule:
|
||||||
|
# replace functions defined like `def load_weights_into_xxx(ClassName, ...`
|
||||||
|
# with `def load_weights_into_xxx(model, ...`
|
||||||
|
code = re.sub(
|
||||||
|
r"(def\s+load_weights_into_\w+\s*\()\s*\w+\s*,",
|
||||||
|
r"\1model,",
|
||||||
|
code
|
||||||
|
)
|
||||||
|
return code
|
||||||
|
|
||||||
|
|
||||||
|
def import_definitions_from_notebook(nb_dir_or_path, notebook_name=None, *, extra_globals=None):
|
||||||
|
nb_path = Path(nb_dir_or_path)
|
||||||
|
if notebook_name is not None:
|
||||||
|
nb_file = nb_path / notebook_name if nb_path.is_dir() else nb_path
|
||||||
|
else:
|
||||||
|
nb_file = nb_path
|
||||||
|
|
||||||
|
if not nb_file.exists():
|
||||||
|
raise FileNotFoundError(f"Notebook not found: {nb_file}")
|
||||||
|
|
||||||
|
nb = nbformat.read(nb_file, as_version=4)
|
||||||
|
|
||||||
|
import_lines = []
|
||||||
|
seen = set()
|
||||||
|
for cell in nb.cells:
|
||||||
|
if cell.cell_type == "code":
|
||||||
|
for line in _extract_imports(cell.source):
|
||||||
|
if line not in seen:
|
||||||
|
import_lines.append(line)
|
||||||
|
seen.add(line)
|
||||||
|
|
||||||
|
for required in ("import torch", "import torch.nn as nn"):
|
||||||
|
if required not in seen:
|
||||||
|
import_lines.append(required)
|
||||||
|
seen.add(required)
|
||||||
|
|
||||||
|
pieces = []
|
||||||
|
for cell in nb.cells:
|
||||||
|
if cell.cell_type == "code":
|
||||||
|
pieces.append(_extract_defs_and_classes_from_code(cell.source))
|
||||||
|
|
||||||
|
src = "\n\n".join(import_lines + pieces)
|
||||||
|
|
||||||
|
mod_name = nb_file.stem.replace("-", "_").replace(" ", "_") or "notebook_defs"
|
||||||
|
mod = types.ModuleType(mod_name)
|
||||||
|
|
||||||
|
if extra_globals:
|
||||||
|
mod.__dict__.update(extra_globals)
|
||||||
|
|
||||||
|
exec(src, mod.__dict__)
|
||||||
|
return mod
|
||||||
@ -30,6 +30,7 @@ dev = [
|
|||||||
"llms-from-scratch",
|
"llms-from-scratch",
|
||||||
"twine>=6.1.0",
|
"twine>=6.1.0",
|
||||||
"tokenizers>=0.21.1",
|
"tokenizers>=0.21.1",
|
||||||
|
"safetensors>=0.6.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user