LLMs-from-scratch/ch05/11_qwen3/standalone-qwen3.ipynb
2025-07-10 12:52:29 -05:00

1210 lines
44 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c",
"metadata": {
"id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c"
},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "efde77f2-6af3-4781-8597-89ecd3f41a52",
"metadata": {
"id": "efde77f2-6af3-4781-8597-89ecd3f41a52"
},
"source": [
"# Qwen3 From Scratch (A Standalone Notebook)"
]
},
{
"cell_type": "markdown",
"id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d",
"metadata": {
"id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d"
},
"source": [
"- This notebook is purposefully minimal and focuses on the code to implement Qwen3 0.6B, 1.7B, 4B, 8B, and 32B; for more information about this model, please see the original blog post and technical report:\n",
" - [Qwen3: Think Deeper, Act Faster](https://qwenlm.github.io/blog/qwen3/)\n",
" - [Qwen3 Technical Report](https://arxiv.org/abs/2505.09388) \n",
"- Many architectural components in Qwen3 are similar to Llama 3; for a step-by-step guide that explains the individual components and the relationship between GPT and the components used here, you may like the GPT-to-Llama conversion notebooks:\n",
" - [Converting a From-Scratch GPT Architecture to Llama 2](../07_gpt_to_llama/converting-gpt-to-llama2.ipynb)\n",
" - [Converting Llama 2 to Llama 3.2 From Scratch](../07_gpt_to_llama/converting-llama2-to-llama3.ipynb)\n",
" \n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen/qwen-overview.webp\">\n",
" \n",
" \n",
"- About the code:\n",
" - all code is my own code, mapping the Qwen3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7c201adb-747e-437b-9a62-442802941e01",
"metadata": {},
"outputs": [],
"source": [
"# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
"outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface_hub version: 0.33.0\n",
"tokenizers version: 0.21.1\n",
"torch version: 2.6.0\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"pkgs = [\n",
" \"huggingface_hub\", # to download pretrained weights\n",
" \"tokenizers\", # to implement the tokenizer\n",
" \"torch\", # to implement the model\n",
"]\n",
"for p in pkgs:\n",
" print(f\"{p} version: {version(p)}\")"
]
},
{
"cell_type": "markdown",
"id": "07e96fbb-8e16-4f6d-835f-c6159321280b",
"metadata": {},
"source": [
"- This notebook supports both the base model and the reasoning (\"thinking\") model; which model to use can be controlled via the following flag:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "70a90338-624a-4706-aa55-6b4358070194",
"metadata": {},
"outputs": [],
"source": [
"USE_REASONING_MODEL = True"
]
},
{
"cell_type": "markdown",
"id": "653410a6-dd2b-4eb2-a722-23d9782e726d",
"metadata": {
"id": "653410a6-dd2b-4eb2-a722-23d9782e726d"
},
"source": [
"&nbsp;\n",
"# 1. Architecture code"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "82076c21-9331-4dcd-b017-42b046cf1a60",
"metadata": {
"id": "82076c21-9331-4dcd-b017-42b046cf1a60"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class FeedForward(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
" self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
" self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
"\n",
" def forward(self, x):\n",
" x_fc1 = self.fc1(x)\n",
" x_fc2 = self.fc2(x)\n",
" x = nn.functional.silu(x_fc1) * x_fc2\n",
" return self.fc3(x)"
]
},
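{
"cell_type": "markdown",
"id": "f3a1c2d4-5e6f-4a7b-8c9d-0e1f2a3b4c5d",
"metadata": {},
"source": [
"- Note that `FeedForward` implements the gated SiLU (\"SwiGLU\") feed-forward variant used by Qwen3 (and Llama): `fc1` and `fc2` both project into the hidden dimension, the SiLU-activated `fc1` branch gates the `fc2` branch, and `fc3` projects back down to the embedding dimension\n",
"- The optional cell below is a minimal shape check with made-up sizes, just to illustrate that the module maps `emb_dim` back to `emb_dim`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4b2d3e5-6f7a-4b8c-9d0e-1f2a3b4c5d6e",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative shape check (the sizes below are made up for demonstration)\n",
"ffn_example = FeedForward({\"emb_dim\": 16, \"hidden_dim\": 32, \"dtype\": torch.float32})\n",
"print(ffn_example(torch.randn(2, 4, 16)).shape) # torch.Size([2, 4, 16])"
]
},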
{
"cell_type": "code",
"execution_count": 5,
"id": "56715760-37e1-433e-89da-04864c139a9e",
"metadata": {},
"outputs": [],
"source": [
"class RMSNorm(nn.Module):\n",
" def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):\n",
" super().__init__()\n",
" self.eps = eps\n",
" self.qwen3_compatible = qwen3_compatible\n",
" self.scale = nn.Parameter(torch.ones(emb_dim))\n",
" self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None\n",
"\n",
" def forward(self, x):\n",
" input_dtype = x.dtype\n",
"\n",
" if self.qwen3_compatible:\n",
" x = x.to(torch.float32)\n",
"\n",
" variance = x.pow(2).mean(dim=-1, keepdim=True)\n",
" norm_x = x * torch.rsqrt(variance + self.eps)\n",
" norm_x = norm_x * self.scale\n",
"\n",
" if self.shift is not None:\n",
" norm_x = norm_x + self.shift\n",
"\n",
" return norm_x.to(input_dtype)"
]
},
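{
"cell_type": "markdown",
"id": "b5c3e4f6-7a8b-4c9d-8e1f-2a3b4c5d6e7f",
"metadata": {},
"source": [
"- As a quick, optional sanity check, we can compare this implementation against PyTorch's built-in `torch.nn.RMSNorm` (this assumes PyTorch 2.4 or newer, which is where `torch.nn.RMSNorm` was introduced):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6d4f5a7-8b9c-4d0e-9f2a-3b4c5d6e7f8a",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check against PyTorch's built-in RMSNorm\n",
"# (assumes torch >= 2.4, which provides torch.nn.RMSNorm)\n",
"torch.manual_seed(123)\n",
"example_batch = torch.randn(2, 8, 16)\n",
"\n",
"rmsnorm_scratch = RMSNorm(emb_dim=16)\n",
"rmsnorm_pytorch = torch.nn.RMSNorm(16, eps=1e-6)\n",
"\n",
"assert torch.allclose(rmsnorm_scratch(example_batch), rmsnorm_pytorch(example_batch), atol=1e-5)"
]
},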
{
"cell_type": "code",
"execution_count": 6,
"id": "4b9a346f-5826-4083-9162-abd56afc03f0",
"metadata": {
"id": "4b9a346f-5826-4083-9162-abd56afc03f0"
},
"outputs": [],
"source": [
"def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):\n",
" assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
"\n",
" # Compute the inverse frequencies\n",
" inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))\n",
"\n",
" # Generate position indices\n",
" positions = torch.arange(context_length, dtype=dtype)\n",
"\n",
" # Compute the angles\n",
" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",
"\n",
" # Precompute sine and cosine\n",
" cos = torch.cos(angles)\n",
" sin = torch.sin(angles)\n",
"\n",
" return cos, sin\n",
"\n",
"\n",
"def apply_rope(x, cos, sin):\n",
" # x: (batch_size, num_heads, seq_len, head_dim)\n",
" batch_size, num_heads, seq_len, head_dim = x.shape\n",
" assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
"\n",
" # Split x into first half and second half\n",
" x1 = x[..., : head_dim // 2] # First half\n",
" x2 = x[..., head_dim // 2 :] # Second half\n",
"\n",
" # Adjust sin and cos shapes\n",
" cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n",
" sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
"\n",
" # Apply the rotary transformation\n",
" rotated = torch.cat((-x2, x1), dim=-1)\n",
" x_rotated = (x * cos) + (rotated * sin)\n",
"\n",
" # It's ok to use lower-precision after applying cos and sin rotation\n",
" return x_rotated.to(dtype=x.dtype)"
]
},
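{
"cell_type": "markdown",
"id": "d7e5a6b8-9c0d-4e1f-8a3b-4c5d6e7f8a9b",
"metadata": {},
"source": [
"- The optional cell below is a minimal sanity check of the two RoPE helpers above (with made-up tensor sizes): the output shape matches the input, and since RoPE only rotates coordinate pairs, the per-head vector norms should be preserved:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8f6b7c9-0d1e-4f2a-9b4c-5d6e7f8a9b0c",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative RoPE sanity check (sizes are made up for demonstration)\n",
"cos_example, sin_example = compute_rope_params(head_dim=16, context_length=32)\n",
"example_q = torch.randn(1, 2, 32, 16) # (batch_size, num_heads, seq_len, head_dim)\n",
"rotated_q = apply_rope(example_q, cos_example, sin_example)\n",
"\n",
"print(rotated_q.shape) # same shape as the input\n",
"\n",
"# A rotation preserves vector lengths, so the per-head norms should be unchanged\n",
"assert torch.allclose(rotated_q.norm(dim=-1), example_q.norm(dim=-1), atol=1e-5)"
]
},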
{
"cell_type": "code",
"execution_count": 7,
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
"metadata": {
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
},
"outputs": [],
"source": [
"class GroupedQueryAttention(nn.Module):\n",
" def __init__(\n",
" self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None\n",
" ):\n",
" super().__init__()\n",
" assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.num_kv_groups = num_kv_groups\n",
" self.group_size = num_heads // num_kv_groups\n",
"\n",
" if head_dim is None:\n",
" assert d_in % num_heads == 0, \"`d_in` must be divisible by `num_heads` if `head_dim` is not set\"\n",
" head_dim = d_in // num_heads\n",
"\n",
" self.head_dim = head_dim\n",
" self.d_out = num_heads * head_dim\n",
"\n",
" self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)\n",
" self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n",
" self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n",
"\n",
" self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)\n",
"\n",
" if qk_norm:\n",
" self.q_norm = RMSNorm(head_dim, eps=1e-6)\n",
" self.k_norm = RMSNorm(head_dim, eps=1e-6)\n",
" else:\n",
" self.q_norm = self.k_norm = None\n",
"\n",
" def forward(self, x, mask, cos, sin):\n",
" b, num_tokens, _ = x.shape\n",
"\n",
" # Apply projections\n",
" queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)\n",
" keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)\n",
" values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)\n",
"\n",
" # Reshape\n",
" queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
" keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
" values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
"\n",
" # Optional normalization\n",
" if self.q_norm:\n",
" queries = self.q_norm(queries)\n",
" if self.k_norm:\n",
" keys = self.k_norm(keys)\n",
"\n",
" # Apply RoPE\n",
" queries = apply_rope(queries, cos, sin)\n",
" keys = apply_rope(keys, cos, sin)\n",
"\n",
" # Expand K and V to match number of heads\n",
" keys = keys.repeat_interleave(self.group_size, dim=1)\n",
" values = values.repeat_interleave(self.group_size, dim=1)\n",
"\n",
" # Attention\n",
" attn_scores = queries @ keys.transpose(2, 3)\n",
" attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n",
" attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)\n",
"\n",
" context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)\n",
" return self.out_proj(context)"
]
},
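{
"cell_type": "markdown",
"id": "f9a7c8d0-1e2f-4a3b-8c5d-6e7f8a9b0c1d",
"metadata": {},
"source": [
"- The main benefit of grouped-query attention is that several query heads share one key/value head, which shrinks the K and V projections (and, in implementations with a KV cache, the cache size)\n",
"- The optional cell below makes this concrete with illustrative (made-up) hyperparameters: with 16 query heads but only 8 key-value groups, the K and V projection matrices are half the size of the Q projection:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0b8d9e1-2f3a-4b4c-9d6e-7f8a9b0c1d2e",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative GQA weight-shape comparison (hyperparameters are made up)\n",
"gqa_example = GroupedQueryAttention(d_in=1024, num_heads=16, num_kv_groups=8, head_dim=128)\n",
"\n",
"print(\"W_query:\", gqa_example.W_query.weight.shape) # 16 heads x 128 dims -> (2048, 1024)\n",
"print(\"W_key: \", gqa_example.W_key.weight.shape) # 8 KV groups x 128 dims -> (1024, 1024)\n",
"print(\"W_value:\", gqa_example.W_value.weight.shape) # 8 KV groups x 128 dims -> (1024, 1024)"
]
},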
{
"cell_type": "code",
"execution_count": 8,
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
"metadata": {
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
},
"outputs": [],
"source": [
"class TransformerBlock(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.att = GroupedQueryAttention(\n",
" d_in=cfg[\"emb_dim\"],\n",
" num_heads=cfg[\"n_heads\"],\n",
" head_dim=cfg[\"head_dim\"],\n",
" num_kv_groups=cfg[\"n_kv_groups\"],\n",
" qk_norm=cfg[\"qk_norm\"],\n",
" dtype=cfg[\"dtype\"]\n",
" )\n",
" self.ff = FeedForward(cfg)\n",
" self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n",
" self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n",
"\n",
" def forward(self, x, mask, cos, sin):\n",
" # Shortcut connection for attention block\n",
" shortcut = x\n",
" x = self.norm1(x)\n",
" x = self.att(x, mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]\n",
" x = x + shortcut # Add the original input back\n",
"\n",
" # Shortcut connection for feed-forward block\n",
" shortcut = x\n",
" x = self.norm2(x)\n",
" x = self.ff(x)\n",
" x = x + shortcut # Add the original input back\n",
"\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
"metadata": {
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
},
"outputs": [],
"source": [
"class Qwen3Model(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
"\n",
" # Main model parameters\n",
" self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
"\n",
" self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`\n",
" [TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])]\n",
" )\n",
"\n",
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" # Reusuable utilities\n",
" if cfg[\"head_dim\"] is None:\n",
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
" else:\n",
" head_dim = cfg[\"head_dim\"]\n",
" cos, sin = compute_rope_params(\n",
" head_dim=head_dim,\n",
" theta_base=cfg[\"rope_base\"],\n",
" context_length=cfg[\"context_length\"]\n",
" )\n",
" self.register_buffer(\"cos\", cos, persistent=False)\n",
" self.register_buffer(\"sin\", sin, persistent=False)\n",
" self.cfg = cfg\n",
"\n",
"\n",
" def forward(self, in_idx):\n",
" # Forward pass\n",
" tok_embeds = self.tok_emb(in_idx)\n",
" x = tok_embeds\n",
"\n",
" num_tokens = x.shape[1]\n",
" mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)\n",
" \n",
" for block in self.trf_blocks:\n",
" x = block(x, mask, self.cos, self.sin)\n",
" x = self.final_norm(x)\n",
" logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
" return logits"
]
},
{
"cell_type": "markdown",
"id": "be2d201f-74ad-4d63-ab9c-601b00674a48",
"metadata": {
"id": "be2d201f-74ad-4d63-ab9c-601b00674a48"
},
"source": [
"&nbsp;\n",
"# 2. Initialize model"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "caa142fa-b375-4e78-b392-2072ced666f3",
"metadata": {
"id": "caa142fa-b375-4e78-b392-2072ced666f3"
},
"outputs": [],
"source": [
"CHOOSE_MODEL = \"0.6B\"\n",
"\n",
"if CHOOSE_MODEL == \"0.6B\":\n",
" QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936, # Vocabulary size\n",
" \"context_length\": 40_960, # Context length that was used to train the model\n",
" \"emb_dim\": 1024, # Embedding dimension\n",
" \"n_heads\": 16, # Number of attention heads\n",
" \"n_layers\": 28, # Number of layers\n",
" \"hidden_dim\": 3072, # Size of the intermediate dimension in FeedForward\n",
" \"head_dim\": 128, # Size of the heads in GQA\n",
" \"qk_norm\": True, # Whether to normalize queries and values in GQA\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 1_000_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
" }\n",
"\n",
"elif CHOOSE_MODEL == \"1.7B\":\n",
" QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936,\n",
" \"context_length\": 40_960,\n",
" \"emb_dim\": 2048, # 2x larger than above\n",
" \"n_heads\": 16,\n",
" \"n_layers\": 28,\n",
" \"hidden_dim\": 6144, # 2x larger than above\n",
" \"head_dim\": 128,\n",
" \"qk_norm\": True,\n",
" \"n_kv_groups\": 8,\n",
" \"rope_base\": 1_000_000.0,\n",
" \"dtype\": torch.bfloat16,\n",
" } \n",
"\n",
"elif CHOOSE_MODEL == \"4B\":\n",
" QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936,\n",
" \"context_length\": 40_960,\n",
" \"emb_dim\": 2560, # 25% larger than above\n",
" \"n_heads\": 32, # 2x larger than above\n",
" \"n_layers\": 36, # 29% larger than above\n",
" \"hidden_dim\": 9728, # ~3x larger than above\n",
" \"head_dim\": 128,\n",
" \"qk_norm\": True,\n",
" \"n_kv_groups\": 8,\n",
" \"rope_base\": 1_000_000.0,\n",
" \"dtype\": torch.bfloat16,\n",
" } \n",
"\n",
"elif CHOOSE_MODEL == \"8B\":\n",
" QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936,\n",
" \"context_length\": 40_960,\n",
" \"emb_dim\": 4096, # 60% larger than above\n",
" \"n_heads\": 32,\n",
" \"n_layers\": 36, # 26% larger than above\n",
" \"hidden_dim\": 12288,\n",
" \"head_dim\": 128,\n",
" \"qk_norm\": True,\n",
" \"n_kv_groups\": 8,\n",
" \"rope_base\": 1_000_000.0,\n",
" \"dtype\": torch.bfloat16,\n",
" } \n",
"\n",
"elif CHOOSE_MODEL == \"14B\":\n",
" QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936,\n",
" \"context_length\": 40_960,\n",
" \"emb_dim\": 5120, # 25% larger than above\n",
" \"n_heads\": 40, # 25% larger than above\n",
" \"n_layers\": 40, # 11% larger than above\n",
" \"hidden_dim\": 17408, # 42% larger than above\n",
" \"head_dim\": 128,\n",
" \"qk_norm\": True,\n",
" \"n_kv_groups\": 8,\n",
" \"rope_base\": 1_000_000.0,\n",
" \"dtype\": torch.bfloat16,\n",
" } \n",
"\n",
"elif CHOOSE_MODEL == \"32B\":\n",
" QWEN3_CONFIG = {\n",
" \"vocab_size\": 151_936,\n",
" \"context_length\": 40_960,\n",
" \"emb_dim\": 5120, \n",
" \"n_heads\": 64, # 60% larger than above\n",
" \"n_layers\": 64, # 60% larger than above\n",
" \"hidden_dim\": 25600, # 47% larger than above\n",
" \"head_dim\": 128,\n",
" \"qk_norm\": True,\n",
" \"n_kv_groups\": 8,\n",
" \"rope_base\": 1_000_000.0,\n",
" \"dtype\": torch.bfloat16,\n",
" } \n",
"\n",
"else:\n",
" raise ValueError(f\"{CHOOSE_MODEL} is not supported.\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
"metadata": {
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
},
"outputs": [],
"source": [
"torch.manual_seed(123)\n",
"model = Qwen3Model(QWEN3_CONFIG)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "eaf86265-4e9d-4024-9ed0-99076944e304",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Qwen3Model(\n",
" (tok_emb): Embedding(151936, 4096)\n",
" (trf_blocks): ModuleList(\n",
" (0-35): 36 x TransformerBlock(\n",
" (att): GroupedQueryAttention(\n",
" (W_query): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (W_key): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (W_value): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (out_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (q_norm): RMSNorm()\n",
" (k_norm): RMSNorm()\n",
" )\n",
" (ff): FeedForward(\n",
" (fc1): Linear(in_features=4096, out_features=12288, bias=False)\n",
" (fc2): Linear(in_features=4096, out_features=12288, bias=False)\n",
" (fc3): Linear(in_features=12288, out_features=4096, bias=False)\n",
" )\n",
" (norm1): RMSNorm()\n",
" (norm2): RMSNorm()\n",
" )\n",
" )\n",
" (final_norm): RMSNorm()\n",
" (out_head): Linear(in_features=4096, out_features=151936, bias=False)\n",
")"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model"
]
},
{
"cell_type": "markdown",
"id": "90aca91d-4bee-45ce-993a-4ec5393abe2b",
"metadata": {},
"source": [
"- A quick check that the forward pass works before continuing:"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-0.7305, -1.2109, 0.4551, ..., -0.0215, -0.5742, -0.2754],\n",
" [-0.4023, -0.6094, 0.0415, ..., 0.6094, -0.6758, 0.3789],\n",
" [-0.4043, 0.1943, -0.0757, ..., 0.4121, -1.2344, -0.1445]]],\n",
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model(torch.tensor([1, 2, 3]).unsqueeze(0))"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"outputId": "00d7e983-262e-4c65-f322-f4d999311988"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of parameters: 8,190,735,360\n",
"\n",
"Total number of unique parameters: 7,568,405,504\n"
]
}
],
"source": [
"total_params = sum(p.numel() for p in model.parameters())\n",
"print(f\"Total number of parameters: {total_params:,}\")\n",
"\n",
"# Account for weight tying\n",
"total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
"print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"outputId": "65c1a95e-b502-4150-9e2e-da619d9053d5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"float32 (PyTorch default): 61.06 GB\n",
"bfloat16: 30.53 GB\n"
]
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
" # Calculate total number of elements per parameter\n",
" param_size = param.numel()\n",
" total_params += param_size\n",
" # Check if gradients are stored for this parameter\n",
" if param.requires_grad:\n",
" total_grads += param_size\n",
"\n",
" # Calculate buffer size (non-parameters that require memory)\n",
" total_buffers = sum(buf.numel() for buf in model.buffers())\n",
"\n",
" # Size in bytes = (Number of elements) * (Size of each element in bytes)\n",
" # We assume parameters and gradients are stored in the same type as input dtype\n",
" element_size = torch.tensor(0, dtype=input_dtype).element_size()\n",
" total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n",
"\n",
" # Convert bytes to gigabytes\n",
" total_memory_gb = total_memory_bytes / (1024**3)\n",
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
"metadata": {
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
"elif torch.backends.mps.is_available():\n",
" device = torch.device(\"mps\")\n",
"else:\n",
" device = torch.device(\"cpu\")\n",
"\n",
"model.to(device);"
]
},
{
"cell_type": "markdown",
"id": "c172f89f-d301-439f-b809-46169e5f5945",
"metadata": {
"id": "c172f89f-d301-439f-b809-46169e5f5945"
},
"source": [
"&nbsp;\n",
"# 4. Load pretrained weights"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "75166128-5899-4995-9b88-9672e135650e",
"metadata": {
"id": "75166128-5899-4995-9b88-9672e135650e"
},
"outputs": [],
"source": [
"def load_weights_into_qwen(model, param_config, params):\n",
" def assign(left, right, tensor_name=\"unknown\"):\n",
" if left.shape != right.shape:\n",
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
" return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n",
"\n",
" model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
"\n",
" for l in range(param_config[\"n_layers\"]):\n",
" block = model.trf_blocks[l]\n",
" att = block.att\n",
"\n",
" # Q, K, V projections\n",
" att.W_query.weight = assign(\n",
" att.W_query.weight,\n",
" params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.q_proj.weight\"\n",
" )\n",
" att.W_key.weight = assign(\n",
" att.W_key.weight,\n",
" params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.k_proj.weight\"\n",
" )\n",
" att.W_value.weight = assign(\n",
" att.W_value.weight,\n",
" params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.v_proj.weight\"\n",
" )\n",
"\n",
" # Output projection\n",
" att.out_proj.weight = assign(\n",
" att.out_proj.weight,\n",
" params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.o_proj.weight\"\n",
" )\n",
"\n",
" # QK norms\n",
" if hasattr(att, \"q_norm\") and att.q_norm is not None:\n",
" att.q_norm.scale = assign(\n",
" att.q_norm.scale,\n",
" params[f\"model.layers.{l}.self_attn.q_norm.weight\"],\n",
" f\"model.layers.{l}.self_attn.q_norm.weight\"\n",
" )\n",
" if hasattr(att, \"k_norm\") and att.k_norm is not None:\n",
" att.k_norm.scale = assign(\n",
" att.k_norm.scale,\n",
" params[f\"model.layers.{l}.self_attn.k_norm.weight\"],\n",
" f\"model.layers.{l}.self_attn.k_norm.weight\"\n",
" )\n",
"\n",
" # Attention layernorm\n",
" block.norm1.scale = assign(\n",
" block.norm1.scale,\n",
" params[f\"model.layers.{l}.input_layernorm.weight\"],\n",
" f\"model.layers.{l}.input_layernorm.weight\"\n",
" )\n",
"\n",
" # Feedforward weights\n",
" block.ff.fc1.weight = assign(\n",
" block.ff.fc1.weight,\n",
" params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n",
" f\"model.layers.{l}.mlp.gate_proj.weight\"\n",
" )\n",
" block.ff.fc2.weight = assign(\n",
" block.ff.fc2.weight,\n",
" params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n",
" f\"model.layers.{l}.mlp.up_proj.weight\"\n",
" )\n",
" block.ff.fc3.weight = assign(\n",
" block.ff.fc3.weight,\n",
" params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n",
" f\"model.layers.{l}.mlp.down_proj.weight\"\n",
" )\n",
" block.norm2.scale = assign(\n",
" block.norm2.scale,\n",
" params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n",
" f\"model.layers.{l}.post_attention_layernorm.weight\"\n",
" )\n",
"\n",
" # Final normalization and output head\n",
" model.final_norm.scale = assign(model.final_norm.scale, params[\"model.norm.weight\"], \"model.norm.weight\")\n",
"\n",
" if \"lm_head.weight\" in params:\n",
" model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
" else:\n",
" # Model uses weight tying, hence we reuse the embedding layer weights here\n",
" print(\"Model uses weight tying.\")\n",
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17,
"referenced_widgets": [
"9881b6995c3f49dc89e6992fd9ab660b",
"17a3174e65c54476b2e0d1faf8f011ca",
"1bbf2e62c0754d1593beb4105a7f1ac1",
"b82112e1dec645d98aa1c1ba64abcb61",
"271e2bd6a35e4a8b92de8697f7c0be5f",
"90a79523187446dfa692723b2e5833a7",
"431ffb83b8c14bf182f0430e07ea6154",
"a8f1b72a33dd4b548de23fbd95e0da18",
"25cc36132d384189acfbecc59483134b",
"bfd06423ad544218968648016e731a46",
"d029630b63ff44cf807ade428d2eb421"
]
},
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bf7fbc5f95ed4f06b5ba47d4aec96738",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 14 files: 0%| | 0/14 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
},
{
"data": {
"text/plain": [
"Qwen3Model(\n",
" (tok_emb): Embedding(151936, 4096)\n",
" (trf_blocks): ModuleList(\n",
" (0-35): 36 x TransformerBlock(\n",
" (att): GroupedQueryAttention(\n",
" (W_query): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (W_key): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (W_value): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (out_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (q_norm): RMSNorm()\n",
" (k_norm): RMSNorm()\n",
" )\n",
" (ff): FeedForward(\n",
" (fc1): Linear(in_features=4096, out_features=12288, bias=False)\n",
" (fc2): Linear(in_features=4096, out_features=12288, bias=False)\n",
" (fc3): Linear(in_features=12288, out_features=4096, bias=False)\n",
" )\n",
" (norm1): RMSNorm()\n",
" (norm2): RMSNorm()\n",
" )\n",
" )\n",
" (final_norm): RMSNorm()\n",
" (out_head): Linear(in_features=4096, out_features=151936, bias=False)\n",
")"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import json\n",
"import os\n",
"from pathlib import Path\n",
"from safetensors.torch import load_file\n",
"from huggingface_hub import hf_hub_download, snapshot_download\n",
"\n",
"\n",
"if USE_REASONING_MODEL:\n",
" repo_id = f\"Qwen/Qwen3-{CHOOSE_MODEL}\"\n",
"else:\n",
" repo_id = f\"Qwen/Qwen3-{CHOOSE_MODEL}-Base\"\n",
"\n",
"local_dir = Path(repo_id).parts[-1]\n",
"\n",
"if CHOOSE_MODEL == \"0.6B\":\n",
" weights_file = hf_hub_download(\n",
" repo_id=repo_id,\n",
" filename=\"model.safetensors\",\n",
" local_dir=local_dir,\n",
" )\n",
" weights_dict = load_file(weights_file)\n",
"else:\n",
" repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)\n",
" index_path = os.path.join(repo_dir, \"model.safetensors.index.json\")\n",
" with open(index_path, \"r\") as f:\n",
" index = json.load(f)\n",
"\n",
" weights_dict = {}\n",
" for filename in set(index[\"weight_map\"].values()):\n",
" shard_path = os.path.join(repo_dir, filename)\n",
" shard = load_file(shard_path)\n",
" weights_dict.update(shard)\n",
"\n",
"load_weights_into_qwen(model, QWEN3_CONFIG, weights_dict)\n",
"model.to(device);"
]
},
{
"cell_type": "markdown",
"id": "6b345491-3510-4397-92d3-cd0a3fa3deee",
"metadata": {},
"source": [
"&nbsp;\n",
"# 4. Load tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "b68ab489-48e5-471e-a814-56cda2d60f81",
"metadata": {},
"outputs": [],
"source": [
"from tokenizers import Tokenizer\n",
"\n",
"\n",
"class Qwen3Tokenizer():\n",
" def __init__(self, tokenizer_file_path=\"tokenizer.json\", repo_id=None, add_generation_prompt=False, add_thinking=False):\n",
" self.tokenizer_file_path = tokenizer_file_path\n",
" self.add_generation_prompt = add_generation_prompt\n",
" self.add_thinking = add_thinking\n",
"\n",
" tokenizer_file_path_obj = Path(tokenizer_file_path)\n",
" if not tokenizer_file_path_obj.is_file() and repo_id is not None:\n",
" _ = hf_hub_download(\n",
" repo_id=repo_id,\n",
" filename=str(tokenizer_file_path_obj.name),\n",
" local_dir=str(tokenizer_file_path_obj.parent.name)\n",
" )\n",
" self.tokenizer = Tokenizer.from_file(tokenizer_file_path)\n",
"\n",
" def encode(self, prompt):\n",
" messages = [\n",
" {\"role\": \"user\", \"content\": prompt}\n",
" ] \n",
" formatted_prompt = self.format_qwen_chat(\n",
" messages,\n",
" add_generation_prompt=self.add_generation_prompt,\n",
" add_thinking=self.add_thinking\n",
" )\n",
" return self.tokenizer.encode(formatted_prompt).ids\n",
" \n",
" def decode(self, token_ids):\n",
" return self.tokenizer.decode(token_ids, skip_special_tokens=False)\n",
" \n",
" @staticmethod\n",
" def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):\n",
" prompt = \"\"\n",
" for msg in messages:\n",
" prompt += f\"<|im_start|>{msg['role']}\\n{msg['content']}<|im_end|>\\n\"\n",
" if add_generation_prompt:\n",
" prompt += \"<|im_start|>assistant\"\n",
" if not add_thinking:\n",
" prompt += \"<|think>\\n\\n<|/think>\\n\\n\"\n",
" else:\n",
" prompt += \"\\n\" \n",
" return prompt"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
"metadata": {},
"outputs": [],
"source": [
"if USE_REASONING_MODEL:\n",
" tokenizer_file_path = f\"Qwen3-{CHOOSE_MODEL}/tokenizer.json\"\n",
"else:\n",
" tokenizer_file_path = f\"Qwen3-{CHOOSE_MODEL}-Base/tokenizer.json\"\n",
"\n",
"tokenizer = Qwen3Tokenizer(\n",
" tokenizer_file_path=tokenizer_file_path,\n",
" repo_id=repo_id,\n",
" add_generation_prompt=USE_REASONING_MODEL,\n",
" add_thinking=USE_REASONING_MODEL\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "1946b534-e3af-431a-a222-391a60bfa892",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'<|im_start|>user\\nGive me a short introduction to large language models.<|im_end|>\\n<|im_start|>assistant\\n'"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prompt = \"Give me a short introduction to large language models.\"\n",
"\n",
"input_token_ids = tokenizer.encode(prompt)\n",
"text = tokenizer.decode(input_token_ids)\n",
"text"
]
},
{
"cell_type": "markdown",
"id": "57d07df1-4401-4792-b549-7c4cc5632323",
"metadata": {
"id": "57d07df1-4401-4792-b549-7c4cc5632323"
},
"source": [
"&nbsp;\n",
"# 5. Generate text"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
},
"outputs": [],
"source": [
"# Identical function from chapter 5\n",
"\n",
"def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):\n",
" # For-loop is the same as before: Get logits, and only focus on last time step\n",
" for _ in range(max_new_tokens):\n",
" idx_cond = idx[:, -context_size:]\n",
" with torch.no_grad():\n",
" logits = model(idx_cond)\n",
" logits = logits[:, -1, :]\n",
"\n",
" # Filter logits with top_k sampling\n",
" if top_k is not None:\n",
" # Keep only top_k values\n",
" top_logits, _ = torch.topk(logits, top_k)\n",
" min_val = top_logits[:, -1]\n",
" logits = torch.where(logits < min_val, torch.tensor(-torch.inf).to(logits.device), logits)\n",
"\n",
" # pply temperature scaling\n",
" if temperature > 0.0:\n",
" logits = logits / temperature\n",
"\n",
" # Apply softmax to get probabilities\n",
" probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n",
"\n",
" # Sample from the distribution\n",
" idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)\n",
"\n",
" # Otherwise same as before: get idx of the vocab entry with the highest logits value\n",
" else:\n",
" idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1)\n",
"\n",
" if eos_id is not None and idx_next.item() == eos_id:\n",
" break # Stop generating early if end-of-sequence token is encountered and eos_id is specified\n",
"\n",
" # Same as before: append sampled index to the running sequence\n",
" idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1)\n",
"\n",
" return idx"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time: 78.98 sec\n",
"<|im_start|>user\n",
"Give me a short introduction to large language models.<|im_end|>\n",
"<|im_start|>assistant\n",
"<think>\n",
"Okay, the user wants a short introduction to large language models. Let me start by defining what they are. They're AI systems trained on vast amounts of text data, right? I should mention their ability to understand and generate human-like text. Maybe include examples like GPT or BERT. Also, highlight their applications in tasks like answering questions, writing, coding, and more. Need to keep it concise but cover the key points. Oh, and maybe touch on how they're trained using deep learning techniques. Wait, should I explain the training process briefly? Probably not necessary for a short intro. Focus on the main aspects: what they are, how they work, and their uses. Make sure it's easy to understand without too...\n"
]
}
],
"source": [
"import time\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"start = time.time()\n",
"\n",
"output_token_ids = generate(\n",
" model=model,\n",
" idx=torch.tensor(input_token_ids, device=device).unsqueeze(0),\n",
" max_new_tokens=150,\n",
" context_size=QWEN3_CONFIG[\"context_length\"],\n",
" top_k=1,\n",
" temperature=0.\n",
")\n",
"\n",
"print(f\"Time: {time.time() - start:.2f} sec\")\n",
"\n",
"if torch.cuda.is_available():\n",
" max_mem_bytes = torch.cuda.max_memory_allocated()\n",
" max_mem_gb = max_mem_bytes / (1024 ** 3)\n",
" print(f\"Max memory allocated: {max_mem_gb:.2f} GB\")\n",
"\n",
"output_text = tokenizer.decode(output_token_ids.squeeze(0).tolist())\n",
"\n",
"print(output_text + \"...\")"
]
},
{
"cell_type": "markdown",
"id": "549324d6-5c71-4147-ae21-2e67675faa3d",
"metadata": {
"id": "549324d6-5c71-4147-ae21-2e67675faa3d"
},
"source": [
"&nbsp;\n",
"# What's next?"
]
},
{
"cell_type": "markdown",
"id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c",
"metadata": {
"id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c"
},
"source": [
"- Check out the [README.md](./README.md), to use this model via the `llms_from_scratch` package\n",
"- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n",
"\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}