mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-14 19:51:35 +00:00
Remove duplicate cells
This commit is contained in:
parent
244137e8a1
commit
dbb5e65a29
@ -967,32 +967,6 @@
|
|||||||
"print(\"Output shape:\", output.shape)"
|
"print(\"Output shape:\", output.shape)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 22,
|
|
||||||
"id": "01e737a6-fc99-42bb-9f7e-4da899168811",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Input shape: torch.Size([2, 4, 768])\n",
|
|
||||||
"Output shape: torch.Size([2, 4, 768])\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"torch.manual_seed(123)\n",
|
|
||||||
"\n",
|
|
||||||
"x = torch.rand(2, 4, 768) # Shape: [batch_size, num_tokens, emb_dim]\n",
|
|
||||||
"block = TransformerBlock(GPT_CONFIG_124M)\n",
|
|
||||||
"output = block(x)\n",
|
|
||||||
"\n",
|
|
||||||
"print(\"Input shape:\", x.shape)\n",
|
|
||||||
"print(\"Output shape:\", output.shape)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "91f502e4-f3e4-40cb-8268-179eec002394",
|
"id": "91f502e4-f3e4-40cb-8268-179eec002394",
|
||||||
@ -1114,44 +1088,6 @@
|
|||||||
"print(out)"
|
"print(out)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 44,
|
|
||||||
"id": "252b78c2-4404-483b-84fe-a412e55c16fc",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Input batch:\n",
|
|
||||||
" tensor([[6109, 3626, 6100, 345],\n",
|
|
||||||
" [6109, 1110, 6622, 257]])\n",
|
|
||||||
"\n",
|
|
||||||
"Output shape: torch.Size([2, 4, 50257])\n",
|
|
||||||
"tensor([[[-0.0055, 0.3224, 0.2185, ..., 0.2539, 0.4578, -0.4747],\n",
|
|
||||||
" [ 0.2663, -0.2975, -0.5040, ..., -0.3903, 0.5328, -0.4224],\n",
|
|
||||||
" [ 1.1146, -0.0923, 0.1303, ..., 0.1521, -0.4494, 0.0276],\n",
|
|
||||||
" [-0.8239, 0.1174, -0.2566, ..., 1.1197, 0.1036, -0.3993]],\n",
|
|
||||||
"\n",
|
|
||||||
" [[-0.1027, 0.1752, -0.1048, ..., 0.2258, 0.1559, -0.8747],\n",
|
|
||||||
" [ 0.2230, 0.1246, 0.0492, ..., 0.8573, -0.2933, 0.3036],\n",
|
|
||||||
" [ 0.9409, 1.3068, -0.1610, ..., 0.8244, 0.1763, 0.0811],\n",
|
|
||||||
" [ 0.4395, 0.2753, 0.1540, ..., 1.3410, -0.3709, 0.1643]]],\n",
|
|
||||||
" grad_fn=<UnsafeViewBackward0>)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"torch.manual_seed(123)\n",
|
|
||||||
"model = GPTModel(GPT_CONFIG_124M)\n",
|
|
||||||
"\n",
|
|
||||||
"out = model(batch)\n",
|
|
||||||
"print(\"Input batch:\\n\", batch)\n",
|
|
||||||
"print(\"\\nOutput shape:\", out.shape)\n",
|
|
||||||
"print(out)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "6d616e7a-568b-4921-af29-bd3f4683cd2e",
|
"id": "6d616e7a-568b-4921-af29-bd3f4683cd2e",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user