Support different Qwen3 sizes in pkg (#714)

Sebastian Raschka 2025-06-28 08:00:23 -05:00 committed by GitHub
parent 8c8ff24118
commit dc2f8e95d4
4 changed files with 194 additions and 175 deletions

View File

@ -6,9 +6,9 @@ This [standalone-qwen3.ipynb](standalone-qwen3.ipynb) Jupyter notebook in this f
 
### Using Qwen3 0.6B via the `llms-from-scratch` package
### Using Qwen3 via the `llms-from-scratch` package
For an easy way to use the Qwen3 0.6B from-scratch implementation, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
For an easy way to use the Qwen3 from-scratch implementation, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
 
#### 1) Installation
@ -36,9 +36,9 @@ TOP_K = 1
```
 
#### 3) Weight download and loading
#### 3a) Weight download and loading of the 0.6B model
This automatically downloads the weight file based on the model choice above:
The following code automatically downloads the weight file based on the model choice (reasoning or base) above. Note that this section covers the 0.6B model; skip ahead to section 3b) if you want to work with any of the larger models (1.7B, 4B, 8B, 14B, or 32B).
```python
from llms_from_scratch.qwen3 import download_from_huggingface
@ -77,10 +77,74 @@ device = (
torch.device("mps") if torch.backends.mps.is_available() else
torch.device("cpu")
)
model.to(device)
model.to(device);
```
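(The trailing semicolon after `model.to(device)` simply suppresses the model's printed representation when the line is run as the last statement of a Jupyter notebook cell.)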
 
#### 3b) Weight download and loading of the larger Qwen models
If you are interested in working with any of the larger Qwen3 models, for instance, 1.7B, 4B, 8B, 14B, or 32B, use the code below instead of the code under 3a). It requires the following additional dependencies:
```bash
pip install safetensors huggingface_hub
```
Then use the following code (adjust `USE_MODEL` to select the desired model size):
```python
USE_MODEL = "1.7B"
if USE_MODEL == "1.7B":
from llms_from_scratch.qwen3 import QWEN3_CONFIG_1_7B as QWEN3_CONFIG
elif USE_MODEL == "4B":
from llms_from_scratch.qwen3 import QWEN3_CONFIG_4B as QWEN3_CONFIG
elif USE_MODEL == "8B":
from llms_from_scratch.qwen3 import QWEN3_CONFIG_8B as QWEN3_CONFIG
elif USE_MODEL == "14B":
from llms_from_scratch.qwen3 import QWEN3_CONFIG_14B as QWEN3_CONFIG
elif USE_MODEL == "32B":
from llms_from_scratch.qwen3 import QWEN3_CONFIG_32B as QWEN3_CONFIG
else:
raise ValueError("Invalid USE_MODEL name.")
repo_id = f"Qwen/Qwen3-{USE_MODEL}"
local_dir = f"Qwen3-{USE_MODEL}"
if not USE_REASONING_MODEL:
repo_id = f"{repo_id}-Base"
local_dir = f"{local_dir}-Base"
```
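For example, with `USE_MODEL = "1.7B"` as above and `USE_REASONING_MODEL = True`, the identifiers resolve to the reasoning-model repository:
```python
print(repo_id)    # Qwen/Qwen3-1.7B
print(local_dir)  # Qwen3-1.7B
```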
Now, download and load the weights into the `model`:
```python
from llms_from_scratch.qwen3 import (
    Qwen3Model,
    download_from_huggingface_from_snapshots,
    load_weights_into_qwen
)

model = Qwen3Model(QWEN3_CONFIG)

weights_dict = download_from_huggingface_from_snapshots(
    repo_id=repo_id,
    local_dir=local_dir
)
load_weights_into_qwen(model, QWEN3_CONFIG, weights_dict)
del weights_dict  # delete the weight dictionary to free up memory

device = (
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cpu")
)
model.to(device);
```
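Keep in mind that the larger models need correspondingly more memory: in bfloat16, the weights alone take roughly 2 bytes per parameter (about 64 GB for the 32B model). As an optional sanity check, the following minimal sketch reports how many parameters were loaded:
```python
# Rough sanity check: count the parameters of the loaded model
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")
```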
 
#### 4) Initialize tokenizer
The following code downloads and initializes the tokenizer:
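The corresponding snippet is not part of this diff; as a rough sketch based on the `Qwen3Tokenizer` class bundled with the package (and assuming the tokenizer file sits inside the `local_dir` used above), it could look like this:
```python
from llms_from_scratch.qwen3 import Qwen3Tokenizer

# Hypothetical example; the actual README snippet may differ
tokenizer = Qwen3Tokenizer(
    tokenizer_file_path=f"{local_dir}/tokenizer.json",  # assumed download location
    repo_id=repo_id,
    add_generation_prompt=USE_REASONING_MODEL,
    add_thinking=USE_REASONING_MODEL,
)
```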

View File

@ -4,29 +4,17 @@
# Code: https://github.com/rasbt/LLMs-from-scratch
from .utils import KVCache # noqa: F401
import os
import urllib.request
from pathlib import Path
from ..qwen3 import (  # noqa: F401
    QWEN_CONFIG_06_B, QWEN3_CONFIG_1_7B, QWEN3_CONFIG_4B,
    QWEN3_CONFIG_8B, QWEN3_CONFIG_14B, QWEN3_CONFIG_32B,
    Qwen3Tokenizer, load_weights_into_qwen,
    download_from_huggingface,
    download_from_huggingface_from_snapshots
)
import torch
import torch.nn as nn
# 0.6B model
QWEN_CONFIG_06_B = {
    "vocab_size": 151_936,     # Vocabulary size
    "context_length": 40_960,  # Context length that was used to train the model
    "emb_dim": 1024,           # Embedding dimension
    "n_heads": 16,             # Number of attention heads
    "n_layers": 28,            # Number of layers
    "hidden_dim": 3072,        # Size of the intermediate dimension in FeedForward
    "head_dim": 128,           # Size of the heads in GQA
    "qk_norm": True,           # Whether to normalize queries and keys in GQA
    "n_kv_groups": 8,          # Key-Value groups for grouped-query attention
    "rope_base": 1_000_000.0,  # The base in RoPE's "theta"
    "dtype": torch.bfloat16,   # Lower-precision dtype to reduce memory usage
}
class Qwen3Model(nn.Module):
    def __init__(self, cfg):
@ -285,150 +273,3 @@ class RMSNorm(nn.Module):
            norm_x = norm_x + self.shift
        return norm_x.to(input_dtype)
def load_weights_into_qwen(model, param_config, params):
    def assign(left, right, tensor_name="unknown"):
        if left.shape != right.shape:
            raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
        return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))

    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")

    for l in range(param_config["n_layers"]):
        block = model.trf_blocks[l]
        att = block.att

        # Q, K, V projections
        att.W_query.weight = assign(
            att.W_query.weight,
            params[f"model.layers.{l}.self_attn.q_proj.weight"],
            f"model.layers.{l}.self_attn.q_proj.weight"
        )
        att.W_key.weight = assign(
            att.W_key.weight,
            params[f"model.layers.{l}.self_attn.k_proj.weight"],
            f"model.layers.{l}.self_attn.k_proj.weight"
        )
        att.W_value.weight = assign(
            att.W_value.weight,
            params[f"model.layers.{l}.self_attn.v_proj.weight"],
            f"model.layers.{l}.self_attn.v_proj.weight"
        )

        # Output projection
        att.out_proj.weight = assign(
            att.out_proj.weight,
            params[f"model.layers.{l}.self_attn.o_proj.weight"],
            f"model.layers.{l}.self_attn.o_proj.weight"
        )

        # QK norms
        if hasattr(att, "q_norm") and att.q_norm is not None:
            att.q_norm.scale = assign(
                att.q_norm.scale,
                params[f"model.layers.{l}.self_attn.q_norm.weight"],
                f"model.layers.{l}.self_attn.q_norm.weight"
            )
        if hasattr(att, "k_norm") and att.k_norm is not None:
            att.k_norm.scale = assign(
                att.k_norm.scale,
                params[f"model.layers.{l}.self_attn.k_norm.weight"],
                f"model.layers.{l}.self_attn.k_norm.weight"
            )

        # Attention layernorm
        block.norm1.scale = assign(
            block.norm1.scale,
            params[f"model.layers.{l}.input_layernorm.weight"],
            f"model.layers.{l}.input_layernorm.weight"
        )

        # Feedforward weights
        block.ff.fc1.weight = assign(
            block.ff.fc1.weight,
            params[f"model.layers.{l}.mlp.gate_proj.weight"],
            f"model.layers.{l}.mlp.gate_proj.weight"
        )
        block.ff.fc2.weight = assign(
            block.ff.fc2.weight,
            params[f"model.layers.{l}.mlp.up_proj.weight"],
            f"model.layers.{l}.mlp.up_proj.weight"
        )
        block.ff.fc3.weight = assign(
            block.ff.fc3.weight,
            params[f"model.layers.{l}.mlp.down_proj.weight"],
            f"model.layers.{l}.mlp.down_proj.weight"
        )
        block.norm2.scale = assign(
            block.norm2.scale,
            params[f"model.layers.{l}.post_attention_layernorm.weight"],
            f"model.layers.{l}.post_attention_layernorm.weight"
        )

    # Final normalization and output head
    model.final_norm.scale = assign(model.final_norm.scale, params["model.norm.weight"], "model.norm.weight")

    # Model uses weight tying, hence we reuse the embedding layer weights here
    model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
class Qwen3Tokenizer():
    def __init__(self, tokenizer_file_path="tokenizer.json",
                 repo_id=None, add_generation_prompt=False, add_thinking=False):
        from tokenizers import Tokenizer
        self.tokenizer_file_path = tokenizer_file_path

        if add_generation_prompt != add_thinking:
            raise ValueError(
                "Only add_generation_prompt==add_thinking settings are currently supported"
            )

        self.add_generation_prompt = add_generation_prompt
        self.add_thinking = add_thinking

        tokenizer_file_path_obj = Path(tokenizer_file_path)
        if not tokenizer_file_path_obj.is_file() and repo_id is not None:
            _ = download_from_huggingface(
                repo_id=repo_id,
                filename=str(tokenizer_file_path_obj.name),
                local_dir=str(tokenizer_file_path_obj.parent.name)
            )
        self.tokenizer = Tokenizer.from_file(tokenizer_file_path)

    def encode(self, prompt):
        messages = [
            {"role": "user", "content": prompt}
        ]
        formatted_prompt = self.format_qwen_chat(
            messages,
            add_generation_prompt=self.add_generation_prompt,
            add_thinking=self.add_thinking
        )
        return self.tokenizer.encode(formatted_prompt).ids

    def decode(self, token_ids):
        return self.tokenizer.decode(token_ids, skip_special_tokens=False)

    @staticmethod
    def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
        prompt = ""
        for msg in messages:
            prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
        if add_generation_prompt:
            prompt += "<|im_start|>assistant"
            if not add_thinking:
                prompt += "<|think>\n\n<|/think>\n\n"
            else:
                prompt += "\n"
        return prompt


def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
    base_url = "https://huggingface.co"
    url = f"{base_url}/{repo_id}/resolve/{revision}/{filename}"
    Path(local_dir).mkdir(parents=True, exist_ok=True)
    dest_path = os.path.join(local_dir, filename)
    print(f"Downloading {url} to {dest_path}...")
    urllib.request.urlretrieve(url, dest_path)
    return dest_path

View File

@ -4,13 +4,15 @@
# Code: https://github.com/rasbt/LLMs-from-scratch
import os
import json
import urllib.request
from pathlib import Path
import torch
import torch.nn as nn
# 0.6B model
# 0.6 billion parameters
QWEN_CONFIG_06_B = {
"vocab_size": 151_936, # Vocabulary size
"context_length": 40_960, # Context length that was used to train the model
@ -25,6 +27,80 @@ QWEN_CONFIG_06_B = {
"dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
}
# 1.7 billion parameters
QWEN3_CONFIG_1_7B = {
    "vocab_size": 151_936,
    "context_length": 40_960,
    "emb_dim": 2048,           # 2x larger than above
    "n_heads": 16,
    "n_layers": 28,
    "hidden_dim": 6144,        # 2x larger than above
    "head_dim": 128,
    "qk_norm": True,
    "n_kv_groups": 8,
    "rope_base": 1_000_000.0,
    "dtype": torch.bfloat16,
}

# 4 billion parameters
QWEN3_CONFIG_4B = {
    "vocab_size": 151_936,
    "context_length": 40_960,
    "emb_dim": 2560,           # 25% larger than above
    "n_heads": 32,             # 2x larger than above
    "n_layers": 36,            # 29% larger than above
    "hidden_dim": 9728,        # ~1.6x larger than above
    "head_dim": 128,
    "qk_norm": True,
    "n_kv_groups": 8,
    "rope_base": 1_000_000.0,
    "dtype": torch.bfloat16,
}

# 8 billion parameters
QWEN3_CONFIG_8B = {
    "vocab_size": 151_936,
    "context_length": 40_960,
    "emb_dim": 4096,           # 60% larger than above
    "n_heads": 32,
    "n_layers": 36,            # same as above
    "hidden_dim": 12288,       # 26% larger than above
    "head_dim": 128,
    "qk_norm": True,
    "n_kv_groups": 8,
    "rope_base": 1_000_000.0,
    "dtype": torch.bfloat16,
}

# 14 billion parameters
QWEN3_CONFIG_14B = {
    "vocab_size": 151_936,
    "context_length": 40_960,
    "emb_dim": 5120,           # 25% larger than above
    "n_heads": 40,             # 25% larger than above
    "n_layers": 40,            # 11% larger than above
    "hidden_dim": 17408,       # 42% larger than above
    "head_dim": 128,
    "qk_norm": True,
    "n_kv_groups": 8,
    "rope_base": 1_000_000.0,
    "dtype": torch.bfloat16,
}

# 32 billion parameters
QWEN3_CONFIG_32B = {
    "vocab_size": 151_936,
    "context_length": 40_960,
    "emb_dim": 5120,
    "n_heads": 64,             # 60% larger than above
    "n_layers": 64,            # 60% larger than above
    "hidden_dim": 25600,       # 47% larger than above
    "head_dim": 128,
    "qk_norm": True,
    "n_kv_groups": 8,
    "rope_base": 1_000_000.0,
    "dtype": torch.bfloat16,
}
class Qwen3Model(nn.Module):
    def __init__(self, cfg):
@ -388,6 +464,44 @@ def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
url = f"{base_url}/{repo_id}/resolve/{revision}/{filename}"
Path(local_dir).mkdir(parents=True, exist_ok=True)
dest_path = os.path.join(local_dir, filename)
print(f"Downloading {url} to {dest_path}...")
urllib.request.urlretrieve(url, dest_path)
if os.path.exists(dest_path):
print(f"File already exists: {dest_path}")
else:
print(f"Downloading {url} to {dest_path}...")
urllib.request.urlretrieve(url, dest_path)
return dest_path
def download_from_huggingface_from_snapshots(repo_id, local_dir):
    # Download the full model snapshot from the Hugging Face Hub and return
    # all safetensors weights merged into a single state dict
    from huggingface_hub import hf_hub_download, snapshot_download
    from safetensors.torch import load_file  # or your preferred loader

    repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)

    index_path = os.path.join(repo_dir, "model.safetensors.index.json")
    single_file_path = os.path.join(repo_dir, "model.safetensors")

    if os.path.exists(index_path):
        # Multi-shard model
        with open(index_path, "r") as f:
            index = json.load(f)

        weights_dict = {}
        for filename in set(index["weight_map"].values()):
            shard_path = os.path.join(repo_dir, filename)
            shard = load_file(shard_path)
            weights_dict.update(shard)
    elif os.path.exists(single_file_path):
        # Single-shard model
        weights_file = hf_hub_download(
            repo_id=repo_id,
            filename="model.safetensors",
            local_dir=local_dir,
        )
        weights_dict = load_file(weights_file)
    else:
        raise FileNotFoundError("No model.safetensors or model.safetensors.index.json found.")

    return weights_dict

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "llms-from-scratch"
version = "1.0.14"
version = "1.0.15"
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
readme = "README.md"
requires-python = ">=3.10"