From ffc5e4e5d6c482e9d382b1fd30d0bdd18f9d964e Mon Sep 17 00:00:00 2001
From: martinzwm
Date: Sun, 22 Jun 2025 14:06:00 -0700
Subject: [PATCH] Fix bug in masking when kv cache is used.

---
 ch04/03_kv-cache/gpt_with_kv_cache.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/ch04/03_kv-cache/gpt_with_kv_cache.py b/ch04/03_kv-cache/gpt_with_kv_cache.py
index c685fe0..f92b669 100644
--- a/ch04/03_kv-cache/gpt_with_kv_cache.py
+++ b/ch04/03_kv-cache/gpt_with_kv_cache.py
@@ -72,7 +72,9 @@ class MultiHeadAttention(nn.Module):
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
 
         # Original mask truncated to the number of tokens and converted to boolean
-        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+        num_tokens_Q = queries.shape[-2]
+        num_tokens_K = keys.shape[-2]
+        mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
 
         # Use the mask to fill attention scores
         attn_scores.masked_fill_(mask_bool, -torch.inf)