From ece59ba58768db7b34d9b5d5f88677de8c1e84ea Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Mon, 16 Jun 2025 16:00:50 -0500 Subject: [PATCH] Optimize KV cache (#673) * Optimize KV cache * style * interpretable generate * interpretable generate * update readme --- ch04/03_kv-cache/README.md | 56 +++++++------ ch04/03_kv-cache/gpt_ch04.py | 1 + ch04/03_kv-cache/gpt_with_kv_cache.py | 31 +++++--- .../gpt_with_kv_cache_optimized.py | 78 +++++++++++-------- 4 files changed, 98 insertions(+), 68 deletions(-) diff --git a/ch04/03_kv-cache/README.md b/ch04/03_kv-cache/README.md index 8c00d1f..3474ab8 100644 --- a/ch04/03_kv-cache/README.md +++ b/ch04/03_kv-cache/README.md @@ -157,23 +157,36 @@ def reset_kv_cache(self): With the changes to the `GPTModel`, `TransformerBlock`, and `MultiHeadAttention`, finally, here's how we use the KV cache in a simple text generation function: ```python -def generate_text_simple_cached(model, idx, max_new_tokens): +def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True): model.eval() - model.reset_kv_cache() - # Init cache with full prompt - logits = model(idx, use_cache=True) + ctx_len = model.pos_emb.num_embeddings # max supported length, e.g. 1024 + if use_cache: + # Init cache with full prompt + model.reset_kv_cache() + with torch.no_grad(): + logits = model(idx[:, -ctx_len:], use_cache=True) - for _ in range(max_new_tokens): - last_logits = logits[:, -1] - next_idx = last_logits.argmax(dim=-1, keepdim=True) - idx = torch.cat([idx, next_idx], dim=1) - - logits = model(next_idx, use_cache=True) + for _ in range(max_new_tokens): + # a) pick the token with the highest log-probability (greedy sampling) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + # b) append it to the running sequence + idx = torch.cat([idx, next_idx], dim=1) + # c) feed model only the new token + with torch.no_grad(): + logits = model(next_idx, use_cache=True) + else: + for _ in range(max_new_tokens): + with torch.no_grad(): + logits = model(idx[:, -ctx_len:], use_cache=False) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + idx = torch.cat([idx, next_idx], dim=1) return idx ``` +Note that we only feed the model the new token in c) via `logits = model(next_idx, use_cache=True)`. Without caching, we feed the model the whole input `logits = model(idx[:, -ctx_len:], use_cache=False)` as it has no stored keys and values to reuse. +   ## Simple performance comparison @@ -190,10 +203,10 @@ python gpt_with_kv_cache.py On a Mac Mini with M4 chip (CPU), the results are as follows: -| | Tokens/sec | -| ----------------------- | ---------- | -| `gpt_ch04.py` | 27 | -| `gpt_with_kv_cache.py` | 110 | +| | Tokens/sec | +| ---------------------- | ---------- | +| `gpt_ch04.py` | 27 | +| `gpt_with_kv_cache.py` | 144 | So, as we can see, we already get a ~5x speed-up with a small 124 M parameter model and a short 200-token sequence length. (Note that this implementation is optimized for code readability and not optimized for CUDA or MPS runtime speed, which would require pre-allocating tensors instead of reinstating and concatenating them.) @@ -263,19 +276,12 @@ cache_v = cache_v[:, :, -window_size:, :] You can find these optimizations in the [`gpt_with_kv_cache_optimized.py`](gpt_with_kv_cache_optimized.py) file. 
-On a Mac Mini with an M4 chip (CPU), with a 200-token generation and a window size of 48 below, the code runtimes compare as follows: +On a Mac Mini with an M4 chip (CPU), with a 200-token generation and a window size equal to the context length (to guarantee same results) below, the code runtimes compare as follows: | | Tokens/sec | | -------------------------------- | ---------- | | `gpt_ch04.py` | 27 | -| `gpt_with_kv_cache.py` | 110 | -| `gpt_with_kv_cache_optimized.py` | 148 | - -Unfortunately, the speed advantages disappear on CUDA devices as this is a tiny model, and the device transfer and communication outweigh the benefits of a KV cache for this small model. However, we can see a significant difference in the memory usage: - -| | RAM | -| -------------------------------- | ------- | -| `gpt_ch04.py` | 0.74 GB | -| `gpt_with_kv_cache.py` | 4.35 GB | -| `gpt_with_kv_cache_optimized.py` | 0.89 GB | +| `gpt_with_kv_cache.py` | 144 | +| `gpt_with_kv_cache_optimized.py` | 166 | +Unfortunately, the speed advantages disappear on CUDA devices as this is a tiny model, and the device transfer and communication outweigh the benefits of a KV cache for this small model. diff --git a/ch04/03_kv-cache/gpt_ch04.py b/ch04/03_kv-cache/gpt_ch04.py index 47b75bb..efda40b 100644 --- a/ch04/03_kv-cache/gpt_ch04.py +++ b/ch04/03_kv-cache/gpt_ch04.py @@ -171,6 +171,7 @@ class GPTModel(nn.Module): def generate_text_simple(model, idx, max_new_tokens, context_size): + model.eval() # idx is (B, T) array of indices in the current context for _ in range(max_new_tokens): diff --git a/ch04/03_kv-cache/gpt_with_kv_cache.py b/ch04/03_kv-cache/gpt_with_kv_cache.py index 9576d8a..c685fe0 100644 --- a/ch04/03_kv-cache/gpt_with_kv_cache.py +++ b/ch04/03_kv-cache/gpt_with_kv_cache.py @@ -264,19 +264,30 @@ def generate_text_simple(model, idx, max_new_tokens, context_size): #################################################### # NEW -def generate_text_simple_cached(model, idx, max_new_tokens): +def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True): model.eval() - model.reset_kv_cache() - # Init cache with full prompt - logits = model(idx, use_cache=True) + ctx_len = model.pos_emb.num_embeddings # max supported length, e.g. 
1024 + if use_cache: + # Init cache with full prompt + model.reset_kv_cache() + with torch.no_grad(): + logits = model(idx[:, -ctx_len:], use_cache=True) - for _ in range(max_new_tokens): - last_logits = logits[:, -1] - next_idx = last_logits.argmax(dim=-1, keepdim=True) - idx = torch.cat([idx, next_idx], dim=1) - - logits = model(next_idx, use_cache=True) + for _ in range(max_new_tokens): + # a) pick the token with the highest log-probability (greedy sampling) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + # b) append it to the running sequence + idx = torch.cat([idx, next_idx], dim=1) + # c) feed model only the new token + with torch.no_grad(): + logits = model(next_idx, use_cache=True) + else: + for _ in range(max_new_tokens): + with torch.no_grad(): + logits = model(idx[:, -ctx_len:], use_cache=False) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + idx = torch.cat([idx, next_idx], dim=1) return idx #################################################### diff --git a/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py b/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py index 8b233ca..a17cc46 100644 --- a/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py +++ b/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py @@ -56,28 +56,29 @@ class MultiHeadAttention(nn.Module): # NEW if use_cache: if self.cache_k is None or self.cache_k.size(0) != b: - self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device) - self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device) - self.current_pos = 0 + self.cache_k = torch.zeros(b, self.num_heads, + self.window_size, self.head_dim, + device=x.device) + self.cache_v = torch.zeros_like(self.cache_k) + self.ptr_cur = 0 # pointer to next free slot - # write new entries - start = self.current_pos - end = start + num_tokens - self.cache_k[:, :, start:end, :] = keys_new - self.cache_v[:, :, start:end, :] = values_new - self.current_pos = end + # if incoming chunk would overflow discard oldest tokens + if self.ptr_cur + num_tokens > self.window_size: + overflow = self.ptr_cur + num_tokens - self.window_size + # shift everything left by `overflow` (cheap view-copy) + self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone() + self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone() + self.ptr_cur -= overflow # pointer after shift - # sliding window truncation - if self.current_pos > self.window_size: - self.cache_k = self.cache_k[:, :, -self.window_size:, :] - self.cache_v = self.cache_v[:, :, -self.window_size:, :] - self.current_pos = self.window_size + self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new + self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new + self.ptr_cur += num_tokens - keys = self.cache_k[:, :, :self.current_pos, :] - values = self.cache_v[:, :, :self.current_pos, :] + keys = self.cache_k[:, :, :self.ptr_cur, :] + values = self.cache_v[:, :, :self.ptr_cur, :] else: - keys = keys_new - values = values_new + keys, values = keys_new, values_new + self.ptr_cur = 0 # keep pointer sane if you interleave modes #################################################### @@ -216,7 +217,7 @@ class GPTModel(nn.Module): self.trf_blocks = nn.ModuleList( [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) - self.current_pos = 0 + self.ptr_current_pos = 0 #################################################### self.final_norm = LayerNorm(cfg["emb_dim"]) @@ -232,8 +233,8 @@ class GPTModel(nn.Module): # 
NEW if use_cache: - pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long) - self.current_pos += seq_len + pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long) + self.ptr_current_pos += seq_len else: pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long) pos_embeds = self.pos_emb(pos_ids).unsqueeze(0) @@ -258,7 +259,7 @@ class GPTModel(nn.Module): def reset_kv_cache(self): for blk in self.trf_blocks: blk.att.reset_cache() - + self.ptr_current_pos = 0 #################################################### @@ -290,19 +291,30 @@ def generate_text_simple(model, idx, max_new_tokens, context_size): #################################################### # NEW -def generate_text_simple_cached(model, idx, max_new_tokens): +def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True): model.eval() - model.reset_kv_cache() - # Init cache with full prompt - logits = model(idx, use_cache=True) + ctx_len = model.pos_emb.num_embeddings # max supported length, e.g. 1024 + if use_cache: + # Init cache with full prompt + model.reset_kv_cache() + with torch.no_grad(): + logits = model(idx[:, -ctx_len:], use_cache=True) - for _ in range(max_new_tokens): - last_logits = logits[:, -1] - next_idx = last_logits.argmax(dim=-1, keepdim=True) - idx = torch.cat([idx, next_idx], dim=1) - - logits = model(next_idx, use_cache=True) + for _ in range(max_new_tokens): + # a) pick the token with the highest log-probability (greedy sampling) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + # b) append it to the running sequence + idx = torch.cat([idx, next_idx], dim=1) + # c) feed model only the new token + with torch.no_grad(): + logits = model(next_idx, use_cache=True) + else: + for _ in range(max_new_tokens): + with torch.no_grad(): + logits = model(idx[:, -ctx_len:], use_cache=False) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + idx = torch.cat([idx, next_idx], dim=1) return idx #################################################### @@ -317,7 +329,7 @@ def main(): "n_layers": 12, # Number of layers "drop_rate": 0.1, # Dropout rate "qkv_bias": False, # Query-Key-Value bias - "kv_window_size": 48 # NEW: KV cache window size + "kv_window_size": 1024 # NEW: KV cache window size } torch.manual_seed(123)
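
For readers who want to poke at the new sliding-window logic without loading the full model, below is a minimal, standalone sketch of the pre-allocated KV buffer that `gpt_with_kv_cache_optimized.py` now uses inside `MultiHeadAttention`. The `SlidingKVBuffer` class and the toy shapes in the demo are illustrative only and not part of the patch; like the patched code, it assumes a single incoming chunk never exceeds the window size.

```python
import torch


class SlidingKVBuffer:
    """Pre-allocated KV cache with a sliding window, mirroring the patch's pointer/shift logic."""

    def __init__(self, batch_size, num_heads, window_size, head_dim, device="cpu"):
        self.window_size = window_size
        # Allocate once up front instead of concatenating tensors every step
        self.cache_k = torch.zeros(batch_size, num_heads, window_size, head_dim, device=device)
        self.cache_v = torch.zeros_like(self.cache_k)
        self.ptr_cur = 0  # pointer to the next free slot

    def update(self, keys_new, values_new):
        # keys_new/values_new: (batch, num_heads, num_tokens, head_dim)
        num_tokens = keys_new.size(2)

        # If the incoming chunk would overflow, discard the oldest tokens first
        if self.ptr_cur + num_tokens > self.window_size:
            overflow = self.ptr_cur + num_tokens - self.window_size
            self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
            self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
            self.ptr_cur -= overflow  # pointer after the shift

        # Write the new keys/values into the free slots and advance the pointer
        self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
        self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
        self.ptr_cur += num_tokens

        # Only the filled part of the buffer participates in attention
        return self.cache_k[:, :, :self.ptr_cur, :], self.cache_v[:, :, :self.ptr_cur, :]


if __name__ == "__main__":
    buf = SlidingKVBuffer(batch_size=1, num_heads=2, window_size=4, head_dim=8)
    for num_tokens in (3, 2, 1):  # a 3-token "prompt", then two decode-style steps
        keys, values = buf.update(torch.randn(1, 2, num_tokens, 8),
                                  torch.randn(1, 2, num_tokens, 8))
        print(keys.shape)  # seq dim grows to 3, then stays capped at window_size=4
```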
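
The other bookkeeping change in `GPTModel.forward` is easy to miss: with the cache enabled, each newly fed token must receive the next absolute position id, not position 0, which is what `ptr_current_pos` tracks. The snippet below reproduces that logic as a standalone toy; the helper name `next_pos_ids` is made up for illustration.

```python
import torch

ptr_current_pos = 0  # reset to 0 via reset_kv_cache() before each new prompt


def next_pos_ids(seq_len, use_cache):
    """Return position ids for the incoming chunk, continuing from the cached length."""
    global ptr_current_pos
    if use_cache:
        pos_ids = torch.arange(ptr_current_pos, ptr_current_pos + seq_len, dtype=torch.long)
        ptr_current_pos += seq_len  # remember how far we are into the sequence
    else:
        pos_ids = torch.arange(0, seq_len, dtype=torch.long)  # recompute from scratch
    return pos_ids


print(next_pos_ids(4, use_cache=True))  # 4-token prompt   -> tensor([0, 1, 2, 3])
print(next_pos_ids(1, use_cache=True))  # first decode step -> tensor([4])
print(next_pos_ids(1, use_cache=True))  # next decode step  -> tensor([5])
```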
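
Finally, a usage sketch for the updated `generate_text_simple_cached` signature. The import path, tokenizer setup, and config dictionary are assumptions based on the rest of the chapter (run it from inside `ch04/03_kv-cache/`); with untrained weights the output is gibberish, but it is enough to exercise and compare the cached and uncached paths.

```python
import tiktoken
import torch

# Assumes this script sits next to gpt_with_kv_cache.py in ch04/03_kv-cache/
from gpt_with_kv_cache import GPTModel, generate_text_simple_cached

GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "emb_dim": 768,          # Embedding dimension
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": False,       # Query-Key-Value bias
}

torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)

tokenizer = tiktoken.get_encoding("gpt2")
idx = torch.tensor(tokenizer.encode("Hello, I am")).unsqueeze(0)  # shape: (1, num_tokens)

# use_cache=True feeds only the newest token after the prompt and reuses cached K/V;
# use_cache=False recomputes attention over the full running sequence every step.
token_ids = generate_text_simple_cached(model, idx, max_new_tokens=20, use_cache=True)
print(tokenizer.decode(token_ids.squeeze(0).tolist()))
```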