mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-27 18:10:39 +00:00
Support different Qwen3 sizes in pkg (#714)
This commit is contained in:
parent
8c8ff24118
commit
dc2f8e95d4
@ -6,9 +6,9 @@ This [standalone-qwen3.ipynb](standalone-qwen3.ipynb) Jupyter notebook in this f
|
||||
|
||||
|
||||
|
||||
### Using Qwen3 0.6B via the `llms-from-scratch` package
|
||||
### Using Qwen3 via the `llms-from-scratch` package
|
||||
|
||||
For an easy way to use the Qwen3 0.6B from-scratch implementation, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
|
||||
For an easy way to use the Qwen3 from-scratch implementation, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
|
||||
|
||||
|
||||
#### 1) Installation
|
||||
@ -36,9 +36,9 @@ TOP_K = 1
|
||||
```
|
||||
|
||||
|
||||
#### 3) Weight download and loading
|
||||
#### 3a) Weight download and loading of the 0.6B model
|
||||
|
||||
This automatically downloads the weight file based on the model choice above:
|
||||
The following automatically downloads the weight file based on the model choice (reasoning or base) above. Note that this section focuses on the 0.6B model. Skip this section and continue with section 3b) if you want to work with any of the larger models (1.7B, 4B, 8B, or 32B).
|
||||
|
||||
```python
|
||||
from llms_from_scratch.qwen3 import download_from_huggingface
|
||||
@ -77,10 +77,74 @@ device = (
|
||||
torch.device("mps") if torch.backends.mps.is_available() else
|
||||
torch.device("cpu")
|
||||
)
|
||||
model.to(device)
|
||||
model.to(device);
|
||||
```
|
||||
|
||||
|
||||
#### 3b) Weight download and loading of the larger Qwen models
|
||||
|
||||
If you are interested in working with any of the larger Qwen models, for instance, 1.7B, 4B, 8B, or 32B, please use the following code below instead of the code under 3a), which requires additional code dependencies:
|
||||
|
||||
```bash
|
||||
pip install safetensors huggingface_hub
|
||||
```
|
||||
|
||||
Then use the following code (make appropriate changes to `USE_MODEL` to select the desired model size)
|
||||
|
||||
```python
|
||||
USE_MODEL = "1.7B"
|
||||
|
||||
if USE_MODEL == "1.7B":
|
||||
from llms_from_scratch.qwen3 import QWEN3_CONFIG_1_7B as QWEN3_CONFIG
|
||||
elif USE_MODEL == "4B":
|
||||
from llms_from_scratch.qwen3 import QWEN3_CONFIG_4B as QWEN3_CONFIG
|
||||
elif USE_MODEL == "8B":
|
||||
from llms_from_scratch.qwen3 import QWEN3_CONFIG_8B as QWEN3_CONFIG
|
||||
elif USE_MODEL == "14B":
|
||||
from llms_from_scratch.qwen3 import QWEN3_CONFIG_14B as QWEN3_CONFIG
|
||||
elif USE_MODEL == "32B":
|
||||
from llms_from_scratch.qwen3 import QWEN3_CONFIG_32B as QWEN3_CONFIG
|
||||
else:
|
||||
raise ValueError("Invalid USE_MODEL name.")
|
||||
|
||||
repo_id = f"Qwen/Qwen3-{USE_MODEL}"
|
||||
local_dir = f"Qwen3-{USE_MODEL}"
|
||||
|
||||
if not USE_REASONING_MODEL:
|
||||
repo_id = f"{repo_id}-Base"
|
||||
local_dir = f"{local_dir}-Base"
|
||||
```
|
||||
|
||||
Now, download and load the weights into the `model`:
|
||||
|
||||
```python
|
||||
from llms_from_scratch.qwen3 import (
|
||||
Qwen3Model,
|
||||
download_from_huggingface_from_snapshots,
|
||||
load_weights_into_qwen
|
||||
)
|
||||
|
||||
model = Qwen3Model(QWEN3_CONFIG)
|
||||
|
||||
weights_dict = download_from_huggingface_from_snapshots(
|
||||
repo_id=repo_id,
|
||||
local_dir=local_dir
|
||||
)
|
||||
load_weights_into_qwen(model, QWEN3_CONFIG, weights_dict)
|
||||
del weights_dict # delete weight dictionary to free up disk space
|
||||
|
||||
device = (
|
||||
torch.device("cuda") if torch.cuda.is_available() else
|
||||
torch.device("mps") if torch.backends.mps.is_available() else
|
||||
torch.device("cpu")
|
||||
)
|
||||
|
||||
model.to(device);
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
#### 4) Initialize tokenizer
|
||||
|
||||
The following code downloads and initializes the tokenizer:
|
||||
|
@ -4,29 +4,17 @@
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from .utils import KVCache # noqa: F401
|
||||
|
||||
import os
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
from ..qwen3 import ( # noqa: F401
|
||||
QWEN_CONFIG_06_B, QWEN3_CONFIG_1_7B, QWEN3_CONFIG_4B,
|
||||
QWEN3_CONFIG_8B, QWEN3_CONFIG_14B, QWEN3_CONFIG_32B,
|
||||
Qwen3Tokenizer, load_weights_into_qwen,
|
||||
download_from_huggingface,
|
||||
download_from_huggingface_from_snapshots
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# 0.6B model
|
||||
QWEN_CONFIG_06_B = {
|
||||
"vocab_size": 151_936, # Vocabulary size
|
||||
"context_length": 40_960, # Context length that was used to train the model
|
||||
"emb_dim": 1024, # Embedding dimension
|
||||
"n_heads": 16, # Number of attention heads
|
||||
"n_layers": 28, # Number of layers
|
||||
"hidden_dim": 3072, # Size of the intermediate dimension in FeedForward
|
||||
"head_dim": 128, # Size of the heads in GQA
|
||||
"qk_norm": True, # Whether to normalize queries and values in GQA
|
||||
"n_kv_groups": 8, # Key-Value groups for grouped-query attention
|
||||
"rope_base": 1_000_000.0, # The base in RoPE's "theta"
|
||||
"dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
|
||||
}
|
||||
|
||||
|
||||
class Qwen3Model(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
@ -285,150 +273,3 @@ class RMSNorm(nn.Module):
|
||||
norm_x = norm_x + self.shift
|
||||
|
||||
return norm_x.to(input_dtype)
|
||||
|
||||
|
||||
def load_weights_into_qwen(model, param_config, params):
|
||||
def assign(left, right, tensor_name="unknown"):
|
||||
if left.shape != right.shape:
|
||||
raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
|
||||
return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))
|
||||
|
||||
model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
|
||||
|
||||
for l in range(param_config["n_layers"]):
|
||||
block = model.trf_blocks[l]
|
||||
att = block.att
|
||||
|
||||
# Q, K, V projections
|
||||
att.W_query.weight = assign(
|
||||
att.W_query.weight,
|
||||
params[f"model.layers.{l}.self_attn.q_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.q_proj.weight"
|
||||
)
|
||||
att.W_key.weight = assign(
|
||||
att.W_key.weight,
|
||||
params[f"model.layers.{l}.self_attn.k_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.k_proj.weight"
|
||||
)
|
||||
att.W_value.weight = assign(
|
||||
att.W_value.weight,
|
||||
params[f"model.layers.{l}.self_attn.v_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.v_proj.weight"
|
||||
)
|
||||
|
||||
# Output projection
|
||||
att.out_proj.weight = assign(
|
||||
att.out_proj.weight,
|
||||
params[f"model.layers.{l}.self_attn.o_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.o_proj.weight"
|
||||
)
|
||||
|
||||
# QK norms
|
||||
if hasattr(att, "q_norm") and att.q_norm is not None:
|
||||
att.q_norm.scale = assign(
|
||||
att.q_norm.scale,
|
||||
params[f"model.layers.{l}.self_attn.q_norm.weight"],
|
||||
f"model.layers.{l}.self_attn.q_norm.weight"
|
||||
)
|
||||
if hasattr(att, "k_norm") and att.k_norm is not None:
|
||||
att.k_norm.scale = assign(
|
||||
att.k_norm.scale,
|
||||
params[f"model.layers.{l}.self_attn.k_norm.weight"],
|
||||
f"model.layers.{l}.self_attn.k_norm.weight"
|
||||
)
|
||||
|
||||
# Attention layernorm
|
||||
block.norm1.scale = assign(
|
||||
block.norm1.scale,
|
||||
params[f"model.layers.{l}.input_layernorm.weight"],
|
||||
f"model.layers.{l}.input_layernorm.weight"
|
||||
)
|
||||
|
||||
# Feedforward weights
|
||||
block.ff.fc1.weight = assign(
|
||||
block.ff.fc1.weight,
|
||||
params[f"model.layers.{l}.mlp.gate_proj.weight"],
|
||||
f"model.layers.{l}.mlp.gate_proj.weight"
|
||||
)
|
||||
block.ff.fc2.weight = assign(
|
||||
block.ff.fc2.weight,
|
||||
params[f"model.layers.{l}.mlp.up_proj.weight"],
|
||||
f"model.layers.{l}.mlp.up_proj.weight"
|
||||
)
|
||||
block.ff.fc3.weight = assign(
|
||||
block.ff.fc3.weight,
|
||||
params[f"model.layers.{l}.mlp.down_proj.weight"],
|
||||
f"model.layers.{l}.mlp.down_proj.weight"
|
||||
)
|
||||
block.norm2.scale = assign(
|
||||
block.norm2.scale,
|
||||
params[f"model.layers.{l}.post_attention_layernorm.weight"],
|
||||
f"model.layers.{l}.post_attention_layernorm.weight"
|
||||
)
|
||||
|
||||
# Final normalization and output head
|
||||
model.final_norm.scale = assign(model.final_norm.scale, params["model.norm.weight"], "model.norm.weight")
|
||||
|
||||
# Model uses weight tying, hence we reuse the embedding layer weights here
|
||||
model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
|
||||
|
||||
|
||||
class Qwen3Tokenizer():
|
||||
def __init__(self, tokenizer_file_path="tokenizer.json",
|
||||
repo_id=None, add_generation_prompt=False, add_thinking=False):
|
||||
from tokenizers import Tokenizer
|
||||
self.tokenizer_file_path = tokenizer_file_path
|
||||
|
||||
if add_generation_prompt != add_thinking:
|
||||
raise ValueError(
|
||||
"Only add_generation_prompt==add_thinking settings are currently supported"
|
||||
)
|
||||
|
||||
self.add_generation_prompt = add_generation_prompt
|
||||
self.add_thinking = add_thinking
|
||||
|
||||
tokenizer_file_path_obj = Path(tokenizer_file_path)
|
||||
if not tokenizer_file_path_obj.is_file() and repo_id is not None:
|
||||
_ = download_from_huggingface(
|
||||
repo_id=repo_id,
|
||||
filename=str(tokenizer_file_path_obj.name),
|
||||
local_dir=str(tokenizer_file_path_obj.parent.name)
|
||||
)
|
||||
self.tokenizer = Tokenizer.from_file(tokenizer_file_path)
|
||||
|
||||
def encode(self, prompt):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
formatted_prompt = self.format_qwen_chat(
|
||||
messages,
|
||||
add_generation_prompt=self.add_generation_prompt,
|
||||
add_thinking=self.add_thinking
|
||||
)
|
||||
return self.tokenizer.encode(formatted_prompt).ids
|
||||
|
||||
def decode(self, token_ids):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=False)
|
||||
|
||||
@staticmethod
|
||||
def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
|
||||
prompt = ""
|
||||
for msg in messages:
|
||||
prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
|
||||
if add_generation_prompt:
|
||||
prompt += "<|im_start|>assistant"
|
||||
if not add_thinking:
|
||||
prompt += "<|think>\n\n<|/think>\n\n"
|
||||
else:
|
||||
prompt += "\n"
|
||||
return prompt
|
||||
|
||||
|
||||
def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
|
||||
base_url = "https://huggingface.co"
|
||||
url = f"{base_url}/{repo_id}/resolve/{revision}/{filename}"
|
||||
Path(local_dir).mkdir(parents=True, exist_ok=True)
|
||||
dest_path = os.path.join(local_dir, filename)
|
||||
print(f"Downloading {url} to {dest_path}...")
|
||||
urllib.request.urlretrieve(url, dest_path)
|
||||
return dest_path
|
||||
|
@ -4,13 +4,15 @@
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
import os
|
||||
import json
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# 0.6B model
|
||||
|
||||
# 0.6 billion parameters
|
||||
QWEN_CONFIG_06_B = {
|
||||
"vocab_size": 151_936, # Vocabulary size
|
||||
"context_length": 40_960, # Context length that was used to train the model
|
||||
@ -25,6 +27,80 @@ QWEN_CONFIG_06_B = {
|
||||
"dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
|
||||
}
|
||||
|
||||
# 1.7 billion parameters
|
||||
QWEN3_CONFIG_1_7B = {
|
||||
"vocab_size": 151_936,
|
||||
"context_length": 40_960,
|
||||
"emb_dim": 2048, # 2x larger than above
|
||||
"n_heads": 16,
|
||||
"n_layers": 28,
|
||||
"hidden_dim": 6144, # 2x larger than above
|
||||
"head_dim": 128,
|
||||
"qk_norm": True,
|
||||
"n_kv_groups": 8,
|
||||
"rope_base": 1_000_000.0,
|
||||
"dtype": torch.bfloat16,
|
||||
}
|
||||
|
||||
# 4 billion parameters
|
||||
QWEN3_CONFIG_4B = {
|
||||
"vocab_size": 151_936,
|
||||
"context_length": 40_960,
|
||||
"emb_dim": 2560, # 25% larger than above
|
||||
"n_heads": 32, # 2x larger than above
|
||||
"n_layers": 36, # 29% larger than above
|
||||
"hidden_dim": 9728, # ~3x larger than above
|
||||
"head_dim": 128,
|
||||
"qk_norm": True,
|
||||
"n_kv_groups": 8,
|
||||
"rope_base": 1_000_000.0,
|
||||
"dtype": torch.bfloat16,
|
||||
}
|
||||
|
||||
# 8 billion parameters
|
||||
QWEN3_CONFIG_8B = {
|
||||
"vocab_size": 151_936,
|
||||
"context_length": 40_960,
|
||||
"emb_dim": 4096, # 60% larger than above
|
||||
"n_heads": 32,
|
||||
"n_layers": 36, # 26% larger than above
|
||||
"hidden_dim": 12288,
|
||||
"head_dim": 128,
|
||||
"qk_norm": True,
|
||||
"n_kv_groups": 8,
|
||||
"rope_base": 1_000_000.0,
|
||||
"dtype": torch.bfloat16,
|
||||
}
|
||||
|
||||
# 14 billion parameters
|
||||
QWEN3_CONFIG_14B = {
|
||||
"vocab_size": 151_936,
|
||||
"context_length": 40_960,
|
||||
"emb_dim": 5120, # 25% larger than above
|
||||
"n_heads": 40, # 25% larger than above
|
||||
"n_layers": 40, # 11% larger than above
|
||||
"hidden_dim": 17408, # 42% larger than above
|
||||
"head_dim": 128,
|
||||
"qk_norm": True,
|
||||
"n_kv_groups": 8,
|
||||
"rope_base": 1_000_000.0,
|
||||
"dtype": torch.bfloat16,
|
||||
}
|
||||
|
||||
QWEN3_CONFIG_32B = {
|
||||
"vocab_size": 151_936,
|
||||
"context_length": 40_960,
|
||||
"emb_dim": 5120,
|
||||
"n_heads": 64, # 60% larger than above
|
||||
"n_layers": 64, # 60% larger than above
|
||||
"hidden_dim": 25600, # 47% larger than above
|
||||
"head_dim": 128,
|
||||
"qk_norm": True,
|
||||
"n_kv_groups": 8,
|
||||
"rope_base": 1_000_000.0,
|
||||
"dtype": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
class Qwen3Model(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
@ -388,6 +464,44 @@ def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
|
||||
url = f"{base_url}/{repo_id}/resolve/{revision}/{filename}"
|
||||
Path(local_dir).mkdir(parents=True, exist_ok=True)
|
||||
dest_path = os.path.join(local_dir, filename)
|
||||
print(f"Downloading {url} to {dest_path}...")
|
||||
urllib.request.urlretrieve(url, dest_path)
|
||||
|
||||
if os.path.exists(dest_path):
|
||||
print(f"File already exists: {dest_path}")
|
||||
else:
|
||||
print(f"Downloading {url} to {dest_path}...")
|
||||
urllib.request.urlretrieve(url, dest_path)
|
||||
|
||||
return dest_path
|
||||
|
||||
|
||||
def download_from_huggingface_from_snapshots(repo_id, local_dir):
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load_file # or your preferred loader
|
||||
|
||||
repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)
|
||||
|
||||
index_path = os.path.join(repo_dir, "model.safetensors.index.json")
|
||||
single_file_path = os.path.join(repo_dir, "model.safetensors")
|
||||
|
||||
if os.path.exists(index_path):
|
||||
# Multi-shard model
|
||||
with open(index_path, "r") as f:
|
||||
index = json.load(f)
|
||||
|
||||
weights_dict = {}
|
||||
for filename in set(index["weight_map"].values()):
|
||||
shard_path = os.path.join(repo_dir, filename)
|
||||
shard = load_file(shard_path)
|
||||
weights_dict.update(shard)
|
||||
elif os.path.exists(single_file_path):
|
||||
# Single-shard model
|
||||
weights_file = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="model.safetensors",
|
||||
local_dir=local_dir,
|
||||
)
|
||||
weights_dict = load_file(weights_file)
|
||||
else:
|
||||
raise FileNotFoundError("No model.safetensors or model.safetensors.index.json found.")
|
||||
|
||||
return weights_dict
|
||||
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "llms-from-scratch"
|
||||
version = "1.0.14"
|
||||
version = "1.0.15"
|
||||
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
Loading…
x
Reference in New Issue
Block a user