add clarification about :num_tokens

This commit is contained in:
rasbt 2024-06-29 07:16:42 -05:00
parent 796f0e2a30
commit c7f892550e
No known key found for this signature in database
GPG Key ID: 3C6E5C7C075611DB

View File

@ -1633,7 +1633,7 @@
"\n",
" attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
" attn_scores.masked_fill_( # New, _ ops are in-place\n",
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n",
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size\n",
" attn_weights = torch.softmax(\n",
" attn_scores / keys.shape[-1]**0.5, dim=-1\n",
" )\n",
@ -2027,7 +2027,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.11.4"
}
},
"nbformat": 4,