diff --git a/notebook/tune_pytorch.ipynb b/notebook/tune_pytorch.ipynb index d90f4fb9c..93153ac50 100644 --- a/notebook/tune_pytorch.ipynb +++ b/notebook/tune_pytorch.ipynb @@ -347,7 +347,11 @@ " best_trained_model = nn.DataParallel(best_trained_model)\n", "best_trained_model.to(device)\n", "\n", - "checkpoint_path = os.path.join(best_trial.checkpoint.value, \"checkpoint\")\n", + "checkpoint_value = (\n", + " getattr(best_trial.checkpoint, \"dir_or_data\", None)\n", + " or best_trial.checkpoint.value\n", + ")\n", + "checkpoint_path = os.path.join(checkpoint_value, \"checkpoint\")\n", "\n", "model_state, optimizer_state = torch.load(checkpoint_path)\n", "best_trained_model.load_state_dict(model_state)\n", @@ -358,11 +362,9 @@ } ], "metadata": { - "interpreter": { - "hash": "f7771e6a3915580179405189f5aa4eb9047494cbe4e005b29b851351b54902f6" - }, "kernelspec": { - "display_name": "Python 3.8.10 64-bit ('venv': venv)", + "display_name": "Python 3.11.0 64-bit", + "language": "python", "name": "python3" }, "language_info": { @@ -375,12 +377,17 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.11.0" }, "metadata": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } + }, + "vscode": { + "interpreter": { + "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" + } } }, "nbformat": 4, diff --git a/test/tune/test_pytorch_cifar10.py b/test/tune/test_pytorch_cifar10.py index 2151bf281..188d9750f 100644 --- a/test/tune/test_pytorch_cifar10.py +++ b/test/tune/test_pytorch_cifar10.py @@ -313,7 +313,11 @@ def cifar10_main( best_trained_model = nn.DataParallel(best_trained_model) best_trained_model.to(device) - checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint") + checkpoint_value = ( + getattr(best_trial.checkpoint, "dir_or_data", None) + or best_trial.checkpoint.value + ) + checkpoint_path = os.path.join(checkpoint_value, "checkpoint") model_state, optimizer_state = torch.load(checkpoint_path) best_trained_model.load_state_dict(model_state) diff --git a/website/docs/Examples/Tune-PyTorch.md b/website/docs/Examples/Tune-PyTorch.md index 83f38e609..d75c716c7 100644 --- a/website/docs/Examples/Tune-PyTorch.md +++ b/website/docs/Examples/Tune-PyTorch.md @@ -261,7 +261,8 @@ if torch.cuda.is_available(): best_trained_model = nn.DataParallel(best_trained_model) best_trained_model.to(device) -checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint") +checkpoint_value = getattr(best_trial.checkpoint, "dir_or_data", None) or best_trial.checkpoint.value +checkpoint_path = os.path.join(checkpoint_value, "checkpoint") model_state, optimizer_state = torch.load(checkpoint_path) best_trained_model.load_state_dict(model_state) @@ -283,4 +284,4 @@ Files already downloaded and verified Best trial test set accuracy: 0.6294 ``` -[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb) \ No newline at end of file +[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb)