mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-29 19:10:19 +00:00
Qwen3 KV cache (#688)
This commit is contained in:
parent
2a530b49fe
commit
0b15a00574
@ -254,3 +254,5 @@ Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is
|
||||
| Llama3Model | Regular compiled | Nvidia A100 GPU | 170 | 3.12 GB |
|
||||
| Llama3Model | KV cache | Nvidia A100 GPU | 60 | 18.87 GB |
|
||||
| Llama3Model | KV cache compiled | Nvidia A100 GPU | 59 | 19.12 GB |
|
||||
|
||||
Note that all settings above have been tested to produce the same text outputs.
|
@ -165,7 +165,7 @@ Large language models (LLMs) are advanced artificial intelligence systems design
|
||||
```
|
||||
|
||||
|
||||
#### Pro tip: speed up inference with compilation
|
||||
#### Pro tip 1: speed up inference with compilation
|
||||
|
||||
|
||||
For up to a 4× speed-up, replace
|
||||
@ -188,4 +188,44 @@ The following table shows a performance comparison on an A100 for consequent `ge
|
||||
| | Tokens/sec | Memory |
|
||||
| ------------------- | ---------- | ------- |
|
||||
| Qwen3Model | 25 | 1.49 GB |
|
||||
| Qwen3Model compiled | 101 | 1.99 GB |
|
||||
| Qwen3Model compiled | 107 | 1.99 GB |
|
||||
|
||||
|
||||
#### Pro tip 2: speed up inference with compilation
|
||||
|
||||
You can significantly boost inference performance using the KV cache `Qwen3Model` drop-in replacement when running the model on a CPU. (See my [Understanding and Coding the KV Cache in LLMs from Scratch](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms) article to learn more about KV caches.)
|
||||
|
||||
```python
|
||||
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model
|
||||
from llms_from_scratch.kv_cache.generate import generate_text_simple
|
||||
|
||||
model = Qwen3Model(QWEN_CONFIG_06_B)
|
||||
# ...
|
||||
token_ids = generate_text_simple(
|
||||
model=model,
|
||||
idx=text_to_token_ids(PROMPT, tokenizer).to(device),
|
||||
max_new_tokens=MAX_NEW_TOKENS,
|
||||
context_size=QWEN_CONFIG_06_B["context_length"],
|
||||
)
|
||||
```
|
||||
|
||||
Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage dominates here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements).
|
||||
|
||||
| Model | Mode | Hardware | Tokens/sec | GPU Memory (VRAM) |
|
||||
|------------|-------------------|-----------------|------------|-------------------|
|
||||
| Qwen3Model | Regular | Mac Mini M4 CPU | 1 | - |
|
||||
| Qwen3Model | Regular compiled | Mac Mini M4 CPU | - | - |
|
||||
| Qwen3Model | KV cache | Mac Mini M4 CPU | 80 | - |
|
||||
| Qwen3Model | KV cache compiled | Mac Mini M4 CPU | - | - |
|
||||
| | | | | |
|
||||
| Qwen3Model | Regular | Mac Mini M4 GPU | 21 | - |
|
||||
| Qwen3Model | Regular compiled | Mac Mini M4 GPU | - | - |
|
||||
| Qwen3Model | KV cache | Mac Mini M4 GPU | 32 | - |
|
||||
| Qwen3Model | KV cache compiled | Mac Mini M4 GPU | - | - |
|
||||
| | | | | |
|
||||
| Qwen3Model | Regular | Nvidia A100 GPU | 25 | 1.49 GB |
|
||||
| Qwen3Model | Regular compiled | Nvidia A100 GPU | 107 | 1.99 GB |
|
||||
| Qwen3Model | KV cache | Nvidia A100 GPU | 25 | 10.20 GB |
|
||||
| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 24 | 10.61 GB |
|
||||
|
||||
Note that all settings above have been tested to produce the same text outputs.
|
@ -109,7 +109,6 @@ from llms_from_scratch.ch07 import (
|
||||
from llms_from_scratch.appendix_a import NeuralNetwork, ToyDataset
|
||||
|
||||
from llms_from_scratch.appendix_d import find_highest_gradient, train_model
|
||||
|
||||
```
|
||||
|
||||
|
||||
@ -140,12 +139,15 @@ from llms_from_scratch.llama3 import (
|
||||
clean_text
|
||||
)
|
||||
|
||||
# KV cache drop-in replacements
|
||||
from llms_from_scratch.kv_cache.llama3 import Llama3Model
|
||||
from llms_from_scratch.kv_cache.generate import generate_text_simple
|
||||
```
|
||||
|
||||
For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).
|
||||
|
||||
For more information about KV caching, please see the [KV cache README](../../ch04/03_kv-cache).
|
||||
|
||||
|
||||
|
||||
### Qwen3 (Bonus material)
|
||||
@ -155,7 +157,12 @@ from llms_from_scratch.qwen3 import (
|
||||
Qwen3Model,
|
||||
Qwen3Tokenizer,
|
||||
)
|
||||
|
||||
# KV cache drop-in replacements
|
||||
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model
|
||||
from llms_from_scratch.kv_cache.generate import generate_text_simple
|
||||
```
|
||||
|
||||
|
||||
For the `llms_from_scratch.qwen3` usage information, please see [this bonus section](../../ch05/11_qwen3/README.md).
|
||||
|
||||
For more information about KV caching, please see the [KV cache README](../../ch04/03_kv-cache).
|
||||
|
299
pkg/llms_from_scratch/kv_cache/qwen3.py
Normal file
299
pkg/llms_from_scratch/kv_cache/qwen3.py
Normal file
@ -0,0 +1,299 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from ..qwen3 import Qwen3Tokenizer, download_from_huggingface, load_weights_into_qwen # noqa: F401
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# 0.6B model
|
||||
QWEN_CONFIG_06_B = {
|
||||
"vocab_size": 151_936, # Vocabulary size
|
||||
"context_length": 40_960, # Context length that was used to train the model
|
||||
"window_size": None, # Window size for the KV cache; context_length if None
|
||||
"emb_dim": 1024, # Embedding dimension
|
||||
"n_heads": 16, # Number of attention heads
|
||||
"n_layers": 28, # Number of layers
|
||||
"hidden_dim": 3072, # Size of the intermediate dimension in FeedForward
|
||||
"head_dim": 128, # Size of the heads in GQA
|
||||
"qk_norm": True, # Whether to normalize queries and values in GQA
|
||||
"n_kv_groups": 8, # Key-Value groups for grouped-query attention
|
||||
"rope_base": 1_000_000.0, # The base in RoPE's "theta"
|
||||
"dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
|
||||
}
|
||||
|
||||
|
||||
class Qwen3Model(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
|
||||
# Main model parameters
|
||||
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
|
||||
|
||||
self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
|
||||
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
|
||||
)
|
||||
self.final_norm = RMSNorm(cfg["emb_dim"])
|
||||
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
|
||||
|
||||
# Reusuable utilities
|
||||
if cfg["head_dim"] is None:
|
||||
head_dim = cfg["emb_dim"] // cfg["n_heads"]
|
||||
else:
|
||||
head_dim = cfg["head_dim"]
|
||||
cos, sin = compute_rope_params(
|
||||
head_dim=head_dim,
|
||||
theta_base=cfg["rope_base"],
|
||||
context_length=cfg["context_length"]
|
||||
)
|
||||
self.register_buffer("cos", cos, persistent=False)
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
self.cfg = cfg
|
||||
|
||||
def forward(self, in_idx, use_cache=False):
|
||||
# Forward pass
|
||||
tok_embeds = self.tok_emb(in_idx)
|
||||
x = tok_embeds
|
||||
|
||||
for block in self.trf_blocks:
|
||||
x = block(x, self.cos, self.sin, use_cache)
|
||||
x = self.final_norm(x)
|
||||
logits = self.out_head(x.to(self.cfg["dtype"]))
|
||||
return logits
|
||||
|
||||
def reset_kv_cache(self):
|
||||
for blk in self.trf_blocks:
|
||||
blk.att.reset_cache()
|
||||
self.ptr_current_pos = 0
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.att = GroupedQueryAttention(
|
||||
d_in=cfg["emb_dim"],
|
||||
num_heads=cfg["n_heads"],
|
||||
head_dim=cfg["head_dim"],
|
||||
num_kv_groups=cfg["n_kv_groups"],
|
||||
qk_norm=cfg["qk_norm"],
|
||||
max_seq_len=cfg["context_length"],
|
||||
dtype=cfg["dtype"]
|
||||
)
|
||||
self.ff = FeedForward(cfg)
|
||||
self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
|
||||
self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)
|
||||
|
||||
def forward(self, x, cos, sin, use_cache=False):
|
||||
# Shortcut connection for attention block
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = self.att(x, cos, sin, use_cache) # Shape [batch_size, num_tokens, emb_size]
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
# Shortcut connection for feed-forward block
|
||||
shortcut = x
|
||||
x = self.norm2(x)
|
||||
x = self.ff(x)
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
|
||||
self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
|
||||
self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x_fc1 = self.fc1(x)
|
||||
x_fc2 = self.fc2(x)
|
||||
x = nn.functional.silu(x_fc1) * x_fc2
|
||||
return self.fc3(x)
|
||||
|
||||
|
||||
class GroupedQueryAttention(nn.Module):
|
||||
def __init__(
|
||||
self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None,
|
||||
max_seq_len=None, window_size=None
|
||||
):
|
||||
super().__init__()
|
||||
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_groups = num_kv_groups
|
||||
self.group_size = num_heads // num_kv_groups
|
||||
|
||||
if head_dim is None:
|
||||
assert d_in % num_heads == 0, "`d_in` must be divisible by `num_heads` if `head_dim` is not set"
|
||||
head_dim = d_in // num_heads
|
||||
|
||||
self.head_dim = head_dim
|
||||
self.d_out = num_heads * head_dim
|
||||
|
||||
self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
|
||||
self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
|
||||
self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
|
||||
|
||||
self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)
|
||||
|
||||
if qk_norm:
|
||||
self.q_norm = RMSNorm(head_dim, eps=1e-6)
|
||||
self.k_norm = RMSNorm(head_dim, eps=1e-6)
|
||||
else:
|
||||
self.q_norm = self.k_norm = None
|
||||
|
||||
# For optional KV cache
|
||||
self.max_seq_len = max_seq_len
|
||||
self.window_size = window_size or self.max_seq_len
|
||||
self.register_buffer("cache_k", None, persistent=False)
|
||||
self.register_buffer("cache_v", None, persistent=False)
|
||||
self.cache_initialized = False
|
||||
self.ptr = 0
|
||||
|
||||
def forward(self, x, cos, sin, use_cache=False):
|
||||
b, num_tokens, _ = x.shape
|
||||
|
||||
# Apply projections
|
||||
queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)
|
||||
keys_new = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
|
||||
values_new = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
|
||||
|
||||
# Reshape
|
||||
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
values_new = values_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Optional normalization
|
||||
if self.q_norm:
|
||||
queries = self.q_norm(queries)
|
||||
if self.k_norm:
|
||||
keys_new = self.k_norm(keys_new)
|
||||
|
||||
# For KV cache
|
||||
pos_start = self.ptr
|
||||
pos_end = pos_start + num_tokens
|
||||
cos_slice = cos[pos_start:pos_end]
|
||||
sin_slice = sin[pos_start:pos_end]
|
||||
|
||||
# Apply RoPE
|
||||
keys_new = apply_rope(keys_new, cos_slice, sin_slice)
|
||||
queries = apply_rope(queries, cos_slice, sin_slice)
|
||||
|
||||
# Expand K and V to match number of heads
|
||||
keys_new = keys_new.repeat_interleave(self.group_size, dim=1)
|
||||
values_new = values_new.repeat_interleave(self.group_size, dim=1)
|
||||
|
||||
if use_cache:
|
||||
if not self.cache_initialized:
|
||||
self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=keys_new.dtype)
|
||||
self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=values_new.dtype)
|
||||
self.ptr = 0
|
||||
self.cache_initialized = True
|
||||
|
||||
# In-place update
|
||||
end = self.ptr + num_tokens
|
||||
self.cache_k[:, :, self.ptr:end].copy_(keys_new)
|
||||
self.cache_v[:, :, self.ptr:end].copy_(values_new)
|
||||
|
||||
keys = self.cache_k[:, :, max(0, end - self.window_size):end]
|
||||
values = self.cache_v[:, :, max(0, end - self.window_size):end]
|
||||
self.ptr = end
|
||||
else:
|
||||
keys, values = keys_new, values_new
|
||||
|
||||
# Attention
|
||||
attn_scores = queries @ keys.transpose(2, 3)
|
||||
|
||||
# Create causal mask to fill attention scores
|
||||
T_q = queries.shape[-2]
|
||||
T_k = keys.shape[-2]
|
||||
|
||||
if not use_cache or T_q > 1:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones((T_q, T_k), device=x.device, dtype=torch.bool),
|
||||
diagonal=1
|
||||
)
|
||||
attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
|
||||
|
||||
attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
|
||||
|
||||
context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
|
||||
return self.out_proj(context)
|
||||
|
||||
def reset_cache(self):
|
||||
if self.cache_k is not None:
|
||||
self.cache_k.zero_()
|
||||
self.cache_v.zero_()
|
||||
self.ptr = 0
|
||||
|
||||
|
||||
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
|
||||
assert head_dim % 2 == 0, "Embedding dimension must be even"
|
||||
|
||||
# Compute the inverse frequencies
|
||||
inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))
|
||||
|
||||
# Generate position indices
|
||||
positions = torch.arange(context_length, dtype=dtype)
|
||||
|
||||
# Compute the angles
|
||||
angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)
|
||||
|
||||
# Expand angles to match the head_dim
|
||||
angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)
|
||||
|
||||
# Precompute sine and cosine
|
||||
cos = torch.cos(angles)
|
||||
sin = torch.sin(angles)
|
||||
|
||||
return cos, sin
|
||||
|
||||
|
||||
def apply_rope(x, cos, sin):
|
||||
# x: (batch_size, num_heads, seq_len, head_dim)
|
||||
batch_size, num_heads, seq_len, head_dim = x.shape
|
||||
assert head_dim % 2 == 0, "Head dimension must be even"
|
||||
|
||||
# Split x into first half and second half
|
||||
x1 = x[..., : head_dim // 2] # First half
|
||||
x2 = x[..., head_dim // 2:] # Second half
|
||||
|
||||
# Adjust sin and cos shapes
|
||||
cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
|
||||
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Apply the rotary transformation
|
||||
rotated = torch.cat((-x2, x1), dim=-1)
|
||||
x_rotated = (x * cos) + (rotated * sin)
|
||||
|
||||
# It's ok to use lower-precision after applying cos and sin rotation
|
||||
return x_rotated.to(dtype=x.dtype)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.qwen3_compatible = qwen3_compatible
|
||||
self.scale = nn.Parameter(torch.ones(emb_dim))
|
||||
self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None
|
||||
|
||||
def forward(self, x):
|
||||
input_dtype = x.dtype
|
||||
|
||||
if self.qwen3_compatible:
|
||||
x = x.to(torch.float32)
|
||||
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
norm_x = x * torch.rsqrt(variance + self.eps)
|
||||
norm_x = norm_x * self.scale
|
||||
|
||||
if self.shift is not None:
|
||||
norm_x = norm_x + self.shift
|
||||
|
||||
return norm_x.to(input_dtype)
|
||||
|
@ -87,7 +87,7 @@ class TransformerBlock(nn.Module):
|
||||
# Shortcut connection for attention block
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = self.att(x, mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]
|
||||
x = self.att(x, mask, cos, sin,) # Shape [batch_size, num_tokens, emb_size]
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
# Shortcut connection for feed-forward block
|
||||
|
@ -189,7 +189,7 @@ def llama3_weights_path(tmp_path_factory):
|
||||
)
|
||||
@pytest.mark.parametrize("ModelClass", [Llama3Model, Llama3ModelKV])
|
||||
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
|
||||
def test_gpt_model_variants(ModelClass, generate_fn, llama3_weights_path):
|
||||
def test_model_variants(ModelClass, generate_fn, llama3_weights_path):
|
||||
|
||||
# Skip incompatible combinations
|
||||
if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
|
||||
|
@ -12,6 +12,9 @@ from llms_from_scratch.qwen3 import (
|
||||
Qwen3Model,
|
||||
Qwen3Tokenizer
|
||||
)
|
||||
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
|
||||
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
|
||||
|
||||
|
||||
import importlib
|
||||
import pytest
|
||||
@ -110,8 +113,16 @@ def qwen3_weights_path(tmp_path_factory):
|
||||
return path
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ModelClass", [Qwen3Model])
|
||||
def test_gpt_model_variants(ModelClass, qwen3_weights_path):
|
||||
@pytest.mark.parametrize("ModelClass", [Qwen3Model, Qwen3ModelKV])
|
||||
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
|
||||
def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
|
||||
|
||||
# Skip incompatible combinations
|
||||
if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
|
||||
return
|
||||
if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
|
||||
return
|
||||
|
||||
torch.manual_seed(123)
|
||||
model = ModelClass(QWEN_CONFIG_06_B)
|
||||
model.load_state_dict(torch.load(qwen3_weights_path))
|
||||
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "llms-from-scratch"
|
||||
version = "1.0.12"
|
||||
version = "1.0.13"
|
||||
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
Loading…
x
Reference in New Issue
Block a user