mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-10-29 17:01:30 +00:00
consistency
This commit is contained in:
parent
c1f9361428
commit
3ba51abf53
@ -1108,8 +1108,6 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from functools import partial\n",
|
|
||||||
"\n",
|
|
||||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# If you have a Mac with Apple Silicon chip, you can uncomment the next lines of code\n",
|
"# If you have a Mac with Apple Silicon chip, you can uncomment the next lines of code\n",
|
||||||
@ -1120,7 +1118,17 @@
|
|||||||
"# if torch.backends.mps.is_available():\n",
|
"# if torch.backends.mps.is_available():\n",
|
||||||
"# device = torch.device(\"mps\")\n",
|
"# device = torch.device(\"mps\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"Device:\", device)\n",
|
"print(\"Device:\", device)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "4e47fb30-c2c6-4e6d-a64c-76cc65be4a2c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from functools import partial\n",
|
||||||
"\n",
|
"\n",
|
||||||
"customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=1024)"
|
"customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=1024)"
|
||||||
]
|
]
|
||||||
@ -2150,7 +2158,8 @@
|
|||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
},
|
},
|
||||||
"id": "8cBU0iHmVfOI",
|
"id": "8cBU0iHmVfOI",
|
||||||
"outputId": "860a2d06-2d0e-4ae8-943d-dd12d299eed9"
|
"outputId": "860a2d06-2d0e-4ae8-943d-dd12d299eed9",
|
||||||
|
"scrolled": true
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -2164,7 +2173,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"import re\n",
|
"import re\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model.cpu()\n",
|
"\n",
|
||||||
"file_name = f\"{re.sub(r'[ ()]', '', CHOOSE_MODEL) }-sft.pth\"\n",
|
"file_name = f\"{re.sub(r'[ ()]', '', CHOOSE_MODEL) }-sft.pth\"\n",
|
||||||
"torch.save(model.state_dict(), file_name)\n",
|
"torch.save(model.state_dict(), file_name)\n",
|
||||||
"print(f\"Model saved as {file_name}\")\n",
|
"print(f\"Model saved as {file_name}\")\n",
|
||||||
@ -2728,7 +2737,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.11"
|
"version": "3.11.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user