{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "rope-inv-freq-intro",
   "metadata": {},
   "source": [
    "# RoPE-style inverse frequencies: `before` vs. `after` formula\n",
    "\n",
    "This notebook checks whether two ways of computing the inverse-frequency vector agree for every `head_dim` from 1 to 11:\n",
    "\n",
    "- **`before`:** $\\omega_i = \\theta^{-i / \\lfloor d/2 \\rfloor}$ for $i = 0, \\dots, \\lfloor d/2 \\rfloor - 1$\n",
    "- **`after`:** $\\omega_i = \\theta^{-2i / d}$ for the same $i$\n",
    "\n",
    "Here $\\theta$ is `theta_base` and $d$ is `head_dim`. For even $d$, $\\lfloor d/2 \\rfloor = d/2$, so the two exponents are identical. For odd $d$, $\\lfloor d/2 \\rfloor = (d-1)/2$, so the `before` formula behaves as if `head_dim` were $d - 1$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "imports-config",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Base of the geometric frequency progression (theta in the formulas above).\n",
    "THETA_BASE = 10_000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "inv-freq-defs",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inv_freq_before(head_dim: int, theta_base: float) -> torch.Tensor:\n",
    "    \"\"\"Return 1 / theta_base ** (i / (head_dim // 2)) for i in 0 .. head_dim // 2 - 1.\n",
    "\n",
    "    Returns an empty tensor when head_dim is 0 or 1.\n",
    "    \"\"\"\n",
    "    half = head_dim // 2\n",
    "    return 1.0 / (theta_base ** (torch.arange(0, half) / half))\n",
    "\n",
    "\n",
    "def inv_freq_after(head_dim: int, theta_base: float) -> torch.Tensor:\n",
    "    \"\"\"Return 1 / theta_base ** (2i / head_dim) for i in 0 .. head_dim // 2 - 1.\n",
    "\n",
    "    Returns an empty tensor when head_dim is 0 or 1.\n",
    "    \"\"\"\n",
    "    # Even indices 0, 2, 4, ...; the slice drops the extra index arange yields for odd head_dim.\n",
    "    even_idx = torch.arange(0, head_dim, 2)[: (head_dim // 2)].float()\n",
    "    return 1.0 / (theta_base ** (even_idx / head_dim))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "40d2405d-ee10-44ad-b20e-cf32078f926a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True | head dim: 1, tensor([]), tensor([])\n",
      "True | head dim: 2, tensor([1.]), tensor([1.])\n",
      "True | head dim: 3, tensor([1.]), tensor([1.])\n",
      "True | head dim: 4, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0100])\n",
      "False | head dim: 5, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0251])\n",
      "True | head dim: 6, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0464, 0.0022])\n",
      "False | head dim: 7, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0720, 0.0052])\n",
      "True | head dim: 8, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1000, 0.0100, 0.0010])\n",
      "False | head dim: 9, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1292, 0.0167, 0.0022])\n",
      "True | head dim: 10, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04])\n",
      "False | head dim: 11, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000, 0.1874, 0.0351, 0.0066, 0.0012])\n"
     ]
    }
   ],
   "source": [
    "for head_dim in range(1, 12):\n",
    "    before = inv_freq_before(head_dim, THETA_BASE)\n",
    "    after = inv_freq_after(head_dim, THETA_BASE)\n",
    "    same = torch.equal(before, after)\n",
    "    print(f\"{same} | head dim: {head_dim}, {before}, {after}\")\n",
    "\n",
    "    # For even head_dim the exponents i/(d/2) and 2i/d are mathematically equal,\n",
    "    # so the two formulas must produce identical tensors.\n",
    "    if head_dim % 2 == 0:\n",
    "        assert same, f\"before/after diverge for even head_dim={head_dim}\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "rope-inv-freq-result",
   "metadata": {},
   "source": [
    "## Result\n",
    "\n",
    "- Every even `head_dim` gives identical tensors.\n",
    "- `head_dim` = 1 and 3 also match, but only trivially: there are no terms, or only the $i = 0$ term, which is always 1.\n",
    "- For `head_dim` = 5, 7, 9 and 11 the formulas differ. The `before` vector for these sizes equals the `after` vector for 4, 6, 8 and 10, so it decays faster than the `after` formula."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}