From ba0370abd1cbd888bef6d8612e66258756fdec7f Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Sun, 15 Jun 2025 14:26:16 -0500
Subject: [PATCH] Optimized KV cache (#672)

* Optimized KV cache

* typo fix
---
 ch04/03_kv-cache/README.md                    |  61 +++
 .../gpt_with_kv_cache_optimized.py            | 380 ++++++++++++++++++
 2 files changed, 441 insertions(+)
 create mode 100644 ch04/03_kv-cache/gpt_with_kv_cache_optimized.py

diff --git a/ch04/03_kv-cache/README.md b/ch04/03_kv-cache/README.md
index 8e14277..8c00d1f 100644
--- a/ch04/03_kv-cache/README.md
+++ b/ch04/03_kv-cache/README.md
@@ -218,3 +218,64 @@ As sequence length increases, the benefits and downsides of a KV cache become mo
+&nbsp;
+## Optimizing the KV Cache Implementation
+
+While my conceptual KV cache implementation above is mainly geared towards code readability and educational purposes, deploying it in real-world scenarios (especially with larger models and longer sequence lengths) requires more careful optimization.
+
+&nbsp;
+### Common pitfalls when scaling the cache
+
+- **Memory fragmentation and repeated allocations**: Continuously concatenating tensors via `torch.cat`, as shown earlier, leads to performance bottlenecks due to frequent memory allocation and reallocation.
+
+- **Linear growth in memory usage**: Without proper handling, the KV cache size becomes impractical for very long sequences.
+
+&nbsp;
+#### Tip 1: Pre-allocate Memory
+
+Rather than concatenating tensors repeatedly, we could pre-allocate a sufficiently large tensor based on the expected maximum sequence length. This ensures consistent memory use and reduces overhead. In pseudo-code, this may look as follows:
+
+```python
+# Example pre-allocation for keys and values
+max_seq_len = 1024  # maximum expected sequence length
+cache_k = torch.zeros((batch_size, num_heads, max_seq_len, head_dim), device=device)
+cache_v = torch.zeros((batch_size, num_heads, max_seq_len, head_dim), device=device)
+```
+
+During inference, we can then simply write into slices of these pre-allocated tensors.
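+
+For illustration, a decoding step could then append the new token's keys and values in place and read back only the valid prefix. (This is only a sketch; `cache_len`, `k_new`, and `v_new` are placeholder names for this example, not variables used elsewhere in this section.)
+
+```python
+# k_new, v_new: (batch_size, num_heads, num_new_tokens, head_dim)
+num_new = k_new.size(2)
+cache_k[:, :, cache_len:cache_len + num_new, :] = k_new
+cache_v[:, :, cache_len:cache_len + num_new, :] = v_new
+cache_len += num_new
+
+# Only the first `cache_len` positions contain valid entries
+keys = cache_k[:, :, :cache_len, :]
+values = cache_v[:, :, :cache_len, :]
+```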
+
+&nbsp;
+#### Tip 2: Truncate Cache via Sliding Window
+
+To avoid blowing up our GPU memory, we can implement a sliding window approach with dynamic truncation. Via the sliding window, we maintain only the last `window_size` tokens in the cache:
+
+```python
+# Sliding window cache implementation
+window_size = 512
+cache_k = cache_k[:, :, -window_size:, :]
+cache_v = cache_v[:, :, -window_size:, :]
+```
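+
+Note that when the sliding window is combined with the pre-allocated buffers from Tip 1, slicing with negative indices as above would keep the (mostly still empty) tail of the buffer rather than the most recently written entries, and it would also shrink the buffer. A minimal sketch of a truncation step that keeps the pre-allocated buffer intact, reusing the `cache_len` placeholder from the sketch above:
+
+```python
+# Keep the most recent `window_size` valid entries at the front of the buffers
+if cache_len > window_size:
+    cache_k[:, :, :window_size, :] = cache_k[:, :, cache_len - window_size:cache_len, :].clone()
+    cache_v[:, :, :window_size, :] = cache_v[:, :, cache_len - window_size:cache_len, :].clone()
+    cache_len = window_size
+```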
+
+&nbsp;
+#### Optimizations in practice
+
+You can find these optimizations in the [`gpt_with_kv_cache_optimized.py`](gpt_with_kv_cache_optimized.py) file.
+
+On a Mac Mini with an M4 chip (CPU), for a 200-token generation with a window size of 48 (as used in the code below), the runtimes compare as follows:
+
+|                                  | Tokens/sec |
+| -------------------------------- | ---------- |
+| `gpt_ch04.py`                    | 27         |
+| `gpt_with_kv_cache.py`           | 110        |
+| `gpt_with_kv_cache_optimized.py` | 148        |
+
+Unfortunately, the speed advantages disappear on CUDA devices: because this is such a small model, the device transfer and communication overhead outweighs the benefits of the KV cache. However, we can see a significant difference in the memory usage:
+
+|                                  | RAM     |
+| -------------------------------- | ------- |
+| `gpt_ch04.py`                    | 0.74 GB |
+| `gpt_with_kv_cache.py`           | 4.35 GB |
+| `gpt_with_kv_cache_optimized.py` | 0.89 GB |
+
diff --git a/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py b/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py
new file mode 100644
index 0000000..8b233ca
--- /dev/null
+++ b/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py
@@ -0,0 +1,380 @@
+# This file collects all the relevant code that we covered thus far
+# throughout Chapters 3-4.
+# This file can be run as a standalone script.
+
+import time
+import tiktoken
+import torch
+import torch.nn as nn
+
+
+#####################################
+# Chapter 3
+#####################################
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, max_seq_len=None, window_size=None):
+        super().__init__()
+        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
+
+        self.d_out = d_out
+        self.num_heads = num_heads
+        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim
+
+        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
+        self.dropout = nn.Dropout(dropout)
+
+        ####################################################
+        # NEW
+        self.max_seq_len = max_seq_len or context_length
+        self.window_size = window_size or self.max_seq_len
+        self.register_buffer("cache_k", None, persistent=False)
+        self.register_buffer("cache_v", None, persistent=False)
+        ####################################################
+
+    def forward(self, x, use_cache=False):
+        b, num_tokens, d_in = x.shape
+
+        keys_new = self.W_key(x)  # Shape: (b, num_tokens, d_out)
+        values_new = self.W_value(x)
+        queries = self.W_query(x)
+
+        # We implicitly split the matrix by adding a `num_heads` dimension
+        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
+        keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
+        values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
+        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
+
+        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
+        keys_new = keys_new.transpose(1, 2)
+        values_new = values_new.transpose(1, 2)
+        queries = queries.transpose(1, 2)
+
+        ####################################################
+        # NEW
+        if use_cache:
+            if self.cache_k is None or self.cache_k.size(0) != b:
+                self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device)
+                self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device)
+                self.current_pos = 0
+
+            # Write the new entries into the pre-allocated buffers
+            start = self.current_pos
+            end = start + num_tokens
+            self.cache_k[:, :, start:end, :] = keys_new
+            self.cache_v[:, :, start:end, :] = values_new
+            self.current_pos = end
+
+            # Sliding window truncation: keep the most recently written `window_size`
+            # entries and move them to the front of the buffer so that later writes
+            # still fit into the pre-allocated tensors
+            if self.current_pos > self.window_size:
+                keep_start = self.current_pos - self.window_size
+                self.cache_k[:, :, :self.window_size, :] = self.cache_k[:, :, keep_start:self.current_pos, :].clone()
+                self.cache_v[:, :, :self.window_size, :] = self.cache_v[:, :, keep_start:self.current_pos, :].clone()
+                self.current_pos = self.window_size
+
+            keys = self.cache_k[:, :, :self.current_pos, :]
+            values = self.cache_v[:, :, :self.current_pos, :]
+        else:
+            keys = keys_new
+            values = values_new
+        ####################################################
+
+        # Compute scaled dot-product attention (aka self-attention) with a causal mask
+        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
+
+        ####################################################
+        # NEW
+        K = attn_scores.size(-1)
+
+        if num_tokens == K:
+            # No cache: build a standard upper-triangular causal mask
+            causal_mask = torch.triu(torch.ones(num_tokens, K, device=x.device, dtype=torch.bool), diagonal=1)
+        else:
+            # Cached: need to offset the diagonal by (K - num_tokens)
+            offset = K - num_tokens  # number of tokens already in cache before this chunk
+            row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1)  # (num_tokens, 1)
+            col_idx = torch.arange(K, device=x.device).unsqueeze(0)  # (1, K)
+            causal_mask = row_idx + offset < col_idx  # True where j > i + offset
+        ####################################################
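+        # Illustrative example (numbers not taken from the code above): with 4 tokens
+        # already in the cache and 1 new token, num_tokens = 1, K = 5, and offset = 4;
+        # the new token's row is masked only for columns j with j > 0 + 4, so nothing
+        # is masked and it attends to all 5 cached positions, including itself.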
+
+        # Use the mask to fill attention scores
+        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf)
+
+        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
+        attn_weights = self.dropout(attn_weights)
+
+        # Shape: (b, num_tokens, num_heads, head_dim)
+        context_vec = (attn_weights @ values).transpose(1, 2)
+
+        # Combine heads, where self.d_out = self.num_heads * self.head_dim
+        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
+        context_vec = self.out_proj(context_vec)  # optional projection
+
+        return context_vec
+
+    ####################################################
+    # NEW
+    def reset_cache(self):
+        self.cache_k, self.cache_v = None, None
+    ####################################################
+
+
+#####################################
+# Chapter 4
+#####################################
+class LayerNorm(nn.Module):
+    def __init__(self, emb_dim):
+        super().__init__()
+        self.eps = 1e-5
+        self.scale = nn.Parameter(torch.ones(emb_dim))
+        self.shift = nn.Parameter(torch.zeros(emb_dim))
+
+    def forward(self, x):
+        mean = x.mean(dim=-1, keepdim=True)
+        var = x.var(dim=-1, keepdim=True, unbiased=False)
+        norm_x = (x - mean) / torch.sqrt(var + self.eps)
+        return self.scale * norm_x + self.shift
+
+
+class GELU(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return 0.5 * x * (1 + torch.tanh(
+            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
+            (x + 0.044715 * torch.pow(x, 3))
+        ))
+
+
+class FeedForward(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
+            GELU(),
+            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.att = MultiHeadAttention(
+            d_in=cfg["emb_dim"],
+            d_out=cfg["emb_dim"],
+            context_length=cfg["context_length"],
+            num_heads=cfg["n_heads"],
+            dropout=cfg["drop_rate"],
+            qkv_bias=cfg["qkv_bias"],
+            window_size=cfg["kv_window_size"])  # NEW
+        self.ff = FeedForward(cfg)
+        self.norm1 = LayerNorm(cfg["emb_dim"])
+        self.norm2 = LayerNorm(cfg["emb_dim"])
+        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
+
+    def forward(self, x, use_cache=False):
+        # Shortcut connection for attention block
+        shortcut = x
+        x = self.norm1(x)
+
+        # x = self.att(x)  # Shape [batch_size, num_tokens, emb_size]
+        ####################################################
+        # NEW
+        x = self.att(x, use_cache=use_cache)
+        ####################################################
+
+        x = self.drop_shortcut(x)
+        x = x + shortcut  # Add the original input back
+
+        # Shortcut connection for feed-forward block
+        shortcut = x
+        x = self.norm2(x)
+        x = self.ff(x)
+        x = self.drop_shortcut(x)
+        x = x + shortcut  # Add the original input back
+
+        return x
+
+
+class GPTModel(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
+        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
+        self.drop_emb = nn.Dropout(cfg["drop_rate"])
+
+        # self.trf_blocks = nn.Sequential(
+        #     *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
+        ####################################################
+        # NEW
+        self.trf_blocks = nn.ModuleList(
+            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
+
+        self.current_pos = 0
+        ####################################################
+
+        self.final_norm = LayerNorm(cfg["emb_dim"])
+        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+
+    def forward(self, in_idx, use_cache=False):
+        batch_size, seq_len = in_idx.shape
+        tok_embeds = self.tok_emb(in_idx)
+
+        # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
+
+        ####################################################
+        # NEW
+        if use_cache:
+            pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
+            self.current_pos += seq_len
+        else:
+            pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
+        pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
+        ####################################################
+
+        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
+        x = self.drop_emb(x)
+
+        # x = self.trf_blocks(x)
+        ####################################################
+        # NEW
+        for blk in self.trf_blocks:
+            x = blk(x, use_cache=use_cache)
+        ####################################################
+
+        x = self.final_norm(x)
+        logits = self.out_head(x)
+        return logits
+
+    ####################################################
+    # NEW
+    def reset_kv_cache(self):
+        for blk in self.trf_blocks:
+            blk.att.reset_cache()
+        self.current_pos = 0  # also reset the position counter used for the positional embeddings
+    ####################################################
+
+
+def generate_text_simple(model, idx, max_new_tokens, context_size):
+    # idx is (B, T) array of indices in the current context
+    for _ in range(max_new_tokens):
+
+        # Crop current context if it exceeds the supported context size
+        # E.g., if LLM supports only 5 tokens, and the context size is 10
+        # then only the last 5 tokens are used as context
+        idx_cond = idx[:, -context_size:]
+
+        # Get the predictions
+        with torch.no_grad():
+            logits = model(idx_cond)
+
+        # Focus only on the last time step
+        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
+        logits = logits[:, -1, :]
+
+        # Get the idx of the vocab entry with the highest logits value
+        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)
+
+        # Append sampled index to the running sequence
+        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)
+
+    return idx
+
+
+####################################################
+# NEW
+def generate_text_simple_cached(model, idx, max_new_tokens):
+    model.eval()
+    model.reset_kv_cache()
+
+    # Init cache with full prompt
+    logits = model(idx, use_cache=True)
+
+    for _ in range(max_new_tokens):
+        last_logits = logits[:, -1]
+        next_idx = last_logits.argmax(dim=-1, keepdim=True)
+        idx = torch.cat([idx, next_idx], dim=1)
+
+        logits = model(next_idx, use_cache=True)
+
+    return idx
+####################################################
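+
+# Note on the cached decoding loop above: the prompt is processed once to populate the
+# KV caches, after which only the single most recent token is fed per step; the attention
+# layers supply the earlier context from their caches (truncated to the sliding window).
+# Unlike `generate_text_simple`, no context re-cropping is needed, although the positional
+# embeddings still cap the total sequence length at `context_length`.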
+
+
+def main():
+    GPT_CONFIG_124M = {
+        "vocab_size": 50257,     # Vocabulary size
+        "context_length": 1024,  # Context length
+        "emb_dim": 768,          # Embedding dimension
+        "n_heads": 12,           # Number of attention heads
+        "n_layers": 12,          # Number of layers
+        "drop_rate": 0.1,        # Dropout rate
+        "qkv_bias": False,       # Query-Key-Value bias
+        "kv_window_size": 48     # NEW: KV cache window size
+    }
+
+    torch.manual_seed(123)
+    model = GPTModel(GPT_CONFIG_124M)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+    model.eval()  # disable dropout
+
+    start_context = "Hello, I am"
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+    encoded = tokenizer.encode(start_context)
+    encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
+
+    print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
+    print("\nInput text:", start_context)
+    print("Encoded input text:", encoded)
+    print("encoded_tensor.shape:", encoded_tensor.shape)
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    start = time.time()
+
+    # token_ids = generate_text_simple(
+    #     model=model,
+    #     idx=encoded_tensor,
+    #     max_new_tokens=200,
+    #     context_size=GPT_CONFIG_124M["context_length"]
+    # )
+
+    ####################################################
+    # NEW
+    token_ids = generate_text_simple_cached(
+        model=model,
+        idx=encoded_tensor,
+        max_new_tokens=200,
+    )
+    ####################################################
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    total_time = time.time() - start
+
+    decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist())
+
+    print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
+    print("\nOutput:", token_ids)
+    print("Output length:", len(token_ids[0]))
+    print("Output text:", decoded_text)
+
+    print(f"\nTime: {total_time:.2f} sec")
+    print(f"{int(len(token_ids[0])/total_time)} tokens/sec")
+    if torch.cuda.is_available():
+        max_mem_bytes = torch.cuda.max_memory_allocated()
+        max_mem_gb = max_mem_bytes / (1024 ** 3)
+        print(f"Max memory allocated: {max_mem_gb:.2f} GB")
+
+
+if __name__ == "__main__":
+    main()