mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-29 02:50:15 +00:00
Include mathematical breakdown for exercise solution 4.1 (#483)
This commit is contained in:
parent
15af754304
commit
37aed8fc2c
@ -62,7 +62,33 @@
|
||||
"execution_count": 2,
|
||||
"id": "2751b0e5-ffd3-4be2-8db3-e20dd4d61d69",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"TransformerBlock(\n",
|
||||
" (att): MultiHeadAttention(\n",
|
||||
" (W_query): Linear(in_features=768, out_features=768, bias=False)\n",
|
||||
" (W_key): Linear(in_features=768, out_features=768, bias=False)\n",
|
||||
" (W_value): Linear(in_features=768, out_features=768, bias=False)\n",
|
||||
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" )\n",
|
||||
" (ff): FeedForward(\n",
|
||||
" (layers): Sequential(\n",
|
||||
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
|
||||
" (1): GELU()\n",
|
||||
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (norm1): LayerNorm()\n",
|
||||
" (norm2): LayerNorm()\n",
|
||||
" (drop_shortcut): Dropout(p=0.1, inplace=False)\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from gpt import TransformerBlock\n",
|
||||
"\n",
|
||||
@ -76,7 +102,8 @@
|
||||
" \"qkv_bias\": False\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"block = TransformerBlock(GPT_CONFIG_124M)"
|
||||
"block = TransformerBlock(GPT_CONFIG_124M)\n",
|
||||
"print(block)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -126,6 +153,31 @@
|
||||
"- Optionally multiply by 12 to capture all transformer blocks in the 124M GPT model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "597e9251-e0a9-4972-8df6-f280f35939f9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Bonus: Mathematical breakdown**\n",
|
||||
"\n",
|
||||
"- For those interested in how these parameter counts are calculated mathematically, you can find the breakdown below (assuming `emb_dim=768`):\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Feed forward module:\n",
|
||||
"\n",
|
||||
"- 1st `Linear` layer: 768 inputs × 4×768 outputs + 4×768 bias units = 2,362,368\n",
|
||||
"- 2nd `Linear` layer: 4×768 inputs × 768 outputs + 768 bias units = 2,360,064\n",
|
||||
"- Total: 1st `Linear` layer + 2nd `Linear` layer = 2,362,368 + 2,360,064 = 4,722,432\n",
|
||||
"\n",
|
||||
"Attention module:\n",
|
||||
"\n",
|
||||
"- `W_query`: 768 inputs × 768 outputs = 589,824 \n",
|
||||
"- `W_key`: 768 inputs × 768 outputs = 589,824\n",
|
||||
"- `W_value`: 768 inputs × 768 outputs = 589,824 \n",
|
||||
"- `out_proj`: 768 inputs × 768 outputs + 768 bias units = 590,592\n",
|
||||
"- Total: `W_query` + `W_key` + `W_value` + `out_proj` = 3×589,824 + 590,592 = 2,360,064 "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0f7b7c7f-0fa1-4d30-ab44-e499edd55b6d",
|
||||
|
Loading…
x
Reference in New Issue
Block a user