Qwen3 KV cache (#688)

Sebastian Raschka 2025-06-21 17:34:39 -05:00 committed by GitHub
parent 2a530b49fe
commit 0b15a00574
8 changed files with 370 additions and 11 deletions

View File

@ -254,3 +254,5 @@ Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is
| Llama3Model | Regular compiled | Nvidia A100 GPU | 170 | 3.12 GB |
| Llama3Model | KV cache | Nvidia A100 GPU | 60 | 18.87 GB |
| Llama3Model | KV cache compiled | Nvidia A100 GPU | 59 | 19.12 GB |
Note that all settings above have been tested to produce the same text outputs.

View File

@ -165,7 +165,7 @@ Large language models (LLMs) are advanced artificial intelligence systems design
```
 
#### Pro tip: speed up inference with compilation
#### Pro tip 1: speed up inference with compilation
For up to a 4× speed-up, replace
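The pattern, sketched here rather than quoted verbatim from the README, is to wrap the model with `torch.compile` before moving it to the device (assuming `model` and `device` are set up as in the surrounding examples):
```python
model = Qwen3Model(QWEN_CONFIG_06_B)

# Instead of calling only `model.to(device)`, compile first;
# the first forward pass is slower due to compilation overhead.
model = torch.compile(model)
model.to(device)
```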
@ -188,4 +188,44 @@ The following table shows a performance comparison on an A100 for consecutive `ge
| | Tokens/sec | Memory |
| ------------------- | ---------- | ------- |
| Qwen3Model | 25 | 1.49 GB |
| Qwen3Model compiled | 101 | 1.99 GB |
| Qwen3Model compiled | 107 | 1.99 GB |
 
#### Pro tip 2: speed up inference with a KV cache
You can significantly boost inference performance using the KV cache `Qwen3Model` drop-in replacement when running the model on a CPU. (See my [Understanding and Coding the KV Cache in LLMs from Scratch](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms) article to learn more about KV caches.)
```python
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model
from llms_from_scratch.kv_cache.generate import generate_text_simple
model = Qwen3Model(QWEN_CONFIG_06_B)
# ...
token_ids = generate_text_simple(
model=model,
idx=text_to_token_ids(PROMPT, tokenizer).to(device),
max_new_tokens=MAX_NEW_TOKENS,
context_size=QWEN_CONFIG_06_B["context_length"],
)
```
Note that the peak memory usage is listed only for Nvidia CUDA devices, as it is easier to measure there. However, memory usage on other devices is likely similar, since they use a similar precision format and the KV cache storage dominates for the generated 150-token text (different devices may implement matrix multiplication differently, which can lead to different peak memory requirements); a rough estimate of the cache size follows the table below.
| Model | Mode | Hardware | Tokens/sec | GPU Memory (VRAM) |
|------------|-------------------|-----------------|------------|-------------------|
| Qwen3Model | Regular | Mac Mini M4 CPU | 1 | - |
| Qwen3Model | Regular compiled | Mac Mini M4 CPU | - | - |
| Qwen3Model | KV cache | Mac Mini M4 CPU | 80 | - |
| Qwen3Model | KV cache compiled | Mac Mini M4 CPU | - | - |
| | | | | |
| Qwen3Model | Regular | Mac Mini M4 GPU | 21 | - |
| Qwen3Model | Regular compiled | Mac Mini M4 GPU | - | - |
| Qwen3Model | KV cache | Mac Mini M4 GPU | 32 | - |
| Qwen3Model | KV cache compiled | Mac Mini M4 GPU | - | - |
| | | | | |
| Qwen3Model | Regular | Nvidia A100 GPU | 25 | 1.49 GB |
| Qwen3Model | Regular compiled | Nvidia A100 GPU | 107 | 1.99 GB |
| Qwen3Model | KV cache | Nvidia A100 GPU | 25 | 10.20 GB |
| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 24 | 10.61 GB |
Note that all settings above have been tested to produce the same text outputs.
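As a rough back-of-the-envelope check on the KV-cache rows above (an estimate, not the benchmark script): the implementation pre-allocates per-layer key and value buffers of shape `(batch, n_heads, context_length, head_dim)` in bfloat16, so for the 0.6B configuration the cache itself accounts for most of the measured ~10 GB.
```python
# Rough KV-cache size estimate for QWEN_CONFIG_06_B (bfloat16 = 2 bytes per value).
# Keys/values are expanded to n_heads before caching, and the buffers are
# pre-allocated for the full context_length, so the generated text length barely matters.
n_layers, n_heads, head_dim = 28, 16, 128
context_length = 40_960
bytes_per_value = 2  # bfloat16

kv_bytes = 2 * n_layers * n_heads * context_length * head_dim * bytes_per_value  # 2 = keys and values
print(f"{kv_bytes / 1e9:.1f} GB")  # ~9.4 GB, in the same ballpark as the ~10 GB measured above
```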

View File

@ -109,7 +109,6 @@ from llms_from_scratch.ch07 import (
from llms_from_scratch.appendix_a import NeuralNetwork, ToyDataset
from llms_from_scratch.appendix_d import find_highest_gradient, train_model
```
@ -140,12 +139,15 @@ from llms_from_scratch.llama3 import (
clean_text
)
# KV cache drop-in replacements
from llms_from_scratch.kv_cache.llama3 import Llama3Model
from llms_from_scratch.kv_cache.generate import generate_text_simple
```
For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).
For more information about KV caching, please see the [KV cache README](../../ch04/03_kv-cache).
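For orientation, usage mirrors the Qwen3 KV-cache example shown further above; the sketch below uses a placeholder config name (`LLAMA32_CONFIG`) and assumes the prompt/tokenizer helpers from the bonus section:
```python
from llms_from_scratch.kv_cache.llama3 import Llama3Model
from llms_from_scratch.kv_cache.generate import generate_text_simple

model = Llama3Model(LLAMA32_CONFIG)  # placeholder config name; see the linked bonus section
# ...
token_ids = generate_text_simple(
    model=model,
    idx=text_to_token_ids(PROMPT, tokenizer).to(device),
    max_new_tokens=MAX_NEW_TOKENS,
    context_size=LLAMA32_CONFIG["context_length"],
)
```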
 
### Qwen3 (Bonus material)
@ -155,7 +157,12 @@ from llms_from_scratch.qwen3 import (
Qwen3Model,
Qwen3Tokenizer,
)
# KV cache drop-in replacements
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model
from llms_from_scratch.kv_cache.generate import generate_text_simple
```
For the `llms_from_scratch.qwen3` usage information, please see [this bonus section](../../ch05/11_qwen3/README.md).
For more information about KV caching, please see the [KV cache README](../../ch04/03_kv-cache).

View File

@ -0,0 +1,299 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
from ..qwen3 import Qwen3Tokenizer, download_from_huggingface, load_weights_into_qwen # noqa: F401
import torch
import torch.nn as nn
# 0.6B model
QWEN_CONFIG_06_B = {
"vocab_size": 151_936, # Vocabulary size
"context_length": 40_960, # Context length that was used to train the model
"window_size": None, # Window size for the KV cache; context_length if None
"emb_dim": 1024, # Embedding dimension
"n_heads": 16, # Number of attention heads
"n_layers": 28, # Number of layers
"hidden_dim": 3072, # Size of the intermediate dimension in FeedForward
"head_dim": 128, # Size of the heads in GQA
"qk_norm": True, # Whether to normalize queries and values in GQA
"n_kv_groups": 8, # Key-Value groups for grouped-query attention
"rope_base": 1_000_000.0, # The base in RoPE's "theta"
"dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
}
class Qwen3Model(nn.Module):
def __init__(self, cfg):
super().__init__()
# Main model parameters
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, cos, sin, use_cache`
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
)
self.final_norm = RMSNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
# Reusable utilities
if cfg["head_dim"] is None:
head_dim = cfg["emb_dim"] // cfg["n_heads"]
else:
head_dim = cfg["head_dim"]
cos, sin = compute_rope_params(
head_dim=head_dim,
theta_base=cfg["rope_base"],
context_length=cfg["context_length"]
)
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)
self.cfg = cfg
def forward(self, in_idx, use_cache=False):
# Forward pass
tok_embeds = self.tok_emb(in_idx)
x = tok_embeds
for block in self.trf_blocks:
x = block(x, self.cos, self.sin, use_cache)
x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"]))
return logits
def reset_kv_cache(self):
for blk in self.trf_blocks:
blk.att.reset_cache()
self.ptr_current_pos = 0
class TransformerBlock(nn.Module):
def __init__(self, cfg):
super().__init__()
self.att = GroupedQueryAttention(
d_in=cfg["emb_dim"],
num_heads=cfg["n_heads"],
head_dim=cfg["head_dim"],
num_kv_groups=cfg["n_kv_groups"],
qk_norm=cfg["qk_norm"],
max_seq_len=cfg["context_length"],
dtype=cfg["dtype"]
)
self.ff = FeedForward(cfg)
self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)
def forward(self, x, cos, sin, use_cache=False):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x, cos, sin, use_cache) # Shape [batch_size, num_tokens, emb_size]
x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = x + shortcut # Add the original input back
return x
class FeedForward(nn.Module):
def __init__(self, cfg):
super().__init__()
self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)
def forward(self, x):
x_fc1 = self.fc1(x)
x_fc2 = self.fc2(x)
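# SwiGLU: the SiLU-activated fc1 output gates the fc2 output; fc3 then projects back down to emb_dim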
x = nn.functional.silu(x_fc1) * x_fc2
return self.fc3(x)
class GroupedQueryAttention(nn.Module):
def __init__(
self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None,
max_seq_len=None, window_size=None
):
super().__init__()
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
self.num_heads = num_heads
self.num_kv_groups = num_kv_groups
self.group_size = num_heads // num_kv_groups
if head_dim is None:
assert d_in % num_heads == 0, "`d_in` must be divisible by `num_heads` if `head_dim` is not set"
head_dim = d_in // num_heads
self.head_dim = head_dim
self.d_out = num_heads * head_dim
self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)
if qk_norm:
self.q_norm = RMSNorm(head_dim, eps=1e-6)
self.k_norm = RMSNorm(head_dim, eps=1e-6)
else:
self.q_norm = self.k_norm = None
# For optional KV cache
self.max_seq_len = max_seq_len
self.window_size = window_size or self.max_seq_len
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
self.cache_initialized = False
self.ptr = 0
def forward(self, x, cos, sin, use_cache=False):
b, num_tokens, _ = x.shape
# Apply projections
queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)
keys_new = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
values_new = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
# Reshape
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
values_new = values_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
# Optional normalization
if self.q_norm:
queries = self.q_norm(queries)
if self.k_norm:
keys_new = self.k_norm(keys_new)
# For KV cache
pos_start = self.ptr
pos_end = pos_start + num_tokens
cos_slice = cos[pos_start:pos_end]
sin_slice = sin[pos_start:pos_end]
# Apply RoPE
keys_new = apply_rope(keys_new, cos_slice, sin_slice)
queries = apply_rope(queries, cos_slice, sin_slice)
# Expand K and V to match number of heads
keys_new = keys_new.repeat_interleave(self.group_size, dim=1)
values_new = values_new.repeat_interleave(self.group_size, dim=1)
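# KV cache: on the first cached call, pre-allocate (b, num_heads, max_seq_len, head_dim) buffers;
# later calls write the new tokens in place at position self.ptr and read back the cached prefix
# (optionally limited to the most recent `window_size` positions)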
if use_cache:
if not self.cache_initialized:
self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=keys_new.dtype)
self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=values_new.dtype)
self.ptr = 0
self.cache_initialized = True
# In-place update
end = self.ptr + num_tokens
self.cache_k[:, :, self.ptr:end].copy_(keys_new)
self.cache_v[:, :, self.ptr:end].copy_(values_new)
keys = self.cache_k[:, :, max(0, end - self.window_size):end]
values = self.cache_v[:, :, max(0, end - self.window_size):end]
self.ptr = end
else:
keys, values = keys_new, values_new
# Attention
attn_scores = queries @ keys.transpose(2, 3)
# Create causal mask to fill attention scores
T_q = queries.shape[-2]
T_k = keys.shape[-2]
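# A mask is only needed when several query tokens are processed at once (e.g., the prompt/prefill pass);
# during cached single-token decoding the new query may attend to every cached position, so no mask is applied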
if not use_cache or T_q > 1:
causal_mask = torch.triu(
torch.ones((T_q, T_k), device=x.device, dtype=torch.bool),
diagonal=1
)
attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
return self.out_proj(context)
def reset_cache(self):
if self.cache_k is not None:
self.cache_k.zero_()
self.cache_v.zero_()
self.ptr = 0
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
assert head_dim % 2 == 0, "Embedding dimension must be even"
# Compute the inverse frequencies
inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))
# Generate position indices
positions = torch.arange(context_length, dtype=dtype)
# Compute the angles
angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)
# Expand angles to match the head_dim
angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)
# Precompute sine and cosine
cos = torch.cos(angles)
sin = torch.sin(angles)
return cos, sin
def apply_rope(x, cos, sin):
# x: (batch_size, num_heads, seq_len, head_dim)
batch_size, num_heads, seq_len, head_dim = x.shape
assert head_dim % 2 == 0, "Head dimension must be even"
# Split x into first half and second half
x1 = x[..., : head_dim // 2] # First half
x2 = x[..., head_dim // 2:] # Second half
# Adjust sin and cos shapes
cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
# Apply the rotary transformation
rotated = torch.cat((-x2, x1), dim=-1)
x_rotated = (x * cos) + (rotated * sin)
# It's ok to use lower-precision after applying cos and sin rotation
return x_rotated.to(dtype=x.dtype)
class RMSNorm(nn.Module):
def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
super().__init__()
self.eps = eps
self.qwen3_compatible = qwen3_compatible
self.scale = nn.Parameter(torch.ones(emb_dim))
self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None
def forward(self, x):
input_dtype = x.dtype
if self.qwen3_compatible:
x = x.to(torch.float32)
variance = x.pow(2).mean(dim=-1, keepdim=True)
norm_x = x * torch.rsqrt(variance + self.eps)
norm_x = norm_x * self.scale
if self.shift is not None:
norm_x = norm_x + self.shift
return norm_x.to(input_dtype)
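To illustrate how the pieces above interact, here is a minimal greedy decoding loop (a sketch only; the packaged `generate_text_simple` in `llms_from_scratch.kv_cache.generate` is the supported helper): the prompt is processed once with `use_cache=True` to fill the per-layer buffers, and each later step feeds only the newly sampled token.
```python
import torch

@torch.no_grad()
def greedy_decode_with_cache(model, idx, max_new_tokens):
    # idx: (batch, prompt_len) token IDs; assumes prompt_len + max_new_tokens <= context_length
    model.eval()
    model.reset_kv_cache()               # zero the per-layer caches and reset the write pointers
    logits = model(idx, use_cache=True)  # prefill: caches K/V for the whole prompt
    for _ in range(max_new_tokens):
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick from the last position
        idx = torch.cat([idx, next_token], dim=1)
        logits = model(next_token, use_cache=True)  # decode step: only the new token runs through the model
    return idx
```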

View File

@ -87,7 +87,7 @@ class TransformerBlock(nn.Module):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x, mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]
x = self.att(x, mask, cos, sin,) # Shape [batch_size, num_tokens, emb_size]
x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block

View File

@ -189,7 +189,7 @@ def llama3_weights_path(tmp_path_factory):
)
@pytest.mark.parametrize("ModelClass", [Llama3Model, Llama3ModelKV])
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
def test_gpt_model_variants(ModelClass, generate_fn, llama3_weights_path):
def test_model_variants(ModelClass, generate_fn, llama3_weights_path):
# Skip incompatible combinations
if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):

View File

@ -12,6 +12,9 @@ from llms_from_scratch.qwen3 import (
Qwen3Model,
Qwen3Tokenizer
)
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
import importlib
import pytest
@ -110,8 +113,16 @@ def qwen3_weights_path(tmp_path_factory):
return path
@pytest.mark.parametrize("ModelClass", [Qwen3Model])
def test_gpt_model_variants(ModelClass, qwen3_weights_path):
@pytest.mark.parametrize("ModelClass", [Qwen3Model, Qwen3ModelKV])
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
# Skip incompatible combinations
if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
return
if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
return
torch.manual_seed(123)
model = ModelClass(QWEN_CONFIG_06_B)
model.load_state_dict(torch.load(qwen3_weights_path))

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "llms-from-scratch"
version = "1.0.12"
version = "1.0.13"
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
readme = "README.md"
requires-python = ">=3.10"