mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-09 22:34:29 +00:00
fix checkpoint.value in the notebook and test
This commit is contained in:
parent
7e4e4c7901
commit
3a194d047b
@ -347,7 +347,11 @@
|
|||||||
" best_trained_model = nn.DataParallel(best_trained_model)\n",
|
" best_trained_model = nn.DataParallel(best_trained_model)\n",
|
||||||
"best_trained_model.to(device)\n",
|
"best_trained_model.to(device)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"checkpoint_path = os.path.join(best_trial.checkpoint.value, \"checkpoint\")\n",
|
"checkpoint_value = (\n",
|
||||||
|
" getattr(best_trial.checkpoint, \"dir_or_data\", None)\n",
|
||||||
|
" or best_trial.checkpoint.value\n",
|
||||||
|
")\n",
|
||||||
|
"checkpoint_path = os.path.join(checkpoint_value, \"checkpoint\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model_state, optimizer_state = torch.load(checkpoint_path)\n",
|
"model_state, optimizer_state = torch.load(checkpoint_path)\n",
|
||||||
"best_trained_model.load_state_dict(model_state)\n",
|
"best_trained_model.load_state_dict(model_state)\n",
|
||||||
@ -358,11 +362,9 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"interpreter": {
|
|
||||||
"hash": "f7771e6a3915580179405189f5aa4eb9047494cbe4e005b29b851351b54902f6"
|
|
||||||
},
|
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3.8.10 64-bit ('venv': venv)",
|
"display_name": "Python 3.11.0 64-bit",
|
||||||
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
@ -375,12 +377,17 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.8.12"
|
"version": "3.11.0"
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"interpreter": {
|
"interpreter": {
|
||||||
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"vscode": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
@ -313,7 +313,11 @@ def cifar10_main(
|
|||||||
best_trained_model = nn.DataParallel(best_trained_model)
|
best_trained_model = nn.DataParallel(best_trained_model)
|
||||||
best_trained_model.to(device)
|
best_trained_model.to(device)
|
||||||
|
|
||||||
checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
|
checkpoint_value = (
|
||||||
|
getattr(best_trial.checkpoint, "dir_or_data", None)
|
||||||
|
or best_trial.checkpoint.value
|
||||||
|
)
|
||||||
|
checkpoint_path = os.path.join(checkpoint_value, "checkpoint")
|
||||||
|
|
||||||
model_state, optimizer_state = torch.load(checkpoint_path)
|
model_state, optimizer_state = torch.load(checkpoint_path)
|
||||||
best_trained_model.load_state_dict(model_state)
|
best_trained_model.load_state_dict(model_state)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user