From 496079c61ec3376df93abc2a862d07db078b1257 Mon Sep 17 00:00:00 2001 From: Rayed Bin Wahed Date: Wed, 6 Mar 2024 23:03:57 +0800 Subject: [PATCH 1/2] Update mha-implementations.ipynb Fix variable spelling in comments to keep consistent with code --- .../mha-implementations.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb index d2bfba2..7cc1058 100644 --- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb +++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb @@ -168,7 +168,7 @@ " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n", " queries, keys, values = qkv.unbind(0)\n", "\n", - " # (b, num_head, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n", + " # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n", " attn_scores = queries @ keys.transpose(-2, -1)\n", " attn_scores = attn_scores.masked_fill(\n", " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n", @@ -258,7 +258,7 @@ " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", " qkv = qkv.permute(2, 0, 3, 1, 4)\n", "\n", - " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n", + " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n", " q, k, v = qkv.unbind(0)\n", "\n", " use_dropout = 0. if not self.training else self.dropout\n", From 99a5e28defb5a589941d3ad73fa9c7d5bbe88f4e Mon Sep 17 00:00:00 2001 From: rasbt Date: Thu, 7 Mar 2024 06:30:40 -0600 Subject: [PATCH 2/2] rename q,k,v for consistency with chapter 3 --- .../mha-implementations.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb index 7cc1058..53795ca 100644 --- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb +++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb @@ -259,11 +259,11 @@ " qkv = qkv.permute(2, 0, 3, 1, 4)\n", "\n", " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n", - " q, k, v = qkv.unbind(0)\n", + " queries, keys, values = qkv.unbind(0)\n", "\n", " use_dropout = 0. if not self.training else self.dropout\n", - " context_vec = torch.nn.functional.scaled_dot_product_attention(q, k, v, \n", - " attn_mask=None, dropout_p=use_dropout, is_causal=True)\n", + " context_vec = nn.functional.scaled_dot_product_attention(\n", + " queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n", "\n", " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", " context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n", @@ -396,7 +396,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.12" } }, "nbformat": 4,