mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-03 11:20:49 +00:00
Qwen3 and Llama3 equivalency teests with HF transformers (#768)
* Qwen3 and Llama3 equivalency teests with HF transformers * update
This commit is contained in:
parent
2e3205f747
commit
07c3122b5c
2
.github/workflows/basic-tests-pixi.yml
vendored
2
.github/workflows/basic-tests-pixi.yml
vendored
@ -28,7 +28,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
os: [ubuntu-latest, windows-latest]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@ -1,4 +1,3 @@
|
||||
|
||||
# Configs and keys
|
||||
ch05/07_gpt_to_llama/config.json
|
||||
ch07/02_dataset-utilities/config.json
|
||||
@ -78,6 +77,11 @@ ch07/01_main-chapter-code/gpt2-medium355M-sft-standalone.pth
|
||||
ch07/01_main-chapter-code/Smalltestmodel-sft-standalone.pth
|
||||
ch07/01_main-chapter-code/gpt2/
|
||||
|
||||
Qwen3-0.6B-Base/
|
||||
Qwen3-0.6B/
|
||||
tokenizer-base.json
|
||||
tokenizer.json
|
||||
|
||||
# Datasets
|
||||
the-verdict.txt
|
||||
|
||||
|
||||
@ -132,7 +132,8 @@ For more information about KV caching, please see the [KV cache README](../../ch
|
||||
|
||||
```python
|
||||
from llms_from_scratch.llama3 import (
|
||||
Llama3Model,
|
||||
load_weights_into_llama,
|
||||
Llama3Model,
|
||||
Llama3ModelFast,
|
||||
Llama3Tokenizer,
|
||||
ChatFormat,
|
||||
@ -154,6 +155,7 @@ For more information about KV caching, please see the [KV cache README](../../ch
|
||||
|
||||
```python
|
||||
from llms_from_scratch.qwen3 import (
|
||||
load_weights_into_qwen
|
||||
Qwen3Model,
|
||||
Qwen3Tokenizer,
|
||||
)
|
||||
|
||||
@ -497,3 +497,77 @@ class Llama3ModelFast(nn.Module):
|
||||
x = self.final_norm(x)
|
||||
logits = self.out_head(x.to(self.cfg["dtype"]))
|
||||
return logits
|
||||
|
||||
|
||||
def assign(left, right, tensor_name="unknown"):
|
||||
if left.shape != right.shape:
|
||||
raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
|
||||
|
||||
if isinstance(right, torch.Tensor):
|
||||
return torch.nn.Parameter(right.clone().detach())
|
||||
else:
|
||||
return torch.nn.Parameter(torch.tensor(right))
|
||||
|
||||
|
||||
def load_weights_into_llama(model, param_config, params):
|
||||
model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
|
||||
|
||||
for l in range(param_config["n_layers"]):
|
||||
|
||||
# Load attention weights
|
||||
model.trf_blocks[l].att.W_query.weight = assign(
|
||||
model.trf_blocks[l].att.W_query.weight,
|
||||
params[f"model.layers.{l}.self_attn.q_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.q_proj.weight"
|
||||
)
|
||||
model.trf_blocks[l].att.W_key.weight = assign(
|
||||
model.trf_blocks[l].att.W_key.weight,
|
||||
params[f"model.layers.{l}.self_attn.k_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.k_proj.weight"
|
||||
)
|
||||
model.trf_blocks[l].att.W_value.weight = assign(
|
||||
model.trf_blocks[l].att.W_value.weight,
|
||||
params[f"model.layers.{l}.self_attn.v_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.v_proj.weight"
|
||||
)
|
||||
model.trf_blocks[l].att.out_proj.weight = assign(
|
||||
model.trf_blocks[l].att.out_proj.weight,
|
||||
params[f"model.layers.{l}.self_attn.o_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.o_proj.weight"
|
||||
)
|
||||
model.trf_blocks[l].norm1.weight = assign(
|
||||
model.trf_blocks[l].norm1.weight,
|
||||
params[f"model.layers.{l}.input_layernorm.weight"],
|
||||
f"model.layers.{l}.input_layernorm.weight"
|
||||
)
|
||||
|
||||
# Load FeedForward weights
|
||||
model.trf_blocks[l].ff.fc1.weight = assign(
|
||||
model.trf_blocks[l].ff.fc1.weight,
|
||||
params[f"model.layers.{l}.mlp.gate_proj.weight"],
|
||||
f"model.layers.{l}.mlp.gate_proj.weight"
|
||||
)
|
||||
model.trf_blocks[l].ff.fc2.weight = assign(
|
||||
model.trf_blocks[l].ff.fc2.weight,
|
||||
params[f"model.layers.{l}.mlp.up_proj.weight"],
|
||||
f"model.layers.{l}.mlp.up_proj.weight"
|
||||
)
|
||||
model.trf_blocks[l].ff.fc3.weight = assign(
|
||||
model.trf_blocks[l].ff.fc3.weight,
|
||||
params[f"model.layers.{l}.mlp.down_proj.weight"],
|
||||
f"model.layers.{l}.mlp.down_proj.weight"
|
||||
)
|
||||
model.trf_blocks[l].norm2.weight = assign(
|
||||
model.trf_blocks[l].norm2.weight,
|
||||
params[f"model.layers.{l}.post_attention_layernorm.weight"],
|
||||
f"model.layers.{l}.post_attention_layernorm.weight"
|
||||
)
|
||||
|
||||
# Load output layer weights
|
||||
model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight")
|
||||
|
||||
if "lm_head.weight" in params.keys():
|
||||
model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
|
||||
else:
|
||||
model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
|
||||
print("Model uses weight tying.")
|
||||
|
||||
@ -5,11 +5,12 @@
|
||||
|
||||
from llms_from_scratch.ch04 import generate_text_simple
|
||||
from llms_from_scratch.llama3 import (
|
||||
compute_rope_params,
|
||||
apply_rope,
|
||||
LLAMA32_CONFIG_1B,
|
||||
compute_rope_params,
|
||||
GroupedQueryAttention,
|
||||
GroupedQueryAttentionFast,
|
||||
load_weights_into_llama,
|
||||
LLAMA32_CONFIG_1B,
|
||||
Llama3Model,
|
||||
)
|
||||
from llms_from_scratch.kv_cache.llama3 import Llama3Model as Llama3ModelKV
|
||||
@ -246,3 +247,61 @@ def test_rmsnorm_equivalence():
|
||||
out2 = lit_norm(x)
|
||||
|
||||
torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||
def test_llama3_base_equivalence_with_transformers():
|
||||
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
|
||||
cfg = {
|
||||
"vocab_size": 257,
|
||||
"context_length": 8192,
|
||||
"emb_dim": 32,
|
||||
"n_heads": 4,
|
||||
"n_layers": 2,
|
||||
"hidden_dim": 64,
|
||||
"n_kv_groups": 2,
|
||||
"rope_base": 500_000.0,
|
||||
"rope_freq": {
|
||||
"factor": 32.0,
|
||||
"low_freq_factor": 1.0,
|
||||
"high_freq_factor": 4.0,
|
||||
"original_context_length": 8192,
|
||||
},
|
||||
"dtype": torch.float32,
|
||||
}
|
||||
|
||||
ours = Llama3Model(cfg)
|
||||
|
||||
hf_cfg = LlamaConfig(
|
||||
vocab_size=cfg["vocab_size"],
|
||||
hidden_size=cfg["emb_dim"],
|
||||
num_attention_heads=cfg["n_heads"],
|
||||
num_key_value_heads=cfg["n_kv_groups"],
|
||||
num_hidden_layers=cfg["n_layers"],
|
||||
intermediate_size=cfg["hidden_dim"],
|
||||
max_position_embeddings=cfg["context_length"],
|
||||
rms_norm_eps=1e-5,
|
||||
attention_bias=False,
|
||||
rope_theta=cfg["rope_base"],
|
||||
tie_word_embeddings=False,
|
||||
attn_implementation="eager",
|
||||
torch_dtype=torch.float32,
|
||||
rope_scaling={
|
||||
"type": "llama3",
|
||||
"factor": cfg["rope_freq"]["factor"],
|
||||
"low_freq_factor": cfg["rope_freq"]["low_freq_factor"],
|
||||
"high_freq_factor": cfg["rope_freq"]["high_freq_factor"],
|
||||
"original_max_position_embeddings": cfg["rope_freq"]["original_context_length"],
|
||||
},
|
||||
)
|
||||
theirs = LlamaForCausalLM(hf_cfg)
|
||||
|
||||
hf_state = theirs.state_dict()
|
||||
load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)
|
||||
|
||||
x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
|
||||
ours_logits = ours(x)
|
||||
theirs_logits = theirs(x).logits.to(ours_logits.dtype)
|
||||
|
||||
torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
|
||||
@ -5,12 +5,13 @@
|
||||
|
||||
from llms_from_scratch.ch04 import generate_text_simple
|
||||
from llms_from_scratch.qwen3 import (
|
||||
compute_rope_params,
|
||||
apply_rope,
|
||||
compute_rope_params,
|
||||
load_weights_into_qwen,
|
||||
QWEN_CONFIG_06_B,
|
||||
RMSNorm,
|
||||
Qwen3Model,
|
||||
Qwen3Tokenizer
|
||||
Qwen3Tokenizer,
|
||||
RMSNorm,
|
||||
)
|
||||
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
|
||||
from llms_from_scratch.kv_cache.utils import KVCache
|
||||
@ -87,6 +88,7 @@ def dummy_cfg_moe(dummy_cfg_base):
|
||||
return cfg
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
|
||||
torch.manual_seed(123)
|
||||
model = Qwen3Model(dummy_cfg_base)
|
||||
@ -95,6 +97,7 @@ def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
|
||||
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
|
||||
torch.manual_seed(123)
|
||||
model = Qwen3Model(dummy_cfg_moe)
|
||||
@ -105,6 +108,7 @@ def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
|
||||
"Expected MoEFeedForward in at least one transformer block"
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
|
||||
def test_qwen3_kvcache_equivalence(cfg_name, request):
|
||||
cfg = request.getfixturevalue(cfg_name)
|
||||
@ -438,3 +442,51 @@ def test_tokenizer_equivalence():
|
||||
expected_pad_token = "<|endoftext|>"
|
||||
assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
|
||||
assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||
def test_qwen3_base_equivalence_with_transformers():
|
||||
|
||||
from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM
|
||||
|
||||
# Tiny config so the test is fast
|
||||
cfg = {
|
||||
"vocab_size": 257,
|
||||
"context_length": 8,
|
||||
"emb_dim": 32,
|
||||
"n_heads": 4,
|
||||
"n_layers": 2,
|
||||
"hidden_dim": 64,
|
||||
"head_dim": 8,
|
||||
"qk_norm": True,
|
||||
"n_kv_groups": 2,
|
||||
"rope_base": 1_000_000.0,
|
||||
"dtype": torch.float32,
|
||||
}
|
||||
model = Qwen3Model(cfg)
|
||||
|
||||
hf_cfg = Qwen3Config(
|
||||
vocab_size=cfg["vocab_size"],
|
||||
max_position_embeddings=cfg["context_length"],
|
||||
hidden_size=cfg["emb_dim"],
|
||||
num_attention_heads=cfg["n_heads"],
|
||||
num_hidden_layers=cfg["n_layers"],
|
||||
intermediate_size=cfg["hidden_dim"],
|
||||
head_dim=cfg["head_dim"],
|
||||
num_key_value_heads=cfg["n_kv_groups"],
|
||||
rope_theta=cfg["rope_base"],
|
||||
tie_word_embeddings=False,
|
||||
attn_implementation="eager",
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
hf_model = Qwen3ForCausalLM(hf_cfg)
|
||||
|
||||
hf_state = hf_model.state_dict()
|
||||
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
|
||||
load_weights_into_qwen(model, param_config, hf_state)
|
||||
|
||||
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
|
||||
ours_logits = model(x)
|
||||
theirs_logits = hf_model(x).logits
|
||||
torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user