Improve KV cache code for torch.compile (#705)

* Improve KV cache code for torch.compile

* cleanup

* cleanup
This commit is contained in:
Sebastian Raschka 2025-06-23 18:08:49 -05:00 committed by GitHub
parent 6522be94be
commit 81eda38d3b
8 changed files with 593 additions and 315 deletions

View File

@ -27,7 +27,7 @@ class MultiHeadAttention(nn.Module):
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length),diagonal=1),
torch.triu(torch.ones(context_length, context_length), diagonal=1),
persistent=False
)

View File

@ -236,14 +236,14 @@ token_ids = generate_text_simple(
)
```
Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate there. However, the memory usage on other devices is likely similar, as they use a similar precision format, and the KV cache storage dominates here for the generated 150-token text (different devices may implement matrix multiplication differently, which can lead to different peak memory requirements).
Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate there. However, the memory usage on other devices is likely similar, as they use a similar precision format, and the KV cache storage results in even lower memory usage here for the generated 150-token text (different devices may implement matrix multiplication differently, which can lead to different peak memory requirements; and KV-cache memory may grow prohibitively for longer context lengths).
| Model | Mode | Hardware | Tokens/sec | GPU Memory (VRAM) |
|-------------|-------------------|-----------------|------------|-------------------|
| ----------- | ----------------- | --------------- | ---------- | ----------------- |
| Llama3Model | Regular | Mac Mini M4 CPU | 1 | - |
| Llama3Model | Regular compiled | Mac Mini M4 CPU | - | - |
| Llama3Model | KV cache | Mac Mini M4 CPU | 62 | - |
| Llama3Model | KV cache compiled | Mac Mini M4 CPU | - | - |
| Llama3Model | KV cache | Mac Mini M4 CPU | 68 | - |
| Llama3Model | KV cache compiled | Mac Mini M4 CPU | 86 | - |
| | | | | |
| Llama3Model | Regular | Mac Mini M4 GPU | 15 | - |
| Llama3Model | Regular compiled | Mac Mini M4 GPU | - | - |
@ -252,7 +252,7 @@ Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is
| | | | | |
| Llama3Model | Regular | Nvidia A100 GPU | 42 | 2.91 GB |
| Llama3Model | Regular compiled | Nvidia A100 GPU | 170 | 3.12 GB |
| Llama3Model | KV cache | Nvidia A100 GPU | 60 | 18.87 GB |
| Llama3Model | KV cache compiled | Nvidia A100 GPU | 59 | 19.12 GB |
| Llama3Model | KV cache | Nvidia A100 GPU | 58 | 2.87 GB |
| Llama3Model | KV cache compiled | Nvidia A100 GPU | 161 | 3.61 GB |
Note that all settings above have been tested to produce the same text outputs.
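For intuition on where the earlier ~19 GB figure came from, here is a rough back-of-the-envelope sketch. It assumes bfloat16 storage (2 bytes per element) and uses the Llama 3.2 1B settings from the config in this PR; the removed implementation preallocated the cache for the full 131,072-token context and stored keys and values after expanding them to all 32 heads.

```python
# Back-of-the-envelope estimate (sketch) of the old, preallocated KV-cache size.
# Values are taken from LLAMA32_CONFIG_1B; bfloat16 (2 bytes/element) is an assumption.
n_layers = 16
n_heads = 32               # the removed code cached K/V after expanding to all heads
head_dim = 2048 // 32      # emb_dim / n_heads = 64
max_seq_len = 131_072      # the full context_length was preallocated
bytes_per_elem = 2         # bfloat16 (assumption)

per_layer = 2 * n_heads * max_seq_len * head_dim * bytes_per_elem  # K and V buffers
print(n_layers * per_layer / 1024**3, "GiB")                       # 16.0 GiB
```

Those 16 GiB, on top of roughly 2.9 GB for weights and activations, account for the 18.87 GB reported before this change; the reworked cache only ever holds the few hundred tokens actually generated, which is why it now stays close to the no-cache baseline.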

View File

@ -209,23 +209,23 @@ token_ids = generate_text_simple(
)
```
Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate there. However, the memory usage on other devices is likely similar, as they use a similar precision format, and the KV cache storage dominates here for the generated 150-token text (different devices may implement matrix multiplication differently, which can lead to different peak memory requirements).
Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate there. However, the memory usage on other devices is likely similar, as they use a similar precision format, and the KV cache storage results in even lower memory usage here for the generated 150-token text (different devices may implement matrix multiplication differently, which can lead to different peak memory requirements; and KV-cache memory may grow prohibitively for longer context lengths).
| Model | Mode | Hardware | Tokens/sec | GPU Memory (VRAM) |
| ---------- | ----------------- | --------------- | ---------- | ----------------- |
| Qwen3Model | Regular | Mac Mini M4 CPU | 1 | - |
| Qwen3Model | Regular compiled | Mac Mini M4 CPU | 1 | - |
| Qwen3Model | KV cache | Mac Mini M4 CPU | 80 | - |
| Qwen3Model | KV cache compiled | Mac Mini M4 CPU | 82 | - |
| Qwen3Model | KV cache compiled | Mac Mini M4 CPU | 137 | - |
| | | | | |
| Qwen3Model | Regular | Mac Mini M4 GPU | 21 | - |
| Qwen3Model | Regular compiled | Mac Mini M4 GPU | Error | - |
| Qwen3Model | KV cache | Mac Mini M4 GPU | 32 | - |
| Qwen3Model | KV cache | Mac Mini M4 GPU | 28 | - |
| Qwen3Model | KV cache compiled | Mac Mini M4 GPU | Error | - |
| | | | | |
| Qwen3Model | Regular | Nvidia A100 GPU | 25 | 1.49 GB |
| Qwen3Model | Regular | Nvidia A100 GPU | 26 | 1.49 GB |
| Qwen3Model | Regular compiled | Nvidia A100 GPU | 107 | 1.99 GB |
| Qwen3Model | KV cache | Nvidia A100 GPU | 25 | 10.20 GB |
| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 24 | 10.61 GB |
| Qwen3Model | KV cache | Nvidia A100 GPU | 25 | 1.47 GB |
| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 90 | 1.48 GB |
Note that all settings above have been tested to produce the same text outputs.
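The compiled rows correspond to wrapping the model with `torch.compile` before generation. A minimal sketch of the call pattern is shown below; it assumes `model` is an already-instantiated Qwen3Model with weights loaded and `input_ids` is a tensor holding the encoded prompt. The first compiled forward pass triggers compilation, so throughput should be measured after a warmup run.

```python
import torch

# Sketch only: `model` (a loaded Qwen3Model) and `input_ids` (shape (1, num_tokens))
# are assumed to exist already.
model = torch.compile(model)   # first call compiles; time generation after a warmup

token_ids = generate_text_simple(
    model=model,
    idx=input_ids,
    max_new_tokens=150,
    context_size=QWEN_CONFIG_06_B["context_length"],
    use_cache=True,            # "KV cache compiled" rows; set False for "Regular compiled"
)
```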

View File

@ -3,23 +3,24 @@
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
from .utils import KVCache
import torch
def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
model.eval()
ctx_len = context_size or model.cfg["context_length"]
cache = KVCache(n_layers=model.cfg["n_layers"]) if use_cache else None
with torch.no_grad():
if use_cache:
model.reset_kv_cache()
logits = model(idx[:, -ctx_len:], use_cache=True)
logits = model(idx[:, -ctx_len:], use_cache=True, cache=cache)
for _ in range(max_new_tokens):
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
logits = model(next_idx, use_cache=True)
logits = model(next_idx, use_cache=True, cache=cache)
else:
for _ in range(max_new_tokens):
logits = model(idx[:, -ctx_len:], use_cache=False)
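To make the new cache flow concrete, here is a minimal sketch of what the loop above does: the `KVCache` is created once per generation call, the prefill pass fills it with the prompt's keys and values, and every subsequent decode pass feeds only the newly sampled token while appending to the cache. Context cropping is omitted for brevity, and `model` (any of the reworked models) plus a prompt tensor `idx` are assumed to exist.

```python
import torch

# Minimal sketch of the prefill + decode pattern (assumes `model` and a prompt
# tensor `idx` of shape (1, num_prompt_tokens) already exist).
cache = KVCache(n_layers=model.cfg["n_layers"])

with torch.no_grad():
    model.reset_kv_cache()                                   # position counter back to 0
    logits = model(idx, use_cache=True, cache=cache)         # prefill: caches the prompt

    next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
    logits = model(next_idx, use_cache=True, cache=cache)    # decode: one token only

keys, values = cache.get(0)                 # cached keys/values of layer 0
print(keys.shape[2] == idx.shape[1] + 1)    # True: prompt tokens + 1 generated token
```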

View File

@ -3,6 +3,8 @@
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
from .utils import KVCache # noqa: F401
import torch
import torch.nn as nn
@ -11,7 +13,7 @@ import torch.nn as nn
# Chapter 3
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, max_seq_len=None, window_size=None):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
@ -25,80 +27,41 @@ class MultiHeadAttention(nn.Module):
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
####################################################
# NEW
self.max_seq_len = max_seq_len or context_length
self.window_size = window_size or self.max_seq_len
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
####################################################
def forward(self, x, use_cache=False):
def forward(self, x, use_cache=False, start_pos=0, cache=None):
b, num_tokens, d_in = x.shape
keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out)
values_new = self.W_value(x)
keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
values = self.W_value(x)
queries = self.W_query(x)
# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys_new = keys_new.transpose(1, 2)
values_new = values_new.transpose(1, 2)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
####################################################
# NEW
if use_cache:
if self.cache_k is None or self.cache_k.size(0) != b:
self.cache_k = torch.zeros(b, self.num_heads,
self.window_size, self.head_dim,
device=x.device)
self.cache_v = torch.zeros_like(self.cache_k)
self.ptr_cur = 0 # pointer to next free slot
# if incoming chunk would overflow discard oldest tokens
if self.ptr_cur + num_tokens > self.window_size:
overflow = self.ptr_cur + num_tokens - self.window_size
# shift everything left by `overflow` (cheap view-copy)
self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
self.ptr_cur -= overflow # pointer after shift
self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
self.ptr_cur += num_tokens
keys = self.cache_k[:, :, :self.ptr_cur, :]
values = self.cache_v[:, :, :self.ptr_cur, :]
if cache is not None:
keys = torch.cat([cache[0], keys], dim=2)
values = torch.cat([cache[1], values], dim=2)
next_cache = (keys, values)
else:
keys, values = keys_new, values_new
self.ptr_cur = 0 # keep pointer sane if you interleave modes
####################################################
next_cache = None
seq_len = keys.size(2)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
causal_mask = causal_mask[-num_tokens:, :][None, None, :, :]
# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
####################################################
# NEW
K = attn_scores.size(-1)
if num_tokens == K:
# No cache → use the prebaked triangular mask slice
causal_mask = torch.triu(torch.ones(num_tokens, K, device=x.device, dtype=torch.bool), diagonal=1)
else:
# Cached: need to offset the diagonal by (K - num_tokens)
offset = K - num_tokens # number of tokens already in cache before this chunk
row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1) # (num_tokens, 1)
col_idx = torch.arange(K, device=x.device).unsqueeze(0) # (1, K)
causal_mask = row_idx + offset < col_idx # True where j > i+offset
####################################################
# Use the mask to fill attention scores
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf)
attn_scores.masked_fill_(causal_mask, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
@ -110,13 +73,7 @@ class MultiHeadAttention(nn.Module):
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection
return context_vec
####################################################
# NEW
def reset_cache(self):
self.cache_k, self.cache_v = None, None
####################################################
return context_vec, next_cache
#####################################
@ -169,25 +126,17 @@ class TransformerBlock(nn.Module):
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"],
window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"] # NEW
)
qkv_bias=cfg["qkv_bias"])
self.ff = FeedForward(cfg)
self.norm1 = LayerNorm(cfg["emb_dim"])
self.norm2 = LayerNorm(cfg["emb_dim"])
self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
def forward(self, x, use_cache=False):
def forward(self, x, use_cache=False, start_pos=0, cache=None):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
# x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
####################################################
# NEW
x = self.att(x, use_cache=use_cache)
####################################################
x, next_cache = self.att(x, use_cache=use_cache, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size]
x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back
@ -198,7 +147,7 @@ class TransformerBlock(nn.Module):
x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back
return x
return x, next_cache
class GPTModel(nn.Module):
@ -208,80 +157,34 @@ class GPTModel(nn.Module):
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
# self.trf_blocks = nn.Sequential(
# *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
####################################################
# NEW
self.trf_blocks = nn.ModuleList(
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
self.ptr_current_pos = 0
####################################################
self.trf_blocks = nn.Sequential(
*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
self.current_pos = 0
def forward(self, in_idx, use_cache=False):
def forward(self, in_idx, use_cache=False, cache=None):
batch_size, seq_len = in_idx.shape
pos = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device)
tok_embeds = self.tok_emb(in_idx)
# pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
####################################################
# NEW
pos_embeds = self.pos_emb(pos)
x = self.drop_emb(tok_embeds + pos_embeds)
if use_cache:
pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
self.ptr_current_pos += seq_len
start_pos = self.current_pos
self.current_pos += seq_len
else:
pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
####################################################
start_pos = 0
x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
x = self.drop_emb(x)
# x = self.trf_blocks(x)
####################################################
# NEW
for blk in self.trf_blocks:
x = blk(x, use_cache=use_cache)
####################################################
next_cache = []
for i, block in enumerate(self.trf_blocks):
blk_cache = cache.get(i) if cache else None
x, new_cache = block(x, use_cache=use_cache, start_pos=start_pos, cache=blk_cache)
if cache:
cache.update(i, new_cache)
next_cache.append(new_cache)
x = self.final_norm(x)
logits = self.out_head(x)
return logits
####################################################
# NEW
def reset_kv_cache(self):
for blk in self.trf_blocks:
blk.att.reset_cache()
self.ptr_current_pos = 0
####################################################
def generate_text_simple(model, idx, max_new_tokens, context_size):
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
# Crop current context if it exceeds the supported context size
# E.g., if LLM supports only 5 tokens, and the context size is 10
# then only the last 5 tokens are used as context
idx_cond = idx[:, -context_size:]
# Get the predictions
with torch.no_grad():
logits = model(idx_cond)
# Focus only on the last time step
# (batch, n_token, vocab_size) becomes (batch, vocab_size)
logits = logits[:, -1, :]
# Get the idx of the vocab entry with the highest logits value
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
# Append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
return idx
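A note on the mask change: the removed attention-level code built a fresh offset mask per chunk (`row_idx + offset < col_idx`), while the reworked models construct a single upper-triangular mask and slice it with `[pos_start:pos_end, :pos_end]`. The two formulations select exactly the same entries; the standalone check below (illustrative only, not part of the repo) makes the equivalence explicit.

```python
import torch

# Illustrative check: old per-chunk offset mask vs. new sliced triangular mask.
pos_start, num_tokens = 5, 3          # 5 tokens already cached, 3 new tokens arriving
pos_end = pos_start + num_tokens      # total number of keys visible to the new queries

# Old formulation (removed above): offset the diagonal by the cached length
row_idx = torch.arange(num_tokens).unsqueeze(1)   # (num_tokens, 1)
col_idx = torch.arange(pos_end).unsqueeze(0)      # (1, pos_end)
old_mask = row_idx + pos_start < col_idx          # True where key j lies in the future

# New formulation: slice rows [pos_start:pos_end] out of one full triangular mask
new_mask = torch.triu(torch.ones(pos_end, pos_end, dtype=torch.bool), diagonal=1)
new_mask = new_mask[pos_start:pos_end, :pos_end]

assert torch.equal(old_mask, new_mask)
```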

View File

@ -3,15 +3,20 @@
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
from ..llama3 import Llama3Tokenizer, ChatFormat, clean_text # noqa: F401
from .utils import KVCache # noqa: F401
import os
from pathlib import Path
import torch
import torch.nn as nn
import tiktoken
from tiktoken.load import load_tiktoken_bpe
LLAMA32_CONFIG_1B = {
"vocab_size": 128_256, # Vocabulary size
"context_length": 131_072, # Context length that was used to train the model
"window_size": None, # Window size for the KV cache; context_length if None
"emb_dim": 2048, # Embedding dimension
"n_heads": 32, # Number of attention heads
"n_layers": 16, # Number of layers
@ -30,7 +35,6 @@ LLAMA32_CONFIG_1B = {
LLAMA32_CONFIG_3B = {
"vocab_size": 128_256, # Vocabulary size
"context_length": 131_072, # Context length that was used to train the model
"window_size": None, # Window size for the KV cache; context_length if None
"emb_dim": 3072, # Embedding dimension
"n_heads": 24, # Number of attention heads
"n_layers": 28, # Number of layers
@ -71,21 +75,45 @@ class Llama3Model(nn.Module):
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)
self.cfg = cfg
self.current_pos = 0 # Track current position in KV cache
def forward(self, in_idx, use_cache=False):
def forward(self, in_idx, use_cache=False, cache=None):
tok_embeds = self.tok_emb(in_idx)
x = tok_embeds
for block in self.trf_blocks:
x = block(x, self.cos, self.sin, use_cache)
num_tokens = x.shape[1]
if use_cache:
pos_start = self.current_pos
pos_end = pos_start + num_tokens
self.current_pos = pos_end
mask = torch.triu(
torch.ones(pos_end, pos_end, device=x.device, dtype=torch.bool), diagonal=1
)[pos_start:pos_end, :pos_end]
else:
pos_start = 0 # Not strictly necessary but helps torch.compile
mask = torch.triu(
torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1
)
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
mask = mask[None, None, :, :]
next_cache = []
for i, block in enumerate(self.trf_blocks):
blk_cache = cache.get(i) if cache else None
x, new_blk_cache = block(x, mask, self.cos, self.sin,
use_cache=use_cache,
start_pos=pos_start,
cache=blk_cache)
if cache:
cache.update(i, new_blk_cache)
next_cache.append(new_blk_cache)
x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"]))
return logits
def reset_kv_cache(self):
for blk in self.trf_blocks:
blk.att.reset_cache()
self.ptr_current_pos = 0
self.current_pos = 0
class TransformerBlock(nn.Module):
@ -96,18 +124,17 @@ class TransformerBlock(nn.Module):
d_out=cfg["emb_dim"],
num_heads=cfg["n_heads"],
num_kv_groups=cfg["n_kv_groups"],
max_seq_len=cfg["context_length"],
dtype=cfg["dtype"]
)
self.ff = FeedForward(cfg)
self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
def forward(self, x, cos, sin, use_cache=False):
def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x, cos, sin, use_cache) # Shape [batch_size, num_tokens, emb_size]
x, next_cache = self.att(x, mask, cos, sin, use_cache=use_cache, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size]
x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block
@ -116,7 +143,7 @@ class TransformerBlock(nn.Module):
x = self.ff(x)
x = x + shortcut # Add the original input back
return x
return x, next_cache
class FeedForward(nn.Module):
@ -135,7 +162,7 @@ class FeedForward(nn.Module):
class GroupedQueryAttention(nn.Module):
def __init__(
self, d_in, d_out, num_heads, num_kv_groups, dtype=None, max_seq_len=None, window_size=None
self, d_in, d_out, num_heads, num_kv_groups, dtype=None
):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
@ -153,45 +180,40 @@ class GroupedQueryAttention(nn.Module):
self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
# For optional KV cache
self.max_seq_len = max_seq_len
self.window_size = window_size or self.max_seq_len
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
self.cache_initialized = False
self.ptr = 0
def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
b, num_tokens, _ = x.shape
def forward(self, x, cos, sin, use_cache=False):
b, num_tokens, d_in = x.shape
# Apply projections
queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)
keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
queries = self.W_query(x) # Shape: (b, num_tokens, d_out)
keys_new = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
values_new = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
# Reshape queries, keys, and values
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim)
values_new = values_new.view(b, num_tokens, self.num_kv_groups, self.head_dim)
# Transpose keys, values, and queries
queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
keys_new = keys_new.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
values_new = values_new.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
# For KV cache
pos_start = self.ptr
pos_end = pos_start + num_tokens
cos_slice = cos[pos_start:pos_end]
sin_slice = sin[pos_start:pos_end]
# Reshape
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
keys_new = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
# Apply RoPE
keys_new = apply_rope(keys_new, cos_slice, sin_slice)
queries = apply_rope(queries, cos_slice, sin_slice)
queries = apply_rope(queries, cos, sin, offset=start_pos)
keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)
if use_cache:
if cache is None:
keys = keys_new
values = values_new
else:
prev_k, prev_v = cache
keys = torch.cat([prev_k, keys_new], dim=2)
values = torch.cat([prev_v, values_new], dim=2)
next_cache = (keys, values)
else:
keys, values = keys_new, values_new
next_cache = None
# Expand keys and values to match the number of heads
# Shape: (b, num_heads, num_tokens, head_dim)
keys_new = keys_new.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
values_new = values_new.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
keys = keys.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
values = values.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
# For example, before repeat_interleave along dim=1 (query groups):
# [K1, K2]
# After repeat_interleave (each query group is repeated group_size times):
@ -199,38 +221,12 @@ class GroupedQueryAttention(nn.Module):
# If we used regular repeat instead of repeat_interleave, we'd get:
# [K1, K2, K1, K2]
if use_cache:
if not self.cache_initialized:
self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=keys_new.dtype)
self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=values_new.dtype)
self.ptr = 0
self.cache_initialized = True
# In-place update
end = self.ptr + num_tokens
self.cache_k[:, :, self.ptr:end].copy_(keys_new)
self.cache_v[:, :, self.ptr:end].copy_(values_new)
keys = self.cache_k[:, :, max(0, end - self.window_size):end]
values = self.cache_v[:, :, max(0, end - self.window_size):end]
self.ptr = end
else:
keys, values = keys_new, values_new
# Compute scaled dot-product attention (aka self-attention) with a causal mask
# Shape: (b, num_heads, num_tokens, num_tokens)
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Create causal mask to fill attention scores
T_q = queries.shape[-2]
T_k = keys.shape[-2]
if not use_cache or T_q > 1:
causal_mask = torch.triu(
torch.ones((T_q, T_k), device=x.device, dtype=torch.bool),
diagonal=1
)
attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
# Use the mask to fill attention scores
attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
assert keys.shape[-1] == self.head_dim
@ -242,13 +238,7 @@ class GroupedQueryAttention(nn.Module):
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection
return context_vec
def reset_cache(self):
if self.cache_k is not None:
self.cache_k.zero_()
self.cache_v.zero_()
self.ptr = 0
return context_vec, next_cache
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
@ -296,7 +286,7 @@ def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_c
return cos, sin
def apply_rope(x, cos, sin):
def apply_rope(x, cos, sin, offset=0):
# x: (batch_size, num_heads, seq_len, head_dim)
batch_size, num_heads, seq_len, head_dim = x.shape
assert head_dim % 2 == 0, "Head dimension must be even"
@ -306,8 +296,8 @@ def apply_rope(x, cos, sin):
x2 = x[..., head_dim // 2:] # Second half
# Adjust sin and cos shapes
cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)
# Apply the rotary transformation
rotated = torch.cat((-x2, x1), dim=-1)
@ -315,3 +305,231 @@ def apply_rope(x, cos, sin):
# It's ok to use lower-precision after applying cos and sin rotation
return x_rotated.to(dtype=x.dtype)
##########################################
# Tokenizer
##########################################
class Llama3Tokenizer:
"""Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""
def __init__(self, model_path):
if not os.path.isfile(model_path):
raise FileNotFoundError(model_path)
mergeable = load_tiktoken_bpe(model_path)
# hard-coded from Meta's tokenizer.json
self.special = {
"<|begin_of_text|>": 128000,
"<|end_of_text|>": 128001,
"<|start_header_id|>": 128006,
"<|end_header_id|>": 128007,
"<|eot_id|>": 128009,
}
self.special.update({f"<|reserved_{i}|>": 128002 + i
for i in range(256)
if 128002 + i not in self.special.values()})
self.model = tiktoken.Encoding(
name=Path(model_path).name,
pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
r"|[^\r\n\p{L}\p{N}]?\p{L}+"
r"|\p{N}{1,3}"
r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
r"|\s*[\r\n]+"
r"|\s+(?!\S)"
r"|\s+",
mergeable_ranks=mergeable,
special_tokens=self.special,
)
def encode(self, text, bos=False, eos=False, **kwargs):
ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
+ self.model.encode(text)
if eos:
ids.append(self.special["<|end_of_text|>"])
return ids
def decode(self, ids):
return self.model.decode(ids)
class ChatFormat:
def __init__(self, tokenizer: Llama3Tokenizer, *,
default_system="You are a helpful assistant."):
self.tok = tokenizer
self.default_system = default_system
def _header(self, role):
"""Encode <|start_header_id|>role<|end_header_id|>\n\n"""
return (
[self.tok.special["<|start_header_id|>"]]
+ self.tok.encode(role)
+ [self.tok.special["<|end_header_id|>"]]
+ self.tok.encode("\n\n")
)
def encode(self, user_message, system_message=None, allowed_special=None):
sys_msg = system_message if system_message is not None else self.default_system
ids = [self.tok.special["<|begin_of_text|>"]]
# system
ids += self._header("system")
ids += self.tok.encode(sys_msg, allowed_special=allowed_special)
ids += [self.tok.special["<|eot_id|>"]]
# user
ids += self._header("user")
ids += self.tok.encode(user_message)
ids += [self.tok.special["<|eot_id|>"]]
# assistant header (no content yet)
ids += self._header("assistant")
return ids
def decode(self, ids):
return self.tok.decode(ids)
def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
# Find the index of the first occurrence of "<|end_header_id|>"
index = text.find(header_end)
if index != -1:
# Return the substring starting after "<|end_header_id|>"
return text[index + len(header_end):].strip() # Strip removes leading/trailing whitespace
else:
# If the token is not found, return the original text
return text
######################################################################
# Llama 3 fast (alternative code geared towards efficiency)
######################################################################
class GroupedQueryAttentionFast(nn.Module):
"""
Drop-in replacement for GroupedQueryAttention but using PyTorch's
scaled_dot_product_attention, which uses FlashAttention if run
on an Ampere GPU (like A100) or newer and uses float16/bfloat16 or lower.
"""
def __init__(self, d_in, d_out, num_heads, num_kv_groups, dtype=None):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads
self.num_kv_groups = num_kv_groups
self.group_size = num_heads // num_kv_groups
self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
def forward(self, x, cos, sin):
b, num_tokens, _ = x.shape
# Project to queries, keys, values
q = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
k = self.W_key(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
v = self.W_value(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
# Apply Rotary Positional Embedding
q = apply_rope(q, cos, sin)
k = apply_rope(k, cos, sin)
# Expand key/value groups to full head count
k = k.repeat_interleave(self.group_size, dim=1)
v = v.repeat_interleave(self.group_size, dim=1)
# Efficient scaled dot-product attention
attn_output = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
is_causal=True # Enables Flash/FlexAttention kernels
)
# Combine heads and project
attn_output = attn_output.transpose(1, 2).reshape(b, num_tokens, self.d_out)
return self.out_proj(attn_output)
class TransformerBlockFast(nn.Module):
"""
Same as original TransformerBlock but uses
GroupedQueryAttentionFast instead of GroupedQueryAttention.
"""
def __init__(self, cfg):
super().__init__()
self.att = GroupedQueryAttentionFast(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
num_heads=cfg["n_heads"],
num_kv_groups=cfg["n_kv_groups"],
dtype=cfg["dtype"]
)
self.ff = FeedForward(cfg)
self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
def forward(self, x, cos, sin):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x, cos, sin) # Shape [batch_size, num_tokens, emb_size]
x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = x + shortcut # Add the original input back
return x
class Llama3ModelFast(nn.Module):
"""
Same as original Llama3Model but uses TransformerBlockFast
instead of TransformerBlock, which in turn uses
GroupedQueryAttentionFast instead of GroupedQueryAttention.
"""
def __init__(self, cfg):
super().__init__()
# Main model parameters
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, cos, sin`
[TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])]
)
self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
cos, sin = compute_rope_params(
head_dim=cfg["emb_dim"] // cfg["n_heads"],
theta_base=cfg["rope_base"],
context_length=cfg["context_length"],
freq_config=cfg["rope_freq"]
)
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)
self.cfg = cfg
def forward(self, in_idx):
tok_embeds = self.tok_emb(in_idx)
x = tok_embeds
for block in self.trf_blocks:
x = block(x, self.cos, self.sin)
x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"]))
return logits
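The new `offset` argument to `apply_rope` exists so that, during cached decoding, a chunk that starts at absolute position `start_pos` is rotated with the cos/sin rows of its true positions instead of starting from 0. Because RoPE rotates each position independently, rotating only the tail of a sequence with the matching offset gives the same result as rotating the full sequence and slicing afterwards. A small sanity check is sketched below; the `llms_from_scratch.llama3` import path is an assumption, so adjust it to wherever this file lives in your setup.

```python
import torch
# Import path is an assumption; adjust to your setup.
from llms_from_scratch.llama3 import compute_rope_params, apply_rope

head_dim, seq_len, start_pos = 64, 16, 10
cos, sin = compute_rope_params(head_dim=head_dim, context_length=seq_len)

x = torch.randn(1, 2, seq_len, head_dim)             # (batch, heads, seq, head_dim)

full = apply_rope(x, cos, sin)                        # rotate all positions at once
chunk = apply_rope(x[:, :, start_pos:], cos, sin,     # rotate only the tail ...
                   offset=start_pos)                  # ... using its true positions

assert torch.allclose(full[:, :, start_pos:], chunk)
```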

View File

@ -3,7 +3,11 @@
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
from ..qwen3 import Qwen3Tokenizer, download_from_huggingface, load_weights_into_qwen # noqa: F401
from .utils import KVCache # noqa: F401
import os
import urllib.request
from pathlib import Path
import torch
import torch.nn as nn
@ -12,7 +16,6 @@ import torch.nn as nn
QWEN_CONFIG_06_B = {
"vocab_size": 151_936, # Vocabulary size
"context_length": 40_960, # Context length that was used to train the model
"window_size": None, # Window size for the KV cache; context_length if None
"emb_dim": 1024, # Embedding dimension
"n_heads": 16, # Number of attention heads
"n_layers": 28, # Number of layers
@ -51,22 +54,46 @@ class Qwen3Model(nn.Module):
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)
self.cfg = cfg
self.current_pos = 0 # Track current position in KV cache
def forward(self, in_idx, use_cache=False):
def forward(self, in_idx, use_cache=False, cache=None):
# Forward pass
tok_embeds = self.tok_emb(in_idx)
x = tok_embeds
for block in self.trf_blocks:
x = block(x, self.cos, self.sin, use_cache)
num_tokens = x.shape[1]
if use_cache:
pos_start = self.current_pos
pos_end = pos_start + num_tokens
self.current_pos = pos_end
mask = torch.triu(
torch.ones(pos_end, pos_end, device=x.device, dtype=torch.bool), diagonal=1
)[pos_start:pos_end, :pos_end]
else:
pos_start = 0 # Not strictly necessary but helps torch.compile
mask = torch.triu(
torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1
)
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
mask = mask[None, None, :, :]
next_cache = []
for i, block in enumerate(self.trf_blocks):
blk_cache = cache.get(i) if cache else None
x, new_blk_cache = block(x, mask, self.cos, self.sin,
use_cache=use_cache,
start_pos=pos_start,
cache=blk_cache)
if cache:
cache.update(i, new_blk_cache)
next_cache.append(new_blk_cache)
x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"]))
return logits
def reset_kv_cache(self):
for blk in self.trf_blocks:
blk.att.reset_cache()
self.ptr_current_pos = 0
self.current_pos = 0
class TransformerBlock(nn.Module):
@ -78,18 +105,17 @@ class TransformerBlock(nn.Module):
head_dim=cfg["head_dim"],
num_kv_groups=cfg["n_kv_groups"],
qk_norm=cfg["qk_norm"],
max_seq_len=cfg["context_length"],
dtype=cfg["dtype"]
)
self.ff = FeedForward(cfg)
self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)
def forward(self, x, cos, sin, use_cache=False):
def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x, cos, sin, use_cache) # Shape [batch_size, num_tokens, emb_size]
x, next_cache = self.att(x, mask, cos, sin, use_cache=use_cache, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size]
x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block
@ -98,7 +124,7 @@ class TransformerBlock(nn.Module):
x = self.ff(x)
x = x + shortcut # Add the original input back
return x
return x, next_cache
class FeedForward(nn.Module):
@ -117,8 +143,7 @@ class FeedForward(nn.Module):
class GroupedQueryAttention(nn.Module):
def __init__(
self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None,
max_seq_len=None, window_size=None
self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None
):
super().__init__()
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
@ -146,26 +171,18 @@ class GroupedQueryAttention(nn.Module):
else:
self.q_norm = self.k_norm = None
# For optional KV cache
self.max_seq_len = max_seq_len
self.window_size = window_size or self.max_seq_len
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
self.cache_initialized = False
self.ptr = 0
def forward(self, x, cos, sin, use_cache=False):
def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
b, num_tokens, _ = x.shape
# Apply projections
queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)
keys_new = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
values_new = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
# Reshape
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
values_new = values_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
keys_new = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
# Optional normalization
if self.q_norm:
@ -173,62 +190,34 @@ class GroupedQueryAttention(nn.Module):
if self.k_norm:
keys_new = self.k_norm(keys_new)
# For KV cache
pos_start = self.ptr
pos_end = pos_start + num_tokens
cos_slice = cos[pos_start:pos_end]
sin_slice = sin[pos_start:pos_end]
# Apply RoPE
keys_new = apply_rope(keys_new, cos_slice, sin_slice)
queries = apply_rope(queries, cos_slice, sin_slice)
# Expand K and V to match number of heads
keys_new = keys_new.repeat_interleave(self.group_size, dim=1)
values_new = values_new.repeat_interleave(self.group_size, dim=1)
queries = apply_rope(queries, cos, sin, offset=start_pos)
keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)
if use_cache:
if not self.cache_initialized:
self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=keys_new.dtype)
self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=values_new.dtype)
self.ptr = 0
self.cache_initialized = True
# In-place update
end = self.ptr + num_tokens
self.cache_k[:, :, self.ptr:end].copy_(keys_new)
self.cache_v[:, :, self.ptr:end].copy_(values_new)
keys = self.cache_k[:, :, max(0, end - self.window_size):end]
values = self.cache_v[:, :, max(0, end - self.window_size):end]
self.ptr = end
if cache is None:
keys = keys_new
values = values_new
else:
prev_k, prev_v = cache
keys = torch.cat([prev_k, keys_new], dim=2)
values = torch.cat([prev_v, values_new], dim=2)
next_cache = (keys, values)
else:
keys, values = keys_new, values_new
next_cache = None
# Expand K and V to match number of heads
keys = keys.repeat_interleave(self.group_size, dim=1)
values = values.repeat_interleave(self.group_size, dim=1)
# Attention
attn_scores = queries @ keys.transpose(2, 3)
# Create causal mask to fill attention scores
T_q = queries.shape[-2]
T_k = keys.shape[-2]
if not use_cache or T_q > 1:
causal_mask = torch.triu(
torch.ones((T_q, T_k), device=x.device, dtype=torch.bool),
diagonal=1
)
attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
attn_scores = attn_scores.masked_fill(mask, -torch.inf)
attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
return self.out_proj(context)
def reset_cache(self):
if self.cache_k is not None:
self.cache_k.zero_()
self.cache_v.zero_()
self.ptr = 0
return self.out_proj(context), next_cache
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
@ -253,7 +242,7 @@ def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=
return cos, sin
def apply_rope(x, cos, sin):
def apply_rope(x, cos, sin, offset=0):
# x: (batch_size, num_heads, seq_len, head_dim)
batch_size, num_heads, seq_len, head_dim = x.shape
assert head_dim % 2 == 0, "Head dimension must be even"
@ -263,8 +252,8 @@ def apply_rope(x, cos, sin):
x2 = x[..., head_dim // 2:] # Second half
# Adjust sin and cos shapes
cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)
# Apply the rotary transformation
rotated = torch.cat((-x2, x1), dim=-1)
@ -297,3 +286,149 @@ class RMSNorm(nn.Module):
return norm_x.to(input_dtype)
def load_weights_into_qwen(model, param_config, params):
def assign(left, right, tensor_name="unknown"):
if left.shape != right.shape:
raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))
model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
for l in range(param_config["n_layers"]):
block = model.trf_blocks[l]
att = block.att
# Q, K, V projections
att.W_query.weight = assign(
att.W_query.weight,
params[f"model.layers.{l}.self_attn.q_proj.weight"],
f"model.layers.{l}.self_attn.q_proj.weight"
)
att.W_key.weight = assign(
att.W_key.weight,
params[f"model.layers.{l}.self_attn.k_proj.weight"],
f"model.layers.{l}.self_attn.k_proj.weight"
)
att.W_value.weight = assign(
att.W_value.weight,
params[f"model.layers.{l}.self_attn.v_proj.weight"],
f"model.layers.{l}.self_attn.v_proj.weight"
)
# Output projection
att.out_proj.weight = assign(
att.out_proj.weight,
params[f"model.layers.{l}.self_attn.o_proj.weight"],
f"model.layers.{l}.self_attn.o_proj.weight"
)
# QK norms
if hasattr(att, "q_norm") and att.q_norm is not None:
att.q_norm.scale = assign(
att.q_norm.scale,
params[f"model.layers.{l}.self_attn.q_norm.weight"],
f"model.layers.{l}.self_attn.q_norm.weight"
)
if hasattr(att, "k_norm") and att.k_norm is not None:
att.k_norm.scale = assign(
att.k_norm.scale,
params[f"model.layers.{l}.self_attn.k_norm.weight"],
f"model.layers.{l}.self_attn.k_norm.weight"
)
# Attention layernorm
block.norm1.scale = assign(
block.norm1.scale,
params[f"model.layers.{l}.input_layernorm.weight"],
f"model.layers.{l}.input_layernorm.weight"
)
# Feedforward weights
block.ff.fc1.weight = assign(
block.ff.fc1.weight,
params[f"model.layers.{l}.mlp.gate_proj.weight"],
f"model.layers.{l}.mlp.gate_proj.weight"
)
block.ff.fc2.weight = assign(
block.ff.fc2.weight,
params[f"model.layers.{l}.mlp.up_proj.weight"],
f"model.layers.{l}.mlp.up_proj.weight"
)
block.ff.fc3.weight = assign(
block.ff.fc3.weight,
params[f"model.layers.{l}.mlp.down_proj.weight"],
f"model.layers.{l}.mlp.down_proj.weight"
)
block.norm2.scale = assign(
block.norm2.scale,
params[f"model.layers.{l}.post_attention_layernorm.weight"],
f"model.layers.{l}.post_attention_layernorm.weight"
)
# Final normalization and output head
model.final_norm.scale = assign(model.final_norm.scale, params["model.norm.weight"], "model.norm.weight")
# Model uses weight tying, hence we reuse the embedding layer weights here
model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
class Qwen3Tokenizer():
def __init__(self, tokenizer_file_path="tokenizer.json",
repo_id=None, add_generation_prompt=False, add_thinking=False):
from tokenizers import Tokenizer
self.tokenizer_file_path = tokenizer_file_path
if add_generation_prompt != add_thinking:
raise ValueError(
"Only add_generation_prompt==add_thinking settings are currently supported"
)
self.add_generation_prompt = add_generation_prompt
self.add_thinking = add_thinking
tokenizer_file_path_obj = Path(tokenizer_file_path)
if not tokenizer_file_path_obj.is_file() and repo_id is not None:
_ = download_from_huggingface(
repo_id=repo_id,
filename=str(tokenizer_file_path_obj.name),
local_dir=str(tokenizer_file_path_obj.parent.name)
)
self.tokenizer = Tokenizer.from_file(tokenizer_file_path)
def encode(self, prompt):
messages = [
{"role": "user", "content": prompt}
]
formatted_prompt = self.format_qwen_chat(
messages,
add_generation_prompt=self.add_generation_prompt,
add_thinking=self.add_thinking
)
return self.tokenizer.encode(formatted_prompt).ids
def decode(self, token_ids):
return self.tokenizer.decode(token_ids, skip_special_tokens=False)
@staticmethod
def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
prompt = ""
for msg in messages:
prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
if add_generation_prompt:
prompt += "<|im_start|>assistant"
if not add_thinking:
prompt += "<|think>\n\n<|/think>\n\n"
else:
prompt += "\n"
return prompt
def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
base_url = "https://huggingface.co"
url = f"{base_url}/{repo_id}/resolve/{revision}/{filename}"
Path(local_dir).mkdir(parents=True, exist_ok=True)
dest_path = os.path.join(local_dir, filename)
print(f"Downloading {url} to {dest_path}...")
urllib.request.urlretrieve(url, dest_path)
return dest_path

View File

@ -0,0 +1,21 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
class KVCache:
def __init__(self, n_layers):
self.cache = [None] * n_layers
def get(self, layer_idx):
return self.cache[layer_idx]
def update(self, layer_idx, value):
self.cache[layer_idx] = value
def get_all(self):
return self.cache
def reset(self):
for i in range(len(self.cache)):
self.cache[i] = None
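When the same model instance is reused for a new prompt, two pieces of state need clearing: the per-layer `(keys, values)` tuples held by `KVCache`, and the model's `current_pos` counter, which drives the causal-mask slice and the RoPE offset. A minimal sketch, assuming `model`, `cache`, and a new prompt tensor `new_idx` already exist:

```python
# Sketch: resetting state before generating from a new prompt with the same model.
cache.reset()             # or simply: cache = KVCache(n_layers=model.cfg["n_layers"])
model.reset_kv_cache()    # sets model.current_pos back to 0
logits = model(new_idx, use_cache=True, cache=cache)
```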