Fix bug in masking when kv cache is used. (#697)

* Fix bug in masking when kv cache is used.

* add tests

* dd tests

* upd

* add kv cache test to gh workflow

* explicit mask slicing

* upd

---------

Co-authored-by: rasbt <mail@sebastianraschka.com>
This commit is contained in:
Martin Ma 2025-06-23 11:12:56 -07:00 committed by GitHub
parent e9ffdbace4
commit ad16b1fbee
5 changed files with 179 additions and 57 deletions

View File

@ -49,6 +49,7 @@ jobs:
source .venv/bin/activate
pytest --ruff setup/02_installing-python-libraries/tests.py
pytest --ruff ch04/01_main-chapter-code/tests.py
pytest --ruff ch04/03_kv-cache/tests.py
pytest --ruff ch05/01_main-chapter-code/tests.py
pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
pytest --ruff ch06/01_main-chapter-code/tests.py

View File

@ -86,6 +86,18 @@ def forward(self, x, use_cache=False):
keys, values = self.cache_k, self.cache_v
else:
keys, values = keys_new, values_new
# ...
num_tokens_Q = queries.shape[-2]
num_tokens_K = keys.shape[-2]
if use_cache:
mask_bool = self.mask.bool()[
self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
]
self.ptr_current_pos += num_tokens_Q
else:
mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
```
&nbsp;
@ -98,6 +110,7 @@ When generating texts, between independent sequences (for instance to text gener
```python
def reset_cache(self):
self.cache_k, self.cache_v = None, None
self.ptr_current_pos = 0
```
&nbsp;
@ -157,14 +170,15 @@ def reset_kv_cache(self):
With the changes to the `GPTModel`, `TransformerBlock`, and `MultiHeadAttention`, finally, here's how we use the KV cache in a simple text generation function:
```python
def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
def generate_text_simple_cached(model, idx, max_new_tokens,
context_size=None, use_cache=True):
model.eval()
ctx_len = context_size or model.pos_emb.num_embeddings
ctx_len = model.pos_emb.num_embeddings # max supported length, e.g. 1024
with torch.no_grad():
if use_cache:
# Init cache with full prompt
model.reset_kv_cache()
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=True)
for _ in range(max_new_tokens):
@ -173,11 +187,9 @@ def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
# b) append it to the running sequence
idx = torch.cat([idx, next_idx], dim=1)
# c) feed model only the new token
with torch.no_grad():
logits = model(next_idx, use_cache=True)
else:
for _ in range(max_new_tokens):
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=False)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)

View File

@ -27,7 +27,7 @@ class MultiHeadAttention(nn.Module):
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length),diagonal=1),
torch.triu(torch.ones(context_length, context_length), diagonal=1),
persistent=False
)
@ -35,6 +35,7 @@ class MultiHeadAttention(nn.Module):
# NEW
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
self.ptr_current_pos = 0
####################################################
def forward(self, x, use_cache=False):
@ -71,8 +72,19 @@ class MultiHeadAttention(nn.Module):
# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
####################################################
# NEW
num_tokens_Q = queries.shape[-2]
num_tokens_K = keys.shape[-2]
if use_cache:
mask_bool = self.mask.bool()[
self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
]
self.ptr_current_pos += num_tokens_Q
####################################################
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
else:
mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
@ -93,6 +105,7 @@ class MultiHeadAttention(nn.Module):
# NEW
def reset_cache(self):
self.cache_k, self.cache_v = None, None
self.ptr_current_pos = 0
####################################################
@ -264,14 +277,15 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
####################################################
# NEW
def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
def generate_text_simple_cached(model, idx, max_new_tokens,
context_size=None, use_cache=True):
model.eval()
ctx_len = context_size or model.pos_emb.num_embeddings
ctx_len = model.pos_emb.num_embeddings # max supported length, e.g. 1024
with torch.no_grad():
if use_cache:
# Init cache with full prompt
model.reset_kv_cache()
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=True)
for _ in range(max_new_tokens):
@ -280,11 +294,9 @@ def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
# b) append it to the running sequence
idx = torch.cat([idx, next_idx], dim=1)
# c) feed model only the new token
with torch.no_grad():
logits = model(next_idx, use_cache=True)
else:
for _ in range(max_new_tokens):
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=False)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)

View File

@ -171,7 +171,8 @@ class TransformerBlock(nn.Module):
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"],
window_size=cfg["kv_window_size"]) # NEW
window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"] # NEW
)
self.ff = FeedForward(cfg)
self.norm1 = LayerNorm(cfg["emb_dim"])
self.norm2 = LayerNorm(cfg["emb_dim"])
@ -289,27 +290,22 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
####################################################
# NEW
def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True):
model.eval()
ctx_len = model.pos_emb.num_embeddings # max supported length, e.g. 1024
if use_cache:
# Init cache with full prompt
model.reset_kv_cache()
ctx_len = context_size or model.pos_emb.num_embeddings
with torch.no_grad():
if use_cache:
model.reset_kv_cache()
logits = model(idx[:, -ctx_len:], use_cache=True)
for _ in range(max_new_tokens):
# a) pick the token with the highest log-probability (greedy sampling)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
# b) append it to the running sequence
idx = torch.cat([idx, next_idx], dim=1)
# c) feed model only the new token
with torch.no_grad():
logits = model(next_idx, use_cache=True)
else:
for _ in range(max_new_tokens):
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=False)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)

101
ch04/03_kv-cache/tests.py Normal file
View File

@ -0,0 +1,101 @@
# Code to test the GPT model implementation against the KV cache variants
import pytest
import torch
import tiktoken
from gpt_ch04 import GPTModel as GPTModelBase
from gpt_ch04 import generate_text_simple
from gpt_with_kv_cache import GPTModel as GPTModelKV1
from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
from gpt_with_kv_cache import generate_text_simple_cached
GPT_CONFIG_124M = {
"vocab_size": 50257,
"context_length": 1024,
"emb_dim": 768,
"n_heads": 12,
"n_layers": 12,
"drop_rate": 0.1,
"qkv_bias": False,
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
def test_gpt_model_equivalence_not_cached(ModelClass):
torch.manual_seed(123)
model = ModelClass(GPT_CONFIG_124M).to(device)
model.eval()
tokenizer = tiktoken.get_encoding("gpt2")
prompt = "Hello, I am"
encoded = tokenizer.encode(prompt)
encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
model_name = ModelClass.__module__ + "." + ModelClass.__name__
token_ids = generate_text_simple(
model=model,
idx=encoded_tensor,
max_new_tokens=30,
context_size=GPT_CONFIG_124M["context_length"]
)
if not hasattr(test_gpt_model_equivalence_not_cached, "results"):
test_gpt_model_equivalence_not_cached.results = []
test_gpt_model_equivalence_not_cached.results.append((model_name, token_ids))
if len(test_gpt_model_equivalence_not_cached.results) == 3:
base_name, base_output = test_gpt_model_equivalence_not_cached.results[0]
for other_name, other_output in test_gpt_model_equivalence_not_cached.results[1:]:
assert torch.equal(base_output, other_output), (
f"Mismatch between {base_name} and {other_name}"
)
@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
def test_gpt_model_equivalence_cached(ModelClass):
torch.manual_seed(123)
model = ModelClass(GPT_CONFIG_124M).to(device)
model.eval()
tokenizer = tiktoken.get_encoding("gpt2")
prompt = "Hello, I am"
encoded_tensor = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)
model_name = ModelClass.__module__ + "." + ModelClass.__name__
if ModelClass is GPTModelBase:
token_ids = generate_text_simple(
model=model,
idx=encoded_tensor,
max_new_tokens=30,
context_size=GPT_CONFIG_124M["context_length"]
)
else:
token_ids = generate_text_simple_cached(
model=model,
idx=encoded_tensor,
max_new_tokens=30,
context_size=GPT_CONFIG_124M["context_length"]
)
if not hasattr(test_gpt_model_equivalence_cached, "results"):
test_gpt_model_equivalence_cached.results = []
test_gpt_model_equivalence_cached.results.append((model_name, token_ids))
if len(test_gpt_model_equivalence_cached.results) == 3:
base_name, base_output = test_gpt_model_equivalence_cached.results[0]
for other_name, other_output in test_gpt_model_equivalence_cached.results[1:]:
assert torch.equal(base_output, other_output), (
f"Mismatch between {base_name} and {other_name}"
)