From 3a194d047bae043fbdc79dfd3cde05be8687dc49 Mon Sep 17 00:00:00 2001 From: Jing Dong Date: Mon, 19 Dec 2022 09:22:16 -0800 Subject: [PATCH] fix checkpoint.value in the notebook and test --- notebook/tune_pytorch.ipynb | 19 +++++++++++++------ test/tune/test_pytorch_cifar10.py | 6 +++++- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/notebook/tune_pytorch.ipynb b/notebook/tune_pytorch.ipynb index d90f4fb9c..93153ac50 100644 --- a/notebook/tune_pytorch.ipynb +++ b/notebook/tune_pytorch.ipynb @@ -347,7 +347,11 @@ " best_trained_model = nn.DataParallel(best_trained_model)\n", "best_trained_model.to(device)\n", "\n", - "checkpoint_path = os.path.join(best_trial.checkpoint.value, \"checkpoint\")\n", + "checkpoint_value = (\n", + " getattr(best_trial.checkpoint, \"dir_or_data\", None)\n", + " or best_trial.checkpoint.value\n", + ")\n", + "checkpoint_path = os.path.join(checkpoint_value, \"checkpoint\")\n", "\n", "model_state, optimizer_state = torch.load(checkpoint_path)\n", "best_trained_model.load_state_dict(model_state)\n", @@ -358,11 +362,9 @@ } ], "metadata": { - "interpreter": { - "hash": "f7771e6a3915580179405189f5aa4eb9047494cbe4e005b29b851351b54902f6" - }, "kernelspec": { - "display_name": "Python 3.8.10 64-bit ('venv': venv)", + "display_name": "Python 3.11.0 64-bit", + "language": "python", "name": "python3" }, "language_info": { @@ -375,12 +377,17 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.11.0" }, "metadata": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } + }, + "vscode": { + "interpreter": { + "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" + } } }, "nbformat": 4, diff --git a/test/tune/test_pytorch_cifar10.py b/test/tune/test_pytorch_cifar10.py index 2151bf281..188d9750f 100644 --- a/test/tune/test_pytorch_cifar10.py +++ b/test/tune/test_pytorch_cifar10.py @@ -313,7 +313,11 @@ def cifar10_main( best_trained_model = nn.DataParallel(best_trained_model) best_trained_model.to(device) - checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint") + checkpoint_value = ( + getattr(best_trial.checkpoint, "dir_or_data", None) + or best_trial.checkpoint.value + ) + checkpoint_path = os.path.join(checkpoint_value, "checkpoint") model_state, optimizer_state = torch.load(checkpoint_path) best_trained_model.load_state_dict(model_state)