Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-10-28 16:30:07 +00:00)
Fix bug in masking when kv cache is used.
parent 01be5a42e4
commit ffc5e4e5d6
@@ -72,7 +72,9 @@ class MultiHeadAttention(nn.Module):
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

         # Original mask truncated to the number of tokens and converted to boolean
-        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+        num_tokens_Q = queries.shape[-2]
+        num_tokens_K = keys.shape[-2]
+        mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]

         # Use the mask to fill attention scores
         attn_scores.masked_fill_(mask_bool, -torch.inf)
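Why the fix matters: with a KV cache, keys holds previously cached tokens in addition to the new ones, so keys.shape[-2] can exceed queries.shape[-2]. The attention-score matrix then has shape (..., num_tokens_Q, num_tokens_K), and slicing the mask with a single num_tokens no longer matches it. Below is a minimal sketch of that shape mismatch; the tensor sizes are made up for illustration and this is not the repository's full MultiHeadAttention class.

# Minimal sketch (hypothetical sizes): with a KV cache, queries hold only the
# new token(s) while keys also hold the cached past tokens, so the attention
# score matrix is rectangular rather than square.
import torch

context_length = 8
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

queries = torch.randn(1, 2, 1, 4)  # (batch, heads, num_tokens_Q = 1, head_dim)
keys    = torch.randn(1, 2, 5, 4)  # (batch, heads, num_tokens_K = 5, head_dim)

attn_scores = queries @ keys.transpose(2, 3)  # shape (1, 2, 1, 5)

num_tokens_Q = queries.shape[-2]
num_tokens_K = keys.shape[-2]

# The old slice mask[:num_tokens, :num_tokens] assumed a square score matrix;
# slicing with both counts matches the rectangular scores above.
mask_bool = mask.bool()[:num_tokens_Q, :num_tokens_K]  # shape (1, 5)
attn_scores.masked_fill_(mask_bool, -torch.inf)
print(attn_scores.shape, mask_bool.shape)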