From cb194fa8fa258967b6301cf063c850a834c22a50 Mon Sep 17 00:00:00 2001 From: rasbt Date: Thu, 20 Jun 2024 08:07:00 -0500 Subject: [PATCH] fix device loading --- ch05/01_main-chapter-code/ch05.ipynb | 3 ++- ch05/01_main-chapter-code/exercise-solutions.ipynb | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ch05/01_main-chapter-code/ch05.ipynb b/ch05/01_main-chapter-code/ch05.ipynb index d8a06ef..21906af 100644 --- a/ch05/01_main-chapter-code/ch05.ipynb +++ b/ch05/01_main-chapter-code/ch05.ipynb @@ -1985,7 +1985,8 @@ "outputs": [], "source": [ "model = GPTModel(GPT_CONFIG_124M)\n", - "model.load_state_dict(torch.load(\"model.pth\"))\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "model.load_state_dict(torch.load(\"model.pth\", map_location=device))\n", "model.eval();" ] }, diff --git a/ch05/01_main-chapter-code/exercise-solutions.ipynb b/ch05/01_main-chapter-code/exercise-solutions.ipynb index 577ddf7..c5f7e92 100644 --- a/ch05/01_main-chapter-code/exercise-solutions.ipynb +++ b/ch05/01_main-chapter-code/exercise-solutions.ipynb @@ -427,6 +427,7 @@ "checkpoint = torch.load(\"model_and_optimizer.pth\")\n", "model = GPTModel(GPT_CONFIG_124M)\n", "model.load_state_dict(checkpoint[\"model_state_dict\"])\n", + "model.to(device)\n", "\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)\n", "optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n", @@ -958,7 +959,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.6" } }, "nbformat": 4,