Mirror of https://github.com/rasbt/LLMs-from-scratch.git
Fix bug in masking when kv cache is used.
parent 01be5a42e4
commit ffc5e4e5d6
@@ -72,7 +72,9 @@ class MultiHeadAttention(nn.Module):
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

         # Original mask truncated to the number of tokens and converted to boolean
-        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+        num_tokens_Q = queries.shape[-2]
+        num_tokens_K = keys.shape[-2]
+        mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]

         # Use the mask to fill attention scores
         attn_scores.masked_fill_(mask_bool, -torch.inf)
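For context: when a KV cache is in use, the key sequence (cached plus new tokens) is longer than the query sequence, so a square slice of the causal mask no longer lines up with the attention-score matrix. Below is a minimal standalone sketch, not part of the commit, illustrating the shape mismatch the fix addresses; the tensor sizes and the mask construction are assumptions chosen for the example.

import torch

context_length = 8
# Causal mask in the style used by MultiHeadAttention: True above the diagonal means "masked"
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

# Hypothetical decode step with a KV cache: 2 new query tokens, 3 cached + 2 new keys
queries = torch.randn(1, 4, 2, 16)   # (batch, num_heads, num_tokens_Q, head_dim)
keys    = torch.randn(1, 4, 5, 16)   # (batch, num_heads, num_tokens_K, head_dim)

attn_scores = queries @ keys.transpose(2, 3)   # shape: (1, 4, 2, 5)

num_tokens_Q = queries.shape[-2]   # 2
num_tokens_K = keys.shape[-2]      # 5

# Old slice: mask.bool()[:2, :2] has shape (2, 2), which cannot be broadcast
# against the (..., 2, 5) score matrix, so masked_fill_ raises a RuntimeError.
# The fixed slice matches the score matrix along both token dimensions:
mask_bool = mask.bool()[:num_tokens_Q, :num_tokens_K]   # shape: (2, 5)
attn_scores.masked_fill_(mask_bool, -torch.inf)
print(attn_scores.shape)   # torch.Size([1, 4, 2, 5])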