diff --git a/ch04/03_kv-cache/gpt_with_kv_cache.py b/ch04/03_kv-cache/gpt_with_kv_cache.py
index c685fe0..f92b669 100644
--- a/ch04/03_kv-cache/gpt_with_kv_cache.py
+++ b/ch04/03_kv-cache/gpt_with_kv_cache.py
@@ -72,7 +72,9 @@ class MultiHeadAttention(nn.Module):
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

         # Original mask truncated to the number of tokens and converted to boolean
-        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+        num_tokens_Q = queries.shape[-2]
+        num_tokens_K = keys.shape[-2]
+        mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]

         # Use the mask to fill attention scores
         attn_scores.masked_fill_(mask_bool, -torch.inf)