Add Qwen3 1.7B, 4B, 8B, and 32B support to from-scratch nb (#709)

This commit is contained in:
Sebastian Raschka 2025-06-25 08:53:09 -05:00 committed by GitHub
parent 1bce95d70c
commit 81be5fab0b
3 changed files with 245 additions and 77 deletions

.gitignore vendored

@@ -55,6 +55,14 @@ ch05/10_llm-training-speed/loss.pdf
ch05/10_llm-training-speed/model.pth
ch05/11_qwen3/Qwen3-0.6B
ch05/11_qwen3/Qwen3-0.6B-Base
ch05/11_qwen3/Qwen3-1.7B
ch05/11_qwen3/Qwen3-1.7B-Base
ch05/11_qwen3/Qwen3-4B
ch05/11_qwen3/Qwen3-4B-Base
ch05/11_qwen3/Qwen3-8B
ch05/11_qwen3/Qwen3-8B-Base
ch05/11_qwen3/Qwen3-32B
ch05/11_qwen3/Qwen3-32B-Base
ch06/01_main-chapter-code/gpt2
ch06/02_bonus_additional-experiments/gpt2

ch05/11_qwen3/README.md

@@ -1,6 +1,6 @@
# Qwen3 From Scratch
This [standalone-qwen3.ipynb](standalone-qwen3.ipynb) Jupyter notebook in this folder contains a from-scratch implementation of Qwen3 0.6B.
This [standalone-qwen3.ipynb](standalone-qwen3.ipynb) Jupyter notebook in this folder contains a from-scratch implementation of Qwen3 0.6B, 1.7B, 4B, 8B, 14B, and 32B.
<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen/qwen-overview.webp">
@@ -8,7 +8,7 @@ This [standalone-qwen3.ipynb](standalone-qwen3.ipynb) Jupyter notebook in this f
&nbsp;
### Using Qwen3 0.6B via the `llms-from-scratch` package
For an easy way to use the Qwen3 from-scratch implementation, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
For an easy way to use the Qwen3 0.6B from-scratch implementation, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
&nbsp;
#### 1) Installation
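A minimal usage sketch (the package name is on PyPI; the exact import path and config name below follow `pkg/llms_from_scratch` and should be treated as assumptions, not as part of this commit):

```python
# Install first: pip install llms_from_scratch
# Import path assumed from pkg/llms_from_scratch in this repository
from llms_from_scratch.qwen3 import Qwen3Model, QWEN_CONFIG_06_B

model = Qwen3Model(QWEN_CONFIG_06_B)  # instantiates the 0.6B architecture
```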

ch05/11_qwen3/standalone-qwen3.ipynb

@@ -29,7 +29,7 @@
"id": "efde77f2-6af3-4781-8597-89ecd3f41a52"
},
"source": [
"# Qwen3 0.6B From Scratch (A Standalone Notebook)"
"# Qwen3 From Scratch (A Standalone Notebook)"
]
},
{
@@ -39,7 +39,7 @@
"id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d"
},
"source": [
"- This notebook is purposefully minimal and focuses on the code to implement Qwen3 0.6B; for more information about this model, please see the original blog post and technical report:\n",
"- This notebook is purposefully minimal and focuses on the code to implement Qwen3 0.6B, 1.7B, 4B, 8B, and 32B; for more information about this model, please see the original blog post and technical report:\n",
" - [Qwen3: Think Deeper, Act Faster](https://qwenlm.github.io/blog/qwen3/)\n",
" - [Qwen3 Technical Report](https://arxiv.org/abs/2505.09388) \n",
"- Many architectural components in Qwen3 are similar to Llama 3; for a step-by-step guide that explains the individual components and the relationship between GPT and the components used here, you may like the GPT-to-Llama conversion notebooks:\n",
@@ -418,33 +418,127 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 25,
"id": "caa142fa-b375-4e78-b392-2072ced666f3",
"metadata": {
"id": "caa142fa-b375-4e78-b392-2072ced666f3"
},
"outputs": [],
"source": [
"# Qwen3 0.6B\n",
"CHOOSE_MODEL = \"0.6B\"\n",
"\n",
"QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936, # Vocabulary size\n",
" \"context_length\": 40_960, # Context length that was used to train the model\n",
" \"emb_dim\": 1024, # Embedding dimension\n",
" \"n_heads\": 16, # Number of attention heads\n",
" \"n_layers\": 28, # Number of layers\n",
" \"hidden_dim\": 3072, # Size of the intermediate dimension in FeedForward\n",
" \"head_dim\": 128, # Size of the heads in GQA\n",
" \"qk_norm\": True, # Whether to normalize queries and values in GQA\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 1_000_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
"}"
"if CHOOSE_MODEL == \"0.6B\":\n",
" QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936, # Vocabulary size\n",
" \"context_length\": 40_960, # Context length that was used to train the model\n",
" \"emb_dim\": 1024, # Embedding dimension\n",
" \"n_heads\": 16, # Number of attention heads\n",
" \"n_layers\": 28, # Number of layers\n",
" \"hidden_dim\": 3072, # Size of the intermediate dimension in FeedForward\n",
" \"head_dim\": 128, # Size of the heads in GQA\n",
" \"qk_norm\": True, # Whether to normalize queries and values in GQA\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 1_000_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
" }\n",
"\n",
"elif CHOOSE_MODEL == \"1.7B\":\n",
" QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936,\n",
" \"context_length\": 40_960,\n",
" \"emb_dim\": 2048, # 2x larger than above\n",
" \"n_heads\": 16,\n",
" \"n_layers\": 28,\n",
" \"hidden_dim\": 6144, # 2x larger than above\n",
" \"head_dim\": 128,\n",
" \"qk_norm\": True,\n",
" \"n_kv_groups\": 8,\n",
" \"rope_base\": 1_000_000.0,\n",
" \"dtype\": torch.bfloat16,\n",
" } \n",
"\n",
"elif CHOOSE_MODEL == \"4B\":\n",
" QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936,\n",
" \"context_length\": 40_960,\n",
" \"emb_dim\": 2560, # 25% larger than above\n",
" \"n_heads\": 32, # 2x larger than above\n",
" \"n_layers\": 36, # 29% larger than above\n",
" \"hidden_dim\": 9728, # ~3x larger than above\n",
" \"head_dim\": 128,\n",
" \"qk_norm\": True,\n",
" \"n_kv_groups\": 8,\n",
" \"rope_base\": 1_000_000.0,\n",
" \"dtype\": torch.bfloat16,\n",
" } \n",
"\n",
"elif CHOOSE_MODEL == \"8B\":\n",
" QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936,\n",
" \"context_length\": 40_960,\n",
" \"emb_dim\": 4096, # 60% larger than above\n",
" \"n_heads\": 32,\n",
" \"n_layers\": 36, # 26% larger than above\n",
" \"hidden_dim\": 12288,\n",
" \"head_dim\": 128,\n",
" \"qk_norm\": True,\n",
" \"n_kv_groups\": 8,\n",
" \"rope_base\": 1_000_000.0,\n",
" \"dtype\": torch.bfloat16,\n",
" } \n",
"\n",
"elif CHOOSE_MODEL == \"8B\":\n",
" QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936,\n",
" \"context_length\": 40_960,\n",
" \"emb_dim\": 4096, # 60% larger than above\n",
" \"n_heads\": 32,\n",
" \"n_layers\": 36, # 26% larger than above\n",
" \"hidden_dim\": 12288,\n",
" \"head_dim\": 128,\n",
" \"qk_norm\": True,\n",
" \"n_kv_groups\": 8,\n",
" \"rope_base\": 1_000_000.0,\n",
" \"dtype\": torch.bfloat16,\n",
" } \n",
"\n",
"elif CHOOSE_MODEL == \"14B\":\n",
" QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936,\n",
" \"context_length\": 40_960,\n",
" \"emb_dim\": 5120, # 25% larger than above\n",
" \"n_heads\": 40, # 25% larger than above\n",
" \"n_layers\": 40, # 11% larger than above\n",
" \"hidden_dim\": 17408, # 42% larger than above\n",
" \"head_dim\": 128,\n",
" \"qk_norm\": True,\n",
" \"n_kv_groups\": 8,\n",
" \"rope_base\": 1_000_000.0,\n",
" \"dtype\": torch.bfloat16,\n",
" } \n",
"\n",
"elif CHOOSE_MODEL == \"32B\":\n",
" QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936,\n",
" \"context_length\": 40_960,\n",
" \"emb_dim\": 5120, \n",
" \"n_heads\": 64, # 60% larger than above\n",
" \"n_layers\": 64, # 60% larger than above\n",
" \"hidden_dim\": 25600, # 47% larger than above\n",
" \"head_dim\": 128,\n",
" \"qk_norm\": True,\n",
" \"n_kv_groups\": 8,\n",
" \"rope_base\": 1_000_000.0,\n",
" \"dtype\": torch.bfloat16,\n",
" } \n",
"\n",
"else:\n",
" raise ValueError(f\"{CHOOSE_MODEL} is not supported.\")"
]
},
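As a sanity check on whichever config is selected, the parameter count can be estimated directly from the dict. The helper below is an illustrative sketch (not part of the notebook), but it reproduces the totals the notebook prints: 751,632,384 / 596,049,920 for "0.6B" and 8,190,735,360 / 7,568,405,504 for "8B".

```python
def count_qwen3_params(cfg):
    # Per-block attention, feed-forward, and RMSNorm parameters
    e, hd = cfg["emb_dim"], cfg["head_dim"]
    attn = (cfg["n_heads"] * hd * e            # W_query
            + 2 * cfg["n_kv_groups"] * hd * e  # W_key and W_value
            + cfg["n_heads"] * hd * e)         # out_proj
    ff = 3 * e * cfg["hidden_dim"]             # fc1, fc2, fc3
    norms = 2 * e + 2 * hd                     # norm1/norm2 + q_norm/k_norm
    block = attn + ff + norms
    emb = cfg["vocab_size"] * e                # tok_emb
    out_head = cfg["vocab_size"] * e           # out_head (untied)
    total = cfg["n_layers"] * block + emb + out_head + e  # + final_norm
    unique = total - out_head                  # count if out_head reuses tok_emb
    return total, unique

total, unique = count_qwen3_params(QWEN3_CONFIG)
print(f"{total:,} total / {unique:,} unique")
```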
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 27,
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
"metadata": {
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
@@ -457,7 +551,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 28,
"id": "eaf86265-4e9d-4024-9ed0-99076944e304",
"metadata": {},
"outputs": [
@@ -465,32 +559,32 @@
"data": {
"text/plain": [
"Qwen3Model(\n",
" (tok_emb): Embedding(151936, 1024)\n",
" (tok_emb): Embedding(151936, 4096)\n",
" (trf_blocks): ModuleList(\n",
" (0-27): 28 x TransformerBlock(\n",
" (0-35): 36 x TransformerBlock(\n",
" (att): GroupedQueryAttention(\n",
" (W_query): Linear(in_features=1024, out_features=2048, bias=False)\n",
" (W_key): Linear(in_features=1024, out_features=1024, bias=False)\n",
" (W_value): Linear(in_features=1024, out_features=1024, bias=False)\n",
" (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
" (W_query): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (W_key): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (W_value): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (out_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (q_norm): RMSNorm()\n",
" (k_norm): RMSNorm()\n",
" )\n",
" (ff): FeedForward(\n",
" (fc1): Linear(in_features=1024, out_features=3072, bias=False)\n",
" (fc2): Linear(in_features=1024, out_features=3072, bias=False)\n",
" (fc3): Linear(in_features=3072, out_features=1024, bias=False)\n",
" (fc1): Linear(in_features=4096, out_features=12288, bias=False)\n",
" (fc2): Linear(in_features=4096, out_features=12288, bias=False)\n",
" (fc3): Linear(in_features=12288, out_features=4096, bias=False)\n",
" )\n",
" (norm1): RMSNorm()\n",
" (norm2): RMSNorm()\n",
" )\n",
" )\n",
" (final_norm): RMSNorm()\n",
" (out_head): Linear(in_features=1024, out_features=151936, bias=False)\n",
" (out_head): Linear(in_features=4096, out_features=151936, bias=False)\n",
")"
]
},
"execution_count": 12,
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
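The projection shapes in the repr follow directly from the 8B config: grouped-query attention gives `W_query` and `out_proj` a full `n_heads * head_dim` features, but `W_key`/`W_value` only `n_kv_groups * head_dim`:

```python
n_heads, n_kv_groups, head_dim = 32, 8, 128  # values from the 8B config
print(n_heads * head_dim)      # 4096 -> W_query and out_proj out_features
print(n_kv_groups * head_dim)  # 1024 -> W_key and W_value out_features
```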
@@ -509,20 +603,20 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 29,
"id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-0.2256, -0.0164, -0.7070, ..., 0.4414, 0.1245, 1.0703],\n",
" [-0.6602, 0.5352, -0.0718, ..., -0.0737, 0.5391, 0.3086],\n",
" [-0.4785, -0.1562, 0.1045, ..., -0.2324, 0.2354, 0.6328]]],\n",
"tensor([[[-0.7305, -1.2109, 0.4551, ..., -0.0215, -0.5742, -0.2754],\n",
" [-0.4023, -0.6094, 0.0415, ..., 0.6094, -0.6758, 0.3789],\n",
" [-0.4043, 0.1943, -0.0757, ..., 0.4121, -1.2344, -0.1445]]],\n",
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
]
},
"execution_count": 13,
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
@@ -533,7 +627,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 30,
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"metadata": {
"colab": {
@@ -547,9 +641,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of parameters: 751,632,384\n",
"Total number of parameters: 8,190,735,360\n",
"\n",
"Total number of unique parameters: 596,049,920\n"
"Total number of unique parameters: 7,568,405,504\n"
]
}
],
@@ -564,7 +658,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 31,
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"metadata": {
"colab": {
@@ -578,8 +672,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"float32 (PyTorch default): 5.64 GB\n",
"bfloat16: 2.82 GB\n"
"float32 (PyTorch default): 61.06 GB\n",
"bfloat16: 30.53 GB\n"
]
}
],
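The 61.06 GB float32 figure is roughly twice the raw weight storage because the estimate also budgets a same-sized gradient tensor per parameter, plus the small RoPE buffers (an assumption about the accounting, but it is consistent with the numbers shown):

```python
params = 8_190_735_360                  # total parameter count for "8B"
weights_plus_grads = 2 * params * 4     # 4 bytes per element in float32
print(f"{weights_plus_grads / 1024**3:.2f} GB")  # ~61.03 GB; buffers add ~0.04 GB
```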
@@ -614,7 +708,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 32,
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
"metadata": {
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
@@ -644,7 +738,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 36,
"id": "75166128-5899-4995-9b88-9672e135650e",
"metadata": {
"id": "75166128-5899-4995-9b88-9672e135650e"
@@ -733,13 +827,17 @@
" # Final normalization and output head\n",
" model.final_norm.scale = assign(model.final_norm.scale, params[\"model.norm.weight\"], \"model.norm.weight\")\n",
"\n",
" # Model uses weight tying, hence we reuse the embedding layer weights here\n",
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")"
" if \"lm_head.weight\" in params:\n",
" model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
" else:\n",
" # Model uses weight tying, hence we reuse the embedding layer weights here\n",
" print(\"Model uses weight tying.\")\n",
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")"
]
},
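Which branch runs depends on whether the downloaded checkpoint ships its own output head. A quick illustrative check on a loaded state dict (not part of the notebook):

```python
# weights_dict is the state dict loaded in the next cell
tied = "lm_head.weight" not in weights_dict
print("weight tying: out_head reuses tok_emb" if tied else "separate lm_head")
```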
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 37,
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"metadata": {
"colab": {
@@ -762,31 +860,98 @@
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bf7fbc5f95ed4f06b5ba47d4aec96738",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 14 files: 0%| | 0/14 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
},
{
"data": {
"text/plain": [
"Qwen3Model(\n",
" (tok_emb): Embedding(151936, 4096)\n",
" (trf_blocks): ModuleList(\n",
" (0-35): 36 x TransformerBlock(\n",
" (att): GroupedQueryAttention(\n",
" (W_query): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (W_key): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (W_value): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (out_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (q_norm): RMSNorm()\n",
" (k_norm): RMSNorm()\n",
" )\n",
" (ff): FeedForward(\n",
" (fc1): Linear(in_features=4096, out_features=12288, bias=False)\n",
" (fc2): Linear(in_features=4096, out_features=12288, bias=False)\n",
" (fc3): Linear(in_features=12288, out_features=4096, bias=False)\n",
" )\n",
" (norm1): RMSNorm()\n",
" (norm2): RMSNorm()\n",
" )\n",
" )\n",
" (final_norm): RMSNorm()\n",
" (out_head): Linear(in_features=4096, out_features=151936, bias=False)\n",
")"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import json\n",
"import os\n",
"from pathlib import Path\n",
"from safetensors.torch import load_file\n",
"from huggingface_hub import hf_hub_download\n",
"from huggingface_hub import hf_hub_download, snapshot_download\n",
"\n",
"\n",
"if USE_REASONING_MODEL:\n",
" repo_id = \"Qwen/Qwen3-0.6B\"\n",
" repo_id = f\"Qwen/Qwen3-{CHOOSE_MODEL}\"\n",
"else:\n",
" repo_id = \"Qwen/Qwen3-0.6B-Base\"\n",
" repo_id = f\"Qwen/Qwen3-{CHOOSE_MODEL}-Base\"\n",
"\n",
"local_dir = Path(repo_id).parts[-1]\n",
"\n",
"if CHOOSE_MODEL == \"0.6B\":\n",
" weights_file = hf_hub_download(\n",
" repo_id=repo_id,\n",
" filename=\"model.safetensors\",\n",
" local_dir=local_dir,\n",
" )\n",
" weights_dict = load_file(weights_file)\n",
"else:\n",
" repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)\n",
" index_path = os.path.join(repo_dir, \"model.safetensors.index.json\")\n",
" with open(index_path, \"r\") as f:\n",
" index = json.load(f)\n",
"\n",
"weights_file = hf_hub_download(\n",
" repo_id=repo_id,\n",
" filename=\"model.safetensors\",\n",
" local_dir=local_dir\n",
")\n",
"\n",
"weights_dict = load_file(weights_file)\n",
" weights_dict = {}\n",
" for filename in set(index[\"weight_map\"].values()):\n",
" shard_path = os.path.join(repo_dir, filename)\n",
" shard = load_file(shard_path)\n",
" weights_dict.update(shard)\n",
"\n",
"load_weights_into_qwen(model, QWEN3_CONFIG, weights_dict)\n",
"model.to(device)\n",
"del weights_file # free up memory"
"model.to(device);"
]
},
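For the sharded checkpoints, `model.safetensors.index.json` maps each tensor name to the shard file that stores it; the loop above simply loads every referenced shard once. Schematically (shard names and the size are illustrative):

```python
index = {
    "metadata": {"total_size": 16_381_470_720},  # bytes, ~2 bytes/param in bf16
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00005.safetensors",
        "lm_head.weight": "model-00005-of-00005.safetensors",
        # ... one entry per tensor
    },
}

shards = set(index["weight_map"].values())  # each shard is loaded exactly once
```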
{
@@ -800,7 +965,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 38,
"id": "b68ab489-48e5-471e-a814-56cda2d60f81",
"metadata": {},
"outputs": [],
@@ -853,15 +1018,15 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 39,
"id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
"metadata": {},
"outputs": [],
"source": [
"if USE_REASONING_MODEL:\n",
" tokenizer_file_path = \"Qwen3-0.6B/tokenizer.json\"\n",
" tokenizer_file_path = f\"Qwen3-{CHOOSE_MODEL}/tokenizer.json\"\n",
"else:\n",
" tokenizer_file_path = \"Qwen3-0.6B-Base/tokenizer.json\"\n",
" tokenizer_file_path = f\"Qwen3-{CHOOSE_MODEL}-Base/tokenizer.json\"\n",
"\n",
"tokenizer = Qwen3Tokenizer(\n",
" tokenizer_file_path=tokenizer_file_path,\n",
@@ -873,7 +1038,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 40,
"id": "1946b534-e3af-431a-a222-391a60bfa892",
"metadata": {},
"outputs": [
@@ -883,7 +1048,7 @@
"'<|im_start|>user\\nGive me a short introduction to large language models.<|im_end|>\\n<|im_start|>assistant\\n'"
]
},
"execution_count": 21,
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
@@ -909,7 +1074,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 41,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
@@ -931,7 +1096,7 @@
" # Keep only top_k values\n",
" top_logits, _ = torch.topk(logits, top_k)\n",
" min_val = top_logits[:, -1]\n",
" logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n",
" logits = torch.where(logits < min_val, torch.tensor(-torch.inf).to(logits.device), logits)\n",
"\n",
" # pply temperature scaling\n",
" if temperature > 0.0:\n",
@@ -958,7 +1123,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 42,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
@@ -968,17 +1133,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Time: 21.00 sec\n",
"Time: 78.98 sec\n",
"<|im_start|>user\n",
"Give me a short introduction to large language models.<|im_end|>\n",
"<|im_start|>assistant\n",
"<think>\n",
"Okay, the user wants a short introduction to large language models. Let me start by recalling what I know. Large language models are AI systems that can understand and generate human language. They're trained on massive datasets, so they can learn complex patterns and nuances.\n",
"\n",
"I should mention their ability to understand and generate text, not just specific tasks. Maybe include examples like chatbots or language assistants. Also, emphasize their adaptability and versatility. Oh, and maybe touch on their applications in various fields. Let me check if I'm covering all key points without being too technical. Keep it concise, around 3-4 sentences. Make sure it's clear and easy to understand.\n",
"</think>\n",
"\n",
"Large language models (LLMs) are AI systems designed...\n"
"Okay, the user wants a short introduction to large language models. Let me start by defining what they are. They're AI systems trained on vast amounts of text data, right? I should mention their ability to understand and generate human-like text. Maybe include examples like GPT or BERT. Also, highlight their applications in tasks like answering questions, writing, coding, and more. Need to keep it concise but cover the key points. Oh, and maybe touch on how they're trained using deep learning techniques. Wait, should I explain the training process briefly? Probably not necessary for a short intro. Focus on the main aspects: what they are, how they work, and their uses. Make sure it's easy to understand without too...\n"
]
}
],