Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-09-02 04:48:07 +00:00)
Include mathematical breakdown for exercise solution 4.1 (#483)
This commit is contained in: parent 15af754304, commit 37aed8fc2c
@@ -62,7 +62,33 @@
   "execution_count": 2,
   "id": "2751b0e5-ffd3-4be2-8db3-e20dd4d61d69",
   "metadata": {},
-  "outputs": [],
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "TransformerBlock(\n",
+     "  (att): MultiHeadAttention(\n",
+     "    (W_query): Linear(in_features=768, out_features=768, bias=False)\n",
+     "    (W_key): Linear(in_features=768, out_features=768, bias=False)\n",
+     "    (W_value): Linear(in_features=768, out_features=768, bias=False)\n",
+     "    (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
+     "    (dropout): Dropout(p=0.1, inplace=False)\n",
+     "  )\n",
+     "  (ff): FeedForward(\n",
+     "    (layers): Sequential(\n",
+     "      (0): Linear(in_features=768, out_features=3072, bias=True)\n",
+     "      (1): GELU()\n",
+     "      (2): Linear(in_features=3072, out_features=768, bias=True)\n",
+     "    )\n",
+     "  )\n",
+     "  (norm1): LayerNorm()\n",
+     "  (norm2): LayerNorm()\n",
+     "  (drop_shortcut): Dropout(p=0.1, inplace=False)\n",
+     ")\n"
+    ]
+   }
+  ],
   "source": [
    "from gpt import TransformerBlock\n",
    "\n",
@@ -76,7 +102,8 @@
   " \"qkv_bias\": False\n",
   "}\n",
   "\n",
-  "block = TransformerBlock(GPT_CONFIG_124M)"
+  "block = TransformerBlock(GPT_CONFIG_124M)\n",
+  "print(block)"
  ]
 },
 {
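For readers following along outside the notebook, here is a minimal sketch of what the change above does, assuming the chapter's `gpt.py` module and a `GPT_CONFIG_124M` dictionary matching the one used in the notebook. It also tallies the per-submodule parameter counts that the mathematical breakdown added further below refers to:

```python
# Sketch (not part of the commit): instantiate the block as in the notebook
# and count parameters per submodule. The config values mirror the
# 124M-parameter GPT configuration assumed by the notebook.
from gpt import TransformerBlock

GPT_CONFIG_124M = {
    "vocab_size": 50257,     # vocabulary size
    "context_length": 1024,  # context length
    "emb_dim": 768,          # embedding dimension
    "n_heads": 12,           # number of attention heads
    "n_layers": 12,          # number of transformer blocks
    "drop_rate": 0.1,        # dropout rate
    "qkv_bias": False        # no bias for query/key/value projections
}

block = TransformerBlock(GPT_CONFIG_124M)
print(block)  # prints the module structure shown in the output cell above

# The submodule names `att` and `ff` come from the printed structure
ff_params = sum(p.numel() for p in block.ff.parameters())
att_params = sum(p.numel() for p in block.att.parameters())
print(f"Feed forward module: {ff_params:,} parameters")  # expected 4,722,432
print(f"Attention module: {att_params:,} parameters")    # expected 2,360,064
```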
@@ -126,6 +153,31 @@
   "- Optionally multiply by 12 to capture all transformer blocks in the 124M GPT model"
  ]
 },
+{
+ "cell_type": "markdown",
+ "id": "597e9251-e0a9-4972-8df6-f280f35939f9",
+ "metadata": {},
+ "source": [
+  "**Bonus: Mathematical breakdown**\n",
+  "\n",
+  "- For those interested in how these parameter counts are calculated mathematically, you can find the breakdown below (assuming `emb_dim=768`):\n",
+  "\n",
+  "\n",
+  "Feed forward module:\n",
+  "\n",
+  "- 1st `Linear` layer: 768 inputs × 4×768 outputs + 4×768 bias units = 2,362,368\n",
+  "- 2nd `Linear` layer: 4×768 inputs × 768 outputs + 768 bias units = 2,360,064\n",
+  "- Total: 1st `Linear` layer + 2nd `Linear` layer = 2,362,368 + 2,360,064 = 4,722,432\n",
+  "\n",
+  "Attention module:\n",
+  "\n",
+  "- `W_query`: 768 inputs × 768 outputs = 589,824\n",
+  "- `W_key`: 768 inputs × 768 outputs = 589,824\n",
+  "- `W_value`: 768 inputs × 768 outputs = 589,824\n",
+  "- `out_proj`: 768 inputs × 768 outputs + 768 bias units = 590,592\n",
+  "- Total: `W_query` + `W_key` + `W_value` + `out_proj` = 3×589,824 + 590,592 = 2,360,064"
+ ]
+},
 {
  "cell_type": "markdown",
  "id": "0f7b7c7f-0fa1-4d30-ab44-e499edd55b6d",
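The arithmetic in the added markdown cell can be double-checked with a few lines of plain Python (a verification sketch, not part of the commit):

```python
# Recompute the counts from the bonus breakdown, assuming emb_dim = 768
emb_dim = 768

# Feed forward module: Linear(768 -> 3072) and Linear(3072 -> 768), both with bias
ff_linear1 = emb_dim * (4 * emb_dim) + 4 * emb_dim  # 2,362,368
ff_linear2 = (4 * emb_dim) * emb_dim + emb_dim      # 2,360,064
print(ff_linear1 + ff_linear2)                      # 4,722,432

# Attention module: three bias-free 768x768 projections plus out_proj with bias
qkv = 3 * (emb_dim * emb_dim)                       # 3 × 589,824 = 1,769,472
out_proj = emb_dim * emb_dim + emb_dim              # 590,592
print(qkv + out_proj)                               # 2,360,064
```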