mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-31 01:41:26 +00:00 
			
		
		
		
	Run generate example in ch06 optionally on GPU (#352)
* model.to("cuda")
model.to("cuda")
* update device placement
---------
Co-authored-by: rasbt <mail@sebastianraschka.com>
			
			
This commit is contained in:
		
							parent
							
								
									0da32f3976
								
							
						
					
					
						commit
						21e6971b11
					
				| @ -3,7 +3,9 @@ | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "1545a16b-bc8d-4e49-b9a6-db6631e7483d", | ||||
|    "metadata": {}, | ||||
|    "metadata": { | ||||
|     "id": "1545a16b-bc8d-4e49-b9a6-db6631e7483d" | ||||
|    }, | ||||
|    "source": [ | ||||
|     "<table style=\"width:100%\">\n", | ||||
|     "<tr>\n", | ||||
| @ -23,7 +25,9 @@ | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "f3f83194-82b9-4478-9550-5ad793467bd0", | ||||
|    "metadata": {}, | ||||
|    "metadata": { | ||||
|     "id": "f3f83194-82b9-4478-9550-5ad793467bd0" | ||||
|    }, | ||||
|    "source": [ | ||||
|     "# Load And Use Finetuned Model" | ||||
|    ] | ||||
| @ -31,7 +35,9 @@ | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "466b564e-4fd5-4d76-a3a1-63f9f0993b7e", | ||||
|    "metadata": {}, | ||||
|    "metadata": { | ||||
|     "id": "466b564e-4fd5-4d76-a3a1-63f9f0993b7e" | ||||
|    }, | ||||
|    "source": [ | ||||
|     "This notebook contains minimal code to load the finetuned model that was created and saved in chapter 6 via [ch06.ipynb](ch06.ipynb)." | ||||
|    ] | ||||
| @ -40,7 +46,13 @@ | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "id": "fd80e5f5-0f79-4a6c-bf31-2026e7d30e52", | ||||
|    "metadata": {}, | ||||
|    "metadata": { | ||||
|     "colab": { | ||||
|      "base_uri": "https://localhost:8080/" | ||||
|     }, | ||||
|     "id": "fd80e5f5-0f79-4a6c-bf31-2026e7d30e52", | ||||
|     "outputId": "9eeefb8e-a7eb-4d62-cf78-c797b3ed4e2e" | ||||
|    }, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
| @ -66,7 +78,9 @@ | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "id": "ed86d6b7-f32d-4601-b585-a2ea3dbf7201", | ||||
|    "metadata": {}, | ||||
|    "metadata": { | ||||
|     "id": "ed86d6b7-f32d-4601-b585-a2ea3dbf7201" | ||||
|    }, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from pathlib import Path\n", | ||||
| @ -83,7 +97,9 @@ | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "id": "fb02584a-5e31-45d5-8377-794876907bc6", | ||||
|    "metadata": {}, | ||||
|    "metadata": { | ||||
|     "id": "fb02584a-5e31-45d5-8377-794876907bc6" | ||||
|    }, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from previous_chapters import GPTModel\n", | ||||
| @ -116,7 +132,9 @@ | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 4, | ||||
|    "id": "f1ccf2b7-176e-4cfd-af7a-53fb76010b94", | ||||
|    "metadata": {}, | ||||
|    "metadata": { | ||||
|     "id": "f1ccf2b7-176e-4cfd-af7a-53fb76010b94" | ||||
|    }, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import torch\n", | ||||
| @ -128,6 +146,7 @@ | ||||
|     "# Then load pretrained weights\n", | ||||
|     "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | ||||
|     "model.load_state_dict(torch.load(\"review_classifier.pth\", map_location=device, weights_only=True))\n", | ||||
|     "model.to(device)\n", | ||||
|     "model.eval();" | ||||
|    ] | ||||
|   }, | ||||
| @ -135,7 +154,9 @@ | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 5, | ||||
|    "id": "a1fd174e-9555-46c5-8780-19b0aa4f26e5", | ||||
|    "metadata": {}, | ||||
|    "metadata": { | ||||
|     "id": "a1fd174e-9555-46c5-8780-19b0aa4f26e5" | ||||
|    }, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import tiktoken\n", | ||||
| @ -147,7 +168,9 @@ | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 6, | ||||
|    "id": "2a4c0129-efe5-46e9-bb90-ba08d407c1a2", | ||||
|    "metadata": {}, | ||||
|    "metadata": { | ||||
|     "id": "2a4c0129-efe5-46e9-bb90-ba08d407c1a2" | ||||
|    }, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "# This function was implemented in ch06.ipynb\n", | ||||
| @ -167,7 +190,7 @@ | ||||
|     "\n", | ||||
|     "    # Model inference\n", | ||||
|     "    with torch.no_grad():\n", | ||||
|     "        logits = model(input_tensor)[:, -1, :]  # Logits of the last output token\n", | ||||
|     "        logits = model(input_tensor.to(device))[:, -1, :]  # Logits of the last output token\n", | ||||
|     "    predicted_label = torch.argmax(logits, dim=-1).item()\n", | ||||
|     "\n", | ||||
|     "    # Return the classified result\n", | ||||
| @ -178,7 +201,13 @@ | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 7, | ||||
|    "id": "1e26862c-10b5-4a0f-9dd6-b6ddbad2fc3f", | ||||
|    "metadata": {}, | ||||
|    "metadata": { | ||||
|     "colab": { | ||||
|      "base_uri": "https://localhost:8080/" | ||||
|     }, | ||||
|     "id": "1e26862c-10b5-4a0f-9dd6-b6ddbad2fc3f", | ||||
|     "outputId": "28eb2c02-0e38-4356-b2a3-2bf6accb5316" | ||||
|    }, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
| @ -203,7 +232,13 @@ | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 8, | ||||
|    "id": "78472e05-cb4e-4ec4-82e8-23777aa90cf8", | ||||
|    "metadata": {}, | ||||
|    "metadata": { | ||||
|     "colab": { | ||||
|      "base_uri": "https://localhost:8080/" | ||||
|     }, | ||||
|     "id": "78472e05-cb4e-4ec4-82e8-23777aa90cf8", | ||||
|     "outputId": "0cd3cd62-f407-45f3-fa4f-51ff665355eb" | ||||
|    }, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
| @ -226,6 +261,11 @@ | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "accelerator": "GPU", | ||||
|   "colab": { | ||||
|    "gpuType": "L4", | ||||
|    "provenance": [] | ||||
|   }, | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3 (ipykernel)", | ||||
|    "language": "python", | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Mingyuan Xu
						Mingyuan Xu