diff --git a/.github/workflows/basic-tests-pixi.yml b/.github/workflows/basic-tests-pixi.yml
index e661151..a296d50 100644
--- a/.github/workflows/basic-tests-pixi.yml
+++ b/.github/workflows/basic-tests-pixi.yml
@@ -28,7 +28,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-latest, macos-latest, windows-latest]
+        os: [ubuntu-latest, windows-latest]
 
     steps:
     - uses: actions/checkout@v4
diff --git a/.gitignore b/.gitignore
index 3127d4d..4e9a19e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,3 @@
-
 # Configs and keys
 ch05/07_gpt_to_llama/config.json
 ch07/02_dataset-utilities/config.json
@@ -78,6 +77,11 @@ ch07/01_main-chapter-code/gpt2-medium355M-sft-standalone.pth
 ch07/01_main-chapter-code/Smalltestmodel-sft-standalone.pth
 ch07/01_main-chapter-code/gpt2/
 
+Qwen3-0.6B-Base/
+Qwen3-0.6B/
+tokenizer-base.json
+tokenizer.json
+
 # Datasets
 the-verdict.txt
 
diff --git a/pkg/llms_from_scratch/README.md b/pkg/llms_from_scratch/README.md
index 295d4ea..326f9f2 100644
--- a/pkg/llms_from_scratch/README.md
+++ b/pkg/llms_from_scratch/README.md
@@ -132,7 +132,8 @@ For more information about KV caching, please see the [KV cache README](../../ch
 
 ```python
 from llms_from_scratch.llama3 import (
-    Llama3Model,
+    load_weights_into_llama,
+    Llama3Model,
     Llama3ModelFast,
     Llama3Tokenizer,
     ChatFormat,
@@ -154,6 +155,7 @@ For more information about KV caching, please see the [KV cache README](../../ch
 
 ```python
 from llms_from_scratch.qwen3 import (
+    load_weights_into_qwen,
     Qwen3Model,
     Qwen3Tokenizer,
 )
diff --git a/pkg/llms_from_scratch/llama3.py b/pkg/llms_from_scratch/llama3.py
index ddd4cde..0580ee2 100644
--- a/pkg/llms_from_scratch/llama3.py
+++ b/pkg/llms_from_scratch/llama3.py
@@ -497,3 +497,77 @@ class Llama3ModelFast(nn.Module):
         x = self.final_norm(x)
         logits = self.out_head(x.to(self.cfg["dtype"]))
         return logits
+
+
+def assign(left, right, tensor_name="unknown"):
+    if left.shape != right.shape:
+        raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
+
+    if isinstance(right, torch.Tensor):
+        return torch.nn.Parameter(right.clone().detach())
+    else:
+        return torch.nn.Parameter(torch.tensor(right))
+
+
+def load_weights_into_llama(model, param_config, params):
+    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
+
+    for l in range(param_config["n_layers"]):
+
+        # Load attention weights
+        model.trf_blocks[l].att.W_query.weight = assign(
+            model.trf_blocks[l].att.W_query.weight,
+            params[f"model.layers.{l}.self_attn.q_proj.weight"],
+            f"model.layers.{l}.self_attn.q_proj.weight"
+        )
+        model.trf_blocks[l].att.W_key.weight = assign(
+            model.trf_blocks[l].att.W_key.weight,
+            params[f"model.layers.{l}.self_attn.k_proj.weight"],
+            f"model.layers.{l}.self_attn.k_proj.weight"
+        )
+        model.trf_blocks[l].att.W_value.weight = assign(
+            model.trf_blocks[l].att.W_value.weight,
+            params[f"model.layers.{l}.self_attn.v_proj.weight"],
+            f"model.layers.{l}.self_attn.v_proj.weight"
+        )
+        model.trf_blocks[l].att.out_proj.weight = assign(
+            model.trf_blocks[l].att.out_proj.weight,
+            params[f"model.layers.{l}.self_attn.o_proj.weight"],
+            f"model.layers.{l}.self_attn.o_proj.weight"
+        )
+        model.trf_blocks[l].norm1.weight = assign(
+            model.trf_blocks[l].norm1.weight,
+            params[f"model.layers.{l}.input_layernorm.weight"],
+            f"model.layers.{l}.input_layernorm.weight"
+        )
+
+        # Load FeedForward weights
+        model.trf_blocks[l].ff.fc1.weight = assign(
+            model.trf_blocks[l].ff.fc1.weight,
+            params[f"model.layers.{l}.mlp.gate_proj.weight"],
+            f"model.layers.{l}.mlp.gate_proj.weight"
+        )
+        model.trf_blocks[l].ff.fc2.weight = assign(
+            model.trf_blocks[l].ff.fc2.weight,
+            params[f"model.layers.{l}.mlp.up_proj.weight"],
+            f"model.layers.{l}.mlp.up_proj.weight"
+        )
+        model.trf_blocks[l].ff.fc3.weight = assign(
+            model.trf_blocks[l].ff.fc3.weight,
+            params[f"model.layers.{l}.mlp.down_proj.weight"],
+            f"model.layers.{l}.mlp.down_proj.weight"
+        )
+        model.trf_blocks[l].norm2.weight = assign(
+            model.trf_blocks[l].norm2.weight,
+            params[f"model.layers.{l}.post_attention_layernorm.weight"],
+            f"model.layers.{l}.post_attention_layernorm.weight"
+        )
+
+    # Load output layer weights
+    model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight")
+
+    if "lm_head.weight" in params.keys():
+        model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
+    else:
+        model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
+        print("Model uses weight tying.")
diff --git a/pkg/llms_from_scratch/tests/test_llama3.py b/pkg/llms_from_scratch/tests/test_llama3.py
index 0159fa4..a3c3f69 100644
--- a/pkg/llms_from_scratch/tests/test_llama3.py
+++ b/pkg/llms_from_scratch/tests/test_llama3.py
@@ -5,11 +5,12 @@
 
 from llms_from_scratch.ch04 import generate_text_simple
 from llms_from_scratch.llama3 import (
-    compute_rope_params,
     apply_rope,
-    LLAMA32_CONFIG_1B,
+    compute_rope_params,
     GroupedQueryAttention,
     GroupedQueryAttentionFast,
+    load_weights_into_llama,
+    LLAMA32_CONFIG_1B,
     Llama3Model,
 )
 from llms_from_scratch.kv_cache.llama3 import Llama3Model as Llama3ModelKV
@@ -246,3 +247,61 @@ def test_rmsnorm_equivalence():
     out2 = lit_norm(x)
 
     torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
+
+
+@torch.inference_mode()
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_llama3_base_equivalence_with_transformers():
+    from transformers.models.llama import LlamaConfig, LlamaForCausalLM
+    cfg = {
+        "vocab_size": 257,
+        "context_length": 8192,
+        "emb_dim": 32,
+        "n_heads": 4,
+        "n_layers": 2,
+        "hidden_dim": 64,
+        "n_kv_groups": 2,
+        "rope_base": 500_000.0,
+        "rope_freq": {
+            "factor": 32.0,
+            "low_freq_factor": 1.0,
+            "high_freq_factor": 4.0,
+            "original_context_length": 8192,
+        },
+        "dtype": torch.float32,
+    }
+
+    ours = Llama3Model(cfg)
+
+    hf_cfg = LlamaConfig(
+        vocab_size=cfg["vocab_size"],
+        hidden_size=cfg["emb_dim"],
+        num_attention_heads=cfg["n_heads"],
+        num_key_value_heads=cfg["n_kv_groups"],
+        num_hidden_layers=cfg["n_layers"],
+        intermediate_size=cfg["hidden_dim"],
+        max_position_embeddings=cfg["context_length"],
+        rms_norm_eps=1e-5,
+        attention_bias=False,
+        rope_theta=cfg["rope_base"],
+        tie_word_embeddings=False,
+        attn_implementation="eager",
+        torch_dtype=torch.float32,
+        rope_scaling={
+            "type": "llama3",
+            "factor": cfg["rope_freq"]["factor"],
+            "low_freq_factor": cfg["rope_freq"]["low_freq_factor"],
+            "high_freq_factor": cfg["rope_freq"]["high_freq_factor"],
+            "original_max_position_embeddings": cfg["rope_freq"]["original_context_length"],
+        },
+    )
+    theirs = LlamaForCausalLM(hf_cfg)
+
+    hf_state = theirs.state_dict()
+    load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)
+
+    x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
+    ours_logits = ours(x)
+    theirs_logits = theirs(x).logits.to(ours_logits.dtype)
+
+    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
\ No newline at end of file
diff --git a/pkg/llms_from_scratch/tests/test_qwen3.py b/pkg/llms_from_scratch/tests/test_qwen3.py
index 3825dae..82956de 100644
--- a/pkg/llms_from_scratch/tests/test_qwen3.py
+++ b/pkg/llms_from_scratch/tests/test_qwen3.py
@@ -5,12 +5,13 @@
 
 from llms_from_scratch.ch04 import generate_text_simple
 from llms_from_scratch.qwen3 import (
-    compute_rope_params,
     apply_rope,
+    compute_rope_params,
+    load_weights_into_qwen,
     QWEN_CONFIG_06_B,
-    RMSNorm,
     Qwen3Model,
-    Qwen3Tokenizer
+    Qwen3Tokenizer,
+    RMSNorm,
 )
 from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
 from llms_from_scratch.kv_cache.utils import KVCache
@@ -87,6 +88,7 @@ def dummy_cfg_moe(dummy_cfg_base):
     return cfg
 
 
+@torch.inference_mode()
 def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
     torch.manual_seed(123)
     model = Qwen3Model(dummy_cfg_base)
@@ -95,6 +97,7 @@
         f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
 
 
+@torch.inference_mode()
 def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
     torch.manual_seed(123)
     model = Qwen3Model(dummy_cfg_moe)
@@ -105,6 +108,7 @@
         "Expected MoEFeedForward in at least one transformer block"
 
 
+@torch.inference_mode()
 @pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
 def test_qwen3_kvcache_equivalence(cfg_name, request):
     cfg = request.getfixturevalue(cfg_name)
@@ -438,3 +442,51 @@ def test_tokenizer_equivalence():
     expected_pad_token = "<|endoftext|>"
     assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
     assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token
+
+
+@torch.inference_mode()
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_qwen3_base_equivalence_with_transformers():
+
+    from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM
+
+    # Tiny config so the test is fast
+    cfg = {
+        "vocab_size": 257,
+        "context_length": 8,
+        "emb_dim": 32,
+        "n_heads": 4,
+        "n_layers": 2,
+        "hidden_dim": 64,
+        "head_dim": 8,
+        "qk_norm": True,
+        "n_kv_groups": 2,
+        "rope_base": 1_000_000.0,
+        "dtype": torch.float32,
+    }
+    model = Qwen3Model(cfg)
+
+    hf_cfg = Qwen3Config(
+        vocab_size=cfg["vocab_size"],
+        max_position_embeddings=cfg["context_length"],
+        hidden_size=cfg["emb_dim"],
+        num_attention_heads=cfg["n_heads"],
+        num_hidden_layers=cfg["n_layers"],
+        intermediate_size=cfg["hidden_dim"],
+        head_dim=cfg["head_dim"],
+        num_key_value_heads=cfg["n_kv_groups"],
+        rope_theta=cfg["rope_base"],
+        tie_word_embeddings=False,
+        attn_implementation="eager",
+        torch_dtype=torch.float32,
+    )
+    hf_model = Qwen3ForCausalLM(hf_cfg)
+
+    hf_state = hf_model.state_dict()
+    param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
+    load_weights_into_qwen(model, param_config, hf_state)
+
+    x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
+    ours_logits = model(x)
+    theirs_logits = hf_model(x).logits
+    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
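
Usage note (not part of the diff): the new `load_weights_into_llama` / `load_weights_into_qwen` helpers take the from-scratch model, a config dict providing at least `"n_layers"`, and a Hugging Face state dict keyed like `model.layers.{l}.*`, as in the tests above. Below is a minimal sketch of loading pretrained Qwen3 weights this way; the `Qwen/Qwen3-0.6B-Base` checkpoint name and the use of `transformers.AutoModelForCausalLM` as the weight source are assumptions, not taken from this diff.

```python
import torch
from transformers import AutoModelForCausalLM

from llms_from_scratch.qwen3 import QWEN_CONFIG_06_B, Qwen3Model, load_weights_into_qwen

# Reference weights from Hugging Face (checkpoint name is an assumption).
hf_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base")

# Build the from-scratch model and copy the weights over; QWEN_CONFIG_06_B is a
# superset of the {"n_layers", "hidden_dim"} dict the tests pass as param_config.
model = Qwen3Model(QWEN_CONFIG_06_B)
load_weights_into_qwen(model, QWEN_CONFIG_06_B, hf_model.state_dict())
model.eval()

# Quick smoke test: one forward pass on dummy token IDs.
x = torch.randint(0, QWEN_CONFIG_06_B["vocab_size"], (1, 8), dtype=torch.long)
with torch.inference_mode():
    logits = model(x)
print(logits.shape)  # (1, 8, vocab_size)
```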