Sebastian Raschka 7cd6a670ed
RoPE updates (#412)
* RoPE updates

* Apply suggestions from code review

* updates

* updates

* updates
2024-10-23 18:07:49 -05:00

75 lines
2.4 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"id": "40d2405d-ee10-44ad-b20e-cf32078f926a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True | head dim: 1, tensor([]), tensor([])\n",
"True | head dim: 2, tensor([1.]), tensor([1.])\n",
"True | head dim: 3, tensor([1.]), tensor([1.])\n",
"True | head dim: 4, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0100])\n",
"False | head dim: 5, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0251])\n",
"True | head dim: 6, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0464, 0.0022])\n",
"False | head dim: 7, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0720, 0.0052])\n",
"True | head dim: 8, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1000, 0.0100, 0.0010])\n",
"False | head dim: 9, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1292, 0.0167, 0.0022])\n",
"True | head dim: 10, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04])\n",
"False | head dim: 11, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000, 0.1874, 0.0351, 0.0066, 0.0012])\n"
]
}
],
"source": [
"import torch\n",
"\n",
"# Compare two ways of computing 1 / theta_base**exponent for every head\n",
"# dimension from 1 to 11 and report whether the results match exactly.\n",
"theta_base = 10_000\n",
"\n",
"for head_dim in range(1, 12):\n",
"    half_dim = head_dim // 2\n",
"\n",
"    # Variant 1: exponents i / half_dim for i in [0, half_dim).\n",
"    exponents_before = torch.arange(0, half_dim) / half_dim\n",
"    before = 1.0 / (theta_base ** exponents_before)\n",
"\n",
"    # Variant 2: exponents 2i / head_dim, keeping only the first half_dim entries\n",
"    # (for odd head_dim the strided arange yields one extra element).\n",
"    exponents_after = torch.arange(0, head_dim, 2)[:half_dim].float() / head_dim\n",
"    after = 1.0 / (theta_base ** exponents_after)\n",
"\n",
"    # torch.equal is an exact (bitwise value) comparison, not a tolerance check.\n",
"    is_equal = torch.equal(before, after)\n",
"    print(f\"{is_equal} | head dim: {head_dim}, {before}, {after}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0abfbf38-93a4-4994-8e7e-a543477268a8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}