LLMs-from-scratch/ch02/03_bonus_embedding-vs-matmul/embeddings-and-linear-layers.ipynb
2023-12-17 16:01:46 +01:00

{
"cells": [
{
"cell_type": "markdown",
"id": "063850ab-22b0-4838-b53a-9bb11757d9d0",
"metadata": {},
"source": [
"# Embedding Layers and Linear Layers"
]
},
{
"cell_type": "markdown",
"id": "0315c598-701f-46ff-8806-15813cad0e51",
"metadata": {},
"source": [
"- Embedding layers in PyTorch accomplish the same as linear layers that perform matrix multiplications; the reason we use embedding layers is computational efficiency\n",
"- We will take a look at this relationship step by step using code examples in PyTorch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "061720f4-f025-4640-82a0-15098fa94cf9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.1.0.dev20230825\n"
]
}
],
"source": [
"import torch\n",
"\n",
"print(\"PyTorch version:\", torch.__version__)"
]
},
{
"cell_type": "markdown",
"id": "a7895a66-7f69-4f62-9361-5c9da2eb76ef",
"metadata": {},
"source": [
"## Using nn.Embedding"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cc489ea5-73db-40b9-959e-0d70cae25f40",
"metadata": {},
"outputs": [],
"source": [
"# Suppose we have the following 3 training examples,\n",
"# which may represent token IDs in a LLM context\n",
"idx = torch.tensor([2, 3, 1])\n",
"\n",
"# The number of rows in the embedding matrix can be determined\n",
"# by obtaining the largest token ID + 1.\n",
"# If the highest token ID is 3, then we want 4 rows, for the possible\n",
"# token IDs 0, 1, 2, 3\n",
"num_idx = max(idx)+1\n",
"\n",
"# The desired embedding dimension is a hyperparameter\n",
"out_dim = 5"
]
},
{
"cell_type": "markdown",
"id": "93d83a6e-8543-40af-b253-fe647640bf36",
"metadata": {},
"source": [
"- Let's implement a simple embedding layer:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "60a7c104-36e1-4b28-bd02-a24a1099dc66",
"metadata": {},
"outputs": [],
"source": [
"# We use the random seed for reproducibility since\n",
"# weights in the embedding layer are initialized with\n",
"# small random values\n",
"torch.manual_seed(123)\n",
"\n",
"embedding = torch.nn.Embedding(num_idx, out_dim)"
]
},
{
"cell_type": "markdown",
"id": "dd96c00a-3297-4a50-8bfc-247aaea7e761",
"metadata": {},
"source": [
"We can optionally take a look at the embedding weights:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "595f603e-8d2a-4171-8f94-eac8106b2e57",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([[ 0.3374, -0.1778, -0.3035, -0.5880, 1.5810],\n",
" [ 1.3010, 1.2753, -0.2010, -0.1606, -0.4015],\n",
" [ 0.6957, -1.8061, -1.1589, 0.3255, -0.6315],\n",
" [-2.8400, -0.7849, -1.4096, -0.4076, 0.7953]], requires_grad=True)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embedding.weight"
]
},
{
"cell_type": "markdown",
"id": "c86eb562-61e2-4171-ab6e-b410a1fd5c18",
"metadata": {},
"source": [
"- We can then use the embedding layers to obtain the vector representation of a training example with ID 1:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "8bbc0255-4805-4be9-9f4c-1d0d967ef9d5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 1.3010, 1.2753, -0.2010, -0.1606, -0.4015]],\n",
" grad_fn=<EmbeddingBackward0>)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embedding(torch.tensor([1]))"
]
},
{
"cell_type": "markdown",
"id": "6a4d47f2-4691-47b8-9855-2528b6c285c9",
"metadata": {},
"source": [
"- Below is a visualization of what happens under the hood:"
]
},
{
"cell_type": "markdown",
"id": "12ffd155-7190-44b1-b6b6-45b11d6fe83b",
"metadata": {},
"source": [
"<img src=\"images/1.png\" width=\"400px\">"
]
},
{
"cell_type": "markdown",
"id": "87d1311b-cfb2-4afc-9e25-e4ecf35370d9",
"metadata": {},
"source": [
"- Similarly, we can use embedding layers to obtain the vector representation of a training example with ID 2:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "c309266a-c601-4633-9404-2e10b1cdde8c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.6957, -1.8061, -1.1589, 0.3255, -0.6315]],\n",
" grad_fn=<EmbeddingBackward0>)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embedding(torch.tensor([2]))"
]
},
{
"cell_type": "markdown",
"id": "7ad3b601-f68c-41b1-a28d-b624b94ef383",
"metadata": {},
"source": [
"<img src=\"images/2.png\" width=\"400px\">"
]
},
{
"cell_type": "markdown",
"id": "27dd54bd-85ae-4887-9c5e-3139da361cf4",
"metadata": {},
"source": [
"- Now, let's convert all the training examples we have defined previously:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "0191aa4b-f6a8-4b0d-9c36-65e82b81d071",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.6957, -1.8061, -1.1589, 0.3255, -0.6315],\n",
" [-2.8400, -0.7849, -1.4096, -0.4076, 0.7953],\n",
" [ 1.3010, 1.2753, -0.2010, -0.1606, -0.4015]],\n",
" grad_fn=<EmbeddingBackward0>)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idx = torch.tensor([2, 3, 1])\n",
"embedding(idx)"
]
},
{
"cell_type": "markdown",
"id": "146cf8eb-c517-4cd4-aa91-0e818fed7651",
"metadata": {},
"source": [
"- Under the hood, it's still the same look-up concept:"
]
},
{
"cell_type": "markdown",
"id": "b392eb43-0bda-4821-b446-b8dcbee8ae00",
"metadata": {},
"source": [
"<img src=\"images/3.png\" width=\"450px\">"
]
},
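{
"cell_type": "markdown",
"id": "added-row-indexing-note",
"metadata": {},
"source": [
"- The look-up can also be written as plain row indexing into the weight matrix. The cell below is a minimal, self-contained sketch of that idea; the names `emb_demo` and `ids_demo` are introduced here only for illustration, so that the notebook's own variables are left untouched:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-row-indexing-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"# Illustrative stand-ins for the embedding layer and token IDs used above\n",
"emb_demo = torch.nn.Embedding(4, 5)\n",
"ids_demo = torch.tensor([2, 3, 1])\n",
"\n",
"# Calling the layer returns the rows of the weight matrix selected by the IDs\n",
"via_layer = emb_demo(ids_demo)\n",
"via_indexing = emb_demo.weight[ids_demo]\n",
"\n",
"print(torch.equal(via_layer, via_indexing))"
]
},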
{
"cell_type": "markdown",
"id": "f0fe863b-d6a3-48f3-ace5-09ecd0eb7b59",
"metadata": {},
"source": [
"## Using nn.Linear"
]
},
{
"cell_type": "markdown",
"id": "138de6a4-2689-4c1f-96af-7899b2d82a4e",
"metadata": {},
"source": [
"- Now, we will demonstrate that the embedding layer above accomplishes exactly the same as `nn.Linear` layer on a one-hot encoded representation in PyTorch\n",
"- First, let's convert the token IDs into a one-hot representation:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "b5bb56cf-bc73-41ab-b107-91a43f77bdba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 0, 1, 0],\n",
" [0, 0, 0, 1],\n",
" [0, 1, 0, 0]])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"onehot = torch.nn.functional.one_hot(idx)\n",
"onehot"
]
},
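{
"cell_type": "markdown",
"id": "added-one-hot-width-note",
"metadata": {},
"source": [
"- When `num_classes` is not given, `one_hot` infers the width from the largest value in its input, so a batch that happens to miss the highest token ID would produce a matrix that is too narrow. A small sketch of passing the width explicitly (the tensor `ids_demo` and the vocabulary size of 4 are assumptions made for this illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-one-hot-width-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"vocab_size_demo = 4  # assumed vocabulary size: token IDs 0, 1, 2, 3\n",
"ids_demo = torch.tensor([2, 0, 1])  # this batch does not contain ID 3\n",
"\n",
"inferred = F.one_hot(ids_demo)\n",
"explicit = F.one_hot(ids_demo, num_classes=vocab_size_demo)\n",
"\n",
"print(inferred.shape)  # width inferred from the largest ID in the batch\n",
"print(explicit.shape)  # width fixed to the assumed vocabulary size"
]
},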
{
"cell_type": "markdown",
"id": "aa45dfdf-fb26-4514-a176-75224f5f179b",
"metadata": {},
"source": [
"- Next, we initialize a `Linear` layer, which caries out a matrix multiplication $X W^\\top$:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "ae04c1ed-242e-4dd7-b8f7-4b7e4caae383",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([[-0.2039, 0.0166, -0.2483, 0.1886],\n",
" [-0.4260, 0.3665, -0.3634, -0.3975],\n",
" [-0.3159, 0.2264, -0.1847, 0.1871],\n",
" [-0.4244, -0.3034, -0.1836, -0.0983],\n",
" [-0.3814, 0.3274, -0.1179, 0.1605]], requires_grad=True)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.manual_seed(123)\n",
"linear = torch.nn.Linear(num_idx, out_dim, bias=False)"
]
},
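{
"cell_type": "markdown",
"id": "added-linear-matmul-note",
"metadata": {},
"source": [
"- The $X W^\\top$ formulation can be checked directly: without a bias, calling the layer should match multiplying the input by the transposed weight matrix. The following is a sketch on random inputs; `lin_demo` and `x_demo` are names made up for this illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-linear-matmul-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"lin_demo = torch.nn.Linear(4, 5, bias=False)  # 4 input features, 5 outputs\n",
"x_demo = torch.rand(3, 4)\n",
"\n",
"# The weight is stored as (out_features, in_features)\n",
"print(lin_demo.weight.shape)\n",
"\n",
"# Manual X @ W^T compared with the layer's forward pass\n",
"manual = x_demo @ lin_demo.weight.T\n",
"print(torch.allclose(lin_demo(x_demo), manual))"
]
},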
{
"cell_type": "markdown",
"id": "63efb98e-5cc4-4e8d-9fe6-ef0ad29ae2d7",
"metadata": {},
"source": [
"- Note that the linear layer in PyTorch is also initialized with small random weights; to directly compare it to the `Embedding` layer above, we have to use the same small random weights, which is why we reassign them here:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "a3b90d69-761c-486e-bd19-b38a2988fe62",
"metadata": {},
"outputs": [],
"source": [
"linear.weight = torch.nn.Parameter(embedding.weight.T.detach())"
]
},
{
"cell_type": "markdown",
"id": "9116482d-f1f9-45e2-9bf3-7ef5e9003898",
"metadata": {},
"source": [
"- Now we can use the linear layer on the one-hot encoded representation of the inputs:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "90d2b0dd-9f1d-4c0f-bb16-1f6ce6b8ac2c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.6957, -1.8061, -1.1589, 0.3255, -0.6315],\n",
" [-2.8400, -0.7849, -1.4096, -0.4076, 0.7953],\n",
" [ 1.3010, 1.2753, -0.2010, -0.1606, -0.4015]], grad_fn=<MmBackward0>)"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"linear(onehot.float())"
]
},
{
"cell_type": "markdown",
"id": "f6204bc8-92e2-4546-9cda-574fe1360fa2",
"metadata": {},
"source": [
"As we can see, this is exactly the same as what we got when we used the embedding layer:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "2b057649-3176-4a54-b58c-fd8fbf818c61",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.6957, -1.8061, -1.1589, 0.3255, -0.6315],\n",
" [-2.8400, -0.7849, -1.4096, -0.4076, 0.7953],\n",
" [ 1.3010, 1.2753, -0.2010, -0.1606, -0.4015]],\n",
" grad_fn=<EmbeddingBackward0>)"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embedding(idx)"
]
},
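{
"cell_type": "markdown",
"id": "added-equivalence-check-note",
"metadata": {},
"source": [
"- Instead of comparing the printed tensors by eye, the match can also be checked programmatically. The sketch below rebuilds both layers so that it runs on its own; the `_demo` names are illustrative and simply mirror the variables used in this notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-equivalence-check-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"ids_demo = torch.tensor([2, 3, 1])\n",
"emb_demo = torch.nn.Embedding(4, 5)\n",
"\n",
"# Linear layer that reuses the embedding weights, transposed to\n",
"# the (out_features, in_features) layout that nn.Linear expects\n",
"lin_demo = torch.nn.Linear(4, 5, bias=False)\n",
"lin_demo.weight = torch.nn.Parameter(emb_demo.weight.T.detach())\n",
"\n",
"onehot_demo = F.one_hot(ids_demo, num_classes=4).float()\n",
"\n",
"print(torch.allclose(lin_demo(onehot_demo), emb_demo(ids_demo)))"
]
},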
{
"cell_type": "markdown",
"id": "0e447639-8952-460e-8c8f-cf9e23c368c9",
"metadata": {},
"source": [
"- What happens under the hood is the following computation for the first training example's token ID:"
]
},
{
"cell_type": "markdown",
"id": "1830eccf-a707-4753-a24a-9b103f55594a",
"metadata": {},
"source": [
"<img src=\"images/4.png\" width=\"450px\">"
]
},
{
"cell_type": "markdown",
"id": "9ce5211a-14e6-46aa-a3a8-14609f086e97",
"metadata": {},
"source": [
"- And for the second training example's token ID:"
]
},
{
"cell_type": "markdown",
"id": "173f6026-a461-44da-b9c5-f571f8ec8bf3",
"metadata": {},
"source": [
"<img src=\"images/5.png\" width=\"450px\">"
]
},
{
"cell_type": "markdown",
"id": "e2608049-f5d1-49a9-a14b-82695fc32e6a",
"metadata": {},
"source": [
"- Since all but one index in each one-hot encoded row are 0 (by design), this matrix multiplication is essentially the same as a look-up of the one-hot elements\n",
"- This use of the matrix multiplication on one-hot encodings is equivalent to the embedding layer look-up but can be inefficient if we work with large embedding matrices, because there are a lot of wasteful multiplications by zero"
]
},
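{
"cell_type": "markdown",
"id": "added-wasted-mults-note",
"metadata": {},
"source": [
"- To get a feeling for the wasted work, we can count the multiplications. The numbers below are illustrative assumptions (a vocabulary of 50,257 tokens, 768 embedding dimensions, and a batch of 1,024 token IDs), not measurements:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-wasted-mults-code",
"metadata": {},
"outputs": [],
"source": [
"vocab_size = 50_257  # assumed vocabulary size\n",
"emb_dim = 768        # assumed embedding dimension\n",
"num_tokens = 1_024   # assumed number of token IDs in a batch\n",
"\n",
"# One-hot matmul: every one-hot row is multiplied with the full weight matrix\n",
"matmul_mults = num_tokens * vocab_size * emb_dim\n",
"\n",
"# Only one entry per one-hot row is non-zero, so only these products matter\n",
"useful_mults = num_tokens * emb_dim\n",
"\n",
"print(f\"multiplications in the one-hot matmul: {matmul_mults:,}\")\n",
"print(f\"multiplications with a non-zero input: {useful_mults:,}\")\n",
"print(f\"fraction that is not wasted: {useful_mults / matmul_mults:.6f}\")"
]
},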
{
"cell_type": "code",
"execution_count": null,
"id": "5eacc005-86fc-490c-8f6a-dc37d8a0df7c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}