mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-10-30 01:10:33 +00:00
Add a note about weight tying in Llama 3.2 (#386)
This commit is contained in:
parent
58c3bb3d9d
commit
81053ccadd
File diff suppressed because it is too large
Load Diff
@ -445,13 +445,19 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Total number of parameters: 1,498,482,688\n"
|
"Total number of parameters: 1,498,482,688\n",
|
||||||
|
"\n",
|
||||||
|
"Total number of unique parameters: 1,235,814,400\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"total_params = sum(p.numel() for p in model.parameters())\n",
|
"total_params = sum(p.numel() for p in model.parameters())\n",
|
||||||
"print(f\"Total number of parameters: {total_params:,}\")"
|
"print(f\"Total number of parameters: {total_params:,}\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Account for weight tying\n",
|
||||||
|
"total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
|
||||||
|
"print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -627,7 +633,7 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "edcc384a-adb7-43f6-acc3-ebe4b182ec91",
|
"id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -749,7 +755,8 @@
|
|||||||
" if \"lm_head.weight\" in params.keys():\n",
|
" if \"lm_head.weight\" in params.keys():\n",
|
||||||
" model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
|
" model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")"
|
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
|
||||||
|
" print(\"Model uses weight tying.\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -757,7 +764,15 @@
|
|||||||
"execution_count": 18,
|
"execution_count": 18,
|
||||||
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
|
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Model uses weight tying.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from safetensors.torch import load_file\n",
|
"from safetensors.torch import load_file\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -773,7 +788,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"else:\n",
|
"else:\n",
|
||||||
" combined_weights = {}\n",
|
" combined_weights = {}\n",
|
||||||
" for i in range(1, 5):\n",
|
" for i in range(1, 3):\n",
|
||||||
" weights_file = hf_hub_download(\n",
|
" weights_file = hf_hub_download(\n",
|
||||||
" repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n",
|
" repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n",
|
||||||
" filename=f\"model-0000{i}-of-00002.safetensors\",\n",
|
" filename=f\"model-0000{i}-of-00002.safetensors\",\n",
|
||||||
@ -787,6 +802,24 @@
|
|||||||
"model.to(device);"
|
"model.to(device);"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 19,
|
||||||
|
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Weight tying: True\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(\"Weight tying:\", torch.equal(model.tok_emb.weight, model.out_head.weight))"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "57d07df1-4401-4792-b549-7c4cc5632323",
|
"id": "57d07df1-4401-4792-b549-7c4cc5632323",
|
||||||
@ -798,7 +831,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 20,
|
||||||
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
|
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -855,7 +888,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 21,
|
||||||
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
|
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -934,14 +967,6 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>"
|
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "bf864c28-2ce1-44bf-84e4-c0671f494d62",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user