diff --git a/.gitignore b/.gitignore
index f23dddb..27baac8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -48,12 +48,13 @@ ch05/07_gpt_to_llama/Llama-3.2-1B
ch05/07_gpt_to_llama/Llama-3.2-1B-Instruct
ch05/07_gpt_to_llama/Llama-3.2-3B
ch05/07_gpt_to_llama/Llama-3.2-3B-Instruct
+ch05/07_gpt_to_llama/llama3.2-1B-instruct.pth
+ch05/07_gpt_to_llama/tokenizer.model
ch05/10_llm-training-speed/middlemarch.txt
ch05/10_llm-training-speed/loss.pdf
ch05/10_llm-training-speed/model.pth
-ch05/07_gpt_to_llama/Untitled.ipynb
-ch05/07_gpt_to_llama/llama3.2-1B-instruct.pth
-ch05/07_gpt_to_llama/tokenizer.model
+ch05/11_qwen3/Qwen3-0.6B
+ch05/11_qwen3/Qwen3-0.6B-Base
ch06/01_main-chapter-code/gpt2
ch06/02_bonus_additional-experiments/gpt2
diff --git a/README.md b/README.md
index fa2648b..64ac9d7 100644
--- a/README.md
+++ b/README.md
@@ -121,6 +121,7 @@ Several folders contain optional materials as a bonus for interested readers:
- [Building a User Interface to Interact With the Pretrained LLM](ch05/06_user_interface)
- [Converting GPT to Llama](ch05/07_gpt_to_llama)
- [Llama 3.2 From Scratch](ch05/07_gpt_to_llama/standalone-llama32.ipynb)
+ - [Qwen3 From Scratch](ch05/11_qwen3/standalone-qwen3.ipynb)
- [Memory-efficient Model Weight Loading](ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb)
- [Extending the Tiktoken BPE Tokenizer with New Tokens](ch05/09_extending-tokenizers/extend-tiktoken.ipynb)
- [PyTorch Performance Tips for Faster LLM Training](ch05/10_llm-training-speed)
diff --git a/ch05/11_qwen3/README.md b/ch05/11_qwen3/README.md
new file mode 100644
index 0000000..bb17fdd
--- /dev/null
+++ b/ch05/11_qwen3/README.md
@@ -0,0 +1,191 @@
+# Qwen3 From Scratch
+
+The [standalone-qwen3.ipynb](standalone-qwen3.ipynb) Jupyter notebook in this folder contains a from-scratch implementation of Qwen3 0.6B.
+
+
+
+
+
+### Using Qwen3 0.6B via the `llms-from-scratch` package
+
+For an easy way to use the Qwen3 from-scratch implementation, you can also use the `llms-from-scratch` PyPI package, which is based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
+
+
+#### 1) Installation
+
+```bash
+pip install llms_from_scratch tokenizers
+```
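+
+To double-check that the installation worked, you can optionally print the installed package versions (a small, optional check; the exact versions on your machine will differ):
+
+```python
+from importlib.metadata import version
+
+# Print the versions of the packages used below
+for pkg in ["llms_from_scratch", "tokenizers", "torch"]:
+    print(f"{pkg} version: {version(pkg)}")
+```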
+
+
+#### 2) Model and text generation settings
+
+Specify which model to use:
+
+```python
+USE_REASONING_MODEL = True   # The "thinking" model
+# USE_REASONING_MODEL = False  # Uncomment this line to use the base model instead
+```
+
+The basic text generation settings below can be adjusted by the user. Generating 150 tokens requires approximately 1.5 GB of memory.
+
+```python
+MAX_NEW_TOKENS = 150
+TEMPERATURE = 0.
+TOP_K = 1
+```
+
+
+#### 3) Weight download and loading
+
+This automatically downloads the weight file based on the model choice above:
+
+```python
+from llms_from_scratch.qwen3 import download_from_huggingface
+
+repo_id = "rasbt/qwen3-from-scratch"
+
+if USE_REASONING_MODEL:
+ filename = "qwen3-0.6B.pth"
+ local_dir = "Qwen3-0.6B"
+else:
+ filename = "qwen3-0.6B-base.pth"
+ local_dir = "Qwen3-0.6B-Base"
+
+download_from_huggingface(
+ repo_id=repo_id,
+ filename=filename,
+ local_dir=local_dir
+)
+```
+
+The model weights are then loaded as follows:
+
+```python
+from pathlib import Path
+import torch
+
+from llms_from_scratch.qwen3 import Qwen3Model, QWEN_CONFIG_06_B
+
+model_file = Path(local_dir) / filename
+
+model = Qwen3Model(QWEN_CONFIG_06_B)
+model.load_state_dict(torch.load(model_file, weights_only=True, map_location="cpu"))
+
+device = (
+ torch.device("cuda") if torch.cuda.is_available() else
+ torch.device("mps") if torch.backends.mps.is_available() else
+ torch.device("cpu")
+)
+model.to(device)
+```
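+
+As an optional sanity check, you can count the model parameters; the code and numbers below are taken from the standalone notebook, and the difference between the two values reflects the weight tying between the token embedding and the output layer:
+
+```python
+total_params = sum(p.numel() for p in model.parameters())
+print(f"Total number of parameters: {total_params:,}")
+# Total number of parameters: 751,632,384
+
+# Account for weight tying between the token embedding and the output head
+total_params_normalized = total_params - model.tok_emb.weight.numel()
+print(f"Total number of unique parameters: {total_params_normalized:,}")
+# Total number of unique parameters: 596,049,920
+```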
+
+
+#### 4) Initialize tokenizer
+
+The following code downloads and initializes the tokenizer:
+
+```python
+from llms_from_scratch.qwen3 import Qwen3Tokenizer
+
+if USE_REASONING_MODEL:
+ tok_filename = "tokenizer.json"
+else:
+ tok_filename = "tokenizer-base.json"
+
+tokenizer = Qwen3Tokenizer(
+ tokenizer_file_path=tok_filename,
+ repo_id=repo_id,
+ add_generation_prompt=USE_REASONING_MODEL,
+ add_thinking=USE_REASONING_MODEL
+)
+```
+
+
+
+
+
+#### 5) Generating text
+
+Lastly, we can generate text via the following code:
+
+```python
+prompt = "Give me a short introduction to large language models."
+input_token_ids = tokenizer.encode(prompt)
+```
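+
+To inspect the chat template that the tokenizer applies, you can decode the token IDs again; for the reasoning model, the formatted prompt looks as follows (this mirrors the corresponding cell in the standalone notebook):
+
+```python
+print(tokenizer.decode(input_token_ids))
+# <|im_start|>user
+# Give me a short introduction to large language models.<|im_end|>
+# <|im_start|>assistant
+```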
+
+
+
+
+
+```python
+from llms_from_scratch.ch05 import generate
+import time
+
+torch.manual_seed(123)
+
+start = time.time()
+
+output_token_ids = generate(
+ model=model,
+ idx=torch.tensor(input_token_ids, device=device).unsqueeze(0),
+    max_new_tokens=MAX_NEW_TOKENS,
+    context_size=QWEN_CONFIG_06_B["context_length"],
+    top_k=TOP_K,
+    temperature=TEMPERATURE
+)
+
+total_time = time.time() - start
+print(f"Time: {total_time:.2f} sec")
+print(f"{int(len(output_token_ids[0])/total_time)} tokens/sec")
+
+if torch.cuda.is_available():
+ max_mem_bytes = torch.cuda.max_memory_allocated()
+ max_mem_gb = max_mem_bytes / (1024 ** 3)
+ print(f"Max memory allocated: {max_mem_gb:.2f} GB")
+
+output_text = tokenizer.decode(output_token_ids.squeeze(0).tolist())
+
+print("\n\nOutput text:\n\n", output_text + "...")
+```
+
+When using the Qwen3 0.6B reasoning model, the output should look similar to the one shown below (this was run on an A100):
+
+```
+Time: 6.35 sec
+25 tokens/sec
+Max memory allocated: 1.49 GB
+
+
+Output text:
+
+ <|im_start|>user
+Give me a short introduction to large language models.<|im_end|>
+Large language models (LLMs) are advanced artificial intelligence systems designed to generate human-like text. They are trained on vast amounts of text data, allowing them to understand and generate coherent, contextually relevant responses. LLMs are used in a variety of applications, including chatbots, virtual assistants, content generation, and more. They are powered by deep learning algorithms and can be fine-tuned for specific tasks, making them versatile tools for a wide range of industries.<|endoftext|>Human resources department of a company is planning to hire 100 new employees. The company has a budget of $100,000 for the recruitment process. The company has a minimum wage of $10 per hour. The company has a total of...
+```
+
+
+#### Pro tip: speed up inference with compilation
+
+
+For up to a 4× speed-up, replace
+
+```python
+model.to(device)
+```
+
+with
+
+```python
+model = torch.compile(model)
+model.to(device)
+```
+
+Note: There is a significant multi-minute upfront cost when compiling, and the speed-up takes effect after the first `generate` call.
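+
+If you want your own timing numbers to reflect the compiled speed, you can, for example, trigger compilation with a short warm-up call before measuring (a hypothetical sketch, not part of the original code):
+
+```python
+# Warm-up: the first call after torch.compile triggers compilation,
+# so it should be excluded from the timing measurements
+_ = generate(
+    model=model,
+    idx=torch.tensor(input_token_ids, device=device).unsqueeze(0),
+    max_new_tokens=1,
+    context_size=QWEN_CONFIG_06_B["context_length"],
+    top_k=TOP_K,
+    temperature=TEMPERATURE
+)
+```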
+
+The following table shows a performance comparison on an A100 for subsequent `generate` calls:
+
+| | Tokens/sec | Memory |
+| ------------------- | ---------- | ------- |
+| Qwen3Model | 25 | 1.49 GB |
+| Qwen3Model compiled | 101 | 1.99 GB |
diff --git a/ch05/11_qwen3/standalone-qwen3.ipynb b/ch05/11_qwen3/standalone-qwen3.ipynb
new file mode 100644
index 0000000..afb4e71
--- /dev/null
+++ b/ch05/11_qwen3/standalone-qwen3.ipynb
@@ -0,0 +1,1788 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c",
+ "metadata": {
+ "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c"
+ },
+ "source": [
+    "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka\n",
+    "\n",
+    "Code repository: https://github.com/rasbt/LLMs-from-scratch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "efde77f2-6af3-4781-8597-89ecd3f41a52",
+ "metadata": {
+ "id": "efde77f2-6af3-4781-8597-89ecd3f41a52"
+ },
+ "source": [
+ "# Qwen3 0.6B From Scratch (A Standalone Notebook)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d",
+ "metadata": {
+ "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d"
+ },
+ "source": [
+ "- This notebook is purposefully minimal and focuses on the code to implement Qwen3 0.6B; for more information about this model, please see the original blog post and technical report:\n",
+ " - [Qwen3: Think Deeper, Act Faster](https://qwenlm.github.io/blog/qwen3/)\n",
+ " - [Qwen3 Technical Report](https://arxiv.org/abs/2505.09388) \n",
+ "- Many architectural components in Qwen3 are similar to Llama 3; for a step-by-step guide that explains the individual components and the relationship between GPT and the components used here, you may like the GPT-to-Llama conversion notebooks:\n",
+ " - [Converting a From-Scratch GPT Architecture to Llama 2](../07_gpt_to_llama/converting-gpt-to-llama2.ipynb)\n",
+ " - [Converting Llama 2 to Llama 3.2 From Scratch](../07_gpt_to_llama/converting-llama2-to-llama3.ipynb)\n",
+ " \n",
+ "\n",
+    "\n",
+ " \n",
+ " \n",
+ "- About the code:\n",
+ " - all code is my own code, mapping the Qwen3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "7c201adb-747e-437b-9a62-442802941e01",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
+ "outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "huggingface_hub version: 0.33.0\n",
+ "tokenizers version: 0.21.1\n",
+ "torch version: 2.6.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "from importlib.metadata import version\n",
+ "\n",
+ "pkgs = [\n",
+ " \"huggingface_hub\", # to download pretrained weights\n",
+ " \"tokenizers\", # to implement the tokenizer\n",
+ " \"torch\", # to implement the model\n",
+ "]\n",
+ "for p in pkgs:\n",
+ " print(f\"{p} version: {version(p)}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "07e96fbb-8e16-4f6d-835f-c6159321280b",
+ "metadata": {},
+ "source": [
+ "- This notebook supports both the base model and the reasoning (\"thinking\") model; which model to use can be controlled via the following flag:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "70a90338-624a-4706-aa55-6b4358070194",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "USE_REASONING_MODEL = True"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "653410a6-dd2b-4eb2-a722-23d9782e726d",
+ "metadata": {
+ "id": "653410a6-dd2b-4eb2-a722-23d9782e726d"
+ },
+ "source": [
+ " \n",
+ "# 1. Architecture code"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "82076c21-9331-4dcd-b017-42b046cf1a60",
+ "metadata": {
+ "id": "82076c21-9331-4dcd-b017-42b046cf1a60"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "\n",
+ "class FeedForward(nn.Module):\n",
+ " def __init__(self, cfg):\n",
+ " super().__init__()\n",
+ " self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
+ " self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
+ " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x_fc1 = self.fc1(x)\n",
+ " x_fc2 = self.fc2(x)\n",
+ " x = nn.functional.silu(x_fc1) * x_fc2\n",
+ " return self.fc3(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "56715760-37e1-433e-89da-04864c139a9e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class RMSNorm(nn.Module):\n",
+ " def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):\n",
+ " super().__init__()\n",
+ " self.eps = eps\n",
+ " self.qwen3_compatible = qwen3_compatible\n",
+ " self.scale = nn.Parameter(torch.ones(emb_dim))\n",
+ " self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None\n",
+ "\n",
+ " def forward(self, x):\n",
+ " input_dtype = x.dtype\n",
+ "\n",
+ " if self.qwen3_compatible:\n",
+ " x = x.to(torch.float32)\n",
+ "\n",
+ " variance = x.pow(2).mean(dim=-1, keepdim=True)\n",
+ " norm_x = x * torch.rsqrt(variance + self.eps)\n",
+ " norm_x = norm_x * self.scale\n",
+ "\n",
+ " if self.shift is not None:\n",
+ " norm_x = norm_x + self.shift\n",
+ "\n",
+ " return norm_x.to(input_dtype)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "4b9a346f-5826-4083-9162-abd56afc03f0",
+ "metadata": {
+ "id": "4b9a346f-5826-4083-9162-abd56afc03f0"
+ },
+ "outputs": [],
+ "source": [
+ "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):\n",
+ " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
+ "\n",
+ " # Compute the inverse frequencies\n",
+ " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))\n",
+ "\n",
+ " # Generate position indices\n",
+ " positions = torch.arange(context_length, dtype=dtype)\n",
+ "\n",
+ " # Compute the angles\n",
+ " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
+ "\n",
+ " # Expand angles to match the head_dim\n",
+ " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",
+ "\n",
+ " # Precompute sine and cosine\n",
+ " cos = torch.cos(angles)\n",
+ " sin = torch.sin(angles)\n",
+ "\n",
+ " return cos, sin\n",
+ "\n",
+ "\n",
+ "def apply_rope(x, cos, sin):\n",
+ " # x: (batch_size, num_heads, seq_len, head_dim)\n",
+ " batch_size, num_heads, seq_len, head_dim = x.shape\n",
+ " assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
+ "\n",
+ " # Split x into first half and second half\n",
+ " x1 = x[..., : head_dim // 2] # First half\n",
+ " x2 = x[..., head_dim // 2 :] # Second half\n",
+ "\n",
+ " # Adjust sin and cos shapes\n",
+ " cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n",
+ " sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
+ "\n",
+ " # Apply the rotary transformation\n",
+ " rotated = torch.cat((-x2, x1), dim=-1)\n",
+ " x_rotated = (x * cos) + (rotated * sin)\n",
+ "\n",
+ " # It's ok to use lower-precision after applying cos and sin rotation\n",
+ " return x_rotated.to(dtype=x.dtype)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
+ "metadata": {
+ "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
+ },
+ "outputs": [],
+ "source": [
+ "class GroupedQueryAttention(nn.Module):\n",
+ " def __init__(\n",
+ " self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None\n",
+ " ):\n",
+ " super().__init__()\n",
+ " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n",
+ "\n",
+ " self.num_heads = num_heads\n",
+ " self.num_kv_groups = num_kv_groups\n",
+ " self.group_size = num_heads // num_kv_groups\n",
+ "\n",
+ " if head_dim is None:\n",
+ " assert d_in % num_heads == 0, \"`d_in` must be divisible by `num_heads` if `head_dim` is not set\"\n",
+ " head_dim = d_in // num_heads\n",
+ "\n",
+ " self.head_dim = head_dim\n",
+ " self.d_out = num_heads * head_dim\n",
+ "\n",
+ " self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)\n",
+ " self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n",
+ " self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n",
+ "\n",
+ " self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)\n",
+ "\n",
+ " if qk_norm:\n",
+ " self.q_norm = RMSNorm(head_dim, eps=1e-6)\n",
+ " self.k_norm = RMSNorm(head_dim, eps=1e-6)\n",
+ " else:\n",
+ " self.q_norm = self.k_norm = None\n",
+ "\n",
+ " def forward(self, x, mask, cos, sin):\n",
+ " b, num_tokens, _ = x.shape\n",
+ "\n",
+ " # Apply projections\n",
+ " queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)\n",
+ " keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)\n",
+ " values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)\n",
+ "\n",
+ " # Reshape\n",
+ " queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
+ " keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
+ " values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
+ "\n",
+ " # Optional normalization\n",
+ " if self.q_norm:\n",
+ " queries = self.q_norm(queries)\n",
+ " if self.k_norm:\n",
+ " keys = self.k_norm(keys)\n",
+ "\n",
+ " # Apply RoPE\n",
+ " queries = apply_rope(queries, cos, sin)\n",
+ " keys = apply_rope(keys, cos, sin)\n",
+ "\n",
+ " # Expand K and V to match number of heads\n",
+ " keys = keys.repeat_interleave(self.group_size, dim=1)\n",
+ " values = values.repeat_interleave(self.group_size, dim=1)\n",
+ "\n",
+ " # Attention\n",
+ " attn_scores = queries @ keys.transpose(2, 3)\n",
+ " attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n",
+ " attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)\n",
+ "\n",
+ " context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)\n",
+ " return self.out_proj(context)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
+ "metadata": {
+ "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
+ },
+ "outputs": [],
+ "source": [
+ "class TransformerBlock(nn.Module):\n",
+ " def __init__(self, cfg):\n",
+ " super().__init__()\n",
+ " self.att = GroupedQueryAttention(\n",
+ " d_in=cfg[\"emb_dim\"],\n",
+ " num_heads=cfg[\"n_heads\"],\n",
+ " head_dim=cfg[\"head_dim\"],\n",
+ " num_kv_groups=cfg[\"n_kv_groups\"],\n",
+ " qk_norm=cfg[\"qk_norm\"],\n",
+ " dtype=cfg[\"dtype\"]\n",
+ " )\n",
+ " self.ff = FeedForward(cfg)\n",
+ " self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n",
+ " self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n",
+ "\n",
+ " def forward(self, x, mask, cos, sin):\n",
+ " # Shortcut connection for attention block\n",
+ " shortcut = x\n",
+ " x = self.norm1(x)\n",
+ " x = self.att(x, mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]\n",
+ " x = x + shortcut # Add the original input back\n",
+ "\n",
+ " # Shortcut connection for feed-forward block\n",
+ " shortcut = x\n",
+ " x = self.norm2(x)\n",
+ " x = self.ff(x)\n",
+ " x = x + shortcut # Add the original input back\n",
+ "\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
+ "metadata": {
+ "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
+ },
+ "outputs": [],
+ "source": [
+ "class Qwen3Model(nn.Module):\n",
+ " def __init__(self, cfg):\n",
+ " super().__init__()\n",
+ "\n",
+ " # Main model parameters\n",
+ " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
+ "\n",
+ " self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`\n",
+ " [TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])]\n",
+ " )\n",
+ "\n",
+ " self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
+ " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
+ "\n",
+    "        # Reusable utilities\n",
+ " if cfg[\"head_dim\"] is None:\n",
+ " head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
+ " else:\n",
+ " head_dim = cfg[\"head_dim\"]\n",
+ " cos, sin = compute_rope_params(\n",
+ " head_dim=head_dim,\n",
+ " theta_base=cfg[\"rope_base\"],\n",
+ " context_length=cfg[\"context_length\"]\n",
+ " )\n",
+ " self.register_buffer(\"cos\", cos, persistent=False)\n",
+ " self.register_buffer(\"sin\", sin, persistent=False)\n",
+ " self.cfg = cfg\n",
+ "\n",
+ "\n",
+ " def forward(self, in_idx):\n",
+ " # Forward pass\n",
+ " tok_embeds = self.tok_emb(in_idx)\n",
+ " x = tok_embeds\n",
+ "\n",
+ " num_tokens = x.shape[1]\n",
+ " mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)\n",
+ " \n",
+ " for block in self.trf_blocks:\n",
+ " x = block(x, mask, self.cos, self.sin)\n",
+ " x = self.final_norm(x)\n",
+ " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
+ " return logits"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "be2d201f-74ad-4d63-ab9c-601b00674a48",
+ "metadata": {
+ "id": "be2d201f-74ad-4d63-ab9c-601b00674a48"
+ },
+ "source": [
+ " \n",
+ "# 2. Initialize model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "caa142fa-b375-4e78-b392-2072ced666f3",
+ "metadata": {
+ "id": "caa142fa-b375-4e78-b392-2072ced666f3"
+ },
+ "outputs": [],
+ "source": [
+ "# Qwen3 0.6B\n",
+ "\n",
+ "QWEN3_CONFIG = {\n",
+ " \"vocab_size\": 151_936, # Vocabulary size\n",
+ " \"context_length\": 40_960, # Context length that was used to train the model\n",
+ " \"emb_dim\": 1024, # Embedding dimension\n",
+ " \"n_heads\": 16, # Number of attention heads\n",
+ " \"n_layers\": 28, # Number of layers\n",
+ " \"hidden_dim\": 3072, # Size of the intermediate dimension in FeedForward\n",
+ " \"head_dim\": 128, # Size of the heads in GQA\n",
+    "    \"qk_norm\": True,                  # Whether to normalize queries and keys in GQA\n",
+ " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
+ " \"rope_base\": 1_000_000.0, # The base in RoPE's \"theta\"\n",
+ " \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
+ "metadata": {
+ "id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
+ },
+ "outputs": [],
+ "source": [
+ "torch.manual_seed(123)\n",
+ "model = Qwen3Model(QWEN3_CONFIG)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "eaf86265-4e9d-4024-9ed0-99076944e304",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Qwen3Model(\n",
+ " (tok_emb): Embedding(151936, 1024)\n",
+ " (trf_blocks): ModuleList(\n",
+ " (0-27): 28 x TransformerBlock(\n",
+ " (att): GroupedQueryAttention(\n",
+ " (W_query): Linear(in_features=1024, out_features=2048, bias=False)\n",
+ " (W_key): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (W_value): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
+ " (q_norm): RMSNorm()\n",
+ " (k_norm): RMSNorm()\n",
+ " )\n",
+ " (ff): FeedForward(\n",
+ " (fc1): Linear(in_features=1024, out_features=3072, bias=False)\n",
+ " (fc2): Linear(in_features=1024, out_features=3072, bias=False)\n",
+ " (fc3): Linear(in_features=3072, out_features=1024, bias=False)\n",
+ " )\n",
+ " (norm1): RMSNorm()\n",
+ " (norm2): RMSNorm()\n",
+ " )\n",
+ " )\n",
+ " (final_norm): RMSNorm()\n",
+ " (out_head): Linear(in_features=1024, out_features=151936, bias=False)\n",
+ ")"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "90aca91d-4bee-45ce-993a-4ec5393abe2b",
+ "metadata": {},
+ "source": [
+ "- A quick check that the forward pass works before continuing:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[-0.2256, -0.0164, -0.7070, ..., 0.4414, 0.1245, 1.0703],\n",
+ " [-0.6602, 0.5352, -0.0718, ..., -0.0737, 0.5391, 0.3086],\n",
+ " [-0.4785, -0.1562, 0.1045, ..., -0.2324, 0.2354, 0.6328]]],\n",
+    "       dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model(torch.tensor([1, 2, 3]).unsqueeze(0))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
+ "outputId": "00d7e983-262e-4c65-f322-f4d999311988"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total number of parameters: 751,632,384\n",
+ "\n",
+ "Total number of unique parameters: 596,049,920\n"
+ ]
+ }
+ ],
+ "source": [
+ "total_params = sum(p.numel() for p in model.parameters())\n",
+ "print(f\"Total number of parameters: {total_params:,}\")\n",
+ "\n",
+ "# Account for weight tying\n",
+ "total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
+ "print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
+ "outputId": "65c1a95e-b502-4150-9e2e-da619d9053d5"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "float32 (PyTorch default): 5.64 GB\n",
+ "bfloat16: 2.82 GB\n"
+ ]
+ }
+ ],
+ "source": [
+ "def model_memory_size(model, input_dtype=torch.float32):\n",
+ " total_params = 0\n",
+ " total_grads = 0\n",
+ " for param in model.parameters():\n",
+ " # Calculate total number of elements per parameter\n",
+ " param_size = param.numel()\n",
+ " total_params += param_size\n",
+ " # Check if gradients are stored for this parameter\n",
+ " if param.requires_grad:\n",
+ " total_grads += param_size\n",
+ "\n",
+ " # Calculate buffer size (non-parameters that require memory)\n",
+ " total_buffers = sum(buf.numel() for buf in model.buffers())\n",
+ "\n",
+ " # Size in bytes = (Number of elements) * (Size of each element in bytes)\n",
+ " # We assume parameters and gradients are stored in the same type as input dtype\n",
+ " element_size = torch.tensor(0, dtype=input_dtype).element_size()\n",
+ " total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n",
+ "\n",
+ " # Convert bytes to gigabytes\n",
+ " total_memory_gb = total_memory_bytes / (1024**3)\n",
+ "\n",
+ " return total_memory_gb\n",
+ "\n",
+ "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
+ "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
+ "metadata": {
+ "id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
+ },
+ "outputs": [],
+ "source": [
+ "if torch.cuda.is_available():\n",
+ " device = torch.device(\"cuda\")\n",
+ "elif torch.backends.mps.is_available():\n",
+ " device = torch.device(\"mps\")\n",
+ "else:\n",
+ " device = torch.device(\"cpu\")\n",
+ "\n",
+ "model.to(device);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c172f89f-d301-439f-b809-46169e5f5945",
+ "metadata": {
+ "id": "c172f89f-d301-439f-b809-46169e5f5945"
+ },
+ "source": [
+ " \n",
+    "# 3. Load pretrained weights"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "75166128-5899-4995-9b88-9672e135650e",
+ "metadata": {
+ "id": "75166128-5899-4995-9b88-9672e135650e"
+ },
+ "outputs": [],
+ "source": [
+ "def load_weights_into_qwen(model, param_config, params):\n",
+ " def assign(left, right, tensor_name=\"unknown\"):\n",
+ " if left.shape != right.shape:\n",
+ " raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
+ " return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n",
+ "\n",
+ " model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
+ "\n",
+ " for l in range(param_config[\"n_layers\"]):\n",
+ " block = model.trf_blocks[l]\n",
+ " att = block.att\n",
+ "\n",
+ " # Q, K, V projections\n",
+ " att.W_query.weight = assign(\n",
+ " att.W_query.weight,\n",
+ " params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n",
+ " f\"model.layers.{l}.self_attn.q_proj.weight\"\n",
+ " )\n",
+ " att.W_key.weight = assign(\n",
+ " att.W_key.weight,\n",
+ " params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n",
+ " f\"model.layers.{l}.self_attn.k_proj.weight\"\n",
+ " )\n",
+ " att.W_value.weight = assign(\n",
+ " att.W_value.weight,\n",
+ " params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n",
+ " f\"model.layers.{l}.self_attn.v_proj.weight\"\n",
+ " )\n",
+ "\n",
+ " # Output projection\n",
+ " att.out_proj.weight = assign(\n",
+ " att.out_proj.weight,\n",
+ " params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n",
+ " f\"model.layers.{l}.self_attn.o_proj.weight\"\n",
+ " )\n",
+ "\n",
+ " # QK norms\n",
+ " if hasattr(att, \"q_norm\") and att.q_norm is not None:\n",
+ " att.q_norm.scale = assign(\n",
+ " att.q_norm.scale,\n",
+ " params[f\"model.layers.{l}.self_attn.q_norm.weight\"],\n",
+ " f\"model.layers.{l}.self_attn.q_norm.weight\"\n",
+ " )\n",
+ " if hasattr(att, \"k_norm\") and att.k_norm is not None:\n",
+ " att.k_norm.scale = assign(\n",
+ " att.k_norm.scale,\n",
+ " params[f\"model.layers.{l}.self_attn.k_norm.weight\"],\n",
+ " f\"model.layers.{l}.self_attn.k_norm.weight\"\n",
+ " )\n",
+ "\n",
+ " # Attention layernorm\n",
+ " block.norm1.scale = assign(\n",
+ " block.norm1.scale,\n",
+ " params[f\"model.layers.{l}.input_layernorm.weight\"],\n",
+ " f\"model.layers.{l}.input_layernorm.weight\"\n",
+ " )\n",
+ "\n",
+ " # Feedforward weights\n",
+ " block.ff.fc1.weight = assign(\n",
+ " block.ff.fc1.weight,\n",
+ " params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n",
+ " f\"model.layers.{l}.mlp.gate_proj.weight\"\n",
+ " )\n",
+ " block.ff.fc2.weight = assign(\n",
+ " block.ff.fc2.weight,\n",
+ " params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n",
+ " f\"model.layers.{l}.mlp.up_proj.weight\"\n",
+ " )\n",
+ " block.ff.fc3.weight = assign(\n",
+ " block.ff.fc3.weight,\n",
+ " params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n",
+ " f\"model.layers.{l}.mlp.down_proj.weight\"\n",
+ " )\n",
+ " block.norm2.scale = assign(\n",
+ " block.norm2.scale,\n",
+ " params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n",
+ " f\"model.layers.{l}.post_attention_layernorm.weight\"\n",
+ " )\n",
+ "\n",
+ " # Final normalization and output head\n",
+ " model.final_norm.scale = assign(model.final_norm.scale, params[\"model.norm.weight\"], \"model.norm.weight\")\n",
+ "\n",
+ " # Model uses weight tying, hence we reuse the embedding layer weights here\n",
+ " model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17,
+ "referenced_widgets": [
+ "9881b6995c3f49dc89e6992fd9ab660b",
+ "17a3174e65c54476b2e0d1faf8f011ca",
+ "1bbf2e62c0754d1593beb4105a7f1ac1",
+ "b82112e1dec645d98aa1c1ba64abcb61",
+ "271e2bd6a35e4a8b92de8697f7c0be5f",
+ "90a79523187446dfa692723b2e5833a7",
+ "431ffb83b8c14bf182f0430e07ea6154",
+ "a8f1b72a33dd4b548de23fbd95e0da18",
+ "25cc36132d384189acfbecc59483134b",
+ "bfd06423ad544218968648016e731a46",
+ "d029630b63ff44cf807ade428d2eb421"
+ ]
+ },
+ "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
+ "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
+ },
+ "outputs": [],
+ "source": [
+ "from pathlib import Path\n",
+ "from safetensors.torch import load_file\n",
+ "from huggingface_hub import hf_hub_download\n",
+ "\n",
+ "\n",
+ "if USE_REASONING_MODEL:\n",
+ " repo_id = \"Qwen/Qwen3-0.6B\"\n",
+ "else:\n",
+ " repo_id = \"Qwen/Qwen3-0.6B-Base\"\n",
+ "local_dir = Path(repo_id).parts[-1]\n",
+ "\n",
+ "\n",
+ "weights_file = hf_hub_download(\n",
+ " repo_id=repo_id,\n",
+ " filename=\"model.safetensors\",\n",
+ " local_dir=local_dir\n",
+ ")\n",
+ "\n",
+ "weights_dict = load_file(weights_file)\n",
+ "\n",
+ "load_weights_into_qwen(model, QWEN3_CONFIG, weights_dict)\n",
+ "model.to(device)\n",
+    "del weights_dict  # free up memory"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6b345491-3510-4397-92d3-cd0a3fa3deee",
+ "metadata": {},
+ "source": [
+ " \n",
+ "# 4. Load tokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "b68ab489-48e5-471e-a814-56cda2d60f81",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tokenizers import Tokenizer\n",
+ "\n",
+ "\n",
+ "class Qwen3Tokenizer():\n",
+ " def __init__(self, tokenizer_file_path=\"tokenizer.json\", repo_id=None, add_generation_prompt=False, add_thinking=False):\n",
+ " self.tokenizer_file_path = tokenizer_file_path\n",
+ " self.add_generation_prompt = add_generation_prompt\n",
+ " self.add_thinking = add_thinking\n",
+ "\n",
+ " tokenizer_file_path_obj = Path(tokenizer_file_path)\n",
+ " if not tokenizer_file_path_obj.is_file() and repo_id is not None:\n",
+ " _ = hf_hub_download(\n",
+ " repo_id=repo_id,\n",
+ " filename=str(tokenizer_file_path_obj.name),\n",
+ " local_dir=str(tokenizer_file_path_obj.parent.name)\n",
+ " )\n",
+ " self.tokenizer = Tokenizer.from_file(tokenizer_file_path)\n",
+ "\n",
+ " def encode(self, prompt):\n",
+ " messages = [\n",
+ " {\"role\": \"user\", \"content\": prompt}\n",
+ " ] \n",
+ " formatted_prompt = self.format_qwen_chat(\n",
+ " messages,\n",
+ " add_generation_prompt=self.add_generation_prompt,\n",
+ " add_thinking=self.add_thinking\n",
+ " )\n",
+ " return self.tokenizer.encode(formatted_prompt).ids\n",
+ " \n",
+ " def decode(self, token_ids):\n",
+ " return self.tokenizer.decode(token_ids, skip_special_tokens=False)\n",
+ " \n",
+ " @staticmethod\n",
+ " def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):\n",
+ " prompt = \"\"\n",
+ " for msg in messages:\n",
+ " prompt += f\"<|im_start|>{msg['role']}\\n{msg['content']}<|im_end|>\\n\"\n",
+ " if add_generation_prompt:\n",
+ " prompt += \"<|im_start|>assistant\"\n",
+ " if not add_thinking:\n",
+    "                prompt += \"<think>\\n\\n</think>\\n\\n\"\n",
+ " else:\n",
+ " prompt += \"\\n\" \n",
+ " return prompt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if USE_REASONING_MODEL:\n",
+ " tokenizer_file_path = \"Qwen3-0.6B/tokenizer.json\"\n",
+ "else:\n",
+ " tokenizer_file_path = \"Qwen3-0.6B-Base/tokenizer.json\"\n",
+ "\n",
+ "tokenizer = Qwen3Tokenizer(\n",
+ " tokenizer_file_path=tokenizer_file_path,\n",
+ " repo_id=repo_id,\n",
+ " add_generation_prompt=USE_REASONING_MODEL,\n",
+ " add_thinking=USE_REASONING_MODEL\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "1946b534-e3af-431a-a222-391a60bfa892",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'<|im_start|>user\\nGive me a short introduction to large language models.<|im_end|>\\n<|im_start|>assistant\\n'"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "prompt = \"Give me a short introduction to large language models.\"\n",
+ "\n",
+ "input_token_ids = tokenizer.encode(prompt)\n",
+ "text = tokenizer.decode(input_token_ids)\n",
+ "text"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "57d07df1-4401-4792-b549-7c4cc5632323",
+ "metadata": {
+ "id": "57d07df1-4401-4792-b549-7c4cc5632323"
+ },
+ "source": [
+ " \n",
+ "# 5. Generate text"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
+ "metadata": {
+ "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
+ },
+ "outputs": [],
+ "source": [
+ "# Identical function from chapter 5\n",
+ "\n",
+ "def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):\n",
+ " # For-loop is the same as before: Get logits, and only focus on last time step\n",
+ " for _ in range(max_new_tokens):\n",
+ " idx_cond = idx[:, -context_size:]\n",
+ " with torch.no_grad():\n",
+ " logits = model(idx_cond)\n",
+ " logits = logits[:, -1, :]\n",
+ "\n",
+ " # Filter logits with top_k sampling\n",
+ " if top_k is not None:\n",
+ " # Keep only top_k values\n",
+ " top_logits, _ = torch.topk(logits, top_k)\n",
+ " min_val = top_logits[:, -1]\n",
+ " logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n",
+ "\n",
+    "        # Apply temperature scaling\n",
+ " if temperature > 0.0:\n",
+ " logits = logits / temperature\n",
+ "\n",
+ " # Apply softmax to get probabilities\n",
+ " probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n",
+ "\n",
+ " # Sample from the distribution\n",
+ " idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)\n",
+ "\n",
+ " # Otherwise same as before: get idx of the vocab entry with the highest logits value\n",
+ " else:\n",
+ " idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1)\n",
+ "\n",
+ " if eos_id is not None and idx_next.item() == eos_id:\n",
+ " break # Stop generating early if end-of-sequence token is encountered and eos_id is specified\n",
+ "\n",
+ " # Same as before: append sampled index to the running sequence\n",
+ " idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1)\n",
+ "\n",
+ " return idx"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
+ "metadata": {
+ "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Time: 21.96 sec\n",
+ "<|im_start|>user\n",
+ "Give me a short introduction to large language models.<|im_end|>\n",
+ "<|im_start|>assistant\n",
+ "\n",
+ "Okay, the user wants a short introduction to large language models. Let me start by recalling what I know. Large language models are AI systems that can understand and generate human language. They're trained on massive datasets, so they can learn complex patterns and nuances.\n",
+ "\n",
+ "I should mention their ability to understand and generate text, not just specific tasks. Maybe include examples like chatbots or language assistants. Also, emphasize their adaptability and versatility. Oh, and maybe touch on their applications in various fields. Let me check if I'm covering all key points without being too technical. Keep it concise, around 3-4 sentences. Make sure it's clear and easy to understand.\n",
+ "\n",
+ "\n",
+ "Large language models (LLMs) are AI systems designed...\n"
+ ]
+ }
+ ],
+ "source": [
+ "import time\n",
+ "\n",
+ "torch.manual_seed(123)\n",
+ "\n",
+ "start = time.time()\n",
+ "\n",
+ "output_token_ids = generate(\n",
+ " model=model,\n",
+ " idx=torch.tensor(input_token_ids, device=device).unsqueeze(0),\n",
+ " max_new_tokens=150,\n",
+ " context_size=QWEN3_CONFIG[\"context_length\"],\n",
+ " top_k=1,\n",
+ " temperature=0.\n",
+ ")\n",
+ "\n",
+ "print(f\"Time: {time.time() - start:.2f} sec\")\n",
+ "\n",
+ "if torch.cuda.is_available():\n",
+ " max_mem_bytes = torch.cuda.max_memory_allocated()\n",
+ " max_mem_gb = max_mem_bytes / (1024 ** 3)\n",
+ " print(f\"Max memory allocated: {max_mem_gb:.2f} GB\")\n",
+ "\n",
+ "output_text = tokenizer.decode(output_token_ids.squeeze(0).tolist())\n",
+ "\n",
+ "print(output_text + \"...\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "549324d6-5c71-4147-ae21-2e67675faa3d",
+ "metadata": {
+ "id": "549324d6-5c71-4147-ae21-2e67675faa3d"
+ },
+ "source": [
+ " \n",
+ "# What's next?"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c",
+ "metadata": {
+ "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c"
+ },
+ "source": [
+    "- Check out the [README.md](./README.md) to use this model via the `llms_from_scratch` package\n",
+ "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n",
+ "\n",
+    ""
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "A100",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.16"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "0dccd57dcc5c43a588157cef957c07e8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "17a3174e65c54476b2e0d1faf8f011ca": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_90a79523187446dfa692723b2e5833a7",
+ "placeholder": "",
+ "style": "IPY_MODEL_431ffb83b8c14bf182f0430e07ea6154",
+ "tabbable": null,
+ "tooltip": null,
+ "value": "model.safetensors: 35%"
+ }
+ },
+ "1bbf2e62c0754d1593beb4105a7f1ac1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "ProgressView",
+ "bar_style": "",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_a8f1b72a33dd4b548de23fbd95e0da18",
+ "max": 2471645608,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_25cc36132d384189acfbecc59483134b",
+ "tabbable": null,
+ "tooltip": null,
+ "value": 880803840
+ }
+ },
+ "25cc36132d384189acfbecc59483134b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "271e2bd6a35e4a8b92de8697f7c0be5f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "3326b6141a1a4eba9f316df528a9b99a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "33ca0cdf2c7f41598a381c4ebe6a4ee1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "431ffb83b8c14bf182f0430e07ea6154": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "5176834aa8784bba9ec21234b87a8948": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "518fb202e4b44aaba47f07d1a61b6762": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_e2dc407afcd945c798e30597fddfcb3c",
+ "placeholder": "",
+ "style": "IPY_MODEL_0dccd57dcc5c43a588157cef957c07e8",
+ "tabbable": null,
+ "tooltip": null,
+ "value": "tokenizer.model: 100%"
+ }
+ },
+ "672cdc5aea954de3af851c001a667ad3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_33ca0cdf2c7f41598a381c4ebe6a4ee1",
+ "max": 2183982,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_ee44487f58454dacb522b1e084ffb733",
+ "tabbable": null,
+ "tooltip": null,
+ "value": 2183982
+ }
+ },
+ "90a79523187446dfa692723b2e5833a7": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9881b6995c3f49dc89e6992fd9ab660b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_17a3174e65c54476b2e0d1faf8f011ca",
+ "IPY_MODEL_1bbf2e62c0754d1593beb4105a7f1ac1",
+ "IPY_MODEL_b82112e1dec645d98aa1c1ba64abcb61"
+ ],
+ "layout": "IPY_MODEL_271e2bd6a35e4a8b92de8697f7c0be5f",
+ "tabbable": null,
+ "tooltip": null
+ }
+ },
+ "a1608feac06d4687967a3e398f01c489": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_518fb202e4b44aaba47f07d1a61b6762",
+ "IPY_MODEL_672cdc5aea954de3af851c001a667ad3",
+ "IPY_MODEL_eebf8874618746b39cf4a21a2728dc7f"
+ ],
+ "layout": "IPY_MODEL_5176834aa8784bba9ec21234b87a8948",
+ "tabbable": null,
+ "tooltip": null
+ }
+ },
+ "a8f1b72a33dd4b548de23fbd95e0da18": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b82112e1dec645d98aa1c1ba64abcb61": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_bfd06423ad544218968648016e731a46",
+ "placeholder": "",
+ "style": "IPY_MODEL_d029630b63ff44cf807ade428d2eb421",
+ "tabbable": null,
+ "tooltip": null,
+ "value": " 870M/2.47G [00:20<00:37, 42.8MB/s]"
+ }
+ },
+ "bfd06423ad544218968648016e731a46": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d029630b63ff44cf807ade428d2eb421": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "d2c41e71a3f441deaed091b620ac5603": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e2dc407afcd945c798e30597fddfcb3c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "ee44487f58454dacb522b1e084ffb733": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "eebf8874618746b39cf4a21a2728dc7f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_d2c41e71a3f441deaed091b620ac5603",
+ "placeholder": "",
+ "style": "IPY_MODEL_3326b6141a1a4eba9f316df528a9b99a",
+ "tabbable": null,
+ "tooltip": null,
+ "value": " 2.18M/2.18M [00:00<00:00, 9.47MB/s]"
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/ch05/README.md b/ch05/README.md
index 1e908be..b0ee9d7 100644
--- a/ch05/README.md
+++ b/ch05/README.md
@@ -17,6 +17,7 @@
- [08_memory_efficient_weight_loading](08_memory_efficient_weight_loading) contains a bonus notebook showing how to load model weights via PyTorch's `load_state_dict` method more efficiently
- [09_extending-tokenizers](09_extending-tokenizers) contains a from-scratch implementation of the GPT-2 BPE tokenizer
- [10_llm-training-speed](10_llm-training-speed) shows PyTorch performance tips to improve the LLM training speed
+- [11_qwen3](11_qwen3) contains a from-scratch implementation of Qwen3 0.6B, including code to load the pretrained weights of the base and reasoning model variants
diff --git a/pkg/llms_from_scratch/README.md b/pkg/llms_from_scratch/README.md
index dc423b6..fceefbf 100644
--- a/pkg/llms_from_scratch/README.md
+++ b/pkg/llms_from_scratch/README.md
@@ -113,7 +113,7 @@ from llms_from_scratch.appendix_d import find_highest_gradient, train_model
```
-
+
### Llama 3 (Bonus material)
```python
@@ -126,5 +126,18 @@ from llms_from_scratch.llama3 import (
)
```
-
For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).
+
+
+
+### Qwen3 (Bonus material)
+
+```python
+from llms_from_scratch.qwen3 import (
+ Qwen3Model,
+ Qwen3Tokenizer,
+)
+```
+
+
+For the `llms_from_scratch.qwen3` usage information, please see [this bonus section](../../ch05/11_qwen3/README.md).
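+
+As a quick smoke test, the following minimal sketch (using randomly initialized weights, not the pretrained ones) instantiates the model from the bundled `QWEN_CONFIG_06_B` config and runs a dummy forward pass to check the output shape:
+
+```python
+import torch
+
+from llms_from_scratch.qwen3 import Qwen3Model, QWEN_CONFIG_06_B
+
+torch.manual_seed(123)
+model = Qwen3Model(QWEN_CONFIG_06_B)  # randomly initialized weights
+
+# Batch of 1 with 8 dummy token IDs
+dummy_input = torch.randint(0, QWEN_CONFIG_06_B["vocab_size"], (1, 8))
+
+with torch.no_grad():
+    logits = model(dummy_input)
+
+print(logits.shape)  # expected: torch.Size([1, 8, 151936])
+```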
diff --git a/pkg/llms_from_scratch/qwen3.py b/pkg/llms_from_scratch/qwen3.py
new file mode 100644
index 0000000..968a473
--- /dev/null
+++ b/pkg/llms_from_scratch/qwen3.py
@@ -0,0 +1,393 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+import os
+import urllib.request
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+
+# 0.6B model
+QWEN_CONFIG_06_B = {
+ "vocab_size": 151_936, # Vocabulary size
+ "context_length": 40_960, # Context length that was used to train the model
+ "emb_dim": 1024, # Embedding dimension
+ "n_heads": 16, # Number of attention heads
+ "n_layers": 28, # Number of layers
+ "hidden_dim": 3072, # Size of the intermediate dimension in FeedForward
+ "head_dim": 128, # Size of the heads in GQA
+ "qk_norm": True, # Whether to normalize queries and values in GQA
+ "n_kv_groups": 8, # Key-Value groups for grouped-query attention
+ "rope_base": 1_000_000.0, # The base in RoPE's "theta"
+ "dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
+}
+
+
+class Qwen3Model(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+
+ # Main model parameters
+ self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
+
+ self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
+ [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
+ )
+ self.final_norm = RMSNorm(cfg["emb_dim"])
+ self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
+
+        # Reusable utilities
+ if cfg["head_dim"] is None:
+ head_dim = cfg["emb_dim"] // cfg["n_heads"]
+ else:
+ head_dim = cfg["head_dim"]
+ cos, sin = compute_rope_params(
+ head_dim=head_dim,
+ theta_base=cfg["rope_base"],
+ context_length=cfg["context_length"]
+ )
+ self.register_buffer("cos", cos, persistent=False)
+ self.register_buffer("sin", sin, persistent=False)
+ self.cfg = cfg
+
+ def forward(self, in_idx):
+ # Forward pass
+ tok_embeds = self.tok_emb(in_idx)
+ x = tok_embeds
+
+ num_tokens = x.shape[1]
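+        # Boolean causal mask: True entries above the diagonal mark future positions to be masked out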
+ mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)
+
+ for block in self.trf_blocks:
+ x = block(x, mask, self.cos, self.sin)
+ x = self.final_norm(x)
+ logits = self.out_head(x.to(self.cfg["dtype"]))
+ return logits
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.att = GroupedQueryAttention(
+ d_in=cfg["emb_dim"],
+ num_heads=cfg["n_heads"],
+ head_dim=cfg["head_dim"],
+ num_kv_groups=cfg["n_kv_groups"],
+ qk_norm=cfg["qk_norm"],
+ dtype=cfg["dtype"]
+ )
+ self.ff = FeedForward(cfg)
+ self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
+ self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)
+
+ def forward(self, x, mask, cos, sin):
+ # Shortcut connection for attention block
+ shortcut = x
+ x = self.norm1(x)
+ x = self.att(x, mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]
+ x = x + shortcut # Add the original input back
+
+ # Shortcut connection for feed-forward block
+ shortcut = x
+ x = self.norm2(x)
+ x = self.ff(x)
+ x = x + shortcut # Add the original input back
+
+ return x
+
+
+class FeedForward(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
+ self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
+ self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)
+
+ def forward(self, x):
+ x_fc1 = self.fc1(x)
+ x_fc2 = self.fc2(x)
+ x = nn.functional.silu(x_fc1) * x_fc2
+ return self.fc3(x)
+
+
+class GroupedQueryAttention(nn.Module):
+ def __init__(
+ self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None
+ ):
+ super().__init__()
+ assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
+
+ self.num_heads = num_heads
+ self.num_kv_groups = num_kv_groups
+ self.group_size = num_heads // num_kv_groups
+
+ if head_dim is None:
+ assert d_in % num_heads == 0, "`d_in` must be divisible by `num_heads` if `head_dim` is not set"
+ head_dim = d_in // num_heads
+
+ self.head_dim = head_dim
+ self.d_out = num_heads * head_dim
+
+ self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
+ self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
+ self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
+
+ self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)
+
+ if qk_norm:
+ self.q_norm = RMSNorm(head_dim, eps=1e-6)
+ self.k_norm = RMSNorm(head_dim, eps=1e-6)
+ else:
+ self.q_norm = self.k_norm = None
+
+ def forward(self, x, mask, cos, sin):
+ b, num_tokens, _ = x.shape
+
+ # Apply projections
+ queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)
+ keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
+ values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
+
+ # Reshape
+ queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
+ keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
+ values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
+
+ # Optional normalization
+ if self.q_norm:
+ queries = self.q_norm(queries)
+ if self.k_norm:
+ keys = self.k_norm(keys)
+
+ # Apply RoPE
+ queries = apply_rope(queries, cos, sin)
+ keys = apply_rope(keys, cos, sin)
+
+ # Expand K and V to match number of heads
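+        # (each of the `num_kv_groups` key/value heads is repeated `group_size` times
+        # so that every query head has a matching key/value head)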
+ keys = keys.repeat_interleave(self.group_size, dim=1)
+ values = values.repeat_interleave(self.group_size, dim=1)
+
+ # Attention
+ attn_scores = queries @ keys.transpose(2, 3)
+ attn_scores = attn_scores.masked_fill(mask, -torch.inf)
+ attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
+
+ context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
+ return self.out_proj(context)
+
+
+def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
+ assert head_dim % 2 == 0, "Embedding dimension must be even"
+
+ # Compute the inverse frequencies
+ inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))
+
+ # Generate position indices
+ positions = torch.arange(context_length, dtype=dtype)
+
+ # Compute the angles
+ angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)
+
+ # Expand angles to match the head_dim
+ angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)
+
+ # Precompute sine and cosine
+ cos = torch.cos(angles)
+ sin = torch.sin(angles)
+
+ return cos, sin
+
+
+def apply_rope(x, cos, sin):
+ # x: (batch_size, num_heads, seq_len, head_dim)
+ batch_size, num_heads, seq_len, head_dim = x.shape
+ assert head_dim % 2 == 0, "Head dimension must be even"
+
+ # Split x into first half and second half
+ x1 = x[..., : head_dim // 2] # First half
+ x2 = x[..., head_dim // 2:] # Second half
+
+ # Adjust sin and cos shapes
+ cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
+ sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
+
+ # Apply the rotary transformation
+ rotated = torch.cat((-x2, x1), dim=-1)
+ x_rotated = (x * cos) + (rotated * sin)
+
+ # It's ok to use lower-precision after applying cos and sin rotation
+ return x_rotated.to(dtype=x.dtype)
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
+ super().__init__()
+ self.eps = eps
+ self.qwen3_compatible = qwen3_compatible
+ self.scale = nn.Parameter(torch.ones(emb_dim))
+ self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None
+
+ def forward(self, x):
+ input_dtype = x.dtype
+
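+        # The reference Qwen3 implementation computes RMSNorm in float32 for numerical
+        # stability and casts back to the original input dtype at the end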
+ if self.qwen3_compatible:
+ x = x.to(torch.float32)
+
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
+ norm_x = x * torch.rsqrt(variance + self.eps)
+ norm_x = norm_x * self.scale
+
+ if self.shift is not None:
+ norm_x = norm_x + self.shift
+
+ return norm_x.to(input_dtype)
+
+
+def load_weights_into_qwen(model, param_config, params):
+ def assign(left, right, tensor_name="unknown"):
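+        # Check that the pretrained tensor's shape matches the model parameter,
+        # then wrap a detached copy in a fresh nn.Parameter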
+ if left.shape != right.shape:
+ raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
+ return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))
+
+ model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
+
+ for l in range(param_config["n_layers"]):
+ block = model.trf_blocks[l]
+ att = block.att
+
+ # Q, K, V projections
+ att.W_query.weight = assign(
+ att.W_query.weight,
+ params[f"model.layers.{l}.self_attn.q_proj.weight"],
+ f"model.layers.{l}.self_attn.q_proj.weight"
+ )
+ att.W_key.weight = assign(
+ att.W_key.weight,
+ params[f"model.layers.{l}.self_attn.k_proj.weight"],
+ f"model.layers.{l}.self_attn.k_proj.weight"
+ )
+ att.W_value.weight = assign(
+ att.W_value.weight,
+ params[f"model.layers.{l}.self_attn.v_proj.weight"],
+ f"model.layers.{l}.self_attn.v_proj.weight"
+ )
+
+ # Output projection
+ att.out_proj.weight = assign(
+ att.out_proj.weight,
+ params[f"model.layers.{l}.self_attn.o_proj.weight"],
+ f"model.layers.{l}.self_attn.o_proj.weight"
+ )
+
+ # QK norms
+ if hasattr(att, "q_norm") and att.q_norm is not None:
+ att.q_norm.scale = assign(
+ att.q_norm.scale,
+ params[f"model.layers.{l}.self_attn.q_norm.weight"],
+ f"model.layers.{l}.self_attn.q_norm.weight"
+ )
+ if hasattr(att, "k_norm") and att.k_norm is not None:
+ att.k_norm.scale = assign(
+ att.k_norm.scale,
+ params[f"model.layers.{l}.self_attn.k_norm.weight"],
+ f"model.layers.{l}.self_attn.k_norm.weight"
+ )
+
+ # Attention layernorm
+ block.norm1.scale = assign(
+ block.norm1.scale,
+ params[f"model.layers.{l}.input_layernorm.weight"],
+ f"model.layers.{l}.input_layernorm.weight"
+ )
+
+ # Feedforward weights
+ block.ff.fc1.weight = assign(
+ block.ff.fc1.weight,
+ params[f"model.layers.{l}.mlp.gate_proj.weight"],
+ f"model.layers.{l}.mlp.gate_proj.weight"
+ )
+ block.ff.fc2.weight = assign(
+ block.ff.fc2.weight,
+ params[f"model.layers.{l}.mlp.up_proj.weight"],
+ f"model.layers.{l}.mlp.up_proj.weight"
+ )
+ block.ff.fc3.weight = assign(
+ block.ff.fc3.weight,
+ params[f"model.layers.{l}.mlp.down_proj.weight"],
+ f"model.layers.{l}.mlp.down_proj.weight"
+ )
+ block.norm2.scale = assign(
+ block.norm2.scale,
+ params[f"model.layers.{l}.post_attention_layernorm.weight"],
+ f"model.layers.{l}.post_attention_layernorm.weight"
+ )
+
+ # Final normalization and output head
+ model.final_norm.scale = assign(model.final_norm.scale, params["model.norm.weight"], "model.norm.weight")
+
+ # Model uses weight tying, hence we reuse the embedding layer weights here
+ model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
+
+
+class Qwen3Tokenizer():
+ def __init__(self, tokenizer_file_path="tokenizer.json",
+ repo_id=None, add_generation_prompt=False, add_thinking=False):
+ from tokenizers import Tokenizer
+ self.tokenizer_file_path = tokenizer_file_path
+
+ if add_generation_prompt != add_thinking:
+ raise ValueError(
+ "Only add_generation_prompt==add_thinking settings are currently supported"
+ )
+
+ self.add_generation_prompt = add_generation_prompt
+ self.add_thinking = add_thinking
+
+ tokenizer_file_path_obj = Path(tokenizer_file_path)
+ if not tokenizer_file_path_obj.is_file() and repo_id is not None:
+ _ = download_from_huggingface(
+ repo_id=repo_id,
+ filename=str(tokenizer_file_path_obj.name),
+ local_dir=str(tokenizer_file_path_obj.parent.name)
+ )
+ self.tokenizer = Tokenizer.from_file(tokenizer_file_path)
+
+ def encode(self, prompt):
+ messages = [
+ {"role": "user", "content": prompt}
+ ]
+ formatted_prompt = self.format_qwen_chat(
+ messages,
+ add_generation_prompt=self.add_generation_prompt,
+ add_thinking=self.add_thinking
+ )
+ return self.tokenizer.encode(formatted_prompt).ids
+
+ def decode(self, token_ids):
+ return self.tokenizer.decode(token_ids, skip_special_tokens=False)
+
+ @staticmethod
+ def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
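+        # ChatML-style format: each message is wrapped in <|im_start|>{role}\n ... <|im_end|> tags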
+ prompt = ""
+ for msg in messages:
+ prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
+ if add_generation_prompt:
+ prompt += "<|im_start|>assistant"
+ if not add_thinking:
+ prompt += "<|think>\n\n<|/think>\n\n"
+ else:
+ prompt += "\n"
+ return prompt
+
+
+def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
+ base_url = "https://huggingface.co"
+ url = f"{base_url}/{repo_id}/resolve/{revision}/{filename}"
+ Path(local_dir).mkdir(parents=True, exist_ok=True)
+ dest_path = os.path.join(local_dir, filename)
+ print(f"Downloading {url} to {dest_path}...")
+ urllib.request.urlretrieve(url, dest_path)
+ return dest_path
diff --git a/pkg/llms_from_scratch/tests/test_llama3.py b/pkg/llms_from_scratch/tests/test_llama3.py
index 1719976..9f6d48c 100644
--- a/pkg/llms_from_scratch/tests/test_llama3.py
+++ b/pkg/llms_from_scratch/tests/test_llama3.py
@@ -19,6 +19,36 @@ import tiktoken
import torch
+class LitGPTRMSNorm(torch.nn.Module):
+ """Root Mean Square Layer Normalization.
+
+ From https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
+    Apache License 2.0: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
+
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
+ """
+
+ def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False) -> None:
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(size))
+ self.eps = eps
+ self.dim = dim
+ self.add_unit_offset = add_unit_offset
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ dtype = x.dtype
+ x = x.float()
+ # NOTE: the original RMSNorm paper implementation is not equivalent
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
+ weight = (1 + self.weight) if self.add_unit_offset else self.weight
+ return (x_normed * weight.float()).to(dtype=dtype)
+
+ def reset_parameters(self) -> None:
+ torch.nn.init.ones_(self.weight)
+
+
transformers_installed = importlib.util.find_spec("transformers") is not None
@@ -179,3 +209,25 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path):
[43, 2543, 292, 4483, 100383, 8113, 76873, 42175, 72641]
])
assert torch.equal(expect, out)
+
+
+def test_rmsnorm_equivalence():
+ torch.manual_seed(42)
+
+ hidden_size = 64
+ batch_size = 8
+ seq_len = 16
+
+ rms_norm = torch.nn.RMSNorm(hidden_size, eps=1e-6)
+ lit_norm = LitGPTRMSNorm(hidden_size)
+
+ # Sync weights
+ with torch.no_grad():
+        lit_norm.weight.copy_(rms_norm.weight)
+
+ x = torch.randn(batch_size, seq_len, hidden_size)
+
+ out1 = rms_norm(x)
+ out2 = lit_norm(x)
+
+ torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
diff --git a/pkg/llms_from_scratch/tests/test_qwen3.py b/pkg/llms_from_scratch/tests/test_qwen3.py
new file mode 100644
index 0000000..f78cc1b
--- /dev/null
+++ b/pkg/llms_from_scratch/tests/test_qwen3.py
@@ -0,0 +1,194 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+from llms_from_scratch.ch04 import generate_text_simple
+from llms_from_scratch.qwen3 import (
+ compute_rope_params,
+ apply_rope,
+ QWEN_CONFIG_06_B,
+ RMSNorm,
+ Qwen3Model,
+ Qwen3Tokenizer
+)
+
+import importlib
+import pytest
+import tiktoken
+import torch
+import torch.nn as nn
+
+
+class Qwen3RMSNorm(nn.Module):
+ # Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py
+ # License: Apache License, Version 2.0 (see file above)
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Qwen3RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+transformers_installed = importlib.util.find_spec("transformers") is not None
+
+
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_rope():
+
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding, apply_rotary_pos_emb
+
+ # Settings
+ batch_size = 1
+ context_len = 8192
+ num_heads = 4
+ head_dim = 16
+ rope_theta = 1_000_000
+
+ # Instantiate RoPE parameters
+ cos, sin = compute_rope_params(
+ head_dim=head_dim,
+ theta_base=rope_theta,
+ context_length=context_len,
+ )
+
+ # Dummy query and key tensors
+ torch.manual_seed(123)
+ queries = torch.randn(batch_size, num_heads, context_len, head_dim)
+ keys = torch.randn(batch_size, num_heads, context_len, head_dim)
+
+ # Apply rotary position embeddings
+ queries_rot = apply_rope(queries, cos, sin)
+ keys_rot = apply_rope(keys, cos, sin)
+
+ # Generate reference RoPE via HF
+ class RoPEConfig:
+ rope_type = "qwen3"
+ factor = 1.0
+ dim: int = head_dim
+ rope_theta = 1_000_000
+ max_position_embeddings: int = 8192
+ hidden_size = head_dim * num_heads
+ num_attention_heads = num_heads
+
+ config = RoPEConfig()
+
+ rot_emb = Qwen3RotaryEmbedding(config=config)
+ position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
+ ref_cos, ref_sin = rot_emb(queries, position_ids)
+ ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
+
+ torch.testing.assert_close(sin, ref_sin.squeeze(0))
+ torch.testing.assert_close(cos, ref_cos.squeeze(0))
+ torch.testing.assert_close(keys_rot, ref_keys_rot)
+ torch.testing.assert_close(queries_rot, ref_queries_rot)
+
+
+@pytest.fixture(scope="session")
+def qwen3_weights_path(tmp_path_factory):
+ """Creates and saves a deterministic Llama3 model for testing."""
+    path = tmp_path_factory.mktemp("models") / "qwen3_test_weights.pt"
+
+ if not path.exists():
+ torch.manual_seed(123)
+ model = Qwen3Model(QWEN_CONFIG_06_B)
+ torch.save(model.state_dict(), path)
+
+ return path
+
+
+@pytest.mark.parametrize("ModelClass", [Qwen3Model])
+def test_qwen3_model_variants(ModelClass, qwen3_weights_path):
+ torch.manual_seed(123)
+ model = ModelClass(QWEN_CONFIG_06_B)
+ model.load_state_dict(torch.load(qwen3_weights_path))
+ model.eval()
+
+ start_context = "Llamas eat"
+
+ tokenizer = tiktoken.get_encoding("gpt2")
+ encoded = tokenizer.encode(start_context)
+ encoded_tensor = torch.tensor(encoded).unsqueeze(0)
+
+ print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
+ print("\nInput text:", start_context)
+ print("Encoded input text:", encoded)
+ print("encoded_tensor.shape:", encoded_tensor.shape)
+
+ out = generate_text_simple(
+ model=model,
+ idx=encoded_tensor,
+ max_new_tokens=5,
+ context_size=QWEN_CONFIG_06_B["context_length"]
+ )
+ print("Encoded output text:", out)
+ expect = torch.tensor([
+ [43, 2543, 292, 4483, 115206, 459, 43010, 104223, 55553]
+ ])
+ assert torch.equal(expect, out)
+
+
+def test_rmsnorm_equivalence():
+ torch.manual_seed(42)
+
+ hidden_size = 64
+ batch_size = 8
+ seq_len = 16
+
+ rms_norm = RMSNorm(hidden_size)
+ ref_norm = Qwen3RMSNorm(hidden_size)
+
+ # Sync weights
+ with torch.no_grad():
+        ref_norm.weight.copy_(rms_norm.scale)
+
+ x = torch.randn(batch_size, seq_len, hidden_size)
+
+ out1 = rms_norm(x)
+ out2 = ref_norm(x)
+
+ torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
+
+
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_tokenizer_equivalence():
+ from transformers import AutoTokenizer
+ repo_id = "Qwen/Qwen3-0.6B"
+ tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+ prompt = "Give me a short introduction to large language models."
+ messages = [
+ {"role": "user", "content": prompt},
+ ]
+
+ for states in ((True, True), (False, False)):
+ tokenizer = Qwen3Tokenizer(
+ tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
+ repo_id=repo_id,
+ add_generation_prompt=states[0],
+ add_thinking=states[1]
+ )
+ input_token_ids = tokenizer.encode(prompt)
+ input_token_ids_ref = tokenizer_ref.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=states[0],
+ enable_thinking=states[1],
+ )
+ assert input_token_ids == input_token_ids_ref, states
+
+ output_text = tokenizer.decode(input_token_ids)
+ out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
+ assert output_text == out_text_ref, states
diff --git a/pyproject.toml b/pyproject.toml
index 52b6e28..d1ecd5f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "llms-from-scratch"
-version = "1.0.7"
+version = "1.0.9"
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
readme = "README.md"
requires-python = ">=3.10"