mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-03 11:20:49 +00:00
rename q,k,v for consistency with chapter 3
This commit is contained in:
parent
496079c61e
commit
99a5e28def
@ -259,11 +259,11 @@
|
|||||||
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
|
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
|
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
|
||||||
" q, k, v = qkv.unbind(0)\n",
|
" queries, keys, values = qkv.unbind(0)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" use_dropout = 0. if not self.training else self.dropout\n",
|
" use_dropout = 0. if not self.training else self.dropout\n",
|
||||||
" context_vec = torch.nn.functional.scaled_dot_product_attention(q, k, v, \n",
|
" context_vec = nn.functional.scaled_dot_product_attention(\n",
|
||||||
" attn_mask=None, dropout_p=use_dropout, is_causal=True)\n",
|
" queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
|
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
|
||||||
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
|
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
|
||||||
@ -396,7 +396,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.6"
|
"version": "3.10.12"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user