Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-10-31 18:00:08 +00:00)
	Fix bug in masking when kv cache is used.
parent 01be5a42e4
commit ffc5e4e5d6
@@ -72,7 +72,9 @@ class MultiHeadAttention(nn.Module):
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

         # Original mask truncated to the number of tokens and converted to boolean
-        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+        num_tokens_Q = queries.shape[-2]
+        num_tokens_K = keys.shape[-2]
+        mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]

         # Use the mask to fill attention scores
         attn_scores.masked_fill_(mask_bool, -torch.inf)
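For context, the sketch below illustrates the masking situation this commit addresses: once cached keys are concatenated in front of the current tokens, keys.shape[-2] is generally larger than queries.shape[-2], so the precomputed causal mask must be sliced with both lengths rather than a single num_tokens. The values (context_length, cached_len, new_len) and the row offset for the new queries' absolute positions are part of this standalone illustration, not taken from the diff above.

import torch

# Hypothetical sizes for illustration only.
context_length = 8
cached_len = 5      # tokens already stored in the KV cache
new_len = 2         # tokens processed in the current forward pass

# Precomputed upper-triangular mask; True marks positions to block,
# mirroring the mask/.bool() pattern used in the attention module.
full_mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()

num_tokens_Q = new_len               # queries.shape[-2]
num_tokens_K = cached_len + new_len  # keys.shape[-2] after appending the cache

# Rows correspond to the new queries at their absolute positions
# (cached_len onward); columns cover every key seen so far.
mask_bool = full_mask[cached_len:cached_len + num_tokens_Q, :num_tokens_K]
print(mask_bool.shape)  # torch.Size([2, 7]), matching attn_scores for this step
print(mask_bool)
# tensor([[False, False, False, False, False, False,  True],
#         [False, False, False, False, False, False, False]])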