add print statements for illustration purposes

This commit is contained in:
rasbt 2024-02-10 10:10:14 -06:00
parent cc459b6b5a
commit 10aa2d099d

View File

@ -538,7 +538,7 @@
"id": "11190e7d-8c29-4115-824a-e03702f9dd54", "id": "11190e7d-8c29-4115-824a-e03702f9dd54",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 4.3 Implementing a feed forward network and GELU activations" "## 4.3 Implementing a feed forward network with GELU activations"
] ]
}, },
{ {
@ -741,19 +741,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 38,
"id": "05473938-799c-49fd-86d4-8ed65f94fee6", "id": "05473938-799c-49fd-86d4-8ed65f94fee6",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "name": "stdout",
"text/plain": [ "output_type": "stream",
"tensor([[0.2427]], grad_fn=<AddmmBackward0>)" "text": [
"Before shortcut: tensor([[0.0950, 0.0634, 0.3361]], grad_fn=<MulBackward0>)\n",
"After shortcut: tensor([[-0.9050, 1.0634, 2.3361]], grad_fn=<AddBackward0>)\n",
"Final network output: tensor([[0.2427]], grad_fn=<AddmmBackward0>)\n"
] ]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
} }
], ],
"source": [ "source": [
@ -768,14 +767,17 @@
" def forward(self, x):\n", " def forward(self, x):\n",
" shortcut = x\n", " shortcut = x\n",
" x = self.gelu(self.fc1(x))\n", " x = self.gelu(self.fc1(x))\n",
" x = self.gelu(self.fc2(x)) + shortcut\n", " x = self.gelu(self.fc2(x))\n",
" print(\"Before shortcut:\", x)\n",
" x = x + shortcut\n",
" print(\"After shortcut:\", x)\n",
" x = self.fc3(x)\n", " x = self.fc3(x)\n",
" return x\n", " return x\n",
"\n", "\n",
"torch.manual_seed(123)\n", "torch.manual_seed(123)\n",
"ex_short = ExampleWithShortcut()\n", "ex_short = ExampleWithShortcut()\n",
"inputs = torch.tensor([[-1., 1., 2.]])\n", "inputs = torch.tensor([[-1., 1., 2.]])\n",
"print(ex_short(inputs))" "print(\"Final network output:\", ex_short(inputs))"
] ]
}, },
{ {