Qwen3 and Llama3 equivalency tests with HF transformers (#768)

* Qwen3 and Llama3 equivalency tests with HF transformers

* update
Sebastian Raschka 2025-08-14 18:36:07 -05:00 committed by GitHub
parent 2e3205f747
commit 07c3122b5c
6 changed files with 199 additions and 8 deletions


@@ -28,7 +28,7 @@ jobs:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
-        os: [ubuntu-latest, macos-latest, windows-latest]
+        os: [ubuntu-latest, windows-latest]
    steps:
      - uses: actions/checkout@v4

.gitignore

@@ -1,4 +1,3 @@
# Configs and keys
ch05/07_gpt_to_llama/config.json
ch07/02_dataset-utilities/config.json
@@ -78,6 +77,11 @@ ch07/01_main-chapter-code/gpt2-medium355M-sft-standalone.pth
ch07/01_main-chapter-code/Smalltestmodel-sft-standalone.pth
ch07/01_main-chapter-code/gpt2/
Qwen3-0.6B-Base/
Qwen3-0.6B/
tokenizer-base.json
tokenizer.json
# Datasets
the-verdict.txt


@@ -132,7 +132,8 @@ For more information about KV caching, please see the [KV cache README](../../ch
```python
from llms_from_scratch.llama3 import (
-    Llama3Model,
+    load_weights_into_llama,
+    Llama3Model,
    Llama3ModelFast,
    Llama3Tokenizer,
    ChatFormat,
@@ -154,6 +155,7 @@ For more information about KV caching, please see the [KV cache README](../../ch
```python
from llms_from_scratch.qwen3 import (
+    load_weights_into_qwen,
    Qwen3Model,
    Qwen3Tokenizer,
)
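# Illustrative usage sketch (not part of this diff): build the 0.6B model and
# load locally downloaded Hugging Face weights with the imports shown above.
# The "model.safetensors" path below is a hypothetical placeholder.
from safetensors.torch import load_file

from llms_from_scratch.qwen3 import QWEN_CONFIG_06_B

model = Qwen3Model(QWEN_CONFIG_06_B)
load_weights_into_qwen(model, QWEN_CONFIG_06_B, load_file("model.safetensors"))
model.eval()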


@@ -497,3 +497,77 @@ class Llama3ModelFast(nn.Module):
        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits


def assign(left, right, tensor_name="unknown"):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")

    if isinstance(right, torch.Tensor):
        return torch.nn.Parameter(right.clone().detach())
    else:
        return torch.nn.Parameter(torch.tensor(right))


def load_weights_into_llama(model, param_config, params):
    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")

    for l in range(param_config["n_layers"]):

        # Load attention weights
        model.trf_blocks[l].att.W_query.weight = assign(
            model.trf_blocks[l].att.W_query.weight,
            params[f"model.layers.{l}.self_attn.q_proj.weight"],
            f"model.layers.{l}.self_attn.q_proj.weight"
        )
        model.trf_blocks[l].att.W_key.weight = assign(
            model.trf_blocks[l].att.W_key.weight,
            params[f"model.layers.{l}.self_attn.k_proj.weight"],
            f"model.layers.{l}.self_attn.k_proj.weight"
        )
        model.trf_blocks[l].att.W_value.weight = assign(
            model.trf_blocks[l].att.W_value.weight,
            params[f"model.layers.{l}.self_attn.v_proj.weight"],
            f"model.layers.{l}.self_attn.v_proj.weight"
        )
        model.trf_blocks[l].att.out_proj.weight = assign(
            model.trf_blocks[l].att.out_proj.weight,
            params[f"model.layers.{l}.self_attn.o_proj.weight"],
            f"model.layers.{l}.self_attn.o_proj.weight"
        )
        model.trf_blocks[l].norm1.weight = assign(
            model.trf_blocks[l].norm1.weight,
            params[f"model.layers.{l}.input_layernorm.weight"],
            f"model.layers.{l}.input_layernorm.weight"
        )

        # Load FeedForward weights
        model.trf_blocks[l].ff.fc1.weight = assign(
            model.trf_blocks[l].ff.fc1.weight,
            params[f"model.layers.{l}.mlp.gate_proj.weight"],
            f"model.layers.{l}.mlp.gate_proj.weight"
        )
        model.trf_blocks[l].ff.fc2.weight = assign(
            model.trf_blocks[l].ff.fc2.weight,
            params[f"model.layers.{l}.mlp.up_proj.weight"],
            f"model.layers.{l}.mlp.up_proj.weight"
        )
        model.trf_blocks[l].ff.fc3.weight = assign(
            model.trf_blocks[l].ff.fc3.weight,
            params[f"model.layers.{l}.mlp.down_proj.weight"],
            f"model.layers.{l}.mlp.down_proj.weight"
        )
        model.trf_blocks[l].norm2.weight = assign(
            model.trf_blocks[l].norm2.weight,
            params[f"model.layers.{l}.post_attention_layernorm.weight"],
            f"model.layers.{l}.post_attention_layernorm.weight"
        )

    # Load output layer weights
    model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight")

    if "lm_head.weight" in params.keys():
        model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
    else:
        model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
        print("Model uses weight tying.")


@@ -5,11 +5,12 @@
from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.llama3 import (
-    compute_rope_params,
    apply_rope,
-    LLAMA32_CONFIG_1B,
+    compute_rope_params,
    GroupedQueryAttention,
    GroupedQueryAttentionFast,
+    load_weights_into_llama,
+    LLAMA32_CONFIG_1B,
    Llama3Model,
)
from llms_from_scratch.kv_cache.llama3 import Llama3Model as Llama3ModelKV
@@ -246,3 +247,61 @@ def test_rmsnorm_equivalence():
    out2 = lit_norm(x)
    torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_llama3_base_equivalence_with_transformers():
    from transformers.models.llama import LlamaConfig, LlamaForCausalLM

    cfg = {
        "vocab_size": 257,
        "context_length": 8192,
        "emb_dim": 32,
        "n_heads": 4,
        "n_layers": 2,
        "hidden_dim": 64,
        "n_kv_groups": 2,
        "rope_base": 500_000.0,
        "rope_freq": {
            "factor": 32.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_context_length": 8192,
        },
        "dtype": torch.float32,
    }
    ours = Llama3Model(cfg)

    hf_cfg = LlamaConfig(
        vocab_size=cfg["vocab_size"],
        hidden_size=cfg["emb_dim"],
        num_attention_heads=cfg["n_heads"],
        num_key_value_heads=cfg["n_kv_groups"],
        num_hidden_layers=cfg["n_layers"],
        intermediate_size=cfg["hidden_dim"],
        max_position_embeddings=cfg["context_length"],
        rms_norm_eps=1e-5,
        attention_bias=False,
        rope_theta=cfg["rope_base"],
        tie_word_embeddings=False,
        attn_implementation="eager",
        torch_dtype=torch.float32,
        rope_scaling={
            "type": "llama3",
            "factor": cfg["rope_freq"]["factor"],
            "low_freq_factor": cfg["rope_freq"]["low_freq_factor"],
            "high_freq_factor": cfg["rope_freq"]["high_freq_factor"],
            "original_max_position_embeddings": cfg["rope_freq"]["original_context_length"],
        },
    )
    theirs = LlamaForCausalLM(hf_cfg)

    hf_state = theirs.state_dict()
    load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)

    x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
    ours_logits = ours(x)
    theirs_logits = theirs(x).logits.to(ours_logits.dtype)

    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
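A small optional check (an assumption, not part of the test suite): the weight-tying fallback at the end of `load_weights_into_llama` could be exercised by appending the following lines, with this indentation, to the test body above:

```python
    # Drop lm_head.weight so the loader takes the weight-tying branch
    # (it prints "Model uses weight tying.") and reuses the embedding matrix.
    hf_state_tied = {k: v for k, v in hf_state.items() if k != "lm_head.weight"}
    tied = Llama3Model(cfg)
    load_weights_into_llama(tied, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state_tied)
    assert torch.equal(tied.out_head.weight, tied.tok_emb.weight)
```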


@@ -5,12 +5,13 @@
from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.qwen3 import (
-    compute_rope_params,
    apply_rope,
+    compute_rope_params,
    load_weights_into_qwen,
    QWEN_CONFIG_06_B,
-    RMSNorm,
    Qwen3Model,
-    Qwen3Tokenizer
+    Qwen3Tokenizer,
+    RMSNorm,
)
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
from llms_from_scratch.kv_cache.utils import KVCache
@@ -87,6 +88,7 @@ def dummy_cfg_moe(dummy_cfg_base):
    return cfg


+@torch.inference_mode()
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
    torch.manual_seed(123)
    model = Qwen3Model(dummy_cfg_base)
@@ -95,6 +97,7 @@ def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
        f"Expected shape (1, seq_len, vocab_size), got {out.shape}"


+@torch.inference_mode()
def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
    torch.manual_seed(123)
    model = Qwen3Model(dummy_cfg_moe)
@@ -105,6 +108,7 @@ def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
        "Expected MoEFeedForward in at least one transformer block"


+@torch.inference_mode()
@pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
def test_qwen3_kvcache_equivalence(cfg_name, request):
    cfg = request.getfixturevalue(cfg_name)
@@ -438,3 +442,51 @@ def test_tokenizer_equivalence():
    expected_pad_token = "<|endoftext|>"
    assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
    assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_qwen3_base_equivalence_with_transformers():
    from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM

    # Tiny config so the test is fast
    cfg = {
        "vocab_size": 257,
        "context_length": 8,
        "emb_dim": 32,
        "n_heads": 4,
        "n_layers": 2,
        "hidden_dim": 64,
        "head_dim": 8,
        "qk_norm": True,
        "n_kv_groups": 2,
        "rope_base": 1_000_000.0,
        "dtype": torch.float32,
    }
    model = Qwen3Model(cfg)

    hf_cfg = Qwen3Config(
        vocab_size=cfg["vocab_size"],
        max_position_embeddings=cfg["context_length"],
        hidden_size=cfg["emb_dim"],
        num_attention_heads=cfg["n_heads"],
        num_hidden_layers=cfg["n_layers"],
        intermediate_size=cfg["hidden_dim"],
        head_dim=cfg["head_dim"],
        num_key_value_heads=cfg["n_kv_groups"],
        rope_theta=cfg["rope_base"],
        tie_word_embeddings=False,
        attn_implementation="eager",
        torch_dtype=torch.float32,
    )
    hf_model = Qwen3ForCausalLM(hf_cfg)

    hf_state = hf_model.state_dict()
    param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
    load_weights_into_qwen(model, param_config, hf_state)

    x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
    ours_logits = model(x)
    theirs_logits = hf_model(x).logits

    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
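As one more hedged extension (not part of the diff), the greedy next-token choices of the two implementations could also be compared by appending the following lines, with this indentation, to the test body above:

```python
    # With matching logits (up to the tolerance above), greedy next-token picks
    # from the last position are expected to coincide for both implementations.
    next_ours = ours_logits[:, -1, :].argmax(dim=-1)
    next_theirs = theirs_logits[:, -1, :].argmax(dim=-1)
    assert torch.equal(next_ours, next_theirs)
```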