consistency

This commit is contained in:
rasbt 2024-06-19 19:47:31 -05:00
parent c1f9361428
commit 3ba51abf53

View File

@ -1108,8 +1108,6 @@
} }
], ],
"source": [ "source": [
"from functools import partial\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n", "\n",
"# If you have a Mac with Apple Silicon chip, you can uncomment the next lines of code\n", "# If you have a Mac with Apple Silicon chip, you can uncomment the next lines of code\n",
@ -1120,7 +1118,17 @@
"# if torch.backends.mps.is_available():\n", "# if torch.backends.mps.is_available():\n",
"# device = torch.device(\"mps\")\n", "# device = torch.device(\"mps\")\n",
"\n", "\n",
"print(\"Device:\", device)\n", "print(\"Device:\", device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e47fb30-c2c6-4e6d-a64c-76cc65be4a2c",
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n", "\n",
"customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=1024)" "customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=1024)"
] ]
@ -2150,7 +2158,8 @@
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
}, },
"id": "8cBU0iHmVfOI", "id": "8cBU0iHmVfOI",
"outputId": "860a2d06-2d0e-4ae8-943d-dd12d299eed9" "outputId": "860a2d06-2d0e-4ae8-943d-dd12d299eed9",
"scrolled": true
}, },
"outputs": [ "outputs": [
{ {
@ -2164,7 +2173,7 @@
"source": [ "source": [
"import re\n", "import re\n",
"\n", "\n",
"model.cpu()\n", "\n",
"file_name = f\"{re.sub(r'[ ()]', '', CHOOSE_MODEL) }-sft.pth\"\n", "file_name = f\"{re.sub(r'[ ()]', '', CHOOSE_MODEL) }-sft.pth\"\n",
"torch.save(model.state_dict(), file_name)\n", "torch.save(model.state_dict(), file_name)\n",
"print(f\"Model saved as {file_name}\")\n", "print(f\"Model saved as {file_name}\")\n",
@ -2728,7 +2737,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.11" "version": "3.11.4"
} }
}, },
"nbformat": 4, "nbformat": 4,