{ "cells": [ { "cell_type": "markdown", "id": "6f678e62-7bcb-4405-86ae-dce94f494303", "metadata": {}, "source": [ "# Efficient Multi-Head Attention Implementations" ] }, { "cell_type": "markdown", "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6", "metadata": {}, "source": [ "## Multi-head attention implementations from chapter 3" ] }, { "cell_type": "code", "execution_count": 1, "id": "7898551e-f582-48ac-9f66-3632abe2a93f", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "torch.manual_seed(123)\n", "\n", "batch_size = 8\n", "context_len = 1024\n", "embed_dim = 768\n", "embeddings = torch.randn((batch_size, context_len, embed_dim))" ] }, { "cell_type": "code", "execution_count": 2, "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 1024, 9216])\n" ] } ], "source": [ "from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_1\n", "\n", "mha_ch03_1 = Ch03_MHA_1(\n", " d_in=embed_dim,\n", " d_out=embed_dim,\n", " block_size=context_len,\n", " dropout=0.0,\n", " num_heads=12,\n", " qkv_bias=False\n", ")\n", "\n", "out = mha_ch03_1(embeddings)\n", "print(out.shape)" ] }, { "cell_type": "code", "execution_count": 3, "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 1024, 768])\n" ] } ], "source": [ "from ch03 import MultiHeadAttention as Ch03_MHA_2\n", "\n", "mha_ch03_2 = Ch03_MHA_2(\n", " d_in=embed_dim,\n", " d_out=embed_dim,\n", " block_size=context_len,\n", " dropout=0.0,\n", " num_heads=12,\n", " qkv_bias=False\n", ")\n", "\n", "out = mha_ch03_2(embeddings)\n", "print(out.shape)" ] }, { "cell_type": "markdown", "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4", "metadata": {}, "source": [ "## An alternative multi-head attention with combined weights" ] }, { "cell_type": "markdown", "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd", "metadata": {}, "source": [ "- The code for 
the `MultiHeadAttentionAlt` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n", "- The main difference between the `MultiHeadAttentionAlt` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionAlt` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)`, instead of three separate weight matrices:\n", "\n", "    - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", "    - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", "    - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", "\n", "- Here, `self.qkv` combines the three weight matrices `self.W_query`, `self.W_key`, and `self.W_value`, so the queries, keys, and values are computed in a single matrix multiplication\n", "- After reshaping and permuting the combined output, `queries, keys, values = qkv.unbind(0)` splits it into the individual query, key, and value tensors, which are then used the same way as in the `MultiHeadAttention` class from chapter 3" ] }, { "cell_type": "code", "execution_count": 4, "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 1024, 768])\n" ] } ], "source": [ "import torch.nn as nn\n", "\n", "\n", "class MultiHeadAttentionAlt(nn.Module):\n", "    def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", "        super().__init__()\n", "\n", "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n", "\n", "        self.num_heads = num_heads\n", "        self.block_size = block_size\n", "        self.head_dim = d_out // num_heads\n", "        self.d_out = d_out\n", "\n", "        # A single linear layer computes queries, keys, and values at once\n", "        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", "        # Output projection: maps the concatenated heads (d_out) to d_out\n", "        self.proj = nn.Linear(d_out, d_out)\n", "        self.dropout = nn.Dropout(dropout)\n", "\n", "        self.register_buffer(\n", "            \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", "        )\n", "\n", "    def forward(self, x):\n", "        batch_size, num_tokens, embed_dim = x.shape\n", 
"\n", "        # (b, num_tokens, d_in) --> (b, num_tokens, 3 * d_out)\n", "        qkv = self.qkv(x)\n", "\n", "        # (b, num_tokens, 3 * d_out) --> (b, num_tokens, 3, num_heads, head_dim)\n", "        qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", "\n", "        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", "        qkv = qkv.permute(2, 0, 3, 1, 4)\n", "\n", "        # (3, b, num_heads, num_tokens, head_dim) --> 3 times (b, num_heads, num_tokens, head_dim)\n", "        queries, keys, values = qkv.unbind(0)\n", "\n", "        # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n", "        attn_scores = queries @ keys.transpose(-2, -1)\n", "        attn_scores = attn_scores.masked_fill(\n", "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n", "        )\n", "\n", "        # Scale the scores by 1/sqrt(head_dim) before applying the softmax\n", "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", "        attn_weights = self.dropout(attn_weights)\n", "\n", "        # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n", "        context_vec = attn_weights @ values\n", "\n", "        # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n", "        context_vec = context_vec.transpose(1, 2)\n", "\n", "        # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, d_out)\n", "        context_vec = context_vec.reshape(batch_size, num_tokens, self.d_out)\n", "\n", "        context_vec = self.proj(context_vec)\n", "\n", "        return context_vec\n", "\n", "\n", "mha_alt = MultiHeadAttentionAlt(\n", "    d_in=embed_dim,\n", "    d_out=embed_dim,\n", "    block_size=context_len,\n", "    dropout=0.0,\n", "    num_heads=12,\n", "    qkv_bias=False\n", ")\n", "\n", "out = mha_alt(embeddings)\n", "print(out.shape)" ] }, { "cell_type": "markdown", "id": "48a042d3-ee78-4c29-bf63-d92fe6706632", "metadata": {}, "source": [ "## Multi-head attention with PyTorch's scaled dot product attention" ] }, { "cell_type": "markdown", "id": "f78e346f-3b85-44e6-9feb-f01131381148", "metadata": {}, "source": [ "- The 
implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which, depending on the hardware, data type, and input shapes, can dispatch to optimized fused kernels such as a memory-efficient implementation of [FlashAttention](https://arxiv.org/abs/2205.14135)\n", "- Because `is_causal=True` applies the causal mask inside the function, this class does not need to register a mask buffer" ] }, { "cell_type": "code", "execution_count": 5, "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5", "metadata": {}, "outputs": [], "source": [ "class MultiHeadAttentionPyTorch(nn.Module):\n", "    def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", "        super().__init__()\n", "\n", "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n", "\n", "        self.num_heads = num_heads\n", "        self.block_size = block_size\n", "        self.head_dim = d_out // num_heads\n", "        self.d_out = d_out\n", "\n", "        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", "        # Output projection: maps the concatenated heads (d_out) to d_out\n", "        self.proj = nn.Linear(d_out, d_out)\n", "        # Dropout probability, passed to scaled_dot_product_attention as dropout_p\n", "        self.dropout = dropout\n", "\n", "        # No mask buffer is needed: is_causal=True applies the causal mask internally\n", "\n", "    def forward(self, x):\n", "        batch_size, num_tokens, embed_dim = x.shape\n", "\n", "        # (b, num_tokens, d_in) --> (b, num_tokens, 3 * d_out)\n", "        qkv = self.qkv(x)\n", "\n", "        # (b, num_tokens, 3 * d_out) --> (b, num_tokens, 3, num_heads, head_dim)\n", "        qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", "\n", "        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", "        qkv = qkv.permute(2, 0, 3, 1, 4)\n", "\n", "        # (3, b, num_heads, num_tokens, head_dim) --> 3 times (b, num_heads, num_tokens, head_dim)\n", "        queries, keys, values = qkv.unbind(0)\n", "\n", "        # Apply attention dropout only during training\n", "        use_dropout = 0. if not self.training else self.dropout\n", "        context_vec = nn.functional.scaled_dot_product_attention(\n", "            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n", "\n", "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", "        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n", "\n", "        context_vec = self.proj(context_vec)\n", "\n", "        return context_vec" ] }, { "cell_type": "code", "execution_count": 6, "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 1024, 768])\n" ] } ], "source": [ "mha_pytorch = MultiHeadAttentionPyTorch(\n", "    d_in=embed_dim,\n", "    d_out=embed_dim,\n", "    block_size=context_len,\n", "    dropout=0.0,\n", "    num_heads=12,\n", "    qkv_bias=False\n", ")\n", "\n", "out = mha_pytorch(embeddings)\n", "print(out.shape)" ] }, { "cell_type": "markdown", "id": "8877de71-f84f-4f6d-bc87-7552013b6301", "metadata": {}, "source": [ "## Speed comparison" ] }, { "cell_type": "code", "execution_count": 7, "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "879 ms ± 4.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ "%timeit mha_ch03_1(embeddings)" ] }, { "cell_type": "code", "execution_count": 8, "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "259 ms ± 7.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ "%timeit mha_ch03_2(embeddings)" ] }, { "cell_type": "code", "execution_count": 9, "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "290 ms ± 2.58 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" ] } ], "source": [ "%timeit mha_alt(embeddings)" ] }, { "cell_type": "code", "execution_count": 10, "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "91.5 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], "source": [ "%timeit mha_pytorch(embeddings)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }