mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-02 02:41:00 +00:00
amend
This commit is contained in:
parent
76205521d7
commit
244137e8a1
@ -1769,36 +1769,6 @@
|
||||
"print(\"\\nSecond head:\\n\", second_res)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"id": "08c2a3fd-e674-4d69-9ef4-ea94b788e937",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"2360064"
|
||||
]
|
||||
},
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"block_size = 1024\n",
|
||||
"d_in, d_out = 768, 768\n",
|
||||
"num_heads = 12\n",
|
||||
"\n",
|
||||
"mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads)\n",
|
||||
"\n",
|
||||
"def count_parameters(model):\n",
|
||||
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
||||
"\n",
|
||||
"count_parameters(mha)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dec671bf-7938-4304-ad1e-75d9920e7f43",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user