Fix bug in masking when kv cache is used.

commit ffc5e4e5d6
parent 01be5a42e4
Author: martinzwm
Date:   2025-06-22 14:06:00 -07:00

@@ -72,7 +72,9 @@ class MultiHeadAttention(nn.Module):
     attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
     # Original mask truncated to the number of tokens and converted to boolean
-    mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+    num_tokens_Q = queries.shape[-2]
+    num_tokens_K = keys.shape[-2]
+    mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
     # Use the mask to fill attention scores
     attn_scores.masked_fill_(mask_bool, -torch.inf)
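
For context, a minimal runnable sketch (not from the repository) of what the new slicing does. Once a KV cache is in play, keys holds cached plus new tokens while queries holds only the tokens of the current step, so the attention-score matrix is rectangular (Q x K) and a square [:num_tokens, :num_tokens] slice can no longer line up with it. The context_length value, the triu mask construction, and the masked_scores helper below are illustrative assumptions, not code from the repository.

import torch

# Causal mask buffered once, as in the class above: True marks
# positions above the diagonal that get filled with -inf.
context_length = 8
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

def masked_scores(queries, keys):
    # queries: (batch, heads, Q, head_dim); keys: (batch, heads, K, head_dim)
    scores = queries @ keys.transpose(2, 3)
    num_tokens_Q = queries.shape[-2]   # tokens in the current step
    num_tokens_K = keys.shape[-2]      # cached + current tokens
    mask_bool = mask.bool()[:num_tokens_Q, :num_tokens_K]
    return scores.masked_fill(mask_bool, -torch.inf)

# Prefill: no cache yet, queries and keys both cover 4 tokens -> square 4x4 mask.
q = torch.randn(1, 2, 4, 16)
k = torch.randn(1, 2, 4, 16)
print(masked_scores(q, k).shape)   # torch.Size([1, 2, 4, 4])

# Decode with a KV cache: 1 new query over 5 keys -> rectangular 1x5 mask.
# A square slice keyed to a single num_tokens could not match this shape.
q = torch.randn(1, 2, 1, 16)
k = torch.randn(1, 2, 5, 16)
print(masked_scores(q, k).shape)   # torch.Size([1, 2, 1, 5])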