"- Embedding layers in PyTorch accomplish the same as linear layers that perform matrix multiplications; the reason we use embedding layers is computational efficiency\n",
"- We will take a look at this relationship step by step using code examples in PyTorch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "061720f4-f025-4640-82a0-15098fa94cf9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.1.0.dev20230825\n"
]
}
],
"source": [
"import torch\n",
"\n",
"print(\"PyTorch version:\", torch.__version__)"
]
},
{
"cell_type": "markdown",
"id": "a7895a66-7f69-4f62-9361-5c9da2eb76ef",
"metadata": {},
"source": [
"## Using nn.Embedding"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cc489ea5-73db-40b9-959e-0d70cae25f40",
"metadata": {},
"outputs": [],
"source": [
"# Suppose we have the following 3 training examples,\n",
"# which may represent token IDs in a LLM context\n",
"idx = torch.tensor([2, 3, 1])\n",
"\n",
"# The number of rows in the embedding matrix can be determined\n",
"# by obtaining the largest token ID + 1.\n",
"# If the highest token ID is 3, then we want 4 rows, for the possible\n",
"# token IDs 0, 1, 2, 3\n",
"num_idx = max(idx)+1\n",
"\n",
"# The desired embedding dimension is a hyperparameter\n",
"out_dim = 5"
]
},
{
"cell_type": "markdown",
"id": "93d83a6e-8543-40af-b253-fe647640bf36",
"metadata": {},
"source": [
"- Let's implement a simple embedding layer:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "60a7c104-36e1-4b28-bd02-a24a1099dc66",
"metadata": {},
"outputs": [],
"source": [
"# We use the random seed for reproducibility since\n",
"# weights in the embedding layer are initialized with\n",
"- Under the hood, it's still the same look-up concept:"
]
},
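{
"cell_type": "markdown",
"id": "2f3c9d7e-8a1b-4c5d-9e0f-1a2b3c4d5e6f",
"metadata": {},
"source": [
"- As a quick sanity check (a small added sketch; the exact values depend on the random initialization), we can inspect the weight matrix and pass the token IDs to the embedding layer, which returns one row of the weight matrix per ID:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d8e9f0a-1b2c-4d3e-8f4a-5b6c7d8e9f0a",
"metadata": {},
"outputs": [],
"source": [
"# Inspect the (randomly initialized) embedding weight matrix\n",
"print(embedding.weight)\n",
"\n",
"# Each token ID in `idx` selects one row of the weight matrix\n",
"print(embedding(idx))"
]
},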
{
"cell_type": "markdown",
"id": "b392eb43-0bda-4821-b446-b8dcbee8ae00",
"metadata": {},
"source": [
"<img src=\"images/3.png\" width=\"450px\">"
]
},
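{
"cell_type": "markdown",
"id": "5a6b7c8d-9e0f-4a1b-8c2d-3e4f5a6b7c8d",
"metadata": {},
"source": [
"- As a small illustrative sketch (added here, using only the layers defined above), calling the embedding layer gives the same result as indexing directly into its weight matrix:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c0d1e2f-3a4b-4c5d-8e6f-7a8b9c0d1e2f",
"metadata": {},
"outputs": [],
"source": [
"# The embedding layer's forward pass is just a row look-up\n",
"# in its weight matrix, so both expressions match\n",
"print(torch.allclose(embedding(idx), embedding.weight[idx]))"
]
},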
{
"cell_type": "markdown",
"id": "f0fe863b-d6a3-48f3-ace5-09ecd0eb7b59",
"metadata": {},
"source": [
"## Using nn.Linear"
]
},
{
"cell_type": "markdown",
"id": "138de6a4-2689-4c1f-96af-7899b2d82a4e",
"metadata": {},
"source": [
"- Now, we will demonstrate that the embedding layer above accomplishes exactly the same as `nn.Linear` layer on a one-hot encoded representation in PyTorch\n",
"- First, let's convert the token IDs into a one-hot representation:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "b5bb56cf-bc73-41ab-b107-91a43f77bdba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 0, 1, 0],\n",
" [0, 0, 0, 1],\n",
" [0, 1, 0, 0]])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"onehot = torch.nn.functional.one_hot(idx)\n",
"onehot"
]
},
{
"cell_type": "markdown",
"id": "aa45dfdf-fb26-4514-a176-75224f5f179b",
"metadata": {},
"source": [
"- Next, we initialize a `Linear` layer, which caries out a matrix multiplication $X W^\\top$:"
"- Note that the linear layer in PyTorch is also initialized with small random weights; to directly compare it to the `Embedding` layer above, we have to use the same small random weights, which is why we reassign them here:"
"- What happens under the hood is the following computation for the first training example's token ID:"
]
},
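{
"cell_type": "markdown",
"id": "1b2c3d4e-5f6a-4b7c-8d9e-0f1a2b3c4d5e",
"metadata": {},
"source": [
"- A minimal sketch of this setup (reconstructed here, so details such as `bias=False` and the use of `detach()` are choices made for this sketch): we create a bias-free `Linear` layer, reassign its weights from the embedding layer, and apply it to the one-hot rows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e7f8a9b-0c1d-4e2f-8a3b-4c5d6e7f8a9b",
"metadata": {},
"outputs": [],
"source": [
"linear = torch.nn.Linear(num_idx, out_dim, bias=False)\n",
"\n",
"# Reassign the linear layer's weights with the embedding layer's weights;\n",
"# note the transpose, since nn.Linear stores its weight with shape\n",
"# (out_dim, num_idx) and computes X W^T\n",
"linear.weight = torch.nn.Parameter(embedding.weight.T.detach())\n",
"\n",
"# The matrix multiplication on the one-hot rows yields the same\n",
"# vectors as the embedding look-up above\n",
"print(linear(onehot.float()))\n",
"print(embedding(idx))"
]
},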
{
"cell_type": "markdown",
"id": "1830eccf-a707-4753-a24a-9b103f55594a",
"metadata": {},
"source": [
"<img src=\"images/4.png\" width=\"450px\">"
]
},
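{
"cell_type": "markdown",
"id": "3d4e5f6a-7b8c-4d9e-8f0a-1b2c3d4e5f6a",
"metadata": {},
"source": [
"- Spelled out in code (a small sketch using the `linear` layer from the sketch above), the one-hot row for the first training example, token ID 2, picks out exactly one row of the embedding matrix:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b9c0d1e-2f3a-4b4c-8d5e-6f7a8b9c0d1e",
"metadata": {},
"outputs": [],
"source": [
"# The first one-hot row is [0, 0, 1, 0], so the product X W^T reduces\n",
"# to selecting the weights associated with token ID 2\n",
"print(onehot[0].float() @ linear.weight.T)\n",
"print(embedding.weight[2])"
]
},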
{
"cell_type": "markdown",
"id": "9ce5211a-14e6-46aa-a3a8-14609f086e97",
"metadata": {},
"source": [
"- And for the second training example's token ID:"
]
},
{
"cell_type": "markdown",
"id": "173f6026-a461-44da-b9c5-f571f8ec8bf3",
"metadata": {},
"source": [
"<img src=\"images/5.png\" width=\"450px\">"
]
},
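{
"cell_type": "markdown",
"id": "4c5d6e7f-8a9b-4c0d-8e1f-2a3b4c5d6e7f",
"metadata": {},
"source": [
"- As a final check (a small added sketch, assuming the `linear` and `embedding` layers defined above), we can confirm that both approaches produce identical vectors for all three token IDs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d1e2f3a-4b5c-4d6e-8f7a-8b9c0d1e2f3a",
"metadata": {},
"outputs": [],
"source": [
"# The one-hot matrix multiplication and the embedding look-up agree\n",
"print(torch.allclose(linear(onehot.float()), embedding(idx)))"
]
},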
{
"cell_type": "markdown",
"id": "e2608049-f5d1-49a9-a14b-82695fc32e6a",
"metadata": {},
"source": [
"- Since all but one index in each one-hot encoded row are 0 (by design), this matrix multiplication is essentially the same as a look-up of the one-hot elements\n",
"- This use of the matrix multiplication on one-hot encodings is equivalent to the embedding layer look-up but can be inefficient if we work with large embedding matrices, because there are a lot of wasteful multiplications by zero"