{ "cells": [ { "cell_type": "markdown", "id": "78224549-3637-44b0-aed1-8ff889c65192", "metadata": {}, "source": [ "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka\n", "\n", "Code repository: https://github.com/rasbt/LLMs-from-scratch\n", "
\n" ] }, { "cell_type": "markdown", "id": "51c9672d-8d0c-470d-ac2d-1271f8ec3f14", "metadata": {}, "source": [ "# Chapter 3 Exercise solutions" ] }, { "cell_type": "code", "execution_count": 1, "id": "513b627b-c197-44bd-99a2-756391c8a1cd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch version: 2.4.0\n" ] } ], "source": [ "from importlib.metadata import version\n", "\n", "import torch\n", "print(\"torch version:\", version(\"torch\"))" ] }, { "cell_type": "markdown", "id": "33dfa199-9aee-41d4-a64b-7e3811b9a616", "metadata": {}, "source": [ "# Exercise 3.1" ] }, { "cell_type": "code", "execution_count": 2, "id": "5fee2cf5-61c3-4167-81b5-44ea155bbaf2", "metadata": {}, "outputs": [], "source": [ "inputs = torch.tensor(\n", "  [[0.43, 0.15, 0.89], # Your (x^1)\n", "   [0.55, 0.87, 0.66], # journey (x^2)\n", "   [0.57, 0.85, 0.64], # starts (x^3)\n", "   [0.22, 0.58, 0.33], # with (x^4)\n", "   [0.77, 0.25, 0.10], # one (x^5)\n", "   [0.05, 0.80, 0.55]] # step (x^6)\n", ")\n", "\n", "d_in, d_out = 3, 2" ] }, { "cell_type": "code", "execution_count": 3, "id": "62ea289c-41cd-4416-89dd-dde6383a6f70", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "class SelfAttention_v1(nn.Module):\n", "\n", "    def __init__(self, d_in, d_out):\n", "        super().__init__()\n", "        self.d_out = d_out\n", "        self.W_query = nn.Parameter(torch.rand(d_in, d_out))\n", "        self.W_key = nn.Parameter(torch.rand(d_in, d_out))\n", "        self.W_value = nn.Parameter(torch.rand(d_in, d_out))\n", "\n", "    def forward(self, x):\n", "        keys = x @ self.W_key\n", "        queries = x @ self.W_query\n", "        values = x @ self.W_value\n", "\n", "        attn_scores = queries @ keys.T # omega\n", "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", "\n", "        context_vec = attn_weights @ values\n", "        return context_vec\n", "\n", "torch.manual_seed(123)\n", "sa_v1 = SelfAttention_v1(d_in, d_out)" ] }, { "cell_type": "code", "execution_count": 4, "id": 
"7b035143-f4e8-45fb-b398-dec1bd5153d4", "metadata": {}, "outputs": [], "source": [ "class SelfAttention_v2(nn.Module):\n", "\n", "    def __init__(self, d_in, d_out):\n", "        super().__init__()\n", "        self.d_out = d_out\n", "        self.W_query = nn.Linear(d_in, d_out, bias=False)\n", "        self.W_key = nn.Linear(d_in, d_out, bias=False)\n", "        self.W_value = nn.Linear(d_in, d_out, bias=False)\n", "\n", "    def forward(self, x):\n", "        keys = self.W_key(x)\n", "        queries = self.W_query(x)\n", "        values = self.W_value(x)\n", "\n", "        attn_scores = queries @ keys.T\n", "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n", "\n", "        context_vec = attn_weights @ values\n", "        return context_vec\n", "\n", "torch.manual_seed(123)\n", "sa_v2 = SelfAttention_v2(d_in, d_out)" ] }, { "cell_type": "code", "execution_count": 5, "id": "7591d79c-c30e-406d-adfd-20c12eb448f6", "metadata": {}, "outputs": [], "source": [ "# nn.Linear stores its weight with shape (d_out, d_in), so we transpose it\n", "# to match the (d_in, d_out) layout that SelfAttention_v1 expects\n", "sa_v1.W_query = torch.nn.Parameter(sa_v2.W_query.weight.T)\n", "sa_v1.W_key = torch.nn.Parameter(sa_v2.W_key.weight.T)\n", "sa_v1.W_value = torch.nn.Parameter(sa_v2.W_value.weight.T)" ] }, { "cell_type": "code", "execution_count": 6, "id": "ddd0f54f-6bce-46cc-a428-17c2a56557d0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.5337, -0.1051],\n", "        [-0.5323, -0.1080],\n", "        [-0.5323, -0.1079],\n", "        [-0.5297, -0.1076],\n", "        [-0.5311, -0.1066],\n", "        [-0.5299, -0.1081]], grad_fn=)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sa_v1(inputs)" ] }, { "cell_type": "code", "execution_count": 7, "id": "340908f8-1144-4ddd-a9e1-a1c5c3d592f5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.5337, -0.1051],\n", "        [-0.5323, -0.1080],\n", "        [-0.5323, -0.1079],\n", "        [-0.5297, -0.1076],\n", "        [-0.5311, -0.1066],\n", "        [-0.5299, -0.1081]], grad_fn=)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sa_v2(inputs)" ] }, { "cell_type": "markdown", 
"id": "33543edb-46b5-4b01-8704-f7f101230544", "metadata": {}, "source": [ "# Exercise 3.2" ] }, { "cell_type": "markdown", "id": "0588e209-1644-496a-8dae-7630b4ef9083", "metadata": {}, "source": [ "If we want an output dimension of 2, as earlier in single-head attention, we have to change the projection dimension `d_out` to 1, because the wrapper concatenates the outputs of its 2 heads:" ] }, { "cell_type": "markdown", "id": "18e748ef-3106-4e11-a781-b230b74a0cef", "metadata": {}, "source": [ "```python\n", "torch.manual_seed(123)\n", "\n", "d_out = 1\n", "mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)\n", "\n", "context_vecs = mha(batch)\n", "\n", "print(context_vecs)\n", "print(\"context_vecs.shape:\", context_vecs.shape)\n", "```" ] }, { "cell_type": "markdown", "id": "78234544-d989-4f71-ac28-85a7ec1e6b7b", "metadata": {}, "source": [ "```\n", "tensor([[[-9.1476e-02,  3.4164e-02],\n", "         [-2.6796e-01, -1.3427e-03],\n", "         [-4.8421e-01, -4.8909e-02],\n", "         [-6.4808e-01, -1.0625e-01],\n", "         [-8.8380e-01, -1.7140e-01],\n", "         [-1.4744e+00, -3.4327e-01]],\n", "\n", "        [[-9.1476e-02,  3.4164e-02],\n", "         [-2.6796e-01, -1.3427e-03],\n", "         [-4.8421e-01, -4.8909e-02],\n", "         [-6.4808e-01, -1.0625e-01],\n", "         [-8.8380e-01, -1.7140e-01],\n", "         [-1.4744e+00, -3.4327e-01]]], grad_fn=)\n", "context_vecs.shape: torch.Size([2, 6, 2])\n", "```" ] }, { "cell_type": "markdown", "id": "92bdabcb-06cf-4576-b810-d883bbd313ba", "metadata": {}, "source": [ "# Exercise 3.3" ] }, { "cell_type": "markdown", "id": "84c9b963-d01f-46e6-96bf-8eb2a54c5e42", "metadata": {}, "source": [ "```python\n", "context_length = 1024\n", "d_in, d_out = 768, 768\n", "num_heads = 12\n", "\n", "mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads)\n", "```" ] }, { "cell_type": "markdown", "id": "375d5290-8e8b-4149-958e-1efb58a69191", "metadata": {}, "source": [ "Optionally, we can count the number of trainable parameters as follows:" ] }, { "cell_type": "markdown", "id": "6d7e603c-1658-4da9-9c0b-ef4bc72832b4", "metadata": {}, "source": [ 
"```python\n", "def count_parameters(model):\n", "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", "count_parameters(mha)\n", "```" ] }, { "cell_type": "markdown", "id": "51ba00bd-feb0-4424-84cb-7c2b1f908779", "metadata": {}, "source": [ "```\n", "2360064 # (2.36 M)\n", "```" ] }, { "cell_type": "markdown", "id": "a56c1d47-9b95-4bd1-a517-580a6f779c52", "metadata": {}, "source": [ "The GPT-2 model has 117M parameters in total, but as we can see, a single multi-head attention module accounts for only about 2.36M of them, so most of the model's parameters are not in the multi-head attention module itself. Assuming the chapter's `MultiHeadAttention` uses bias-free query, key, and value projections and an output projection with a bias, the count breaks down as 3 × 768 × 768 + (768 × 768 + 768) = 2,360,064." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }