Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2025-10-31 01:41:26 +00:00
	Run generate example in ch06 optionally on GPU (#352)
* model.to("cuda")
* update device placement
---------
Co-authored-by: rasbt <mail@sebastianraschka.com>
parent 0da32f3976
commit 21e6971b11
@@ -3,7 +3,9 @@
   {
    "cell_type": "markdown",
    "id": "1545a16b-bc8d-4e49-b9a6-db6631e7483d",
-   "metadata": {},
+   "metadata": {
+    "id": "1545a16b-bc8d-4e49-b9a6-db6631e7483d"
+   },
    "source": [
     "<table style=\"width:100%\">\n",
     "<tr>\n",
@@ -23,7 +25,9 @@
   {
    "cell_type": "markdown",
    "id": "f3f83194-82b9-4478-9550-5ad793467bd0",
-   "metadata": {},
+   "metadata": {
+    "id": "f3f83194-82b9-4478-9550-5ad793467bd0"
+   },
    "source": [
     "# Load And Use Finetuned Model"
    ]
@@ -31,7 +35,9 @@
   {
    "cell_type": "markdown",
    "id": "466b564e-4fd5-4d76-a3a1-63f9f0993b7e",
-   "metadata": {},
+   "metadata": {
+    "id": "466b564e-4fd5-4d76-a3a1-63f9f0993b7e"
+   },
    "source": [
     "This notebook contains minimal code to load the finetuned model that was created and saved in chapter 6 via [ch06.ipynb](ch06.ipynb)."
    ]
@@ -40,7 +46,13 @@
    "cell_type": "code",
    "execution_count": 1,
    "id": "fd80e5f5-0f79-4a6c-bf31-2026e7d30e52",
-   "metadata": {},
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "fd80e5f5-0f79-4a6c-bf31-2026e7d30e52",
+    "outputId": "9eeefb8e-a7eb-4d62-cf78-c797b3ed4e2e"
+   },
    "outputs": [
     {
      "name": "stdout",
@@ -66,7 +78,9 @@
    "cell_type": "code",
    "execution_count": 2,
    "id": "ed86d6b7-f32d-4601-b585-a2ea3dbf7201",
-   "metadata": {},
+   "metadata": {
+    "id": "ed86d6b7-f32d-4601-b585-a2ea3dbf7201"
+   },
    "outputs": [],
    "source": [
     "from pathlib import Path\n",
@@ -83,7 +97,9 @@
    "cell_type": "code",
    "execution_count": 3,
    "id": "fb02584a-5e31-45d5-8377-794876907bc6",
-   "metadata": {},
+   "metadata": {
+    "id": "fb02584a-5e31-45d5-8377-794876907bc6"
+   },
    "outputs": [],
    "source": [
     "from previous_chapters import GPTModel\n",
@@ -116,7 +132,9 @@
    "cell_type": "code",
    "execution_count": 4,
    "id": "f1ccf2b7-176e-4cfd-af7a-53fb76010b94",
-   "metadata": {},
+   "metadata": {
+    "id": "f1ccf2b7-176e-4cfd-af7a-53fb76010b94"
+   },
    "outputs": [],
    "source": [
     "import torch\n",
@@ -128,6 +146,7 @@
     "# Then load pretrained weights\n",
     "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
     "model.load_state_dict(torch.load(\"review_classifier.pth\", map_location=device, weights_only=True))\n",
+    "model.to(device)\n",
     "model.eval();"
    ]
   },
@@ -135,7 +154,9 @@
    "cell_type": "code",
    "execution_count": 5,
    "id": "a1fd174e-9555-46c5-8780-19b0aa4f26e5",
-   "metadata": {},
+   "metadata": {
+    "id": "a1fd174e-9555-46c5-8780-19b0aa4f26e5"
+   },
    "outputs": [],
    "source": [
     "import tiktoken\n",
@@ -147,7 +168,9 @@
    "cell_type": "code",
    "execution_count": 6,
    "id": "2a4c0129-efe5-46e9-bb90-ba08d407c1a2",
-   "metadata": {},
+   "metadata": {
+    "id": "2a4c0129-efe5-46e9-bb90-ba08d407c1a2"
+   },
    "outputs": [],
    "source": [
     "# This function was implemented in ch06.ipynb\n",
@@ -167,7 +190,7 @@
     "\n",
     "    # Model inference\n",
     "    with torch.no_grad():\n",
-    "        logits = model(input_tensor)[:, -1, :]  # Logits of the last output token\n",
+    "        logits = model(input_tensor.to(device))[:, -1, :]  # Logits of the last output token\n",
     "    predicted_label = torch.argmax(logits, dim=-1).item()\n",
     "\n",
     "    # Return the classified result\n",
@@ -178,7 +201,13 @@
    "cell_type": "code",
    "execution_count": 7,
    "id": "1e26862c-10b5-4a0f-9dd6-b6ddbad2fc3f",
-   "metadata": {},
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "1e26862c-10b5-4a0f-9dd6-b6ddbad2fc3f",
+    "outputId": "28eb2c02-0e38-4356-b2a3-2bf6accb5316"
+   },
    "outputs": [
     {
      "name": "stdout",
@@ -203,7 +232,13 @@
    "cell_type": "code",
    "execution_count": 8,
    "id": "78472e05-cb4e-4ec4-82e8-23777aa90cf8",
-   "metadata": {},
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "78472e05-cb4e-4ec4-82e8-23777aa90cf8",
+    "outputId": "0cd3cd62-f407-45f3-fa4f-51ff665355eb"
+   },
    "outputs": [
     {
      "name": "stdout",
@@ -226,6 +261,11 @@
   }
  ],
  "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "L4",
+   "provenance": []
+  },
   "kernelspec": {
    "display_name": "Python 3 (ipykernel)",
    "language": "python",
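The substantive change in this diff is the device placement: after the checkpoint is loaded, the model itself is moved to the selected device with model.to(device), and classify_review now moves the input tensor to the same device via model(input_tensor.to(device)) before the forward pass. The snippet below is a minimal, self-contained sketch of that pattern; the tiny linear classifier and the random input tensor are stand-ins for the notebook's GPTModel and tokenized review text, which are not reproduced here.

import torch
import torch.nn as nn

# Pick the GPU when available, otherwise fall back to the CPU
# (same selection logic as in the notebook).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for GPTModel; the notebook instead loads the real weights with
# torch.load("review_classifier.pth", map_location=device, weights_only=True).
model = nn.Linear(8, 2)
model.to(device)   # the added step: move the model's parameters to the device
model.eval()

# Inputs must live on the same device as the model before the forward pass,
# which is what model(input_tensor.to(device)) does inside classify_review.
input_tensor = torch.randn(1, 8)
with torch.no_grad():
    logits = model(input_tensor.to(device))
predicted_label = torch.argmax(logits, dim=-1).item()
print(predicted_label)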
Mingyuan Xu