Merge pull request #853 from jingdong00/jingdong00-patch-1

Fix example tune-pytorch where the checkpoint path may be named differently
This commit is contained in:
Shaokun 2022-12-19 16:38:19 -05:00 committed by GitHub
commit f98b7555e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 9 deletions

View File

@ -347,7 +347,11 @@
" best_trained_model = nn.DataParallel(best_trained_model)\n", " best_trained_model = nn.DataParallel(best_trained_model)\n",
"best_trained_model.to(device)\n", "best_trained_model.to(device)\n",
"\n", "\n",
"checkpoint_path = os.path.join(best_trial.checkpoint.value, \"checkpoint\")\n", "checkpoint_value = (\n",
" getattr(best_trial.checkpoint, \"dir_or_data\", None)\n",
" or best_trial.checkpoint.value\n",
")\n",
"checkpoint_path = os.path.join(checkpoint_value, \"checkpoint\")\n",
"\n", "\n",
"model_state, optimizer_state = torch.load(checkpoint_path)\n", "model_state, optimizer_state = torch.load(checkpoint_path)\n",
"best_trained_model.load_state_dict(model_state)\n", "best_trained_model.load_state_dict(model_state)\n",
@ -358,11 +362,9 @@
} }
], ],
"metadata": { "metadata": {
"interpreter": {
"hash": "f7771e6a3915580179405189f5aa4eb9047494cbe4e005b29b851351b54902f6"
},
"kernelspec": { "kernelspec": {
"display_name": "Python 3.8.10 64-bit ('venv': venv)", "display_name": "Python 3.11.0 64-bit",
"language": "python",
"name": "python3" "name": "python3"
}, },
"language_info": { "language_info": {
@ -375,12 +377,17 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.12" "version": "3.11.0"
}, },
"metadata": { "metadata": {
"interpreter": { "interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
} }
},
"vscode": {
"interpreter": {
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
}
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -313,7 +313,11 @@ def cifar10_main(
best_trained_model = nn.DataParallel(best_trained_model) best_trained_model = nn.DataParallel(best_trained_model)
best_trained_model.to(device) best_trained_model.to(device)
checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint") checkpoint_value = (
getattr(best_trial.checkpoint, "dir_or_data", None)
or best_trial.checkpoint.value
)
checkpoint_path = os.path.join(checkpoint_value, "checkpoint")
model_state, optimizer_state = torch.load(checkpoint_path) model_state, optimizer_state = torch.load(checkpoint_path)
best_trained_model.load_state_dict(model_state) best_trained_model.load_state_dict(model_state)

View File

@ -261,7 +261,8 @@ if torch.cuda.is_available():
best_trained_model = nn.DataParallel(best_trained_model) best_trained_model = nn.DataParallel(best_trained_model)
best_trained_model.to(device) best_trained_model.to(device)
checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint") checkpoint_value = getattr(best_trial.checkpoint, "dir_or_data", None) or best_trial.checkpoint.value
checkpoint_path = os.path.join(checkpoint_value, "checkpoint")
model_state, optimizer_state = torch.load(checkpoint_path) model_state, optimizer_state = torch.load(checkpoint_path)
best_trained_model.load_state_dict(model_state) best_trained_model.load_state_dict(model_state)
@ -283,4 +284,4 @@ Files already downloaded and verified
Best trial test set accuracy: 0.6294 Best trial test set accuracy: 0.6294
``` ```
[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb) [Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb)