diff --git a/ch04/01_main-chapter-code/ch04.ipynb b/ch04/01_main-chapter-code/ch04.ipynb index dc4e3ec..4f75a93 100644 --- a/ch04/01_main-chapter-code/ch04.ipynb +++ b/ch04/01_main-chapter-code/ch04.ipynb @@ -967,32 +967,6 @@ "print(\"Output shape:\", output.shape)" ] }, - { - "cell_type": "code", - "execution_count": 22, - "id": "01e737a6-fc99-42bb-9f7e-4da899168811", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input shape: torch.Size([2, 4, 768])\n", - "Output shape: torch.Size([2, 4, 768])\n" - ] - } - ], - "source": [ - "torch.manual_seed(123)\n", - "\n", - "x = torch.rand(2, 4, 768) # Shape: [batch_size, num_tokens, emb_dim]\n", - "block = TransformerBlock(GPT_CONFIG_124M)\n", - "output = block(x)\n", - "\n", - "print(\"Input shape:\", x.shape)\n", - "print(\"Output shape:\", output.shape)" - ] - }, { "cell_type": "markdown", "id": "91f502e4-f3e4-40cb-8268-179eec002394", @@ -1114,44 +1088,6 @@ "print(out)" ] }, - { - "cell_type": "code", - "execution_count": 44, - "id": "252b78c2-4404-483b-84fe-a412e55c16fc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input batch:\n", - " tensor([[6109, 3626, 6100, 345],\n", - " [6109, 1110, 6622, 257]])\n", - "\n", - "Output shape: torch.Size([2, 4, 50257])\n", - "tensor([[[-0.0055, 0.3224, 0.2185, ..., 0.2539, 0.4578, -0.4747],\n", - " [ 0.2663, -0.2975, -0.5040, ..., -0.3903, 0.5328, -0.4224],\n", - " [ 1.1146, -0.0923, 0.1303, ..., 0.1521, -0.4494, 0.0276],\n", - " [-0.8239, 0.1174, -0.2566, ..., 1.1197, 0.1036, -0.3993]],\n", - "\n", - " [[-0.1027, 0.1752, -0.1048, ..., 0.2258, 0.1559, -0.8747],\n", - " [ 0.2230, 0.1246, 0.0492, ..., 0.8573, -0.2933, 0.3036],\n", - " [ 0.9409, 1.3068, -0.1610, ..., 0.8244, 0.1763, 0.0811],\n", - " [ 0.4395, 0.2753, 0.1540, ..., 1.3410, -0.3709, 0.1643]]],\n", - " grad_fn=)\n" - ] - } - ], - "source": [ - "torch.manual_seed(123)\n", - "model = GPTModel(GPT_CONFIG_124M)\n", - "\n", - "out = model(batch)\n", - "print(\"Input batch:\\n\", batch)\n", - "print(\"\\nOutput shape:\", out.shape)\n", - "print(out)" - ] - }, { "cell_type": "markdown", "id": "6d616e7a-568b-4921-af29-bd3f4683cd2e",