RoPE theta rescaling (#419)

* rope fixes

* update

* update

* cleanup
Sebastian Raschka 2024-10-25 15:27:23 -05:00 committed by GitHub
parent 0ed1e0d099
commit 75ede3e340
2 changed files with 1025 additions and 105 deletions


@@ -2094,12 +2094,37 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "9bdbe32f-4c96-4e60-8bf4-52b5217df1e6",
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 10,
+   "id": "a55a8769-1a03-4265-8fd0-15f1c423da53",
+   "metadata": {
+    "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "New RoPE theta: 31250.0\n"
+     ]
+    }
+   ],
    "source": [
-    "LLAMA31_CONFIG_8B[\"context_length\"] = 8192"
+    "old_context_length = LLAMA31_CONFIG_8B[\"context_length\"]\n",
+    "LLAMA31_CONFIG_8B[\"context_length\"] = 8192\n",
+    "\n",
+    "\n",
+    "def rescale_theta(theta_old, context_length_old, context_length_new):\n",
+    "    scaling_factor = context_length_new / context_length_old\n",
+    "    theta_new = theta_old * scaling_factor\n",
+    "    return theta_new\n",
+    "\n",
+    "LLAMA31_CONFIG_8B[\"rope_base\"] = rescale_theta(\n",
+    "    LLAMA31_CONFIG_8B[\"rope_base\"],\n",
+    "    old_context_length,\n",
+    "    LLAMA31_CONFIG_8B[\"context_length\"]\n",
+    ")\n",
+    "\n",
+    "print(\"New RoPE theta:\", LLAMA31_CONFIG_8B[\"rope_base\"])"
    ]
   },
   {
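
For context, the arithmetic in the new cell is easy to verify standalone. The sketch below assumes the notebook's original Llama 3.1 8B settings of rope_base = 500_000.0 and context_length = 131_072 (those values are not visible in this diff, but they are consistent with the printed output "New RoPE theta: 31250.0"): theta is simply scaled linearly by the ratio of the new to the old context length.

    def rescale_theta(theta_old, context_length_old, context_length_new):
        # RoPE theta scales linearly with the context-length ratio:
        # shrinking the window by 16x shrinks the base frequency by 16x.
        scaling_factor = context_length_new / context_length_old
        return theta_old * scaling_factor

    # Assumed original Llama 3.1 8B values (not shown in this diff):
    LLAMA31_CONFIG_8B = {"context_length": 131_072, "rope_base": 500_000.0}

    old_context_length = LLAMA31_CONFIG_8B["context_length"]
    LLAMA31_CONFIG_8B["context_length"] = 8192

    LLAMA31_CONFIG_8B["rope_base"] = rescale_theta(
        LLAMA31_CONFIG_8B["rope_base"],
        old_context_length,
        LLAMA31_CONFIG_8B["context_length"],
    )

    print("New RoPE theta:", LLAMA31_CONFIG_8B["rope_base"])
    # 500_000.0 * (8192 / 131_072) = 500_000.0 / 16 = 31250.0
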
@@ -2462,12 +2487,31 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "387456c3-c6a1-46fe-8830-6e00eb46ac13",
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 10,
+   "id": "73f001a6-7ae0-4204-aa83-a27a8878dfd2",
+   "metadata": {
+    "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "New RoPE theta: 31250.0\n"
+     ]
+    }
+   ],
    "source": [
-    "LLAMA32_CONFIG_1B[\"context_length\"] = 8192"
+    "old_context_length = LLAMA32_CONFIG_1B[\"context_length\"]\n",
+    "LLAMA32_CONFIG_1B[\"context_length\"] = 8192\n",
+    "\n",
+    "LLAMA32_CONFIG_1B[\"rope_base\"] = rescale_theta(\n",
+    "    LLAMA32_CONFIG_1B[\"rope_base\"],\n",
+    "    old_context_length,\n",
+    "    LLAMA32_CONFIG_1B[\"context_length\"]\n",
+    ")\n",
+    "\n",
+    "print(\"New RoPE theta:\", LLAMA32_CONFIG_1B[\"rope_base\"])"
    ]
   },
   {
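
This second hunk reuses the same rescale_theta helper for the Llama 3.2 1B config and prints the identical 31250.0, which is what one would expect if that config likewise starts from rope_base 500000.0 and a 131072-token context window: 8192 / 131072 = 1/16, and 500000 / 16 = 31250.
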
@@ -2689,7 +2733,7 @@
    "provenance": []
   },
   "kernelspec": {
-   "display_name": "pt",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -2703,7 +2747,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.9"
+   "version": "3.11.4"
   },
   "widgets": {
    "application/vnd.jupyter.widget-state+json": {

File diff suppressed because it is too large