Fix checkpoint path issue

checkpoint path may be named dir_or_data instead of value
This commit is contained in:
Jing Dong 2022-12-16 14:41:33 +08:00 committed by GitHub
parent c1872861b6
commit 5778227a71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -261,7 +261,8 @@ if torch.cuda.is_available():
best_trained_model = nn.DataParallel(best_trained_model)
best_trained_model.to(device)
checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
checkpoint_value = getattr(best_trial.checkpoint, "dir_or_data", None) or best_trial.checkpoint.value
checkpoint_path = os.path.join(checkpoint_value, "checkpoint")
model_state, optimizer_state = torch.load(checkpoint_path)
best_trained_model.load_state_dict(model_state)
@ -283,4 +284,4 @@ Files already downloaded and verified
Best trial test set accuracy: 0.6294
```
[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb)
[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb)