remove redundant double-unsqueeze

This commit is contained in:
rasbt 2024-02-29 08:31:07 -06:00
parent d89aaf319d
commit b827bf4eea
4 changed files with 13 additions and 21 deletions

View File

@ -1608,7 +1608,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 37, "execution_count": 39,
"id": "110b0188-6e9e-4e56-a988-10523c6c8538", "id": "110b0188-6e9e-4e56-a988-10523c6c8538",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -1672,8 +1672,8 @@
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n", " attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
" # Original mask truncated to the number of tokens and converted to boolean\n", " # Original mask truncated to the number of tokens and converted to boolean\n",
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", " mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
" # Unsqueeze the mask twice to match dimensions\n", " # Unsqueeze the mask to match dimensions\n",
" mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)\n", " mask_unsqueezed = mask_bool.unsqueeze(0)\n",
" # Use the unsqueezed mask to fill attention scores\n", " # Use the unsqueezed mask to fill attention scores\n",
" attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)\n", " attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)\n",
" \n", " \n",
@ -1729,7 +1729,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 38, "execution_count": 40,
"id": "e8cfc1ae-78ab-4faa-bc73-98bd054806c9", "id": "e8cfc1ae-78ab-4faa-bc73-98bd054806c9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -1772,7 +1772,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 39, "execution_count": 41,
"id": "053760f1-1a02-42f0-b3bf-3d939e407039", "id": "053760f1-1a02-42f0-b3bf-3d939e407039",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -1804,7 +1804,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 40, "execution_count": 42,
"id": "08c2a3fd-e674-4d69-9ef4-ea94b788e937", "id": "08c2a3fd-e674-4d69-9ef4-ea94b788e937",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -1814,7 +1814,7 @@
"2360064" "2360064"
] ]
}, },
"execution_count": 40, "execution_count": 42,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -1847,14 +1847,6 @@
"source": [ "source": [
"- See the [./multihead-attention.ipynb](./multihead-attention.ipynb) code notebook, which is a concise version of the data loader (chapter 2) plus the multi-head attention class that we implemented in this chapter and will need for training the GPT model in upcoming chapters." "- See the [./multihead-attention.ipynb](./multihead-attention.ipynb) code notebook, which is a concise version of the data loader (chapter 2) plus the multi-head attention class that we implemented in this chapter and will need for training the GPT model in upcoming chapters."
] ]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f5b7a94-78d0-49d5-896f-21696cb331b7",
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {

View File

@ -278,8 +278,8 @@
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n", " attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
" # Original mask truncated to the number of tokens and converted to boolean\n", " # Original mask truncated to the number of tokens and converted to boolean\n",
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", " mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
" # Unsqueeze the mask twice to match dimensions\n", " # Unsqueeze the mask to match dimensions\n",
" mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)\n", " mask_unsqueezed = mask_bool.unsqueeze(0)\n",
" # Use the unsqueezed mask to fill attention scores\n", " # Use the unsqueezed mask to fill attention scores\n",
" attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)\n", " attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)\n",
" \n", " \n",

View File

@ -91,8 +91,8 @@ class MultiHeadAttention(nn.Module):
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean # Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens] mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Unsqueeze the mask twice to match dimensions # Unsqueeze the mask to match dimensions
mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0) mask_unsqueezed = mask_bool.unsqueeze(0)
# Use the unsqueezed mask to fill attention scores # Use the unsqueezed mask to fill attention scores
attn_scores.masked_fill_(mask_unsqueezed, -torch.inf) attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)

View File

@ -80,8 +80,8 @@ class MultiHeadAttention(nn.Module):
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean # Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens] mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Unsqueeze the mask twice to match dimensions # Unsqueeze the mask to match dimensions
mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0) mask_unsqueezed = mask_bool.unsqueeze(0)
# Use the unsqueezed mask to fill attention scores # Use the unsqueezed mask to fill attention scores
attn_scores.masked_fill_(mask_unsqueezed, -torch.inf) attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)