Optimize KV cache (#673)
* Optimize KV cache
* style
* interpretable generate
* update readme
parent ba0370abd1 · commit ece59ba587
@@ -157,23 +157,36 @@ def reset_kv_cache(self):
 With the changes to the `GPTModel`, `TransformerBlock`, and `MultiHeadAttention`, finally, here's how we use the KV cache in a simple text generation function:
 
 ```python
-def generate_text_simple_cached(model, idx, max_new_tokens):
+def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
     model.eval()
-    model.reset_kv_cache()
 
+    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
+    if use_cache:
         # Init cache with full prompt
-    logits = model(idx, use_cache=True)
+        model.reset_kv_cache()
+        with torch.no_grad():
+            logits = model(idx[:, -ctx_len:], use_cache=True)
 
         for _ in range(max_new_tokens):
-        last_logits = logits[:, -1]
-        next_idx = last_logits.argmax(dim=-1, keepdim=True)
+            # a) pick the token with the highest log-probability (greedy sampling)
+            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+            # b) append it to the running sequence
             idx = torch.cat([idx, next_idx], dim=1)
 
+            # c) feed model only the new token
             with torch.no_grad():
                 logits = model(next_idx, use_cache=True)
+    else:
+        for _ in range(max_new_tokens):
+            with torch.no_grad():
+                logits = model(idx[:, -ctx_len:], use_cache=False)
+            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+            idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ```
 
 Note that we only feed the model the new token in c) via `logits = model(next_idx, use_cache=True)`. Without caching, we feed the model the whole input `logits = model(idx[:, -ctx_len:], use_cache=False)` as it has no stored keys and values to reuse.
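To make the cached vs. uncached paths concrete, here is a minimal timing sketch. It assumes that `GPTModel` and `generate_text_simple_cached` can be imported from `gpt_with_kv_cache.py`, and it uses the book's standard 124M configuration with an arbitrary prompt and untrained weights, so it measures throughput only:

```python
import time

import tiktoken
import torch

from gpt_with_kv_cache import GPTModel, generate_text_simple_cached  # assumes module-level definitions

GPT_CONFIG_124M = {        # standard 124M config used throughout the book (assumed here)
    "vocab_size": 50257,
    "context_length": 1024,
    "emb_dim": 768,
    "n_heads": 12,
    "n_layers": 12,
    "drop_rate": 0.1,
    "qkv_bias": False,
}

torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
tokenizer = tiktoken.get_encoding("gpt2")

idx = torch.tensor(tokenizer.encode("Hello, I am")).unsqueeze(0)  # shape: (1, num_tokens)

for use_cache in (False, True):
    start = time.time()
    out = generate_text_simple_cached(model, idx, max_new_tokens=200, use_cache=use_cache)
    elapsed = time.time() - start
    num_new = out.shape[1] - idx.shape[1]
    print(f"use_cache={use_cache}: {num_new} tokens in {elapsed:.2f} s "
          f"({num_new / elapsed:.1f} tokens/sec)")
```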
## Simple performance comparison
@@ -191,9 +204,9 @@ python gpt_with_kv_cache.py
 On a Mac Mini with M4 chip (CPU), the results are as follows:
 
 |                        | Tokens/sec |
-| ----------------------- | ---------- |
+| ---------------------- | ---------- |
 | `gpt_ch04.py`          | 27         |
-| `gpt_with_kv_cache.py` | 110        |
+| `gpt_with_kv_cache.py` | 144        |
 
 So, as we can see, we already get a ~5x speed-up with a small 124M-parameter model and a short 200-token sequence length. (Note that this implementation is optimized for code readability and not for CUDA or MPS runtime speed, which would require pre-allocating tensors instead of re-instantiating and concatenating them.)
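The parenthetical above is the idea behind the optimized variant discussed next: instead of growing the cache with `torch.cat` at every step, allocate the buffer once and write into it in place. A minimal sketch with made-up shapes (batch size 1, 12 heads, head dimension 64):

```python
import torch

b, num_heads, head_dim, max_seq_len = 1, 12, 64, 1024

# a) concatenation-based cache: allocates a new, larger tensor on every step
cache_k = torch.empty(b, num_heads, 0, head_dim)
for _ in range(3):
    k_new = torch.randn(b, num_heads, 1, head_dim)   # keys for the newly generated token
    cache_k = torch.cat([cache_k, k_new], dim=2)      # re-allocation + copy each step

# b) pre-allocated cache: allocate once, write in place, track a fill pointer
cache_k = torch.zeros(b, num_heads, max_seq_len, head_dim)
ptr = 0
for _ in range(3):
    k_new = torch.randn(b, num_heads, 1, head_dim)
    num_tokens = k_new.size(2)
    cache_k[:, :, ptr:ptr + num_tokens, :] = k_new    # in-place write, no re-allocation
    ptr += num_tokens

keys = cache_k[:, :, :ptr, :]  # only the filled part participates in attention
print(keys.shape)              # torch.Size([1, 12, 3, 64])
```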
@@ -263,19 +276,12 @@ cache_v = cache_v[:, :, -window_size:, :]
 You can find these optimizations in the [`gpt_with_kv_cache_optimized.py`](gpt_with_kv_cache_optimized.py) file.
 
 
-On a Mac Mini with an M4 chip (CPU), with a 200-token generation and a window size of 48 below, the code runtimes compare as follows:
+On a Mac Mini with an M4 chip (CPU), with a 200-token generation and a window size equal to the context length (to guarantee the same results), the code runtimes compare as follows:
 
 |                                  | Tokens/sec |
 | -------------------------------- | ---------- |
 | `gpt_ch04.py`                    | 27         |
-| `gpt_with_kv_cache.py`           | 110        |
-| `gpt_with_kv_cache_optimized.py` | 148        |
-
-Unfortunately, the speed advantages disappear on CUDA devices as this is a tiny model, and the device transfer and communication outweigh the benefits of a KV cache for this small model. However, we can see a significant difference in the memory usage:
-
-|                                  | RAM     |
-| -------------------------------- | ------- |
-| `gpt_ch04.py`                    | 0.74 GB |
-| `gpt_with_kv_cache.py`           | 4.35 GB |
-| `gpt_with_kv_cache_optimized.py` | 0.89 GB |
+| `gpt_with_kv_cache.py`           | 144        |
+| `gpt_with_kv_cache_optimized.py` | 166        |
 
+Unfortunately, the speed advantages disappear on CUDA devices as this is a tiny model, and the device transfer and communication outweigh the benefits of a KV cache for this small model.
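For intuition about the window-size/memory trade-off discussed above, here is a back-of-the-envelope sketch. It assumes the 124M configuration (12 layers, 12 heads, head dimension 64, float32) and counts only the steady-state key/value storage, not allocator overhead from repeated concatenation:

```python
def kv_cache_bytes(n_layers, n_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=4):
    # 2x because keys and values are both cached, per layer
    return 2 * n_layers * batch_size * n_heads * seq_len * head_dim * bytes_per_elem

full_ctx = kv_cache_bytes(n_layers=12, n_heads=12, head_dim=64, seq_len=1024)
windowed = kv_cache_bytes(n_layers=12, n_heads=12, head_dim=64, seq_len=48)

print(f"1024-token cache: {full_ctx / 1e6:.1f} MB")  # ~75.5 MB
print(f"48-token window:  {windowed / 1e6:.1f} MB")  # ~3.5 MB
```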
@@ -171,6 +171,7 @@ class GPTModel(nn.Module):
 
 
 def generate_text_simple(model, idx, max_new_tokens, context_size):
+    model.eval()
     # idx is (B, T) array of indices in the current context
     for _ in range(max_new_tokens):
 
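The `model.eval()` call added above matters because the config enables dropout; in training mode each forward pass randomly zeroes activations, so repeated generations would differ. A quick standalone way to see the effect (a sketch using a bare `nn.Dropout` rather than the full model):

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
drop = nn.Dropout(p=0.1)
x = torch.ones(5)

drop.train()
print(drop(x))  # some elements zeroed, the rest rescaled: generations become non-deterministic

drop.eval()
print(drop(x))  # identity: tensor([1., 1., 1., 1., 1.])
```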
@@ -264,19 +264,30 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens):
+def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
     model.eval()
-    model.reset_kv_cache()
 
+    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
+    if use_cache:
         # Init cache with full prompt
-    logits = model(idx, use_cache=True)
+        model.reset_kv_cache()
+        with torch.no_grad():
+            logits = model(idx[:, -ctx_len:], use_cache=True)
 
         for _ in range(max_new_tokens):
-        last_logits = logits[:, -1]
-        next_idx = last_logits.argmax(dim=-1, keepdim=True)
+            # a) pick the token with the highest log-probability (greedy sampling)
+            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+            # b) append it to the running sequence
             idx = torch.cat([idx, next_idx], dim=1)
 
+            # c) feed model only the new token
             with torch.no_grad():
                 logits = model(next_idx, use_cache=True)
+    else:
+        for _ in range(max_new_tokens):
+            with torch.no_grad():
+                logits = model(idx[:, -ctx_len:], use_cache=False)
+            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+            idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ####################################################
@@ -56,28 +56,29 @@ class MultiHeadAttention(nn.Module):
         # NEW
         if use_cache:
             if self.cache_k is None or self.cache_k.size(0) != b:
-                self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device)
-                self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device)
-                self.current_pos = 0
+                self.cache_k = torch.zeros(b, self.num_heads,
+                                           self.window_size, self.head_dim,
+                                           device=x.device)
+                self.cache_v = torch.zeros_like(self.cache_k)
+                self.ptr_cur = 0  # pointer to next free slot
 
-            # write new entries
-            start = self.current_pos
-            end = start + num_tokens
-            self.cache_k[:, :, start:end, :] = keys_new
-            self.cache_v[:, :, start:end, :] = values_new
-            self.current_pos = end
+            # if incoming chunk would overflow discard oldest tokens
+            if self.ptr_cur + num_tokens > self.window_size:
+                overflow = self.ptr_cur + num_tokens - self.window_size
+                # shift everything left by `overflow` (cheap view-copy)
+                self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
+                self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
+                self.ptr_cur -= overflow  # pointer after shift
 
-            # sliding window truncation
-            if self.current_pos > self.window_size:
-                self.cache_k = self.cache_k[:, :, -self.window_size:, :]
-                self.cache_v = self.cache_v[:, :, -self.window_size:, :]
-                self.current_pos = self.window_size
+            self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
+            self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
+            self.ptr_cur += num_tokens
 
-            keys = self.cache_k[:, :, :self.current_pos, :]
-            values = self.cache_v[:, :, :self.current_pos, :]
+            keys = self.cache_k[:, :, :self.ptr_cur, :]
+            values = self.cache_v[:, :, :self.ptr_cur, :]
         else:
-            keys = keys_new
-            values = values_new
+            keys, values = keys_new, values_new
+            self.ptr_cur = 0  # keep pointer sane if you interleave modes
         ####################################################
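To see what the new shift-on-overflow bookkeeping does in isolation, here is a small standalone sketch. It is illustrative only: the window size of 4 and the single-token chunks are made up, and a 1-D tensor stands in for `cache_k`/`cache_v`:

```python
import torch

window_size, num_tokens = 4, 1                 # made-up window; one new token per step
cache = torch.zeros(window_size)               # stands in for cache_k/cache_v along the sequence dim
ptr_cur = 0

for step in range(1, 7):                       # push six single-token "keys": 1, 2, ..., 6
    new = torch.tensor([float(step)])

    # if the incoming chunk would overflow, discard the oldest entries
    if ptr_cur + num_tokens > window_size:
        overflow = ptr_cur + num_tokens - window_size
        cache[:-overflow] = cache[overflow:].clone()   # shift left
        ptr_cur -= overflow

    cache[ptr_cur:ptr_cur + num_tokens] = new
    ptr_cur += num_tokens
    print(step, cache[:ptr_cur].tolist())

# step 4 -> [1.0, 2.0, 3.0, 4.0], step 5 -> [2.0, 3.0, 4.0, 5.0], step 6 -> [3.0, 4.0, 5.0, 6.0]
```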
@@ -216,7 +217,7 @@ class GPTModel(nn.Module):
         self.trf_blocks = nn.ModuleList(
             [TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
 
-        self.current_pos = 0
+        self.ptr_current_pos = 0
         ####################################################
 
         self.final_norm = LayerNorm(cfg["emb_dim"])
@@ -232,8 +233,8 @@ class GPTModel(nn.Module):
         # NEW
 
         if use_cache:
-            pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
-            self.current_pos += seq_len
+            pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
+            self.ptr_current_pos += seq_len
         else:
             pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
         pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
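The renamed counter above exists so that, with `use_cache=True`, position ids keep advancing across successive calls instead of restarting at 0. A minimal standalone sketch of that bookkeeping (the 5-token prompt length is made up):

```python
import torch

ptr_current_pos = 0  # mirrors self.ptr_current_pos

def cached_pos_ids(seq_len):
    # mimics the use_cache=True branch: positions continue where the previous call stopped
    global ptr_current_pos
    pos_ids = torch.arange(ptr_current_pos, ptr_current_pos + seq_len, dtype=torch.long)
    ptr_current_pos += seq_len
    return pos_ids

print(cached_pos_ids(5))  # 5-token prompt  -> tensor([0, 1, 2, 3, 4])
print(cached_pos_ids(1))  # first new token -> tensor([5])
print(cached_pos_ids(1))  # next new token  -> tensor([6])
```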
@@ -258,7 +259,7 @@ class GPTModel(nn.Module):
     def reset_kv_cache(self):
         for blk in self.trf_blocks:
             blk.att.reset_cache()
-        self.current_pos = 0
+        self.ptr_current_pos = 0
     ####################################################
 
 
@@ -290,19 +291,30 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens):
+def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
     model.eval()
-    model.reset_kv_cache()
 
+    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
+    if use_cache:
         # Init cache with full prompt
-    logits = model(idx, use_cache=True)
+        model.reset_kv_cache()
+        with torch.no_grad():
+            logits = model(idx[:, -ctx_len:], use_cache=True)
 
         for _ in range(max_new_tokens):
-        last_logits = logits[:, -1]
-        next_idx = last_logits.argmax(dim=-1, keepdim=True)
+            # a) pick the token with the highest log-probability (greedy sampling)
+            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+            # b) append it to the running sequence
             idx = torch.cat([idx, next_idx], dim=1)
 
+            # c) feed model only the new token
             with torch.no_grad():
                 logits = model(next_idx, use_cache=True)
+    else:
+        for _ in range(max_new_tokens):
+            with torch.no_grad():
+                logits = model(idx[:, -ctx_len:], use_cache=False)
+            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+            idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ####################################################
@@ -317,7 +329,7 @@ def main():
         "n_layers": 12,          # Number of layers
         "drop_rate": 0.1,        # Dropout rate
         "qkv_bias": False,       # Query-Key-Value bias
-        "kv_window_size": 48     # NEW: KV cache window size
+        "kv_window_size": 1024   # NEW: KV cache window size
     }
 
     torch.manual_seed(123)
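For reference, a hedged sketch of what this setting trades off (the surrounding config keys are assumed to follow the book's standard 124M configuration): a window equal to the context length reproduces the uncached results exactly, while a smaller window caps cache memory at the cost of dropping the oldest keys/values once generation exceeds it.

```python
# Hypothetical comparison of two window settings for gpt_with_kv_cache_optimized.py
base_cfg = {
    "vocab_size": 50257, "context_length": 1024, "emb_dim": 768,
    "n_heads": 12, "n_layers": 12, "drop_rate": 0.1, "qkv_bias": False,
}

exact_cfg = {**base_cfg, "kv_window_size": 1024}  # window == context length: matches no-cache output
lossy_cfg = {**base_cfg, "kv_window_size": 48}    # smaller window: less cache memory, oldest tokens dropped
```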