Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-09-21 14:14:19 +00:00)
Run generate example in ch06 optionally on GPU (#352)

* model.to("cuda")
* update device placement

Co-authored-by: rasbt <mail@sebastianraschka.com>

parent 6f92909e58
commit f77c376b05
@@ -3,7 +3,9 @@
   {
    "cell_type": "markdown",
    "id": "1545a16b-bc8d-4e49-b9a6-db6631e7483d",
-   "metadata": {},
+   "metadata": {
+    "id": "1545a16b-bc8d-4e49-b9a6-db6631e7483d"
+   },
    "source": [
     "<table style=\"width:100%\">\n",
     "<tr>\n",
@@ -23,7 +25,9 @@
   {
    "cell_type": "markdown",
    "id": "f3f83194-82b9-4478-9550-5ad793467bd0",
-   "metadata": {},
+   "metadata": {
+    "id": "f3f83194-82b9-4478-9550-5ad793467bd0"
+   },
    "source": [
     "# Load And Use Finetuned Model"
    ]
@@ -31,7 +35,9 @@
   {
    "cell_type": "markdown",
    "id": "466b564e-4fd5-4d76-a3a1-63f9f0993b7e",
-   "metadata": {},
+   "metadata": {
+    "id": "466b564e-4fd5-4d76-a3a1-63f9f0993b7e"
+   },
    "source": [
     "This notebook contains minimal code to load the finetuned model that was created and saved in chapter 6 via [ch06.ipynb](ch06.ipynb)."
    ]
@@ -40,7 +46,13 @@
    "cell_type": "code",
    "execution_count": 1,
    "id": "fd80e5f5-0f79-4a6c-bf31-2026e7d30e52",
-   "metadata": {},
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "fd80e5f5-0f79-4a6c-bf31-2026e7d30e52",
+    "outputId": "9eeefb8e-a7eb-4d62-cf78-c797b3ed4e2e"
+   },
    "outputs": [
     {
      "name": "stdout",
@@ -66,7 +78,9 @@
    "cell_type": "code",
    "execution_count": 2,
    "id": "ed86d6b7-f32d-4601-b585-a2ea3dbf7201",
-   "metadata": {},
+   "metadata": {
+    "id": "ed86d6b7-f32d-4601-b585-a2ea3dbf7201"
+   },
    "outputs": [],
    "source": [
     "from pathlib import Path\n",
@@ -83,7 +97,9 @@
    "cell_type": "code",
    "execution_count": 3,
    "id": "fb02584a-5e31-45d5-8377-794876907bc6",
-   "metadata": {},
+   "metadata": {
+    "id": "fb02584a-5e31-45d5-8377-794876907bc6"
+   },
    "outputs": [],
    "source": [
     "from previous_chapters import GPTModel\n",
@@ -116,7 +132,9 @@
    "cell_type": "code",
    "execution_count": 4,
    "id": "f1ccf2b7-176e-4cfd-af7a-53fb76010b94",
-   "metadata": {},
+   "metadata": {
+    "id": "f1ccf2b7-176e-4cfd-af7a-53fb76010b94"
+   },
    "outputs": [],
    "source": [
     "import torch\n",
@@ -128,6 +146,7 @@
     "# Then load pretrained weights\n",
     "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
     "model.load_state_dict(torch.load(\"review_classifier.pth\", map_location=device, weights_only=True))\n",
+    "model.to(device)\n",
     "model.eval();"
    ]
   },
@@ -135,7 +154,9 @@
    "cell_type": "code",
    "execution_count": 5,
    "id": "a1fd174e-9555-46c5-8780-19b0aa4f26e5",
-   "metadata": {},
+   "metadata": {
+    "id": "a1fd174e-9555-46c5-8780-19b0aa4f26e5"
+   },
    "outputs": [],
    "source": [
     "import tiktoken\n",
@@ -147,7 +168,9 @@
    "cell_type": "code",
    "execution_count": 6,
    "id": "2a4c0129-efe5-46e9-bb90-ba08d407c1a2",
-   "metadata": {},
+   "metadata": {
+    "id": "2a4c0129-efe5-46e9-bb90-ba08d407c1a2"
+   },
    "outputs": [],
    "source": [
     "# This function was implemented in ch06.ipynb\n",
@@ -167,7 +190,7 @@
     "\n",
     "    # Model inference\n",
     "    with torch.no_grad():\n",
-    "        logits = model(input_tensor)[:, -1, :]  # Logits of the last output token\n",
+    "        logits = model(input_tensor.to(device))[:, -1, :]  # Logits of the last output token\n",
     "    predicted_label = torch.argmax(logits, dim=-1).item()\n",
     "\n",
     "    # Return the classified result\n",
@@ -178,7 +201,13 @@
    "cell_type": "code",
    "execution_count": 7,
    "id": "1e26862c-10b5-4a0f-9dd6-b6ddbad2fc3f",
-   "metadata": {},
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "1e26862c-10b5-4a0f-9dd6-b6ddbad2fc3f",
+    "outputId": "28eb2c02-0e38-4356-b2a3-2bf6accb5316"
+   },
    "outputs": [
     {
      "name": "stdout",
@@ -203,7 +232,13 @@
    "cell_type": "code",
    "execution_count": 8,
    "id": "78472e05-cb4e-4ec4-82e8-23777aa90cf8",
-   "metadata": {},
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "78472e05-cb4e-4ec4-82e8-23777aa90cf8",
+    "outputId": "0cd3cd62-f407-45f3-fa4f-51ff665355eb"
+   },
    "outputs": [
     {
      "name": "stdout",
@@ -226,6 +261,11 @@
    }
  ],
  "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "L4",
+   "provenance": []
+  },
   "kernelspec": {
    "display_name": "Python 3 (ipykernel)",
    "language": "python",
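In plain Python, the change this diff makes to the notebook boils down to one device-placement pattern: pick CUDA when available, load the checkpoint with map_location=device, move the model's parameters there with model.to(device), and send each input tensor to the same device before the forward pass (model(input_tensor.to(device))). The sketch below is a minimal, self-contained illustration of that pattern under stated assumptions: DummyClassifier and the toy input are stand-ins invented here, not the notebook's GPTModel, review_classifier.pth checkpoint, or classify_review helper.

import torch
import torch.nn as nn

# Hypothetical stand-in for the notebook's GPT-based spam classifier.
class DummyClassifier(nn.Module):
    def __init__(self, vocab_size=50257, emb_dim=16, num_classes=2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.out_head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        # Returns logits of shape (batch, seq_len, num_classes)
        return self.out_head(self.emb(x))

# 1) Choose the device: GPU if available, otherwise CPU (same check as the notebook).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DummyClassifier()
# 2) A real checkpoint would be loaded onto that device, e.g. (path from the diff):
# model.load_state_dict(torch.load("review_classifier.pth", map_location=device, weights_only=True))

# 3) Move the model's parameters to the chosen device and switch to inference mode.
model.to(device)
model.eval()

# 4) Move each input tensor to the same device before the forward pass,
#    mirroring model(input_tensor.to(device)) in the updated classify_review cell.
input_tensor = torch.tensor([[1, 2, 3, 4]])  # toy token IDs
with torch.no_grad():
    logits = model(input_tensor.to(device))[:, -1, :]  # logits of the last token
predicted_label = torch.argmax(logits, dim=-1).item()
print(predicted_label)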