add print statements for illustration purposes

This commit is contained in:
rasbt 2024-02-10 10:10:14 -06:00
parent cc459b6b5a
commit 10aa2d099d

View File

@ -538,7 +538,7 @@
"id": "11190e7d-8c29-4115-824a-e03702f9dd54",
"metadata": {},
"source": [
"## 4.3 Implementing a feed forward network and GELU activations"
"## 4.3 Implementing a feed forward network with GELU activations"
]
},
{
@ -741,19 +741,18 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 38,
"id": "05473938-799c-49fd-86d4-8ed65f94fee6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.2427]], grad_fn=<AddmmBackward0>)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
"name": "stdout",
"output_type": "stream",
"text": [
"Before shortcut: tensor([[0.0950, 0.0634, 0.3361]], grad_fn=<MulBackward0>)\n",
"After shortcut: tensor([[-0.9050, 1.0634, 2.3361]], grad_fn=<AddBackward0>)\n",
"Final network output: tensor([[0.2427]], grad_fn=<AddmmBackward0>)\n"
]
}
],
"source": [
@ -768,14 +767,17 @@
" def forward(self, x):\n",
" shortcut = x\n",
" x = self.gelu(self.fc1(x))\n",
" x = self.gelu(self.fc2(x)) + shortcut\n",
" x = self.gelu(self.fc2(x))\n",
" print(\"Before shortcut:\", x)\n",
" x = x + shortcut\n",
" print(\"After shortcut:\", x)\n",
" x = self.fc3(x)\n",
" return x\n",
"\n",
"torch.manual_seed(123)\n",
"ex_short = ExampleWithShortcut()\n",
"inputs = torch.tensor([[-1., 1., 2.]])\n",
"print(ex_short(inputs))"
"print(\"Final network output:\", ex_short(inputs))"
]
},
{