Fix an incorrect input dimension

This commit is contained in:
Kostyantyn Borysenko 2024-05-26 13:05:07 -07:00
parent 8f362634b8
commit 18b5724e75

View File

@ -228,7 +228,7 @@
" [CausalSelfAttention(d_in, d_out, context_length, dropout, qkv_bias) \n",
" for _ in range(num_heads)]\n",
" )\n",
" self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
" self.out_proj = nn.Linear(d_in*num_heads, d_out*num_heads)\n",
"\n",
" def forward(self, x):\n",
" context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n",
@ -383,7 +383,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.12.3"
}
},
"nbformat": 4,