From ec6e09136aa1a404cbd2d98a2029c28bc40261e8 Mon Sep 17 00:00:00 2001
From: rasbt
Date: Sun, 26 May 2024 15:38:35 -0500
Subject: [PATCH] update pr

---
 .../01_main-chapter-code/multihead-attention.ipynb | 14 +++++++++++---
 .../mha-implementations.ipynb                      |  2 +-
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/ch03/01_main-chapter-code/multihead-attention.ipynb b/ch03/01_main-chapter-code/multihead-attention.ipynb
index 98eed3f..b788040 100644
--- a/ch03/01_main-chapter-code/multihead-attention.ipynb
+++ b/ch03/01_main-chapter-code/multihead-attention.ipynb
@@ -38,7 +38,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "torch version: 2.2.1\n"
+      "torch version: 2.2.2\n"
      ]
     }
    ],
@@ -228,7 +228,7 @@
     "            [CausalSelfAttention(d_in, d_out, context_length, dropout, qkv_bias) \n",
     "             for _ in range(num_heads)]\n",
     "        )\n",
-    "        self.out_proj = nn.Linear(d_in*num_heads, d_out*num_heads)\n",
+    "        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
     "\n",
     "    def forward(self, x):\n",
     "        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n",
@@ -365,6 +365,14 @@
     "\n",
     "print(\"context_vecs.shape:\", context_vecs.shape)"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f1d965a5-9b98-4554-8646-7ecd497874cb",
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
@@ -383,7 +391,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.12.3"
+   "version": "3.11.4"
   }
  },
  "nbformat": 4,
diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
index ce5a33e..82f5cde 100644
--- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
+++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
@@ -341,7 +341,7 @@
     "        self.d_out = d_out\n",
     "\n",
     "        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
-    "        self.proj = nn.Linear(d_in, d_out)\n",
+    "        self.proj = nn.Linear(d_out, d_out)\n",
     "        self.dropout = dropout\n",
     "\n",
     "    def forward(self, x):\n",
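
Why the dimension fixes: each CausalSelfAttention head maps d_in to d_out, so concatenating num_heads head outputs yields d_out*num_heads features. The output projection must therefore take d_out*num_heads inputs (first file), and d_out in the fused-QKV variant (second file); d_in only happens to work when d_in == d_out. Below is a minimal, self-contained sketch of the first fix, not the book's implementation: HeadStub is a hypothetical stand-in for the notebook's CausalSelfAttention with the same d_in -> d_out interface.

import torch
import torch.nn as nn

class HeadStub(nn.Module):
    """Hypothetical stand-in for CausalSelfAttention: maps d_in -> d_out."""
    def __init__(self, d_in, d_out):
        super().__init__()
        self.attn = nn.Linear(d_in, d_out)

    def forward(self, x):
        return self.attn(x)

class MultiHeadWrapperSketch(nn.Module):
    def __init__(self, d_in, d_out, num_heads):
        super().__init__()
        self.heads = nn.ModuleList(
            [HeadStub(d_in, d_out) for _ in range(num_heads)]
        )
        # Patched line: the concatenated head outputs carry d_out*num_heads
        # features, so that, not d_in*num_heads, is the projection's input size.
        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)

    def forward(self, x):
        # Each head: (b, num_tokens, d_out); after cat: (b, num_tokens, d_out*num_heads)
        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.out_proj(context_vec)

# With d_in != d_out, the pre-patch nn.Linear(d_in*num_heads, ...) raises a
# shape-mismatch error on the call below; the patched version runs cleanly.
mha = MultiHeadWrapperSketch(d_in=3, d_out=2, num_heads=2)
print(mha(torch.randn(8, 6, 3)).shape)  # torch.Size([8, 6, 4])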