Llama 3 KV Cache (#685)

* Llama 3 KV Cache

* Skip expensive tests on GitHub Actions

* Update __init__.py
Sebastian Raschka, 2025-06-21 10:55:20 -05:00 (committed via GitHub)
commit 3be0f3202a, parent c008f95072
7 changed files with 410 additions and 4 deletions


@@ -216,3 +216,41 @@ The following table shows a performance comparison on an A100 for consecutive `ge
| --------------- | ---------- | ------- |
| Llama3Model | 170 | 3.12 GB |
| Llama3ModelFast | 177 | 3.61 GB |
 
#### Pro tip 3: speed up inference with a KV cache

You can significantly boost inference performance using the KV cache `Llama3Model` drop-in replacement when running the model on a CPU. (See my [Understanding and Coding the KV Cache in LLMs from Scratch](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms) article to learn more about KV caches.)
```python
from llms_from_scratch.kv_cache.llama3 import Llama3Model
from llms_from_scratch.kv_cache.generate import generate_text_simple

model = Llama3Model(LLAMA32_CONFIG)
# ...

token_ids = generate_text_simple(
    model=model,
    idx=text_to_token_ids(PROMPT, tokenizer).to(device),
    max_new_tokens=MAX_NEW_TOKENS,
    context_size=LLAMA32_CONFIG["context_length"],
)
```
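
The "compiled" rows in the table below presumably combine the KV-cache model with `torch.compile`; a minimal sketch of that combination (the `device` variable comes from the snippet above, and the exact compile settings used for the benchmark are an assumption):

```python
import torch

model = Llama3Model(LLAMA32_CONFIG)
model.to(device)
model = torch.compile(model)  # the first generate call pays a one-time compilation cost

# afterwards, generate_text_simple can be used exactly as shown above
```
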
Note that peak memory usage is listed only for Nvidia CUDA devices, where it is straightforward to measure. Memory usage on other devices is likely similar, since they use a comparable precision format and the KV cache storage dominates for the generated 150-token text here (although different devices may implement matrix multiplication differently, which can lead to different peak memory requirements). A back-of-the-envelope estimate of the KV cache size follows the table below.

| Model | Mode | Hardware | Tokens/sec | GPU Memory (VRAM) |
|-------------|-------------------|-----------------|------------|-------------------|
| Llama3Model | Regular | Mac Mini M4 CPU | 1 | - |
| Llama3Model | Regular compiled | Mac Mini M4 CPU | - | - |
| Llama3Model | KV cache | Mac Mini M4 CPU | 62 | - |
| Llama3Model | KV cache compiled | Mac Mini M4 CPU | - | - |
| | | | | |
| Llama3Model | Regular | Mac Mini M4 GPU | 15 | - |
| Llama3Model | Regular compiled | Mac Mini M4 GPU | - | - |
| Llama3Model | KV cache | Mac Mini M4 GPU | 62 | - |
| Llama3Model | KV cache compiled | Mac Mini M4 GPU | - | - |
| | | | | |
| Llama3Model | Regular | Nvidia A100 GPU | 42 | 2.91 GB |
| Llama3Model | Regular compiled | Nvidia A100 GPU | 170 | 3.12 GB |
| Llama3Model | KV cache | Nvidia A100 GPU | 60 | 18.87 GB |
| Llama3Model | KV cache compiled | Nvidia A100 GPU | 59 | 19.12 GB |
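
For a sense of where the ~19 GB KV-cache numbers in the A100 rows come from, here is a rough back-of-the-envelope estimate (a sketch based on the config values, not a measurement). Each of the 16 layers caches keys and values of shape `(batch, n_heads, context_length, head_dim)` in bfloat16; note that this implementation stores the cache after `repeat_interleave`, i.e. with all 32 heads rather than the 8 KV groups:

```python
# KV-cache size estimate for the 1B config at the full 131k context (batch size 1)
n_layers, n_heads, emb_dim = 16, 32, 2048
head_dim = emb_dim // n_heads            # 64
context_length = 131_072
bytes_per_elem = 2                       # bfloat16

kv_bytes = 2 * n_layers * n_heads * context_length * head_dim * bytes_per_elem  # keys + values
print(f"{kv_bytes / 1e9:.1f} GB")        # ~17.2 GB for the cache alone
```

Since the cache is pre-allocated for the full `context_length`, lowering `context_length` in the config before instantiating the model shrinks this footprint proportionally.
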


@@ -124,6 +124,9 @@ from llms_from_scratch.llama3 import (
    ChatFormat,
    clean_text
)
from llms_from_scratch.kv_cache.llama3 import Llama3Model
from llms_from_scratch.kv_cache.generate import generate_text_simple
```
For more information about using `llms_from_scratch.llama3`, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).


@@ -0,0 +1,4 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch


@@ -0,0 +1,29 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import torch


def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
    model.eval()
    ctx_len = context_size or model.cfg["context_length"]

    with torch.no_grad():
        if use_cache:
            # Prefill: run the (truncated) prompt once and populate the KV cache
            model.reset_kv_cache()
            logits = model(idx[:, -ctx_len:], use_cache=True)

            for _ in range(max_new_tokens):
                # Greedy decoding: pick the most likely next token
                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                idx = torch.cat([idx, next_idx], dim=1)
                # Only the new token is fed in; past keys/values come from the cache
                logits = model(next_idx, use_cache=True)
        else:
            for _ in range(max_new_tokens):
                # Without a cache, the full (truncated) sequence is re-encoded every step
                logits = model(idx[:, -ctx_len:], use_cache=False)
                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                idx = torch.cat([idx, next_idx], dim=1)

    return idx
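
A quick way to smoke-test this helper (a sketch: the shrunken `context_length`, random weights, and dummy token ids below are illustrative choices, not part of the package):

```python
import torch
from llms_from_scratch.kv_cache.llama3 import Llama3Model, LLAMA32_CONFIG_1B
from llms_from_scratch.kv_cache.generate import generate_text_simple

cfg = dict(LLAMA32_CONFIG_1B, context_length=256)  # small context keeps the KV cache tiny
torch.manual_seed(123)
model = Llama3Model(cfg)                            # randomly initialized, so output is gibberish
idx = torch.tensor([[1, 2, 3]])                     # dummy prompt token ids

out = generate_text_simple(model, idx, max_new_tokens=5, context_size=cfg["context_length"])
print(out.shape)                                    # torch.Size([1, 8]): 3 prompt + 5 generated tokens
```
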


@@ -0,0 +1,317 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
from ..llama3 import Llama3Tokenizer, ChatFormat, clean_text # noqa: F401
import torch
import torch.nn as nn

LLAMA32_CONFIG_1B = {
    "vocab_size": 128_256,       # Vocabulary size
    "context_length": 131_072,   # Context length that was used to train the model
    "window_size": None,         # Window size for the KV cache; context_length if None
    "emb_dim": 2048,             # Embedding dimension
    "n_heads": 32,               # Number of attention heads
    "n_layers": 16,              # Number of layers
    "hidden_dim": 8192,          # Size of the intermediate dimension in FeedForward
    "n_kv_groups": 8,            # Key-Value groups for grouped-query attention
    "rope_base": 500_000.0,      # The base in RoPE's "theta"
    "dtype": torch.bfloat16,     # Lower-precision dtype to reduce memory usage
    "rope_freq": {               # RoPE frequency scaling
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": 8192,
    }
}

LLAMA32_CONFIG_3B = {
    "vocab_size": 128_256,       # Vocabulary size
    "context_length": 131_072,   # Context length that was used to train the model
    "window_size": None,         # Window size for the KV cache; context_length if None
    "emb_dim": 3072,             # Embedding dimension
    "n_heads": 24,               # Number of attention heads
    "n_layers": 28,              # Number of layers
    "hidden_dim": 8192,          # Size of the intermediate dimension in FeedForward
    "n_kv_groups": 8,            # Key-Value groups for grouped-query attention
    "rope_base": 500_000.0,      # The base in RoPE's "theta"
    "dtype": torch.bfloat16,     # Lower-precision dtype to reduce memory usage
    "rope_freq": {               # RoPE frequency scaling
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": 8192,
    }
}


class Llama3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        self.trf_blocks = nn.ModuleList(  # ModuleList since Sequential can only accept one input, and we need `x, cos, sin, use_cache`
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )

        self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        # Reusable utilities
        cos, sin = compute_rope_params(
            head_dim=cfg["emb_dim"] // cfg["n_heads"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            freq_config=cfg["rope_freq"]
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg

    def forward(self, in_idx, use_cache=False):
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds

        for block in self.trf_blocks:
            x = block(x, self.cos, self.sin, use_cache)

        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits

    def reset_kv_cache(self):
        for blk in self.trf_blocks:
            blk.att.reset_cache()
        self.ptr_current_pos = 0


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            max_seq_len=cfg["context_length"],
            dtype=cfg["dtype"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
        self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])

    def forward(self, x, cos, sin, use_cache=False):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x, cos, sin, use_cache)  # Shape [batch_size, num_tokens, emb_size]
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut  # Add the original input back

        return x


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)

    def forward(self, x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = nn.functional.silu(x_fc1) * x_fc2
        return self.fc3(x)


class GroupedQueryAttention(nn.Module):
    def __init__(
        self, d_in, d_out, num_heads, num_kv_groups, dtype=None, max_seq_len=None, window_size=None
    ):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

        # For optional KV cache
        self.max_seq_len = max_seq_len
        self.window_size = window_size or self.max_seq_len
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)
        self.cache_initialized = False
        self.ptr = 0

    def forward(self, x, cos, sin, use_cache=False):
        b, num_tokens, d_in = x.shape

        queries = self.W_query(x)  # Shape: (b, num_tokens, d_out)
        keys_new = self.W_key(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)
        values_new = self.W_value(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)

        # Reshape queries, keys, and values
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim)
        values_new = values_new.view(b, num_tokens, self.num_kv_groups, self.head_dim)

        # Transpose keys, values, and queries
        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)
        keys_new = keys_new.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)
        values_new = values_new.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)

        # For KV cache
        pos_start = self.ptr
        pos_end = pos_start + num_tokens
        cos_slice = cos[pos_start:pos_end]
        sin_slice = sin[pos_start:pos_end]

        # Apply RoPE
        keys_new = apply_rope(keys_new, cos_slice, sin_slice)
        queries = apply_rope(queries, cos_slice, sin_slice)

        # Expand keys and values to match the number of heads
        # Shape: (b, num_heads, num_tokens, head_dim)
        keys_new = keys_new.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)
        values_new = values_new.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)
        # For example, before repeat_interleave along dim=1 (query groups):
        #   [K1, K2]
        # After repeat_interleave (each query group is repeated group_size times):
        #   [K1, K1, K2, K2]
        # If we used regular repeat instead of repeat_interleave, we'd get:
        #   [K1, K2, K1, K2]

        if use_cache:
            if not self.cache_initialized:
                self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=keys_new.dtype)
                self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=values_new.dtype)
                self.ptr = 0
                self.cache_initialized = True

            # In-place update
            end = self.ptr + num_tokens
            self.cache_k[:, :, self.ptr:end].copy_(keys_new)
            self.cache_v[:, :, self.ptr:end].copy_(values_new)

            keys = self.cache_k[:, :, max(0, end - self.window_size):end]
            values = self.cache_v[:, :, max(0, end - self.window_size):end]
            self.ptr = end
        else:
            keys, values = keys_new, values_new

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        # Shape: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Create causal mask to fill attention scores
        T_q = queries.shape[-2]
        T_k = keys.shape[-2]

        if not use_cache or T_q > 1:
            causal_mask = torch.triu(
                torch.ones((T_q, T_k), device=x.device, dtype=torch.bool),
                diagonal=1
            )
            attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        assert keys.shape[-1] == self.head_dim

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec

    def reset_cache(self):
        if self.cache_k is not None:
            self.cache_k.zero_()
            self.cache_v.zero_()
        self.ptr = 0


def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Frequency adjustments
    if freq_config is not None:
        low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
        high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]

        wavelen = 2 * torch.pi / inv_freq

        inv_freq_llama = torch.where(
            wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
        )

        smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
            freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
        )

        smoothed_inv_freq = (
            (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
        )

        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        inv_freq = inv_freq_llama

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin


def apply_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2:]  # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    # It's ok to use lower-precision after applying cos and sin rotation
    return x_rotated.to(dtype=x.dtype)
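
As a quick sanity check on the two RoPE helpers above (a sketch, not part of the file): the rotary transform is a pure rotation of each (x1, x2) pair, so it should leave the norm of every head vector unchanged.

```python
import torch
from llms_from_scratch.kv_cache.llama3 import compute_rope_params, apply_rope

cos, sin = compute_rope_params(head_dim=64, theta_base=500_000.0, context_length=128)
x = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
x_rot = apply_rope(x, cos, sin)

print(torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-4))  # True
```
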


@@ -12,8 +12,11 @@ from llms_from_scratch.llama3 import (
    GroupedQueryAttentionFast,
    Llama3Model,
)
from llms_from_scratch.kv_cache.llama3 import Llama3Model as Llama3ModelKV
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
import importlib
import os
import pytest
import tiktoken
import torch
@@ -180,8 +183,20 @@ def llama3_weights_path(tmp_path_factory):
     return path


-@pytest.mark.parametrize("ModelClass", [Llama3Model])
-def test_gpt_model_variants(ModelClass, llama3_weights_path):
+@pytest.mark.skipif(
+    os.getenv("GITHUB_ACTIONS") == "true",
+    reason="Skipping in GitHub Actions due to compute or memory constraints"
+)
+@pytest.mark.parametrize("ModelClass", [Llama3Model, Llama3ModelKV])
+@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
+def test_gpt_model_variants(ModelClass, generate_fn, llama3_weights_path):
+    # Skip incompatible combinations
+    if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
+        return
+    if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
+        return
+
     torch.manual_seed(123)
     model = ModelClass(LLAMA32_CONFIG_1B)
     model.load_state_dict(torch.load(llama3_weights_path))
@@ -198,7 +213,7 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path):
     print("Encoded input text:", encoded)
     print("encoded_tensor.shape:", encoded_tensor.shape)

-    out = generate_text_simple(
+    out = generate_fn(
         model=model,
         idx=encoded_tensor,
         max_new_tokens=5,


@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "llms-from-scratch"
-version = "1.0.9"
+version = "1.0.12"
 description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
 readme = "README.md"
 requires-python = ">=3.10"