mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-30 19:41:51 +00:00
Qwen3 From Scratch (#678)
* Qwen3 From Scratch * rev other file * upd * upd * upd * url fixes
This commit is contained in:
parent
e700c66b7a
commit
3d4bce6d57
7
.gitignore
vendored
7
.gitignore
vendored
@ -48,12 +48,13 @@ ch05/07_gpt_to_llama/Llama-3.2-1B
|
||||
ch05/07_gpt_to_llama/Llama-3.2-1B-Instruct
|
||||
ch05/07_gpt_to_llama/Llama-3.2-3B
|
||||
ch05/07_gpt_to_llama/Llama-3.2-3B-Instruct
|
||||
ch05/07_gpt_to_llama/llama3.2-1B-instruct.pth
|
||||
ch05/07_gpt_to_llama/tokenizer.model
|
||||
ch05/10_llm-training-speed/middlemarch.txt
|
||||
ch05/10_llm-training-speed/loss.pdf
|
||||
ch05/10_llm-training-speed/model.pth
|
||||
ch05/07_gpt_to_llama/Untitled.ipynb
|
||||
ch05/07_gpt_to_llama/llama3.2-1B-instruct.pth
|
||||
ch05/07_gpt_to_llama/tokenizer.model
|
||||
ch05/11_qwen3/Qwen3-0.6B
|
||||
ch05/11_qwen3/Qwen3-0.6B-Base
|
||||
|
||||
ch06/01_main-chapter-code/gpt2
|
||||
ch06/02_bonus_additional-experiments/gpt2
|
||||
|
@ -121,6 +121,7 @@ Several folders contain optional materials as a bonus for interested readers:
|
||||
- [Building a User Interface to Interact With the Pretrained LLM](ch05/06_user_interface)
|
||||
- [Converting GPT to Llama](ch05/07_gpt_to_llama)
|
||||
- [Llama 3.2 From Scratch](ch05/07_gpt_to_llama/standalone-llama32.ipynb)
|
||||
- [Qwen3 From Scratch](ch05/11_qwen3/standalone-qwen3.ipynb)
|
||||
- [Memory-efficient Model Weight Loading](ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb)
|
||||
- [Extending the Tiktoken BPE Tokenizer with New Tokens](ch05/09_extending-tokenizers/extend-tiktoken.ipynb)
|
||||
- [PyTorch Performance Tips for Faster LLM Training](ch05/10_llm-training-speed)
|
||||
|
191
ch05/11_qwen3/README.md
Normal file
191
ch05/11_qwen3/README.md
Normal file
@ -0,0 +1,191 @@
|
||||
# Qwen3 From Scratch
|
||||
|
||||
This [standalone-qwen3.ipynb](standalone-qwen3.ipynb) Jupyter notebook in this folder contains a from-scratch implementation of Qwen3 0.6B.
|
||||
|
||||
<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen/qwen-overview.webp">
|
||||
|
||||
|
||||
|
||||
### Using Qwen3 0.6B via the `llms-from-scratch` package
|
||||
|
||||
For an easy way to use the Qwen3 from-scratch implementation, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
|
||||
|
||||
|
||||
#### 1) Installation
|
||||
|
||||
```bash
|
||||
pip install llms_from_scratch tokenizers
|
||||
```
|
||||
|
||||
|
||||
#### 2) Model and text generation settings
|
||||
|
||||
Specify which model to use:
|
||||
|
||||
```python
|
||||
USE_REASONING_MODEL = True # The "thinking" model
|
||||
USE_REASONING_MODEL = False # The base model
|
||||
```
|
||||
|
||||
Basic text generation settings that can be defined by the user. With 150 tokens, the model requires approximately 1.5 GB memory.
|
||||
|
||||
```python
|
||||
MAX_NEW_TOKENS = 150
|
||||
TEMPERATURE = 0.
|
||||
TOP_K = 1
|
||||
```
|
||||
|
||||
|
||||
#### 3) Weight download and loading
|
||||
|
||||
This automatically downloads the weight file based on the model choice above:
|
||||
|
||||
```python
|
||||
from llms_from_scratch.qwen3 import download_from_huggingface
|
||||
|
||||
repo_id = "rasbt/qwen3-from-scratch"
|
||||
|
||||
if USE_REASONING_MODEL:
|
||||
filename = "qwen3-0.6B.pth"
|
||||
local_dir = "Qwen3-0.6B"
|
||||
else:
|
||||
filename = "qwen3-0.6B-base.pth"
|
||||
local_dir = "Qwen3-0.6B-Base"
|
||||
|
||||
download_from_huggingface(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
local_dir=local_dir
|
||||
)
|
||||
```
|
||||
|
||||
The model weights are then loaded as follows:
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
import torch
|
||||
|
||||
from llms_from_scratch.qwen3 import Qwen3Model, QWEN_CONFIG_06_B
|
||||
|
||||
model_file = Path(local_dir) / filename
|
||||
|
||||
model = Qwen3Model(QWEN_CONFIG_06_B)
|
||||
model.load_state_dict(torch.load(model_file, weights_only=True, map_location="cpu"))
|
||||
|
||||
device = (
|
||||
torch.device("cuda") if torch.cuda.is_available() else
|
||||
torch.device("mps") if torch.backends.mps.is_available() else
|
||||
torch.device("cpu")
|
||||
)
|
||||
model.to(device)
|
||||
```
|
||||
|
||||
|
||||
#### 4) Initialize tokenizer
|
||||
|
||||
The following code downloads and initializes the tokenizer:
|
||||
|
||||
```python
|
||||
from llms_from_scratch.qwen3 import Qwen3Tokenizer
|
||||
|
||||
if USE_REASONING_MODEL:
|
||||
tok_filename = "tokenizer.json"
|
||||
else:
|
||||
tok_filename = "tokenizer-base.json"
|
||||
|
||||
tokenizer = Qwen3Tokenizer(
|
||||
tokenizer_file_path=tok_filename,
|
||||
repo_id=repo_id,
|
||||
add_generation_prompt=USE_REASONING_MODEL,
|
||||
add_thinking=USE_REASONING_MODEL
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#### 5) Generating text
|
||||
|
||||
Lastly, we can generate text via the following code:
|
||||
|
||||
```python
|
||||
prompt = "Give me a short introduction to large language models."
|
||||
input_token_ids = tokenizer.encode(prompt)
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
```python
|
||||
from llms_from_scratch.ch05 import generate
|
||||
import time
|
||||
|
||||
torch.manual_seed(123)
|
||||
|
||||
start = time.time()
|
||||
|
||||
output_token_ids = generate(
|
||||
model=model,
|
||||
idx=torch.tensor(input_token_ids, device=device).unsqueeze(0),
|
||||
max_new_tokens=150,
|
||||
context_size=QWEN_CONFIG_06_B["context_length"],
|
||||
top_k=1,
|
||||
temperature=0.
|
||||
)
|
||||
|
||||
total_time = time.time() - start
|
||||
print(f"Time: {total_time:.2f} sec")
|
||||
print(f"{int(len(output_token_ids[0])/total_time)} tokens/sec")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
max_mem_bytes = torch.cuda.max_memory_allocated()
|
||||
max_mem_gb = max_mem_bytes / (1024 ** 3)
|
||||
print(f"Max memory allocated: {max_mem_gb:.2f} GB")
|
||||
|
||||
output_text = tokenizer.decode(output_token_ids.squeeze(0).tolist())
|
||||
|
||||
print("\n\nOutput text:\n\n", output_text + "...")
|
||||
```
|
||||
|
||||
When using the Qwen3 0.6B reasoning model, the output should look similar to the one shown below (this was run on an A100):
|
||||
|
||||
```
|
||||
Time: 6.35 sec
|
||||
25 tokens/sec
|
||||
Max memory allocated: 1.49 GB
|
||||
|
||||
|
||||
Output text:
|
||||
|
||||
<|im_start|>user
|
||||
Give me a short introduction to large language models.<|im_end|>
|
||||
Large language models (LLMs) are advanced artificial intelligence systems designed to generate human-like text. They are trained on vast amounts of text data, allowing them to understand and generate coherent, contextually relevant responses. LLMs are used in a variety of applications, including chatbots, virtual assistants, content generation, and more. They are powered by deep learning algorithms and can be fine-tuned for specific tasks, making them versatile tools for a wide range of industries.<|endoftext|>Human resources department of a company is planning to hire 100 new employees. The company has a budget of $100,000 for the recruitment process. The company has a minimum wage of $10 per hour. The company has a total of...
|
||||
```
|
||||
|
||||
|
||||
#### Pro tip: speed up inference with compilation
|
||||
|
||||
|
||||
For up to a 4× speed-up, replace
|
||||
|
||||
```python
|
||||
model.to(device)
|
||||
```
|
||||
|
||||
with
|
||||
|
||||
```python
|
||||
model = torch.compile(model)
|
||||
model.to(device)
|
||||
```
|
||||
|
||||
Note: There is a significant multi-minute upfront cost when compiling, and the speed-up takes effect after the first `generate` call.
|
||||
|
||||
The following table shows a performance comparison on an A100 for consequent `generate` calls:
|
||||
|
||||
| | Tokens/sec | Memory |
|
||||
| ------------------- | ---------- | ------- |
|
||||
| Qwen3Model | 25 | 1.49 GB |
|
||||
| Qwen3Model compiled | 101 | 1.99 GB |
|
1788
ch05/11_qwen3/standalone-qwen3.ipynb
Normal file
1788
ch05/11_qwen3/standalone-qwen3.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
@ -17,6 +17,7 @@
|
||||
- [08_memory_efficient_weight_loading](08_memory_efficient_weight_loading) contains a bonus notebook showing how to load model weights via PyTorch's `load_state_dict` method more efficiently
|
||||
- [09_extending-tokenizers](09_extending-tokenizers) contains a from-scratch implementation of the GPT-2 BPE tokenizer
|
||||
- [10_llm-training-speed](10_llm-training-speed) shows PyTorch performance tips to improve the LLM training speed
|
||||
- [11_qwen3](11_qwen3) A from-scratch implementation of Qwen3 0.6B including code to load the pretrained weights of the base and reasoning model variants
|
||||
|
||||
|
||||
|
||||
|
@ -113,7 +113,7 @@ from llms_from_scratch.appendix_d import find_highest_gradient, train_model
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
### Llama 3 (Bonus material)
|
||||
|
||||
```python
|
||||
@ -126,5 +126,18 @@ from llms_from_scratch.llama3 import (
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).
|
||||
|
||||
|
||||
|
||||
### Qwen3 (Bonus material)
|
||||
|
||||
```python
|
||||
from llms_from_scratch.qwen3 import (
|
||||
Qwen3Model,
|
||||
Qwen3Tokenizer,
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
For the `llms_from_scratch.qwen3` usage information, please see [this bonus section](../../ch05/11_qwen3/README.md).
|
||||
|
393
pkg/llms_from_scratch/qwen3.py
Normal file
393
pkg/llms_from_scratch/qwen3.py
Normal file
@ -0,0 +1,393 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
import os
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# 0.6B model
|
||||
QWEN_CONFIG_06_B = {
|
||||
"vocab_size": 151_936, # Vocabulary size
|
||||
"context_length": 40_960, # Context length that was used to train the model
|
||||
"emb_dim": 1024, # Embedding dimension
|
||||
"n_heads": 16, # Number of attention heads
|
||||
"n_layers": 28, # Number of layers
|
||||
"hidden_dim": 3072, # Size of the intermediate dimension in FeedForward
|
||||
"head_dim": 128, # Size of the heads in GQA
|
||||
"qk_norm": True, # Whether to normalize queries and values in GQA
|
||||
"n_kv_groups": 8, # Key-Value groups for grouped-query attention
|
||||
"rope_base": 1_000_000.0, # The base in RoPE's "theta"
|
||||
"dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
|
||||
}
|
||||
|
||||
|
||||
class Qwen3Model(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
|
||||
# Main model parameters
|
||||
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
|
||||
|
||||
self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
|
||||
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
|
||||
)
|
||||
self.final_norm = RMSNorm(cfg["emb_dim"])
|
||||
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
|
||||
|
||||
# Reusuable utilities
|
||||
if cfg["head_dim"] is None:
|
||||
head_dim = cfg["emb_dim"] // cfg["n_heads"]
|
||||
else:
|
||||
head_dim = cfg["head_dim"]
|
||||
cos, sin = compute_rope_params(
|
||||
head_dim=head_dim,
|
||||
theta_base=cfg["rope_base"],
|
||||
context_length=cfg["context_length"]
|
||||
)
|
||||
self.register_buffer("cos", cos, persistent=False)
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
self.cfg = cfg
|
||||
|
||||
def forward(self, in_idx):
|
||||
# Forward pass
|
||||
tok_embeds = self.tok_emb(in_idx)
|
||||
x = tok_embeds
|
||||
|
||||
num_tokens = x.shape[1]
|
||||
mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)
|
||||
|
||||
for block in self.trf_blocks:
|
||||
x = block(x, mask, self.cos, self.sin)
|
||||
x = self.final_norm(x)
|
||||
logits = self.out_head(x.to(self.cfg["dtype"]))
|
||||
return logits
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.att = GroupedQueryAttention(
|
||||
d_in=cfg["emb_dim"],
|
||||
num_heads=cfg["n_heads"],
|
||||
head_dim=cfg["head_dim"],
|
||||
num_kv_groups=cfg["n_kv_groups"],
|
||||
qk_norm=cfg["qk_norm"],
|
||||
dtype=cfg["dtype"]
|
||||
)
|
||||
self.ff = FeedForward(cfg)
|
||||
self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
|
||||
self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)
|
||||
|
||||
def forward(self, x, mask, cos, sin):
|
||||
# Shortcut connection for attention block
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = self.att(x, mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
# Shortcut connection for feed-forward block
|
||||
shortcut = x
|
||||
x = self.norm2(x)
|
||||
x = self.ff(x)
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
|
||||
self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
|
||||
self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x_fc1 = self.fc1(x)
|
||||
x_fc2 = self.fc2(x)
|
||||
x = nn.functional.silu(x_fc1) * x_fc2
|
||||
return self.fc3(x)
|
||||
|
||||
|
||||
class GroupedQueryAttention(nn.Module):
|
||||
def __init__(
|
||||
self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_groups = num_kv_groups
|
||||
self.group_size = num_heads // num_kv_groups
|
||||
|
||||
if head_dim is None:
|
||||
assert d_in % num_heads == 0, "`d_in` must be divisible by `num_heads` if `head_dim` is not set"
|
||||
head_dim = d_in // num_heads
|
||||
|
||||
self.head_dim = head_dim
|
||||
self.d_out = num_heads * head_dim
|
||||
|
||||
self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
|
||||
self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
|
||||
self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
|
||||
|
||||
self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)
|
||||
|
||||
if qk_norm:
|
||||
self.q_norm = RMSNorm(head_dim, eps=1e-6)
|
||||
self.k_norm = RMSNorm(head_dim, eps=1e-6)
|
||||
else:
|
||||
self.q_norm = self.k_norm = None
|
||||
|
||||
def forward(self, x, mask, cos, sin):
|
||||
b, num_tokens, _ = x.shape
|
||||
|
||||
# Apply projections
|
||||
queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)
|
||||
keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
|
||||
values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
|
||||
|
||||
# Reshape
|
||||
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Optional normalization
|
||||
if self.q_norm:
|
||||
queries = self.q_norm(queries)
|
||||
if self.k_norm:
|
||||
keys = self.k_norm(keys)
|
||||
|
||||
# Apply RoPE
|
||||
queries = apply_rope(queries, cos, sin)
|
||||
keys = apply_rope(keys, cos, sin)
|
||||
|
||||
# Expand K and V to match number of heads
|
||||
keys = keys.repeat_interleave(self.group_size, dim=1)
|
||||
values = values.repeat_interleave(self.group_size, dim=1)
|
||||
|
||||
# Attention
|
||||
attn_scores = queries @ keys.transpose(2, 3)
|
||||
attn_scores = attn_scores.masked_fill(mask, -torch.inf)
|
||||
attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
|
||||
|
||||
context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
|
||||
return self.out_proj(context)
|
||||
|
||||
|
||||
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
|
||||
assert head_dim % 2 == 0, "Embedding dimension must be even"
|
||||
|
||||
# Compute the inverse frequencies
|
||||
inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))
|
||||
|
||||
# Generate position indices
|
||||
positions = torch.arange(context_length, dtype=dtype)
|
||||
|
||||
# Compute the angles
|
||||
angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)
|
||||
|
||||
# Expand angles to match the head_dim
|
||||
angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)
|
||||
|
||||
# Precompute sine and cosine
|
||||
cos = torch.cos(angles)
|
||||
sin = torch.sin(angles)
|
||||
|
||||
return cos, sin
|
||||
|
||||
|
||||
def apply_rope(x, cos, sin):
|
||||
# x: (batch_size, num_heads, seq_len, head_dim)
|
||||
batch_size, num_heads, seq_len, head_dim = x.shape
|
||||
assert head_dim % 2 == 0, "Head dimension must be even"
|
||||
|
||||
# Split x into first half and second half
|
||||
x1 = x[..., : head_dim // 2] # First half
|
||||
x2 = x[..., head_dim // 2:] # Second half
|
||||
|
||||
# Adjust sin and cos shapes
|
||||
cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
|
||||
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Apply the rotary transformation
|
||||
rotated = torch.cat((-x2, x1), dim=-1)
|
||||
x_rotated = (x * cos) + (rotated * sin)
|
||||
|
||||
# It's ok to use lower-precision after applying cos and sin rotation
|
||||
return x_rotated.to(dtype=x.dtype)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.qwen3_compatible = qwen3_compatible
|
||||
self.scale = nn.Parameter(torch.ones(emb_dim))
|
||||
self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None
|
||||
|
||||
def forward(self, x):
|
||||
input_dtype = x.dtype
|
||||
|
||||
if self.qwen3_compatible:
|
||||
x = x.to(torch.float32)
|
||||
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
norm_x = x * torch.rsqrt(variance + self.eps)
|
||||
norm_x = norm_x * self.scale
|
||||
|
||||
if self.shift is not None:
|
||||
norm_x = norm_x + self.shift
|
||||
|
||||
return norm_x.to(input_dtype)
|
||||
|
||||
|
||||
def load_weights_into_qwen(model, param_config, params):
|
||||
def assign(left, right, tensor_name="unknown"):
|
||||
if left.shape != right.shape:
|
||||
raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
|
||||
return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))
|
||||
|
||||
model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
|
||||
|
||||
for l in range(param_config["n_layers"]):
|
||||
block = model.trf_blocks[l]
|
||||
att = block.att
|
||||
|
||||
# Q, K, V projections
|
||||
att.W_query.weight = assign(
|
||||
att.W_query.weight,
|
||||
params[f"model.layers.{l}.self_attn.q_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.q_proj.weight"
|
||||
)
|
||||
att.W_key.weight = assign(
|
||||
att.W_key.weight,
|
||||
params[f"model.layers.{l}.self_attn.k_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.k_proj.weight"
|
||||
)
|
||||
att.W_value.weight = assign(
|
||||
att.W_value.weight,
|
||||
params[f"model.layers.{l}.self_attn.v_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.v_proj.weight"
|
||||
)
|
||||
|
||||
# Output projection
|
||||
att.out_proj.weight = assign(
|
||||
att.out_proj.weight,
|
||||
params[f"model.layers.{l}.self_attn.o_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.o_proj.weight"
|
||||
)
|
||||
|
||||
# QK norms
|
||||
if hasattr(att, "q_norm") and att.q_norm is not None:
|
||||
att.q_norm.scale = assign(
|
||||
att.q_norm.scale,
|
||||
params[f"model.layers.{l}.self_attn.q_norm.weight"],
|
||||
f"model.layers.{l}.self_attn.q_norm.weight"
|
||||
)
|
||||
if hasattr(att, "k_norm") and att.k_norm is not None:
|
||||
att.k_norm.scale = assign(
|
||||
att.k_norm.scale,
|
||||
params[f"model.layers.{l}.self_attn.k_norm.weight"],
|
||||
f"model.layers.{l}.self_attn.k_norm.weight"
|
||||
)
|
||||
|
||||
# Attention layernorm
|
||||
block.norm1.scale = assign(
|
||||
block.norm1.scale,
|
||||
params[f"model.layers.{l}.input_layernorm.weight"],
|
||||
f"model.layers.{l}.input_layernorm.weight"
|
||||
)
|
||||
|
||||
# Feedforward weights
|
||||
block.ff.fc1.weight = assign(
|
||||
block.ff.fc1.weight,
|
||||
params[f"model.layers.{l}.mlp.gate_proj.weight"],
|
||||
f"model.layers.{l}.mlp.gate_proj.weight"
|
||||
)
|
||||
block.ff.fc2.weight = assign(
|
||||
block.ff.fc2.weight,
|
||||
params[f"model.layers.{l}.mlp.up_proj.weight"],
|
||||
f"model.layers.{l}.mlp.up_proj.weight"
|
||||
)
|
||||
block.ff.fc3.weight = assign(
|
||||
block.ff.fc3.weight,
|
||||
params[f"model.layers.{l}.mlp.down_proj.weight"],
|
||||
f"model.layers.{l}.mlp.down_proj.weight"
|
||||
)
|
||||
block.norm2.scale = assign(
|
||||
block.norm2.scale,
|
||||
params[f"model.layers.{l}.post_attention_layernorm.weight"],
|
||||
f"model.layers.{l}.post_attention_layernorm.weight"
|
||||
)
|
||||
|
||||
# Final normalization and output head
|
||||
model.final_norm.scale = assign(model.final_norm.scale, params["model.norm.weight"], "model.norm.weight")
|
||||
|
||||
# Model uses weight tying, hence we reuse the embedding layer weights here
|
||||
model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
|
||||
|
||||
|
||||
class Qwen3Tokenizer():
|
||||
def __init__(self, tokenizer_file_path="tokenizer.json",
|
||||
repo_id=None, add_generation_prompt=False, add_thinking=False):
|
||||
from tokenizers import Tokenizer
|
||||
self.tokenizer_file_path = tokenizer_file_path
|
||||
|
||||
if add_generation_prompt != add_thinking:
|
||||
raise ValueError(
|
||||
"Only add_generation_prompt==add_thinking settings are currently supported"
|
||||
)
|
||||
|
||||
self.add_generation_prompt = add_generation_prompt
|
||||
self.add_thinking = add_thinking
|
||||
|
||||
tokenizer_file_path_obj = Path(tokenizer_file_path)
|
||||
if not tokenizer_file_path_obj.is_file() and repo_id is not None:
|
||||
_ = download_from_huggingface(
|
||||
repo_id=repo_id,
|
||||
filename=str(tokenizer_file_path_obj.name),
|
||||
local_dir=str(tokenizer_file_path_obj.parent.name)
|
||||
)
|
||||
self.tokenizer = Tokenizer.from_file(tokenizer_file_path)
|
||||
|
||||
def encode(self, prompt):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
formatted_prompt = self.format_qwen_chat(
|
||||
messages,
|
||||
add_generation_prompt=self.add_generation_prompt,
|
||||
add_thinking=self.add_thinking
|
||||
)
|
||||
return self.tokenizer.encode(formatted_prompt).ids
|
||||
|
||||
def decode(self, token_ids):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=False)
|
||||
|
||||
@staticmethod
|
||||
def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
|
||||
prompt = ""
|
||||
for msg in messages:
|
||||
prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
|
||||
if add_generation_prompt:
|
||||
prompt += "<|im_start|>assistant"
|
||||
if not add_thinking:
|
||||
prompt += "<|think>\n\n<|/think>\n\n"
|
||||
else:
|
||||
prompt += "\n"
|
||||
return prompt
|
||||
|
||||
|
||||
def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
|
||||
base_url = "https://huggingface.co"
|
||||
url = f"{base_url}/{repo_id}/resolve/{revision}/{filename}"
|
||||
Path(local_dir).mkdir(parents=True, exist_ok=True)
|
||||
dest_path = os.path.join(local_dir, filename)
|
||||
print(f"Downloading {url} to {dest_path}...")
|
||||
urllib.request.urlretrieve(url, dest_path)
|
||||
return dest_path
|
@ -19,6 +19,36 @@ import tiktoken
|
||||
import torch
|
||||
|
||||
|
||||
class LitGPTRMSNorm(torch.nn.Module):
|
||||
"""Root Mean Square Layer Normalization.
|
||||
|
||||
From https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
|
||||
Apache License 2.0-Clause License: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
|
||||
|
||||
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
|
||||
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(size))
|
||||
self.eps = eps
|
||||
self.dim = dim
|
||||
self.add_unit_offset = add_unit_offset
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
dtype = x.dtype
|
||||
x = x.float()
|
||||
# NOTE: the original RMSNorm paper implementation is not equivalent
|
||||
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
|
||||
x_normed = x * torch.rsqrt(norm_x + self.eps)
|
||||
weight = (1 + self.weight) if self.add_unit_offset else self.weight
|
||||
return (x_normed * weight.float()).to(dtype=dtype)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
torch.nn.init.ones_(self.weight)
|
||||
|
||||
|
||||
transformers_installed = importlib.util.find_spec("transformers") is not None
|
||||
|
||||
|
||||
@ -179,3 +209,25 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path):
|
||||
[43, 2543, 292, 4483, 100383, 8113, 76873, 42175, 72641]
|
||||
])
|
||||
assert torch.equal(expect, out)
|
||||
|
||||
|
||||
def test_rmsnorm_equivalence():
|
||||
torch.manual_seed(42)
|
||||
|
||||
hidden_size = 64
|
||||
batch_size = 8
|
||||
seq_len = 16
|
||||
|
||||
rms_norm = torch.nn.RMSNorm(hidden_size, eps=1e-6)
|
||||
lit_norm = LitGPTRMSNorm(hidden_size)
|
||||
|
||||
# Sync weights
|
||||
with torch.no_grad():
|
||||
lit_norm.weight.copy_(lit_norm.weight)
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_size)
|
||||
|
||||
out1 = rms_norm(x)
|
||||
out2 = lit_norm(x)
|
||||
|
||||
torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
|
||||
|
194
pkg/llms_from_scratch/tests/test_qwen3.py
Normal file
194
pkg/llms_from_scratch/tests/test_qwen3.py
Normal file
@ -0,0 +1,194 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from llms_from_scratch.ch04 import generate_text_simple
|
||||
from llms_from_scratch.qwen3 import (
|
||||
compute_rope_params,
|
||||
apply_rope,
|
||||
QWEN_CONFIG_06_B,
|
||||
RMSNorm,
|
||||
Qwen3Model,
|
||||
Qwen3Tokenizer
|
||||
)
|
||||
|
||||
import importlib
|
||||
import pytest
|
||||
import tiktoken
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Qwen3RMSNorm(nn.Module):
|
||||
# Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py
|
||||
# License: Apache License, Version 2.0 (see file above)
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Qwen3RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
print(input_dtype)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
transformers_installed = importlib.util.find_spec("transformers") is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||
def test_rope():
|
||||
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding, apply_rotary_pos_emb
|
||||
|
||||
# Settings
|
||||
batch_size = 1
|
||||
context_len = 8192
|
||||
num_heads = 4
|
||||
head_dim = 16
|
||||
rope_theta = 1_000_000
|
||||
|
||||
# Instantiate RoPE parameters
|
||||
cos, sin = compute_rope_params(
|
||||
head_dim=head_dim,
|
||||
theta_base=rope_theta,
|
||||
context_length=context_len,
|
||||
)
|
||||
|
||||
# Dummy query and key tensors
|
||||
torch.manual_seed(123)
|
||||
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
|
||||
keys = torch.randn(batch_size, num_heads, context_len, head_dim)
|
||||
|
||||
# Apply rotary position embeddings
|
||||
queries_rot = apply_rope(queries, cos, sin)
|
||||
keys_rot = apply_rope(keys, cos, sin)
|
||||
|
||||
# Generate reference RoPE via HF
|
||||
class RoPEConfig:
|
||||
rope_type = "qwen3"
|
||||
factor = 1.0
|
||||
dim: int = head_dim
|
||||
rope_theta = 1_000_000
|
||||
max_position_embeddings: int = 8192
|
||||
hidden_size = head_dim * num_heads
|
||||
num_attention_heads = num_heads
|
||||
|
||||
config = RoPEConfig()
|
||||
|
||||
rot_emb = Qwen3RotaryEmbedding(config=config)
|
||||
position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
|
||||
ref_cos, ref_sin = rot_emb(queries, position_ids)
|
||||
ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
|
||||
|
||||
torch.testing.assert_close(sin, ref_sin.squeeze(0))
|
||||
torch.testing.assert_close(cos, ref_cos.squeeze(0))
|
||||
torch.testing.assert_close(keys_rot, ref_keys_rot)
|
||||
torch.testing.assert_close(queries_rot, ref_queries_rot)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def qwen3_weights_path(tmp_path_factory):
|
||||
"""Creates and saves a deterministic Llama3 model for testing."""
|
||||
path = tmp_path_factory.mktemp("models") / "llama3_test_weights.pt"
|
||||
|
||||
if not path.exists():
|
||||
torch.manual_seed(123)
|
||||
model = Qwen3Model(QWEN_CONFIG_06_B)
|
||||
torch.save(model.state_dict(), path)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ModelClass", [Qwen3Model])
|
||||
def test_gpt_model_variants(ModelClass, qwen3_weights_path):
|
||||
torch.manual_seed(123)
|
||||
model = ModelClass(QWEN_CONFIG_06_B)
|
||||
model.load_state_dict(torch.load(qwen3_weights_path))
|
||||
model.eval()
|
||||
|
||||
start_context = "Llamas eat"
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
encoded = tokenizer.encode(start_context)
|
||||
encoded_tensor = torch.tensor(encoded).unsqueeze(0)
|
||||
|
||||
print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
|
||||
print("\nInput text:", start_context)
|
||||
print("Encoded input text:", encoded)
|
||||
print("encoded_tensor.shape:", encoded_tensor.shape)
|
||||
|
||||
out = generate_text_simple(
|
||||
model=model,
|
||||
idx=encoded_tensor,
|
||||
max_new_tokens=5,
|
||||
context_size=QWEN_CONFIG_06_B["context_length"]
|
||||
)
|
||||
print("Encoded output text:", out)
|
||||
expect = torch.tensor([
|
||||
[43, 2543, 292, 4483, 115206, 459, 43010, 104223, 55553]
|
||||
])
|
||||
assert torch.equal(expect, out)
|
||||
|
||||
|
||||
def test_rmsnorm_equivalence():
|
||||
torch.manual_seed(42)
|
||||
|
||||
hidden_size = 64
|
||||
batch_size = 8
|
||||
seq_len = 16
|
||||
|
||||
rms_norm = RMSNorm(hidden_size)
|
||||
ref_norm = Qwen3RMSNorm(hidden_size)
|
||||
|
||||
# Sync weights
|
||||
with torch.no_grad():
|
||||
ref_norm.weight.copy_(ref_norm.weight)
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_size)
|
||||
|
||||
out1 = rms_norm(x)
|
||||
out2 = ref_norm(x)
|
||||
|
||||
torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||
def test_tokenizer_equivalence():
|
||||
from transformers import AutoTokenizer
|
||||
repo_id = "Qwen/Qwen3-0.6B"
|
||||
tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
|
||||
prompt = "Give me a short introduction to large language models."
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
for states in ((True, True), (False, False)):
|
||||
tokenizer = Qwen3Tokenizer(
|
||||
tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
|
||||
repo_id=repo_id,
|
||||
add_generation_prompt=states[0],
|
||||
add_thinking=states[1]
|
||||
)
|
||||
input_token_ids = tokenizer.encode(prompt)
|
||||
input_token_ids_ref = tokenizer_ref.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=states[0],
|
||||
enable_thinking=states[1],
|
||||
)
|
||||
assert input_token_ids == input_token_ids_ref, states
|
||||
|
||||
output_text = tokenizer.decode(input_token_ids)
|
||||
out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
|
||||
assert output_text == out_text_ref, states
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "llms-from-scratch"
|
||||
version = "1.0.7"
|
||||
version = "1.0.9"
|
||||
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
Loading…
x
Reference in New Issue
Block a user