Optimize KV cache (#673)

* Optimize KV cache

* style

* interpretable generate

* interpretable generate

* update readme
This commit is contained in:
Sebastian Raschka 2025-06-16 16:00:50 -05:00 committed by GitHub
parent ba0370abd1
commit ece59ba587
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 98 additions and 68 deletions

View File

@ -157,23 +157,36 @@ def reset_kv_cache(self):
With the changes to the `GPTModel`, `TransformerBlock`, and `MultiHeadAttention`, finally, here's how we use the KV cache in a simple text generation function:
```python
def generate_text_simple_cached(model, idx, max_new_tokens):
def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
model.eval()
model.reset_kv_cache()
ctx_len = model.pos_emb.num_embeddings # max supported length, e.g. 1024
if use_cache:
# Init cache with full prompt
logits = model(idx, use_cache=True)
model.reset_kv_cache()
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=True)
for _ in range(max_new_tokens):
last_logits = logits[:, -1]
next_idx = last_logits.argmax(dim=-1, keepdim=True)
# a) pick the token with the highest log-probability (greedy sampling)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
# b) append it to the running sequence
idx = torch.cat([idx, next_idx], dim=1)
# c) feed model only the new token
with torch.no_grad():
logits = model(next_idx, use_cache=True)
else:
for _ in range(max_new_tokens):
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=False)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
return idx
```
Note that we only feed the model the new token in c) via `logits = model(next_idx, use_cache=True)`. Without caching, we feed the model the whole input `logits = model(idx[:, -ctx_len:], use_cache=False)` as it has no stored keys and values to reuse.
 
## Simple performance comparison
@ -191,9 +204,9 @@ python gpt_with_kv_cache.py
On a Mac Mini with M4 chip (CPU), the results are as follows:
| | Tokens/sec |
| ----------------------- | ---------- |
| ---------------------- | ---------- |
| `gpt_ch04.py` | 27 |
| `gpt_with_kv_cache.py` | 110 |
| `gpt_with_kv_cache.py` | 144 |
So, as we can see, we already get a ~5x speed-up with a small 124 M parameter model and a short 200-token sequence length. (Note that this implementation is optimized for code readability and not optimized for CUDA or MPS runtime speed, which would require pre-allocating tensors instead of reinstating and concatenating them.)
@ -263,19 +276,12 @@ cache_v = cache_v[:, :, -window_size:, :]
You can find these optimizations in the [`gpt_with_kv_cache_optimized.py`](gpt_with_kv_cache_optimized.py) file.
On a Mac Mini with an M4 chip (CPU), with a 200-token generation and a window size of 48 below, the code runtimes compare as follows:
On a Mac Mini with an M4 chip (CPU), with a 200-token generation and a window size equal to the context length (to guarantee same results) below, the code runtimes compare as follows:
| | Tokens/sec |
| -------------------------------- | ---------- |
| `gpt_ch04.py` | 27 |
| `gpt_with_kv_cache.py` | 110 |
| `gpt_with_kv_cache_optimized.py` | 148 |
Unfortunately, the speed advantages disappear on CUDA devices as this is a tiny model, and the device transfer and communication outweigh the benefits of a KV cache for this small model. However, we can see a significant difference in the memory usage:
| | RAM |
| -------------------------------- | ------- |
| `gpt_ch04.py` | 0.74 GB |
| `gpt_with_kv_cache.py` | 4.35 GB |
| `gpt_with_kv_cache_optimized.py` | 0.89 GB |
| `gpt_with_kv_cache.py` | 144 |
| `gpt_with_kv_cache_optimized.py` | 166 |
Unfortunately, the speed advantages disappear on CUDA devices as this is a tiny model, and the device transfer and communication outweigh the benefits of a KV cache for this small model.

View File

@ -171,6 +171,7 @@ class GPTModel(nn.Module):
def generate_text_simple(model, idx, max_new_tokens, context_size):
model.eval()
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):

View File

@ -264,19 +264,30 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
####################################################
# NEW
def generate_text_simple_cached(model, idx, max_new_tokens):
def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
model.eval()
model.reset_kv_cache()
ctx_len = model.pos_emb.num_embeddings # max supported length, e.g. 1024
if use_cache:
# Init cache with full prompt
logits = model(idx, use_cache=True)
model.reset_kv_cache()
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=True)
for _ in range(max_new_tokens):
last_logits = logits[:, -1]
next_idx = last_logits.argmax(dim=-1, keepdim=True)
# a) pick the token with the highest log-probability (greedy sampling)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
# b) append it to the running sequence
idx = torch.cat([idx, next_idx], dim=1)
# c) feed model only the new token
with torch.no_grad():
logits = model(next_idx, use_cache=True)
else:
for _ in range(max_new_tokens):
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=False)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
return idx
####################################################

View File

@ -56,28 +56,29 @@ class MultiHeadAttention(nn.Module):
# NEW
if use_cache:
if self.cache_k is None or self.cache_k.size(0) != b:
self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device)
self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device)
self.current_pos = 0
self.cache_k = torch.zeros(b, self.num_heads,
self.window_size, self.head_dim,
device=x.device)
self.cache_v = torch.zeros_like(self.cache_k)
self.ptr_cur = 0 # pointer to next free slot
# write new entries
start = self.current_pos
end = start + num_tokens
self.cache_k[:, :, start:end, :] = keys_new
self.cache_v[:, :, start:end, :] = values_new
self.current_pos = end
# if incoming chunk would overflow discard oldest tokens
if self.ptr_cur + num_tokens > self.window_size:
overflow = self.ptr_cur + num_tokens - self.window_size
# shift everything left by `overflow` (cheap view-copy)
self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
self.ptr_cur -= overflow # pointer after shift
# sliding window truncation
if self.current_pos > self.window_size:
self.cache_k = self.cache_k[:, :, -self.window_size:, :]
self.cache_v = self.cache_v[:, :, -self.window_size:, :]
self.current_pos = self.window_size
self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
self.ptr_cur += num_tokens
keys = self.cache_k[:, :, :self.current_pos, :]
values = self.cache_v[:, :, :self.current_pos, :]
keys = self.cache_k[:, :, :self.ptr_cur, :]
values = self.cache_v[:, :, :self.ptr_cur, :]
else:
keys = keys_new
values = values_new
keys, values = keys_new, values_new
self.ptr_cur = 0 # keep pointer sane if you interleave modes
####################################################
@ -216,7 +217,7 @@ class GPTModel(nn.Module):
self.trf_blocks = nn.ModuleList(
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
self.current_pos = 0
self.ptr_current_pos = 0
####################################################
self.final_norm = LayerNorm(cfg["emb_dim"])
@ -232,8 +233,8 @@ class GPTModel(nn.Module):
# NEW
if use_cache:
pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
self.current_pos += seq_len
pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
self.ptr_current_pos += seq_len
else:
pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
@ -258,7 +259,7 @@ class GPTModel(nn.Module):
def reset_kv_cache(self):
for blk in self.trf_blocks:
blk.att.reset_cache()
self.ptr_current_pos = 0
####################################################
@ -290,19 +291,30 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
####################################################
# NEW
def generate_text_simple_cached(model, idx, max_new_tokens):
def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
model.eval()
model.reset_kv_cache()
ctx_len = model.pos_emb.num_embeddings # max supported length, e.g. 1024
if use_cache:
# Init cache with full prompt
logits = model(idx, use_cache=True)
model.reset_kv_cache()
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=True)
for _ in range(max_new_tokens):
last_logits = logits[:, -1]
next_idx = last_logits.argmax(dim=-1, keepdim=True)
# a) pick the token with the highest log-probability (greedy sampling)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
# b) append it to the running sequence
idx = torch.cat([idx, next_idx], dim=1)
# c) feed model only the new token
with torch.no_grad():
logits = model(next_idx, use_cache=True)
else:
for _ in range(max_new_tokens):
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=False)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
return idx
####################################################
@ -317,7 +329,7 @@ def main():
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False, # Query-Key-Value bias
"kv_window_size": 48 # NEW: KV cache window size
"kv_window_size": 1024 # NEW: KV cache window size
}
torch.manual_seed(123)