Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-10-27 07:49:25 +00:00)
Optimize KV cache (#673)

* Optimize KV cache
* style
* interpretable generate
* interpretable generate
* update readme

Parent: ba0370abd1
Commit: ece59ba587
@@ -157,23 +157,36 @@ def reset_kv_cache(self):
 With the changes to the `GPTModel`, `TransformerBlock`, and `MultiHeadAttention` in place, here is how we use the KV cache in a simple text generation function:

 ```python
-def generate_text_simple_cached(model, idx, max_new_tokens):
+def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
     model.eval()
-    model.reset_kv_cache()

-    # Init cache with full prompt
-    logits = model(idx, use_cache=True)
-
-    for _ in range(max_new_tokens):
-        last_logits = logits[:, -1]
-        next_idx = last_logits.argmax(dim=-1, keepdim=True)
-        idx = torch.cat([idx, next_idx], dim=1)
-        logits = model(next_idx, use_cache=True)
+    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
+    if use_cache:
+        # Init cache with full prompt
+        model.reset_kv_cache()
+        with torch.no_grad():
+            logits = model(idx[:, -ctx_len:], use_cache=True)
+
+        for _ in range(max_new_tokens):
+            # a) pick the token with the highest log-probability (greedy sampling)
+            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+            # b) append it to the running sequence
+            idx = torch.cat([idx, next_idx], dim=1)
+            # c) feed the model only the new token
+            with torch.no_grad():
+                logits = model(next_idx, use_cache=True)
+    else:
+        for _ in range(max_new_tokens):
+            with torch.no_grad():
+                logits = model(idx[:, -ctx_len:], use_cache=False)
+            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+            idx = torch.cat([idx, next_idx], dim=1)

     return idx
 ```

+Note that we only feed the model the new token in c) via `logits = model(next_idx, use_cache=True)`. Without caching, we feed the model the whole input via `logits = model(idx[:, -ctx_len:], use_cache=False)`, as it has no stored keys and values to reuse.

 ## Simple performance comparison
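As a quick orientation, here is a minimal usage sketch for the updated `generate_text_simple_cached` (not part of the commit; it assumes a `GPTModel` instance named `model` and the tiktoken GPT-2 tokenizer):

```python
# Hypothetical usage sketch: encode a prompt, generate with the KV cache,
# and decode the result. Assumes `model` and `generate_text_simple_cached`
# are defined as in this bonus section.
import tiktoken
import torch

tokenizer = tiktoken.get_encoding("gpt2")
idx = torch.tensor(tokenizer.encode("Hello, I am"), dtype=torch.long).unsqueeze(0)

token_ids = generate_text_simple_cached(model, idx, max_new_tokens=200)
print(tokenizer.decode(token_ids.squeeze(0).tolist()))
```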
@@ -190,10 +203,10 @@ python gpt_with_kv_cache.py

 On a Mac Mini with an M4 chip (CPU), the results are as follows:

 |                        | Tokens/sec |
 | ---------------------- | ---------- |
 | `gpt_ch04.py`          | 27         |
-| `gpt_with_kv_cache.py` | 110        |
+| `gpt_with_kv_cache.py` | 144        |

 So, as we can see, we already get a ~5x speed-up with a small 124M-parameter model and a short 200-token sequence length. (Note that this implementation is optimized for code readability, not for CUDA or MPS runtime speed, which would require pre-allocating tensors instead of re-instantiating and concatenating them.)
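The tokens/sec numbers above come from simple wall-clock timing; a minimal sketch of such a measurement (assuming a `model`, an input `idx` tensor, and the generation function from this commit) could look like this:

```python
# Rough tokens/sec measurement sketch (illustrative, not the scripts' exact benchmark code)
import time

def tokens_per_second(generate_fn, model, idx, max_new_tokens=200):
    start = time.perf_counter()
    out = generate_fn(model, idx, max_new_tokens=max_new_tokens)
    elapsed = time.perf_counter() - start
    num_generated = out.shape[1] - idx.shape[1]  # tokens appended to the prompt
    return num_generated / elapsed

# Example call (names assumed from the surrounding code):
# print(tokens_per_second(generate_text_simple_cached, model, idx))
```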
@@ -263,19 +276,12 @@ cache_v = cache_v[:, :, -window_size:, :]
 You can find these optimizations in the [`gpt_with_kv_cache_optimized.py`](gpt_with_kv_cache_optimized.py) file.


-On a Mac Mini with an M4 chip (CPU), with a 200-token generation and a window size of 48 below, the code runtimes compare as follows:
+On a Mac Mini with an M4 chip (CPU), with a 200-token generation and a window size equal to the context length (to guarantee the same results), the code runtimes compare as follows:

 |                                  | Tokens/sec |
 | -------------------------------- | ---------- |
 | `gpt_ch04.py`                    | 27         |
-| `gpt_with_kv_cache.py`           | 110        |
-| `gpt_with_kv_cache_optimized.py` | 148        |
+| `gpt_with_kv_cache.py`           | 144        |
+| `gpt_with_kv_cache_optimized.py` | 166        |

-Unfortunately, the speed advantages disappear on CUDA devices as this is a tiny model, and the device transfer and communication outweigh the benefits of a KV cache for this small model. However, we can see a significant difference in the memory usage:
-
-|                                  | RAM     |
-| -------------------------------- | ------- |
-| `gpt_ch04.py`                    | 0.74 GB |
-| `gpt_with_kv_cache.py`           | 4.35 GB |
-| `gpt_with_kv_cache_optimized.py` | 0.89 GB |
-
+Unfortunately, the speed advantages disappear on CUDA devices, as this is a tiny model and the device transfer and communication overhead outweigh the benefits of a KV cache for this small model.
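As a rough sanity check on why the window size matters for memory, the cache footprint can be estimated directly from the model configuration; the sketch below assumes float32 tensors, batch size 1, and the 124M GPT-2 shape (12 layers, 12 heads, head dimension 64):

```python
# Back-of-the-envelope KV-cache memory estimate (illustrative sketch)
def kv_cache_bytes(n_layers=12, n_heads=12, head_dim=64,
                   window_size=1024, batch_size=1, bytes_per_elem=4):
    # one key tensor and one value tensor per layer,
    # each of shape (batch, n_heads, window_size, head_dim)
    per_layer = 2 * batch_size * n_heads * window_size * head_dim * bytes_per_elem
    return n_layers * per_layer

print(f"{kv_cache_bytes(window_size=1024) / 1024**2:.0f} MB")  # ~72 MB
print(f"{kv_cache_bytes(window_size=48) / 1024**2:.1f} MB")    # ~3.4 MB
```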
@@ -171,6 +171,7 @@ class GPTModel(nn.Module):


 def generate_text_simple(model, idx, max_new_tokens, context_size):
+    model.eval()
     # idx is (B, T) array of indices in the current context
     for _ in range(max_new_tokens):

@@ -264,19 +264,30 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):

 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens):
+def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
     model.eval()
-    model.reset_kv_cache()

-    # Init cache with full prompt
-    logits = model(idx, use_cache=True)
-
-    for _ in range(max_new_tokens):
-        last_logits = logits[:, -1]
-        next_idx = last_logits.argmax(dim=-1, keepdim=True)
-        idx = torch.cat([idx, next_idx], dim=1)
-        logits = model(next_idx, use_cache=True)
+    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
+    if use_cache:
+        # Init cache with full prompt
+        model.reset_kv_cache()
+        with torch.no_grad():
+            logits = model(idx[:, -ctx_len:], use_cache=True)
+
+        for _ in range(max_new_tokens):
+            # a) pick the token with the highest log-probability (greedy sampling)
+            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+            # b) append it to the running sequence
+            idx = torch.cat([idx, next_idx], dim=1)
+            # c) feed the model only the new token
+            with torch.no_grad():
+                logits = model(next_idx, use_cache=True)
+    else:
+        for _ in range(max_new_tokens):
+            with torch.no_grad():
+                logits = model(idx[:, -ctx_len:], use_cache=False)
+            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+            idx = torch.cat([idx, next_idx], dim=1)

     return idx
 ####################################################
@@ -56,28 +56,29 @@ class MultiHeadAttention(nn.Module):
         # NEW
         if use_cache:
             if self.cache_k is None or self.cache_k.size(0) != b:
-                self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device)
-                self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device)
-                self.current_pos = 0
+                self.cache_k = torch.zeros(b, self.num_heads,
+                                           self.window_size, self.head_dim,
+                                           device=x.device)
+                self.cache_v = torch.zeros_like(self.cache_k)
+                self.ptr_cur = 0  # pointer to next free slot

-            # write new entries
-            start = self.current_pos
-            end = start + num_tokens
-            self.cache_k[:, :, start:end, :] = keys_new
-            self.cache_v[:, :, start:end, :] = values_new
-            self.current_pos = end
+            # if incoming chunk would overflow discard oldest tokens
+            if self.ptr_cur + num_tokens > self.window_size:
+                overflow = self.ptr_cur + num_tokens - self.window_size
+                # shift everything left by `overflow` (cheap view-copy)
+                self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
+                self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
+                self.ptr_cur -= overflow  # pointer after shift

-            # sliding window truncation
-            if self.current_pos > self.window_size:
-                self.cache_k = self.cache_k[:, :, -self.window_size:, :]
-                self.cache_v = self.cache_v[:, :, -self.window_size:, :]
-                self.current_pos = self.window_size
-
-            keys = self.cache_k[:, :, :self.current_pos, :]
-            values = self.cache_v[:, :, :self.current_pos, :]
+            self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
+            self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
+            self.ptr_cur += num_tokens
+
+            keys = self.cache_k[:, :, :self.ptr_cur, :]
+            values = self.cache_v[:, :, :self.ptr_cur, :]
         else:
-            keys = keys_new
-            values = values_new
+            keys, values = keys_new, values_new
+            self.ptr_cur = 0  # keep pointer sane if you interleave modes
         ####################################################


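The left-shift logic above can be hard to picture; here is a small self-contained sketch (toy shapes, illustrative names only) that replays the same overflow handling on a dummy cache tensor:

```python
import torch

# Toy replay of the sliding-window buffer bookkeeping (illustrative only)
window_size, num_tokens = 4, 2
cache = torch.zeros(1, 1, window_size, 1)   # (batch, heads, window, head_dim)
ptr = 0

for step in range(4):  # feed 4 chunks of 2 tokens -> 8 tokens total
    new = torch.full((1, 1, num_tokens, 1), float(step))
    if ptr + num_tokens > window_size:
        overflow = ptr + num_tokens - window_size
        # drop the oldest entries by shifting the buffer left
        cache[:, :, :-overflow, :] = cache[:, :, overflow:, :].clone()
        ptr -= overflow
    cache[:, :, ptr:ptr + num_tokens, :] = new
    ptr += num_tokens

print(cache.squeeze())  # tensor([2., 2., 3., 3.]) -> only the newest 4 entries remain
```

A circular buffer would avoid the copy entirely, at the cost of extra index bookkeeping to read the cache back in temporal order; the shift-copy above keeps the read path trivial.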
@@ -216,7 +217,7 @@ class GPTModel(nn.Module):
         self.trf_blocks = nn.ModuleList(
             [TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

-        self.current_pos = 0
+        self.ptr_current_pos = 0
         ####################################################

         self.final_norm = LayerNorm(cfg["emb_dim"])
@@ -232,8 +233,8 @@ class GPTModel(nn.Module):
         # NEW

         if use_cache:
-            pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
-            self.current_pos += seq_len
+            pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
+            self.ptr_current_pos += seq_len
         else:
             pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
         pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
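Because absolute position embeddings are used, the model has to track how many positions have already been consumed while the cache is active; a tiny standalone sketch of that bookkeeping (the counter and function names are illustrative, not from the commit):

```python
import torch

# Standalone illustration of the position-id bookkeeping used with the KV cache
ptr_current_pos = 0

def next_pos_ids(seq_len):
    # mirrors the use_cache=True branch: positions continue where the cache left off
    global ptr_current_pos
    pos_ids = torch.arange(ptr_current_pos, ptr_current_pos + seq_len, dtype=torch.long)
    ptr_current_pos += seq_len
    return pos_ids

print(next_pos_ids(4))  # prompt of 4 tokens  -> tensor([0, 1, 2, 3])
print(next_pos_ids(1))  # first generated token -> tensor([4])
print(next_pos_ids(1))  # next generated token  -> tensor([5])
```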
@@ -258,7 +259,7 @@ class GPTModel(nn.Module):
     def reset_kv_cache(self):
         for blk in self.trf_blocks:
             blk.att.reset_cache()
+        self.ptr_current_pos = 0
     ####################################################


@@ -290,19 +291,30 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):

 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens):
+def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
     model.eval()
-    model.reset_kv_cache()

-    # Init cache with full prompt
-    logits = model(idx, use_cache=True)
-
-    for _ in range(max_new_tokens):
-        last_logits = logits[:, -1]
-        next_idx = last_logits.argmax(dim=-1, keepdim=True)
-        idx = torch.cat([idx, next_idx], dim=1)
-        logits = model(next_idx, use_cache=True)
+    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
+    if use_cache:
+        # Init cache with full prompt
+        model.reset_kv_cache()
+        with torch.no_grad():
+            logits = model(idx[:, -ctx_len:], use_cache=True)
+
+        for _ in range(max_new_tokens):
+            # a) pick the token with the highest log-probability (greedy sampling)
+            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+            # b) append it to the running sequence
+            idx = torch.cat([idx, next_idx], dim=1)
+            # c) feed the model only the new token
+            with torch.no_grad():
+                logits = model(next_idx, use_cache=True)
+    else:
+        for _ in range(max_new_tokens):
+            with torch.no_grad():
+                logits = model(idx[:, -ctx_len:], use_cache=False)
+            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+            idx = torch.cat([idx, next_idx], dim=1)

     return idx
 ####################################################
@@ -317,7 +329,7 @@ def main():
         "n_layers": 12,          # Number of layers
         "drop_rate": 0.1,        # Dropout rate
         "qkv_bias": False,       # Query-Key-Value bias
-        "kv_window_size": 48     # NEW: KV cache window size
+        "kv_window_size": 1024   # NEW: KV cache window size
     }

     torch.manual_seed(123)
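Setting `kv_window_size` equal to the context length (1024) reproduces the non-windowed results exactly; a smaller window trades exactness for cache memory. A hypothetical configuration sketch is shown below; only `kv_window_size` and the keys visible in this diff are taken from the commit, the remaining key names and values are assumptions based on the chapter's usual GPT-2 124M settings:

```python
# Hypothetical config sketch (values are assumptions, not part of the commit)
GPT_CONFIG_124M_SLIDING = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "emb_dim": 768,          # Embedding dimension
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": False,       # Query-Key-Value bias
    "kv_window_size": 48,    # smaller window: less cache memory, but outputs may differ
}
```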