diff --git a/ch03/01_main-chapter-code/ch03.ipynb b/ch03/01_main-chapter-code/ch03.ipynb index 62a26ac..688ff27 100644 --- a/ch03/01_main-chapter-code/ch03.ipynb +++ b/ch03/01_main-chapter-code/ch03.ipynb @@ -37,7 +37,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "torch version: 2.2.1\n" + "torch version: 2.2.2\n" ] } ], @@ -625,7 +625,7 @@ } ], "source": [ - "attn_weights = torch.softmax(attn_scores, dim=1)\n", + "attn_weights = torch.softmax(attn_scores, dim=-1)\n", "print(attn_weights)" ] }, @@ -656,7 +656,7 @@ "row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])\n", "print(\"Row 2 sum:\", row_2_sum)\n", "\n", - "print(\"All row sums:\", attn_weights.sum(dim=1))" + "print(\"All row sums:\", attn_weights.sum(dim=-1))" ] }, { @@ -1139,7 +1139,7 @@ " values = self.W_value(x)\n", " \n", " attn_scores = queries @ keys.T\n", - " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n", + " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", "\n", " context_vec = attn_weights @ values\n", " return context_vec\n", @@ -1243,7 +1243,7 @@ "keys = sa_v2.W_key(inputs) \n", "attn_scores = queries @ keys.T\n", "\n", - "attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n", + "attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", "print(attn_weights)" ] }, @@ -1429,7 +1429,7 @@ } ], "source": [ - "attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)\n", + "attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)\n", "print(attn_weights)" ] }, @@ -1765,7 +1765,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 36, "id": "110b0188-6e9e-4e56-a988-10523c6c8538", "metadata": {}, "outputs": [ @@ -1894,7 +1894,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 37, "id": "e8cfc1ae-78ab-4faa-bc73-98bd054806c9", "metadata": {}, "outputs": [ @@ -1937,7 +1937,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 38, "id": "053760f1-1a02-42f0-b3bf-3d939e407039", "metadata": {}, "outputs": [ @@ -2000,7 +2000,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.11.4" } }, "nbformat": 4,