Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-08-28 18:40:01 +00:00)

Fix bug in masking when kv cache is used. (#697)

* Fix bug in masking when KV cache is used
* Add tests
* Add KV cache test to GitHub workflow
* Explicit mask slicing

Co-authored-by: rasbt <mail@sebastianraschka.com>

Parent: e9ffdbace4
Commit: ad16b1fbee
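
The gist of the fix, as a minimal standalone sketch (the variable names mirror the diff below; the loop and the token counts are made up for illustration): once a KV cache is active, the query tensor only covers the newest tokens while the keys span everything cached so far, so the causal mask has to be sliced starting at the current cache position instead of the square `[:num_tokens, :num_tokens]` slice, which implicitly assumed queries and keys always have the same length.

```python
import torch

context_length = 8
# Upper-triangular causal mask: 1 where a query must NOT attend (future positions)
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

ptr_current_pos = 0                # number of tokens already in the cache
for step_tokens in (4, 1, 1):      # e.g. a 4-token prompt, then two single-token steps
    num_tokens_Q = step_tokens                     # new (query) tokens in this step
    num_tokens_K = ptr_current_pos + step_tokens   # cached + new (key) tokens
    mask_bool = mask.bool()[
        ptr_current_pos:ptr_current_pos + num_tokens_Q, :num_tokens_K
    ]
    print(mask_bool.shape)         # (num_tokens_Q, num_tokens_K), rows aligned with the new queries
    ptr_current_pos += num_tokens_Q
```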
.github/workflows/basic-tests-linux-uv.yml (vendored, 1 line changed)
@@ -49,6 +49,7 @@ jobs:
 source .venv/bin/activate
 pytest --ruff setup/02_installing-python-libraries/tests.py
 pytest --ruff ch04/01_main-chapter-code/tests.py
+pytest --ruff ch04/03_kv-cache/tests.py
 pytest --ruff ch05/01_main-chapter-code/tests.py
 pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
 pytest --ruff ch06/01_main-chapter-code/tests.py
@@ -86,6 +86,18 @@ def forward(self, x, use_cache=False):
         keys, values = self.cache_k, self.cache_v
     else:
         keys, values = keys_new, values_new
+
+    # ...
+
+    num_tokens_Q = queries.shape[-2]
+    num_tokens_K = keys.shape[-2]
+    if use_cache:
+        mask_bool = self.mask.bool()[
+            self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
+        ]
+        self.ptr_current_pos += num_tokens_Q
+    else:
+        mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
 ```
 
 
@@ -98,6 +110,7 @@ When generating texts, between independent sequences (for instance to text gener
 ```python
 def reset_cache(self):
     self.cache_k, self.cache_v = None, None
+    self.ptr_current_pos = 0
 ```
 
 
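
For context, a hypothetical usage sketch (not part of the commit; it assumes a `model` built from the `GPTModel` in this folder and uses illustrative token IDs): when driving the model manually with `use_cache=True`, the cache has to be reset between unrelated prompts so that neither the cached keys/values nor the new `ptr_current_pos` pointer leaks from one sequence into the next.

```python
import torch

prompts = [
    torch.tensor([[15496, 11, 314, 716]]),    # illustrative token IDs, prompt 1
    torch.tensor([[6109, 3626, 6100, 345]]),  # illustrative token IDs, prompt 2
]

for prompt_ids in prompts:
    model.reset_kv_cache()                      # clears cache_k/cache_v and ptr_current_pos
    logits = model(prompt_ids, use_cache=True)  # fills the cache with the prompt
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    logits = model(next_token, use_cache=True)  # reuses the cache, advances ptr_current_pos
```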
@@ -157,14 +170,15 @@ def reset_kv_cache(self):
 With the changes to the `GPTModel`, `TransformerBlock`, and `MultiHeadAttention`, finally, here's how we use the KV cache in a simple text generation function:
 
 ```python
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens,
+                                context_size=None, use_cache=True):
     model.eval()
+    ctx_len = context_size or model.pos_emb.num_embeddings
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
+    with torch.no_grad():
         if use_cache:
             # Init cache with full prompt
             model.reset_kv_cache()
-            with torch.no_grad():
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
             for _ in range(max_new_tokens):
@@ -173,11 +187,9 @@ def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
                 # b) append it to the running sequence
                 idx = torch.cat([idx, next_idx], dim=1)
                 # c) feed model only the new token
-                with torch.no_grad():
                 logits = model(next_idx, use_cache=True)
         else:
             for _ in range(max_new_tokens):
-                with torch.no_grad():
                 logits = model(idx[:, -ctx_len:], use_cache=False)
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                 idx = torch.cat([idx, next_idx], dim=1)
@@ -35,6 +35,7 @@ class MultiHeadAttention(nn.Module):
         # NEW
         self.register_buffer("cache_k", None, persistent=False)
         self.register_buffer("cache_v", None, persistent=False)
+        self.ptr_current_pos = 0
         ####################################################
 
     def forward(self, x, use_cache=False):
@@ -71,8 +72,19 @@ class MultiHeadAttention(nn.Module):
         # Compute scaled dot-product attention (aka self-attention) with a causal mask
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
 
+        ####################################################
+        # NEW
+        num_tokens_Q = queries.shape[-2]
+        num_tokens_K = keys.shape[-2]
+        if use_cache:
+            mask_bool = self.mask.bool()[
+                self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
+            ]
+            self.ptr_current_pos += num_tokens_Q
+        ####################################################
         # Original mask truncated to the number of tokens and converted to boolean
-        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+        else:
+            mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
 
         # Use the mask to fill attention scores
         attn_scores.masked_fill_(mask_bool, -torch.inf)
@@ -93,6 +105,7 @@ class MultiHeadAttention(nn.Module):
     # NEW
     def reset_cache(self):
         self.cache_k, self.cache_v = None, None
+        self.ptr_current_pos = 0
     ####################################################
 
 
@@ -264,14 +277,15 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens,
+                                context_size=None, use_cache=True):
     model.eval()
+    ctx_len = context_size or model.pos_emb.num_embeddings
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
+    with torch.no_grad():
         if use_cache:
             # Init cache with full prompt
             model.reset_kv_cache()
-            with torch.no_grad():
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
             for _ in range(max_new_tokens):
@@ -280,11 +294,9 @@ def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
                 # b) append it to the running sequence
                 idx = torch.cat([idx, next_idx], dim=1)
                 # c) feed model only the new token
-                with torch.no_grad():
                 logits = model(next_idx, use_cache=True)
         else:
             for _ in range(max_new_tokens):
-                with torch.no_grad():
                 logits = model(idx[:, -ctx_len:], use_cache=False)
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                 idx = torch.cat([idx, next_idx], dim=1)
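
A quick call-site sketch for the updated signature (hypothetical; it assumes `model` and a tokenized prompt are set up as elsewhere in this folder, and the token IDs are illustrative): `context_size` is optional and falls back to the model's positional-embedding size, while callers such as the new tests pass it explicitly.

```python
import torch

idx = torch.tensor([[15496, 11, 314, 716]])  # illustrative prompt token IDs

# Default: ctx_len is taken from model.pos_emb.num_embeddings
out = generate_text_simple_cached(model, idx, max_new_tokens=20)

# Explicit context size, matching how tests.py calls it
out = generate_text_simple_cached(model, idx, max_new_tokens=20,
                                  context_size=1024, use_cache=True)
```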
@@ -171,7 +171,8 @@ class TransformerBlock(nn.Module):
             num_heads=cfg["n_heads"],
             dropout=cfg["drop_rate"],
             qkv_bias=cfg["qkv_bias"],
-            window_size=cfg["kv_window_size"])  # NEW
+            window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]  # NEW
+        )
         self.ff = FeedForward(cfg)
         self.norm1 = LayerNorm(cfg["emb_dim"])
         self.norm2 = LayerNorm(cfg["emb_dim"])
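
The conditional above is a backwards-compatible default: configs that do not define `"kv_window_size"` (such as `GPT_CONFIG_124M` in the new tests) fall back to the full context length. An equivalent one-liner with `dict.get` (a sketch, not the repository code):

```python
window_size = cfg.get("kv_window_size", cfg["context_length"])
```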
@@ -289,27 +290,22 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
+    ctx_len = context_size or model.pos_emb.num_embeddings
     with torch.no_grad():
+        if use_cache:
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
             for _ in range(max_new_tokens):
-                # a) pick the token with the highest log-probability (greedy sampling)
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-                # b) append it to the running sequence
                 idx = torch.cat([idx, next_idx], dim=1)
-                # c) feed model only the new token
-                with torch.no_grad():
                 logits = model(next_idx, use_cache=True)
         else:
             for _ in range(max_new_tokens):
-                with torch.no_grad():
                 logits = model(idx[:, -ctx_len:], use_cache=False)
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                 idx = torch.cat([idx, next_idx], dim=1)
ch04/03_kv-cache/tests.py (new file, 101 lines)

@@ -0,0 +1,101 @@
+# Code to test the GPT model implementation against the KV cache variants
+
+import pytest
+import torch
+import tiktoken
+
+from gpt_ch04 import GPTModel as GPTModelBase
+from gpt_ch04 import generate_text_simple
+
+from gpt_with_kv_cache import GPTModel as GPTModelKV1
+from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
+from gpt_with_kv_cache import generate_text_simple_cached
+
+
+GPT_CONFIG_124M = {
+    "vocab_size": 50257,
+    "context_length": 1024,
+    "emb_dim": 768,
+    "n_heads": 12,
+    "n_layers": 12,
+    "drop_rate": 0.1,
+    "qkv_bias": False,
+}
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
+def test_gpt_model_equivalence_not_cached(ModelClass):
+    torch.manual_seed(123)
+
+    model = ModelClass(GPT_CONFIG_124M).to(device)
+    model.eval()
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+    prompt = "Hello, I am"
+    encoded = tokenizer.encode(prompt)
+    encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
+
+    model_name = ModelClass.__module__ + "." + ModelClass.__name__
+
+    token_ids = generate_text_simple(
+        model=model,
+        idx=encoded_tensor,
+        max_new_tokens=30,
+        context_size=GPT_CONFIG_124M["context_length"]
+    )
+
+    if not hasattr(test_gpt_model_equivalence_not_cached, "results"):
+        test_gpt_model_equivalence_not_cached.results = []
+
+    test_gpt_model_equivalence_not_cached.results.append((model_name, token_ids))
+
+    if len(test_gpt_model_equivalence_not_cached.results) == 3:
+        base_name, base_output = test_gpt_model_equivalence_not_cached.results[0]
+        for other_name, other_output in test_gpt_model_equivalence_not_cached.results[1:]:
+            assert torch.equal(base_output, other_output), (
+                f"Mismatch between {base_name} and {other_name}"
+            )
+
+
+@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
+def test_gpt_model_equivalence_cached(ModelClass):
+    torch.manual_seed(123)
+
+    model = ModelClass(GPT_CONFIG_124M).to(device)
+    model.eval()
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+    prompt = "Hello, I am"
+    encoded_tensor = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)
+
+    model_name = ModelClass.__module__ + "." + ModelClass.__name__
+
+    if ModelClass is GPTModelBase:
+        token_ids = generate_text_simple(
+            model=model,
+            idx=encoded_tensor,
+            max_new_tokens=30,
+            context_size=GPT_CONFIG_124M["context_length"]
+        )
+    else:
+        token_ids = generate_text_simple_cached(
+            model=model,
+            idx=encoded_tensor,
+            max_new_tokens=30,
+            context_size=GPT_CONFIG_124M["context_length"]
+        )
+
+    if not hasattr(test_gpt_model_equivalence_cached, "results"):
+        test_gpt_model_equivalence_cached.results = []
+
+    test_gpt_model_equivalence_cached.results.append((model_name, token_ids))
+
+    if len(test_gpt_model_equivalence_cached.results) == 3:
+        base_name, base_output = test_gpt_model_equivalence_cached.results[0]
+        for other_name, other_output in test_gpt_model_equivalence_cached.results[1:]:
+            assert torch.equal(base_output, other_output), (
+                f"Mismatch between {base_name} and {other_name}"
+            )
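
A note on the test design (my reading of the file above, not something stated in the commit message): each test is parametrized over the three model classes, so the generated token IDs are accumulated on a function attribute and only compared once the third parametrized run has finished. A stripped-down sketch of the same pattern:

```python
import pytest


@pytest.mark.parametrize("value", [1, 1, 1])
def test_accumulate_then_compare(value):
    # Stash results across parametrized runs on the function object itself.
    if not hasattr(test_accumulate_then_compare, "results"):
        test_accumulate_then_compare.results = []
    test_accumulate_then_compare.results.append(value)

    # Only the final parametrized run performs the comparison.
    if len(test_accumulate_then_compare.results) == 3:
        first = test_accumulate_then_compare.results[0]
        assert all(r == first for r in test_accumulate_then_compare.results[1:])
```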