mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-31 09:50:23 +00:00 
			
		
		
		
	Add a note about weight tying in Llama 3.2 (#386)
This commit is contained in:
		
							parent
							
								
									58c3bb3d9d
								
							
						
					
					
						commit
						81053ccadd
					
				
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -445,13 +445,19 @@ | |||||||
|      "name": "stdout", |      "name": "stdout", | ||||||
|      "output_type": "stream", |      "output_type": "stream", | ||||||
|      "text": [ |      "text": [ | ||||||
|       "Total number of parameters: 1,498,482,688\n" |       "Total number of parameters: 1,498,482,688\n", | ||||||
|  |       "\n", | ||||||
|  |       "Total number of unique parameters: 1,235,814,400\n" | ||||||
|      ] |      ] | ||||||
|     } |     } | ||||||
|    ], |    ], | ||||||
|    "source": [ |    "source": [ | ||||||
|     "total_params = sum(p.numel() for p in model.parameters())\n", |     "total_params = sum(p.numel() for p in model.parameters())\n", | ||||||
|     "print(f\"Total number of parameters: {total_params:,}\")" |     "print(f\"Total number of parameters: {total_params:,}\")\n", | ||||||
|  |     "\n", | ||||||
|  |     "# Account for weight tying\n", | ||||||
|  |     "total_params_normalized = total_params - model.tok_emb.weight.numel()\n", | ||||||
|  |     "print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")" | ||||||
|    ] |    ] | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
| @ -627,7 +633,7 @@ | |||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": null, |    "execution_count": null, | ||||||
|    "id": "edcc384a-adb7-43f6-acc3-ebe4b182ec91", |    "id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed", | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [], |    "outputs": [], | ||||||
|    "source": [ |    "source": [ | ||||||
| @ -749,7 +755,8 @@ | |||||||
|     "    if \"lm_head.weight\" in params.keys():\n", |     "    if \"lm_head.weight\" in params.keys():\n", | ||||||
|     "        model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n", |     "        model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n", | ||||||
|     "    else:\n", |     "    else:\n", | ||||||
|     "        model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")" |     "        model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n", | ||||||
|  |     "        print(\"Model uses weight tying.\")" | ||||||
|    ] |    ] | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
| @ -757,7 +764,15 @@ | |||||||
|    "execution_count": 18, |    "execution_count": 18, | ||||||
|    "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", |    "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [], |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "Model uses weight tying.\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|    "source": [ |    "source": [ | ||||||
|     "from safetensors.torch import load_file\n", |     "from safetensors.torch import load_file\n", | ||||||
|     "\n", |     "\n", | ||||||
| @ -773,7 +788,7 @@ | |||||||
|     "\n", |     "\n", | ||||||
|     "else:\n", |     "else:\n", | ||||||
|     "    combined_weights = {}\n", |     "    combined_weights = {}\n", | ||||||
|     "    for i in range(1, 5):\n", |     "    for i in range(1, 3):\n", | ||||||
|     "        weights_file = hf_hub_download(\n", |     "        weights_file = hf_hub_download(\n", | ||||||
|     "            repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n", |     "            repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n", | ||||||
|     "            filename=f\"model-0000{i}-of-00002.safetensors\",\n", |     "            filename=f\"model-0000{i}-of-00002.safetensors\",\n", | ||||||
| @ -787,6 +802,24 @@ | |||||||
|     "model.to(device);" |     "model.to(device);" | ||||||
|    ] |    ] | ||||||
|   }, |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 19, | ||||||
|  |    "id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "Weight tying: True\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "print(\"Weight tying:\", torch.equal(model.tok_emb.weight, model.out_head.weight))" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "markdown", |    "cell_type": "markdown", | ||||||
|    "id": "57d07df1-4401-4792-b549-7c4cc5632323", |    "id": "57d07df1-4401-4792-b549-7c4cc5632323", | ||||||
| @ -798,7 +831,7 @@ | |||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": 19, |    "execution_count": 20, | ||||||
|    "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", |    "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [], |    "outputs": [], | ||||||
| @ -855,7 +888,7 @@ | |||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": 20, |    "execution_count": 21, | ||||||
|    "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", |    "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [ |    "outputs": [ | ||||||
| @ -934,14 +967,6 @@ | |||||||
|     "\n", |     "\n", | ||||||
|     "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>" |     "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>" | ||||||
|    ] |    ] | ||||||
|   }, |  | ||||||
|   { |  | ||||||
|    "cell_type": "code", |  | ||||||
|    "execution_count": null, |  | ||||||
|    "id": "bf864c28-2ce1-44bf-84e4-c0671f494d62", |  | ||||||
|    "metadata": {}, |  | ||||||
|    "outputs": [], |  | ||||||
|    "source": [] |  | ||||||
|   } |   } | ||||||
|  ], |  ], | ||||||
|  "metadata": { |  "metadata": { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Sebastian Raschka
						Sebastian Raschka