diff --git a/ch03/01_main-chapter-code/multihead-attention.ipynb b/ch03/01_main-chapter-code/multihead-attention.ipynb index 1748607..98eed3f 100644 --- a/ch03/01_main-chapter-code/multihead-attention.ipynb +++ b/ch03/01_main-chapter-code/multihead-attention.ipynb @@ -228,7 +228,7 @@ " [CausalSelfAttention(d_in, d_out, context_length, dropout, qkv_bias) \n", " for _ in range(num_heads)]\n", " )\n", - " self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n", + " self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n", "\n", " def forward(self, x):\n", " context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n", @@ -383,7 +383,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.12.3" } }, "nbformat": 4,