mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-06-26 23:50:03 +00:00

* fixed typo * removed remaining nbviewer links * Update mha-implementations.ipynb --------- Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
2010 lines
340 KiB
Plaintext
2010 lines
340 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e2e65c03-36d4-413f-9b23-5cdd816729ab",
|
|
"metadata": {
|
|
"id": "e2e65c03-36d4-413f-9b23-5cdd816729ab"
|
|
},
|
|
"source": [
|
|
"<table style=\"width:100%\">\n",
|
|
"<tr>\n",
|
|
"<td style=\"vertical-align:middle; text-align:left;\">\n",
|
|
"<font size=\"2\">\n",
|
|
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
|
|
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
|
|
"</font>\n",
|
|
"</td>\n",
|
|
"<td style=\"vertical-align:middle; text-align:left;\">\n",
|
|
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
|
|
"</td>\n",
|
|
"</tr>\n",
|
|
"</table>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1HABx0Hr3PDD",
|
|
"metadata": {
|
|
"id": "1HABx0Hr3PDD"
|
|
},
|
|
"source": [
|
|
"Uncomment and execute the following code cell to install the dependencies:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "qPnVNAOxwy5s",
|
|
"metadata": {
|
|
"id": "qPnVNAOxwy5s"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/requirements.txt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "LYLcq3403Yq6",
|
|
"metadata": {
|
|
"id": "LYLcq3403Yq6"
|
|
},
|
|
"source": [
|
|
"Uncomment and execute the following code cell to install the PyTorch nightly dependency if you want to run the FlexAttention benchmarks (this is required because FlexAttention is not yet included in the latest PyTorch release):"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "gAgYvxm_xVct",
|
|
"metadata": {
|
|
"id": "gAgYvxm_xVct"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 -U"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6f678e62-7bcb-4405-86ae-dce94f494303",
|
|
"metadata": {
|
|
"id": "6f678e62-7bcb-4405-86ae-dce94f494303"
|
|
},
|
|
"source": [
|
|
"# Comparing Efficient Multi-Head Attention Implementations"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b742938a-4bfc-4527-a1f1-d5963508967d",
|
|
"metadata": {
|
|
"id": "b742938a-4bfc-4527-a1f1-d5963508967d"
|
|
},
|
|
"source": [
|
|
"This code notebook compares different ways to implement causal multi-head attention used in decoder-style LLMs like GPT, Llama, etc."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
|
|
"outputId": "1a7d22c1-96d8-4a42-e3ec-ce78abaf18eb"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"PyTorch version: 2.5.0.dev20240905+cu121\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import torch\n",
|
|
"\n",
|
|
"# Seed the global RNG once: the random embeddings below and the weight\n",
|
|
"# initialization of the modules created in later cells all draw from it\n",
|
|
"torch.manual_seed(123)\n",
|
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
|
"print(f\"PyTorch version: {torch.__version__}\")\n",
|
|
"\n",
|
|
"batch_size = 8\n",
|
|
"context_len = 1024\n",
|
|
"embed_dim = 768\n",
|
|
"# Shared input for every implementation: shape (batch_size, context_len, embed_dim)\n",
|
|
"embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
|
|
"metadata": {
|
|
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"## 1) CausalAttention MHA wrapper class from chapter 3"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
|
|
"outputId": "b6f596e4-b778-496c-bea8-3fe83d873c5b"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([8, 1024, 768])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import torch.nn as nn\n",
|
|
"\n",
|
|
"# A single causal self-attention head (chapter 3); the wrapper below stacks num_heads of them\n",
|
|
"class CausalAttention(nn.Module):\n",
|
|
"\n",
|
|
"    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
|
|
"        super().__init__()\n",
|
|
"        self.d_out = d_out\n",
|
|
"        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
|
"        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
|
"        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
|
"        self.dropout = nn.Dropout(dropout) # New\n",
|
|
"        # 1s strictly above the diagonal mark the future positions that must be hidden\n",
|
|
"        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New\n",
|
|
"\n",
|
|
"    def forward(self, x):\n",
|
|
"        b, num_tokens, d_in = x.shape # New batch dimension b\n",
|
|
"        keys = self.W_key(x)\n",
|
|
"        queries = self.W_query(x)\n",
|
|
"        values = self.W_value(x)\n",
|
|
"\n",
|
|
"        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
|
|
"        attn_scores.masked_fill_( # New, _ ops are in-place\n",
|
|
"            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n",
|
|
"        # Divide by sqrt(d_out) of this head, then normalize over the key axis\n",
|
|
"        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
|
"        attn_weights = self.dropout(attn_weights) # New\n",
|
|
"\n",
|
|
"        context_vec = attn_weights @ values\n",
|
|
"        return context_vec\n",
|
|
"\n",
|
|
"\n",
|
|
"class Ch03_MHA_Wrapper(nn.Module):\n",
|
|
"    # Runs num_heads independent CausalAttention heads on the same input and\n",
|
|
"    # concatenates their outputs, giving d_out * num_heads features per token\n",
|
|
"\n",
|
|
"    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
|
|
"        super().__init__()\n",
|
|
"        self.heads = nn.ModuleList(\n",
|
|
"            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)\n",
|
|
"             for _ in range(num_heads)]\n",
|
|
"        )\n",
|
|
"        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
|
|
"\n",
|
|
"    def forward(self, x):\n",
|
|
"        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n",
|
|
"        return self.out_proj(context_vec)\n",
|
|
"\n",
|
|
"\n",
|
|
"# d_out is per head here, so the 12 concatenated heads give embed_dim output features\n",
|
|
"mha_ch03_wrapper = Ch03_MHA_Wrapper(\n",
|
|
"    d_in=embed_dim,\n",
|
|
"    d_out=embed_dim//12,\n",
|
|
"    context_length=context_len,\n",
|
|
"    dropout=0.0,\n",
|
|
"    num_heads=12,\n",
|
|
"    qkv_bias=False\n",
|
|
").to(device)\n",
|
|
"\n",
|
|
"out = mha_ch03_wrapper(embeddings)\n",
|
|
"print(out.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "21930804-b327-40b1-8e63-94dcad39ce7b",
|
|
"metadata": {
|
|
"id": "21930804-b327-40b1-8e63-94dcad39ce7b"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"## 2) The multi-head attention class from chapter 3"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
|
|
"outputId": "4d9ade55-4710-4ae6-9f00-aa87811bfb04"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([8, 1024, 768])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Chapter 3 multi-head attention: separate Q/K/V projections whose d_out outputs are\n",
|
|
"# split into num_heads heads of size head_dim = d_out // num_heads\n",
|
|
"class Ch03_MHA(nn.Module):\n",
|
|
"    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
|
|
"        super().__init__()\n",
|
|
"        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
|
|
"\n",
|
|
"        self.d_out = d_out\n",
|
|
"        self.num_heads = num_heads\n",
|
|
"        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
|
|
"\n",
|
|
"        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
|
"        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
|
"        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
|
"        self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
|
|
"        self.dropout = nn.Dropout(dropout)\n",
|
|
"        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
|
|
"\n",
|
|
"    def forward(self, x):\n",
|
|
"        b, num_tokens, d_in = x.shape\n",
|
|
"\n",
|
|
"        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)\n",
|
|
"        queries = self.W_query(x)\n",
|
|
"        values = self.W_value(x)\n",
|
|
"\n",
|
|
"        # We implicitly split the matrix by adding a `num_heads` dimension\n",
|
|
"        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)\n",
|
|
"        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)\n",
|
|
"        values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
|
|
"        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
|
|
"\n",
|
|
"        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)\n",
|
|
"        keys = keys.transpose(1, 2)\n",
|
|
"        queries = queries.transpose(1, 2)\n",
|
|
"        values = values.transpose(1, 2)\n",
|
|
"\n",
|
|
"        # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
|
|
"        attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
|
|
"\n",
|
|
"        # Original mask truncated to the number of tokens and converted to boolean\n",
|
|
"        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
|
|
"\n",
|
|
"        # Use the mask to fill attention scores\n",
|
|
"        attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
|
|
"\n",
|
|
"        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
|
"        attn_weights = self.dropout(attn_weights)\n",
|
|
"\n",
|
|
"        # Shape: (b, num_tokens, num_heads, head_dim)\n",
|
|
"        context_vec = (attn_weights @ values).transpose(1, 2)\n",
|
|
"\n",
|
|
"        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
|
|
"        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",
|
|
"        context_vec = self.out_proj(context_vec) # optional projection\n",
|
|
"\n",
|
|
"        return context_vec\n",
|
|
"\n",
|
|
"\n",
|
|
"mha_ch03 = Ch03_MHA(\n",
|
|
"    d_in=embed_dim,\n",
|
|
"    d_out=embed_dim,\n",
|
|
"    context_length=context_len,\n",
|
|
"    dropout=0.0,\n",
|
|
"    num_heads=12,\n",
|
|
"    qkv_bias=False\n",
|
|
").to(device)\n",
|
|
"\n",
|
|
"out = mha_ch03(embeddings)\n",
|
|
"print(out.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4",
|
|
"metadata": {
|
|
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"## 3) An alternative multi-head attention with combined weights"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd",
|
|
"metadata": {
|
|
"id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd"
|
|
},
|
|
"source": [
|
|
"- The code for the `MultiHeadAttentionCombinedQKV` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n",
|
|
"- The main difference between the `MultiHeadAttentionCombinedQKV` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionCombinedQKV` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)` instead of separate weight matrices:\n",
|
|
"\n",
|
|
" - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
|
|
" - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
|
|
" - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
|
|
"\n",
|
|
"- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n",
|
|
"- Using `queries, keys, values = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
|
|
"outputId": "a0a023ee-3bc7-4a89-cdba-7c97921160ee"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([8, 1024, 768])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import torch.nn as nn\n",
|
|
"\n",
|
|
"\n",
|
|
"class MultiHeadAttentionCombinedQKV(nn.Module):\n",
|
|
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
|
|
" super().__init__()\n",
|
|
"\n",
|
|
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
|
|
"\n",
|
|
" self.num_heads = num_heads\n",
|
|
" self.context_length = context_length\n",
|
|
" self.head_dim = d_out // num_heads\n",
|
|
"\n",
|
|
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
|
|
" self.proj = nn.Linear(d_out, d_out)\n",
|
|
" self.dropout = nn.Dropout(dropout)\n",
|
|
"\n",
|
|
" self.register_buffer(\n",
|
|
" \"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
|
|
" )\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" batch_size, num_tokens, embed_dim = x.shape\n",
|
|
"\n",
|
|
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
|
|
" qkv = self.qkv(x)\n",
|
|
"\n",
|
|
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
|
|
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
|
|
"\n",
|
|
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
|
|
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
|
|
"\n",
|
|
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n",
|
|
" queries, keys, values = qkv.unbind(0)\n",
|
|
"\n",
|
|
" # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n",
|
|
" attn_scores = queries @ keys.transpose(-2, -1)\n",
|
|
" attn_scores = attn_scores.masked_fill(\n",
|
|
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n",
|
|
" )\n",
|
|
"\n",
|
|
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**-0.5, dim=-1)\n",
|
|
" attn_weights = self.dropout(attn_weights)\n",
|
|
"\n",
|
|
" # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n",
|
|
" context_vec = attn_weights @ values\n",
|
|
"\n",
|
|
" # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n",
|
|
" context_vec = context_vec.transpose(1, 2)\n",
|
|
"\n",
|
|
" # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)\n",
|
|
" context_vec = context_vec.contiguous().view(batch_size, num_tokens, embed_dim)\n",
|
|
"\n",
|
|
" context_vec = self.proj(context_vec)\n",
|
|
"\n",
|
|
" return context_vec\n",
|
|
"\n",
|
|
"\n",
|
|
"mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n",
|
|
" d_in=embed_dim,\n",
|
|
" d_out=embed_dim,\n",
|
|
" context_length=context_len,\n",
|
|
" dropout=0.0,\n",
|
|
" num_heads=12,\n",
|
|
" qkv_bias=False\n",
|
|
").to(device)\n",
|
|
"\n",
|
|
"out = mha_combined_qkv(embeddings)\n",
|
|
"print(out.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9b14390d-3e21-43fd-87be-43e7029163ee",
|
|
"metadata": {
|
|
"id": "9b14390d-3e21-43fd-87be-43e7029163ee"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"## 4) Multi-head attention with Einsum\n",
|
|
"\n",
|
|
"- Implementing multi-head attention using Einstein summation via [`torch.einsum`](https://pytorch.org/docs/stable/generated/torch.einsum.html)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "92481814-068d-439b-a65c-b1310ebbe0aa",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "92481814-068d-439b-a65c-b1310ebbe0aa",
|
|
"outputId": "59a75f6e-ef06-418f-8e54-d3b368fbed13"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([8, 1024, 768])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import math\n",
|
|
"\n",
|
|
"\n",
|
|
"class MHAEinsum(nn.Module):\n",
|
|
"\n",
|
|
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
|
|
" super().__init__()\n",
|
|
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
|
|
"\n",
|
|
" self.d_out = d_out\n",
|
|
" self.num_heads = num_heads\n",
|
|
" self.head_dim = d_out // num_heads\n",
|
|
"\n",
|
|
" # Initialize parameters for Q, K, V\n",
|
|
" self.W_query = nn.Parameter(torch.randn(d_out, d_in))\n",
|
|
" self.W_key = nn.Parameter(torch.randn(d_out, d_in))\n",
|
|
" self.W_value = nn.Parameter(torch.randn(d_out, d_in))\n",
|
|
"\n",
|
|
" if qkv_bias:\n",
|
|
" self.bias_q = nn.Parameter(torch.zeros(d_out))\n",
|
|
" self.bias_k = nn.Parameter(torch.zeros(d_out))\n",
|
|
" self.bias_v = nn.Parameter(torch.zeros(d_out))\n",
|
|
" else:\n",
|
|
" self.register_parameter(\"bias_q\", None)\n",
|
|
" self.register_parameter(\"bias_k\", None)\n",
|
|
" self.register_parameter(\"bias_v\", None)\n",
|
|
"\n",
|
|
" self.out_proj = nn.Linear(d_out, d_out)\n",
|
|
" self.dropout = nn.Dropout(dropout)\n",
|
|
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
|
|
"\n",
|
|
" # Initialize parameters\n",
|
|
" self.reset_parameters()\n",
|
|
"\n",
|
|
"\n",
|
|
" def reset_parameters(self):\n",
|
|
" nn.init.kaiming_uniform_(self.W_query, a=math.sqrt(5))\n",
|
|
" nn.init.kaiming_uniform_(self.W_key, a=math.sqrt(5))\n",
|
|
" nn.init.kaiming_uniform_(self.W_value, a=math.sqrt(5))\n",
|
|
" if self.bias_q is not None:\n",
|
|
" fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_query)\n",
|
|
" bound = 1 / math.sqrt(fan_in)\n",
|
|
" nn.init.uniform_(self.bias_q, -bound, bound)\n",
|
|
" nn.init.uniform_(self.bias_k, -bound, bound)\n",
|
|
" nn.init.uniform_(self.bias_v, -bound, bound)\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" b, n, _ = x.shape\n",
|
|
"\n",
|
|
" # Calculate Q, K, V using einsum, first perform linear transformations\n",
|
|
" Q = torch.einsum(\"bnd,di->bni\", x, self.W_query)\n",
|
|
" K = torch.einsum(\"bnd,di->bni\", x, self.W_key)\n",
|
|
" V = torch.einsum(\"bnd,di->bni\", x, self.W_value)\n",
|
|
"\n",
|
|
" # Add biases if they are used\n",
|
|
" if self.bias_q is not None:\n",
|
|
" Q += self.bias_q\n",
|
|
" K += self.bias_k\n",
|
|
" V += self.bias_v\n",
|
|
"\n",
|
|
" # Reshape for multi-head attention\n",
|
|
" Q = Q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)\n",
|
|
" K = K.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)\n",
|
|
" V = V.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)\n",
|
|
"\n",
|
|
" # Scaled dot-product attention\n",
|
|
" scores = torch.einsum(\"bhnd,bhmd->bhnm\", Q, K) / (self.head_dim ** 0.5)\n",
|
|
"\n",
|
|
" # Apply mask\n",
|
|
" mask = self.mask[:n, :n].unsqueeze(0).unsqueeze(1).expand(b, self.num_heads, n, n)\n",
|
|
" scores = scores.masked_fill(mask.bool(), -torch.inf)\n",
|
|
"\n",
|
|
" # Softmax and dropout\n",
|
|
" attn_weights = torch.softmax(scores, dim=-1)\n",
|
|
" attn_weights = self.dropout(attn_weights)\n",
|
|
"\n",
|
|
" # Aggregate the attended context vectors\n",
|
|
" context_vec = torch.einsum(\"bhnm,bhmd->bhnd\", attn_weights, V)\n",
|
|
"\n",
|
|
" # Combine heads and project the output\n",
|
|
" context_vec = context_vec.transpose(1, 2).reshape(b, n, self.d_out)\n",
|
|
" context_vec = self.out_proj(context_vec)\n",
|
|
"\n",
|
|
" return context_vec\n",
|
|
"\n",
|
|
"\n",
|
|
"mha_einsum = MHAEinsum(\n",
|
|
" d_in=embed_dim,\n",
|
|
" d_out=embed_dim,\n",
|
|
" context_length=context_len,\n",
|
|
" dropout=0.0,\n",
|
|
" num_heads=12,\n",
|
|
" qkv_bias=False\n",
|
|
").to(device)\n",
|
|
"\n",
|
|
"out = mha_einsum(embeddings)\n",
|
|
"print(out.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632",
|
|
"metadata": {
|
|
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"## 5) Multi-head attention with PyTorch's scaled dot product attention and FlashAttention"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "f78e346f-3b85-44e6-9feb-f01131381148",
|
|
"metadata": {
|
|
"id": "f78e346f-3b85-44e6-9feb-f01131381148"
|
|
},
|
|
"source": [
|
|
"- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which implements a memory-optimized version of self-attention called [FlashAttention](https://arxiv.org/abs/2205.14135)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5",
|
|
"metadata": {
|
|
"id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Fused-QKV multi-head attention that delegates the attention itself to\n",
|
|
"# torch.nn.functional.scaled_dot_product_attention with is_causal=True\n",
|
|
"class MHAPyTorchScaledDotProduct(nn.Module):\n",
|
|
"    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
|
|
"        super().__init__()\n",
|
|
"\n",
|
|
"        assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
|
|
"\n",
|
|
"        self.num_heads = num_heads\n",
|
|
"        self.context_length = context_length\n",
|
|
"        self.head_dim = d_out // num_heads\n",
|
|
"        self.d_out = d_out\n",
|
|
"\n",
|
|
"        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
|
|
"        self.proj = nn.Linear(d_out, d_out)\n",
|
|
"        # Kept as a float (not nn.Dropout) because it is passed as dropout_p below\n",
|
|
"        self.dropout = dropout\n",
|
|
"\n",
|
|
"    def forward(self, x):\n",
|
|
"        batch_size, num_tokens, embed_dim = x.shape\n",
|
|
"\n",
|
|
"        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
|
|
"        qkv = self.qkv(x)\n",
|
|
"\n",
|
|
"        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
|
|
"        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
|
|
"\n",
|
|
"        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
|
|
"        qkv = qkv.permute(2, 0, 3, 1, 4)\n",
|
|
"\n",
|
|
"        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
|
|
"        # Unpacking iterates over the leading axis of size 3\n",
|
|
"        queries, keys, values = qkv\n",
|
|
"\n",
|
|
"        # Dropout is only applied in training mode\n",
|
|
"        use_dropout = 0. if not self.training else self.dropout\n",
|
|
"\n",
|
|
"        # is_causal=True applies the causal mask internally, so no mask buffer is needed\n",
|
|
"        context_vec = nn.functional.scaled_dot_product_attention(\n",
|
|
"            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n",
|
|
"\n",
|
|
"        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
|
|
"        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
|
|
"\n",
|
|
"        context_vec = self.proj(context_vec)\n",
|
|
"\n",
|
|
"        return context_vec"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
|
|
"outputId": "087a53e7-86d8-48dc-bf2e-023f0f2104cb"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([8, 1024, 768])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Keyword arguments matter here: this constructor lists num_heads before context_length\n",
|
|
"mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n",
|
|
"    d_in=embed_dim,\n",
|
|
"    d_out=embed_dim,\n",
|
|
"    context_length=context_len,\n",
|
|
"    dropout=0.0,\n",
|
|
"    num_heads=12,\n",
|
|
"    qkv_bias=False\n",
|
|
").to(device)\n",
|
|
"\n",
|
|
"out = mha_pytorch_scaled(embeddings)\n",
|
|
"print(out.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "51492724-6018-49f6-8bf6-ae9e585229c3",
|
|
"metadata": {
|
|
"id": "51492724-6018-49f6-8bf6-ae9e585229c3"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"## 6) PyTorch's scaled dot product attention without FlashAttention\n",
|
|
"\n",
|
|
"- This is similar to above, except that we disable FlashAttention by passing an explicit causal mask"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "bad53538-e905-4065-ba0c-caacdfec5a0b",
|
|
"metadata": {
|
|
"id": "bad53538-e905-4065-ba0c-caacdfec5a0b"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"class MHAPyTorchSDPAWithoutFlash(nn.Module):\n",
|
|
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
|
|
" super().__init__()\n",
|
|
"\n",
|
|
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
|
|
"\n",
|
|
" self.num_heads = num_heads\n",
|
|
" self.context_length = context_length\n",
|
|
" self.head_dim = d_out // num_heads\n",
|
|
" self.d_out = d_out\n",
|
|
"\n",
|
|
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
|
|
" self.proj = nn.Linear(d_out, d_out)\n",
|
|
" self.dropout = dropout\n",
|
|
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" batch_size, num_tokens, embed_dim = x.shape\n",
|
|
"\n",
|
|
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
|
|
" qkv = self.qkv(x)\n",
|
|
"\n",
|
|
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
|
|
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
|
|
"\n",
|
|
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
|
|
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
|
|
"\n",
|
|
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
|
|
" queries, keys, values = qkv\n",
|
|
"\n",
|
|
" use_dropout = 0. if not self.training else self.dropout\n",
|
|
"\n",
|
|
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
|
|
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
|
|
" if self.context_length >= num_tokens:\n",
|
|
" attn_mask = self.mask[:num_tokens, :num_tokens]\n",
|
|
" else:\n",
|
|
" attn_mask = self.mask[:self.context_length, :self.context_length]\n",
|
|
"\n",
|
|
" context_vec = nn.functional.scaled_dot_product_attention(\n",
|
|
" queries, keys, values, attn_mask=attn_mask, dropout_p=use_dropout, is_causal=False)\n",
|
|
"\n",
|
|
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
|
|
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
|
|
"\n",
|
|
" context_vec = self.proj(context_vec)\n",
|
|
"\n",
|
|
" return context_vec"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "f3da7850-e772-47d3-bd51-22d077b01412",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "f3da7850-e772-47d3-bd51-22d077b01412",
|
|
"outputId": "cc8fc837-8e06-42fc-bad5-b17816f47fcd"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([8, 1024, 768])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Keyword arguments matter here: this constructor lists num_heads before context_length\n",
|
|
"mha_pytorch_sdpa_no_flash = MHAPyTorchSDPAWithoutFlash(\n",
|
|
"    d_in=embed_dim,\n",
|
|
"    d_out=embed_dim,\n",
|
|
"    context_length=context_len,\n",
|
|
"    dropout=0.0,\n",
|
|
"    num_heads=12,\n",
|
|
"    qkv_bias=False\n",
|
|
").to(device)\n",
|
|
"\n",
|
|
"out = mha_pytorch_sdpa_no_flash(embeddings)\n",
|
|
"print(out.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "351c318f-4835-4d74-8d58-a070222447c4",
|
|
"metadata": {
|
|
"id": "351c318f-4835-4d74-8d58-a070222447c4"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"## 7) Using PyTorch's torch.nn.MultiheadAttention"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "74a6d060-6324-48fa-a35c-cb09f2a48965",
|
|
"metadata": {
|
|
"id": "74a6d060-6324-48fa-a35c-cb09f2a48965"
|
|
},
|
|
"source": [
|
|
"- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "3799c7ef-3155-42c6-a829-f95656453ae0",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "3799c7ef-3155-42c6-a829-f95656453ae0",
|
|
"outputId": "78236eea-a0f4-47e4-c846-606e7f8f8768"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([8, 1024, 768])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import torch.nn as nn\n",
|
|
"\n",
|
|
"\n",
|
|
"class MHAPyTorchClass(nn.Module):\n",
|
|
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False, need_weights=True):\n",
|
|
" super().__init__()\n",
|
|
"\n",
|
|
" self.context_length = context_length\n",
|
|
" self.multihead_attn = nn.MultiheadAttention(\n",
|
|
" embed_dim=d_out,\n",
|
|
" num_heads=num_heads,\n",
|
|
" dropout=dropout,\n",
|
|
" bias=qkv_bias,\n",
|
|
" add_bias_kv=qkv_bias,\n",
|
|
" batch_first=True,\n",
|
|
" )\n",
|
|
"\n",
|
|
" self.need_weights = need_weights\n",
|
|
" self.proj = nn.Linear(d_out, d_out)\n",
|
|
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" batch_size, num_tokens, _ = x.shape\n",
|
|
"\n",
|
|
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
|
|
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
|
|
" if self.context_length >= num_tokens:\n",
|
|
" attn_mask = self.mask[:num_tokens, :num_tokens]\n",
|
|
" else:\n",
|
|
" attn_mask = self.mask[:self.context_length, :self.context_length]\n",
|
|
"\n",
|
|
" # attn_mask broadcasting will handle batch_size dimension implicitly\n",
|
|
" attn_output, _ = self.multihead_attn(\n",
|
|
" x, x, x, attn_mask=attn_mask, need_weights=self.need_weights\n",
|
|
" )\n",
|
|
"\n",
|
|
" output = self.proj(attn_output)\n",
|
|
"\n",
|
|
" return output\n",
|
|
"\n",
|
|
"\n",
|
|
"mha_pytorch_class_default = MHAPyTorchClass(\n",
|
|
" d_in=embed_dim,\n",
|
|
" d_out=embed_dim,\n",
|
|
" context_length=context_len,\n",
|
|
" dropout=0.0,\n",
|
|
" num_heads=12,\n",
|
|
" qkv_bias=False\n",
|
|
").to(device)\n",
|
|
"\n",
|
|
"out = mha_pytorch_class_default(embeddings)\n",
|
|
"print(out.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a3953bff-1056-4de2-bfd1-dfccf659eee4",
|
|
"metadata": {
|
|
"id": "a3953bff-1056-4de2-bfd1-dfccf659eee4"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"## 8) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d2164859-31a0-4537-b4fb-27d57675ba77",
|
|
"metadata": {
|
|
"id": "d2164859-31a0-4537-b4fb-27d57675ba77"
|
|
},
|
|
"source": [
|
|
"- Set `need_weights` (default: `True`) to `False` so that `MultiheadAttention` uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n",
|
|
"\n",
|
|
"> need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n",
|
|
" Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n",
|
|
" and achieve the best performance for MHA.\n",
|
|
" Default: ``True``."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
|
|
"outputId": "6359dcff-ddcf-4cf9-eada-c3f0685cced2"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([8, 1024, 768])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Same configuration as `mha_pytorch_class_default`, except that no attention\n",
|
|
"# weights are requested from nn.MultiheadAttention\n",
|
|
"mha_pytorch_class_noweights = MHAPyTorchClass(\n",
|
|
"    d_in=embed_dim,\n",
|
|
"    d_out=embed_dim,\n",
|
|
"    num_heads=12,\n",
|
|
"    context_length=context_len,\n",
|
|
"    dropout=0.0,\n",
|
|
"    qkv_bias=False,\n",
|
|
"    need_weights=False  # NEW!\n",
|
|
").to(device)\n",
|
|
"\n",
|
|
"out = mha_pytorch_class_noweights(embeddings)\n",
|
|
"print(out.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "21f4ff35-651c-4e47-bfa1-016f3de01ecc",
|
|
"metadata": {
|
|
"id": "21f4ff35-651c-4e47-bfa1-016f3de01ecc"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"## 9) Using PyTorch's FlexAttention\n",
|
|
"\n",
|
|
"- See [FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention](https://pytorch.org/blog/flexattention/) to learn more about FlexAttention\n",
|
|
"- This is currently only supported in PyTorch 2.5 (nightly), which you can install on a CPU machine via\n",
|
|
"\n",
|
|
" ```bash\n",
|
|
" pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu -U\n",
|
|
" ```\n",
|
|
"\n",
|
|
"- To install PyTorch nightly on a GPU machine, use the following (for more information, also see the installation menu on [pytorch.org](https://pytorch.org/))\n",
|
|
"\n",
|
|
" ```bash\n",
|
|
" pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 -U\n",
|
|
" ```"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "834318c8-4748-4902-99f0-70ee02bef63e",
|
|
"metadata": {
|
|
"id": "834318c8-4748-4902-99f0-70ee02bef63e"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from packaging.version import parse as parse_version\n",
|
|
"\n",
|
|
"def normalize_version(version):\n",
|
|
"    \"\"\"Reduce `version` to its major.minor.micro release, so that dev/pre-release and\n",
|
|
"    local builds (e.g. \"2.5.0.dev20240905+cu121\") compare equal to that release.\"\"\"\n",
|
|
"    parsed = parse_version(version)\n",
|
|
"    release_only = f\"{parsed.major}.{parsed.minor}.{parsed.micro}\"\n",
|
|
"    return parse_version(release_only)\n",
|
|
"\n",
|
|
"current_version = normalize_version(torch.__version__)\n",
|
|
"MIN_TORCH_VERSION = \"2.5.0\"\n",
|
|
"required_version = parse_version(MIN_TORCH_VERSION)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "WYyFRCXndVH9",
|
|
"metadata": {
|
|
"id": "WYyFRCXndVH9"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"if current_version >= required_version:\n",
|
|
"    from torch.nn.attention.flex_attention import flex_attention, create_block_mask\n",
|
|
"\n",
|
|
"\n",
|
|
"def causal(b, h, q_idx, kv_idx):\n",
|
|
"    # `mask_mod` for create_block_mask: a query position may only attend to key\n",
|
|
"    # positions at or before itself (batch `b` and head `h` are ignored)\n",
|
|
"    return q_idx >= kv_idx\n",
|
|
"\n",
|
|
"\n",
|
|
"class MHAPyTorchFlexAttention(nn.Module):\n",
|
|
"    \"\"\"Causal multi-head attention computed with FlexAttention (PyTorch >= 2.5).\n",
|
|
"\n",
|
|
"    Note: `flex_attention` takes no dropout argument, so `dropout` is stored for\n",
|
|
"    interface parity with the other implementations but is not applied.\n",
|
|
"    \"\"\"\n",
|
|
"\n",
|
|
"    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
|
|
"        super().__init__()\n",
|
|
"\n",
|
|
"        assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
|
|
"\n",
|
|
"        self.num_heads = num_heads\n",
|
|
"        self.context_length = context_length\n",
|
|
"        self.head_dim = d_out // num_heads\n",
|
|
"        self.d_out = d_out\n",
|
|
"\n",
|
|
"        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
|
|
"        self.proj = nn.Linear(d_out, d_out)\n",
|
|
"        self.dropout = dropout\n",
|
|
"        # self.register_buffer(\"block_mask\", create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length))\n",
|
|
"        # `create_block_mask` function does not support buffers, yet\n",
|
|
"        self.block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length)\n",
|
|
"\n",
|
|
"    def _get_block_mask(self, num_tokens, device):\n",
|
|
"        # Indexing a BlockMask (`mask[i, j]`) selects batch/head entries, not\n",
|
|
"        # query/key positions, so slicing cannot shrink the cached mask. Reuse the\n",
|
|
"        # cached mask for full-length inputs and build a correctly sized one otherwise.\n",
|
|
"        if num_tokens > self.context_length:\n",
|
|
"            raise ValueError(\n",
|
|
"                f\"Input has {num_tokens} tokens, but context_length is {self.context_length}\"\n",
|
|
"            )\n",
|
|
"        if num_tokens == self.context_length:\n",
|
|
"            return self.block_mask\n",
|
|
"        return create_block_mask(\n",
|
|
"            causal, B=None, H=None, Q_LEN=num_tokens, KV_LEN=num_tokens, device=device\n",
|
|
"        )\n",
|
|
"\n",
|
|
"    def forward(self, x):\n",
|
|
"        batch_size, num_tokens, embed_dim = x.shape\n",
|
|
"\n",
|
|
"        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
|
|
"        qkv = self.qkv(x)\n",
|
|
"\n",
|
|
"        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
|
|
"        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
|
|
"\n",
|
|
"        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
|
|
"        qkv = qkv.permute(2, 0, 3, 1, 4)\n",
|
|
"\n",
|
|
"        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
|
|
"        queries, keys, values = qkv\n",
|
|
"\n",
|
|
"        block_mask = self._get_block_mask(num_tokens, x.device)\n",
|
|
"        context_vec = flex_attention(queries, keys, values, block_mask=block_mask)\n",
|
|
"\n",
|
|
"        # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim),\n",
|
|
"        # then combine heads, where self.d_out = self.num_heads * self.head_dim\n",
|
|
"        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
|
|
"\n",
|
|
"        context_vec = self.proj(context_vec)\n",
|
|
"\n",
|
|
"        return context_vec"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "9cdaaf8a-f956-44bc-932f-4d33448e8aaf",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "9cdaaf8a-f956-44bc-932f-4d33448e8aaf",
|
|
"outputId": "a88a7398-159e-401f-d96c-2fc928908e3e"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([8, 1024, 768])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# FlexAttention needs PyTorch >= 2.5 and, in this notebook, a CUDA device\n",
|
|
"if current_version >= required_version and torch.cuda.is_available():\n",
|
|
"\n",
|
|
"    mha_pytorch_flex = MHAPyTorchFlexAttention(\n",
|
|
"        d_in=embed_dim,\n",
|
|
"        d_out=embed_dim,\n",
|
|
"        num_heads=12,\n",
|
|
"        context_length=context_len,\n",
|
|
"        dropout=0.0,\n",
|
|
"        qkv_bias=False,\n",
|
|
"    ).to(device)\n",
|
|
"\n",
|
|
"    out = mha_pytorch_flex(embeddings)\n",
|
|
"    print(out.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "8877de71-f84f-4f6d-bc87-7552013b6301",
|
|
"metadata": {
|
|
"id": "8877de71-f84f-4f6d-bc87-7552013b6301"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"## Quick speed comparison (M3 MacBook Air CPU)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "219cf93a-078f-434d-888c-2458d0731285",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "219cf93a-078f-434d-888c-2458d0731285",
|
|
"outputId": "a10b52d4-b4e6-43c2-9677-113c41edd3b7"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"PyTorch version: 2.4.0\n",
|
|
"Running on cpu\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Re-seed and report the runtime used for the measurements below\n",
|
|
"torch.manual_seed(123)\n",
|
|
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
|
|
"print(f\"PyTorch version: {torch.__version__}\")\n",
|
|
"print(f\"Running on {device}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
|
|
"outputId": "7bcd7da4-d115-4ba6-efba-377a0bd7d3a8"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"179 ms ± 7.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 1) CausalAttention MHA wrapper class from chapter 3\n",
|
|
"# Forward pass only (no backward), timed with IPython's %timeit\n",
|
|
"%timeit mha_ch03_wrapper(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
|
|
"outputId": "b04b4d0d-71aa-4944-f02b-131bf5a50202"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"166 ms ± 2.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 2) The multi-head attention class from chapter 3\n",
|
|
"# Forward pass only (no backward), timed with IPython's %timeit\n",
|
|
"%timeit mha_ch03(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
|
|
"outputId": "5436928a-7b98-4c40-bf51-97973f13327e"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"190 ms ± 2.03 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 3) An alternative multi-head attention with combined weights\n",
|
|
"# Forward pass only (no backward), timed with IPython's %timeit\n",
|
|
"%timeit mha_combined_qkv(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "131ca826-35bf-47e5-b497-540aba439ef9",
|
|
"metadata": {
|
|
"id": "131ca826-35bf-47e5-b497-540aba439ef9",
|
|
"outputId": "f5848852-f81b-4e5f-a7ff-e37a8445ad91"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"196 ms ± 1.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 4) Multi-head attention using Einstein summation\n",
|
|
"# Forward pass only (no backward), timed with IPython's %timeit\n",
|
|
"%timeit mha_einsum(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
|
|
"outputId": "9e07ce73-a2de-4e2c-8276-64626df9450e"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"110 ms ± 423 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 5) Multi-head attention with PyTorch's scaled dot product attention\n",
|
|
"# Forward pass only (no backward), timed with IPython's %timeit\n",
|
|
"%timeit mha_pytorch_scaled(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c44305ce-9f61-451a-b9ef-30caba222357",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "c44305ce-9f61-451a-b9ef-30caba222357",
|
|
"outputId": "6bab4a24-5bb4-4ad6-b260-3b442f598950"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"99.5 ms ± 790 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 6) PyTorch's scaled dot product attention without FlashAttention\n",
|
|
"# Forward pass only (no backward), timed with IPython's %timeit\n",
|
|
"%timeit mha_pytorch_sdpa_no_flash(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
|
|
"outputId": "630c49d1-8a06-4148-cd96-a7b2467310a0"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"198 ms ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 7) Using PyTorch's torch.nn.MultiheadAttention\n",
|
|
"# Forward pass only (no backward), timed with IPython's %timeit\n",
|
|
"%timeit mha_pytorch_class_default(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756",
|
|
"outputId": "10f6a268-f9cf-446c-aa83-e87b6a0b4f5c"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"168 ms ± 2.63 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 8) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n",
|
|
"# Forward pass only (no backward), timed with IPython's %timeit\n",
|
|
"%timeit mha_pytorch_class_noweights(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "bdd8e0fc-ef24-424c-bccf-c381e73da228",
|
|
"metadata": {
|
|
"id": "bdd8e0fc-ef24-424c-bccf-c381e73da228"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"## 9) Using PyTorch's FlexAttention\n",
|
|
"\n",
|
|
"# Requires PyTorch 2.5.0 or newer and currently only supports CUDA PyTorch;\n",
|
|
"# `mha_pytorch_flex` is only created earlier under this same condition\n",
|
|
"if current_version >= required_version and torch.cuda.is_available():\n",
|
|
"    %timeit mha_pytorch_flex(embeddings)\n",
|
|
"else:\n",
|
|
"    print(\"Skipping FlexAttention: requires PyTorch >= 2.5.0 and a CUDA device\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a78ff594-6cc2-496d-a302-789fa104c3c9",
|
|
"metadata": {
|
|
"id": "a78ff594-6cc2-496d-a302-789fa104c3c9"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"## Quick speed comparison (Nvidia A100 GPU)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "RStnI1pEi6Eo",
|
|
"metadata": {
|
|
"id": "RStnI1pEi6Eo"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Enable tensor cores: \"high\" lets float32 matmuls use faster reduced-precision\n",
|
|
"# paths (e.g. TF32) where the hardware supports them; see the\n",
|
|
"# torch.set_float32_matmul_precision documentation for details\n",
|
|
"torch.set_float32_matmul_precision(\"high\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "e8431d75-e1c9-4d9a-b7da-9a1ff391f2bf",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "e8431d75-e1c9-4d9a-b7da-9a1ff391f2bf",
|
|
"outputId": "f6356d4c-7a3f-47f5-cf51-5507db3f5748"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"PyTorch version: 2.5.0.dev20240905+cu121\n",
|
|
"Running on cuda\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Re-seed and report the runtime used for the measurements below\n",
|
|
"torch.manual_seed(123)\n",
|
|
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
|
|
"print(f\"PyTorch version: {torch.__version__}\")\n",
|
|
"print(f\"Running on {device}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0",
|
|
"outputId": "4ea5798b-a590-401b-d049-3fed0716db34"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"4.33 ms ± 51.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 1) CausalAttention MHA wrapper class from chapter 3\n",
|
|
"# NOTE: CUDA kernels run asynchronously; without torch.cuda.synchronize(), %timeit\n",
|
|
"# may under-report GPU time (the CUDA-event benchmarks below synchronize explicitly)\n",
|
|
"%timeit mha_ch03_wrapper(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "8686dd69-3655-40e4-a57b-a2c55532a010",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "8686dd69-3655-40e4-a57b-a2c55532a010",
|
|
"outputId": "88094b61-4d87-47bd-8c8b-c9344ab57062"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"3.09 ms ± 363 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 2) The multi-head attention class from chapter 3\n",
|
|
"# NOTE: CUDA kernels run asynchronously; without torch.cuda.synchronize(), %timeit\n",
|
|
"# may under-report GPU time (the CUDA-event benchmarks below synchronize explicitly)\n",
|
|
"%timeit mha_ch03(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf",
|
|
"outputId": "e3d82c53-f75b-425a-ed3e-5e48ea9ef768"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"3.81 ms ± 656 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 3) An alternative multi-head attention with combined weights\n",
|
|
"# NOTE: CUDA kernels run asynchronously; without torch.cuda.synchronize(), %timeit\n",
|
|
"# may under-report GPU time (the CUDA-event benchmarks below synchronize explicitly)\n",
|
|
"%timeit mha_combined_qkv(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "abee5edf-2585-4f0e-846c-b1c7ca88f545",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "abee5edf-2585-4f0e-846c-b1c7ca88f545",
|
|
"outputId": "c9bf17f5-de62-4c39-a328-fe430812b156"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"4.12 ms ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 4) Multi-head attention using Einstein summation\n",
|
|
"# NOTE: CUDA kernels run asynchronously; without torch.cuda.synchronize(), %timeit\n",
|
|
"# may under-report GPU time (the CUDA-event benchmarks below synchronize explicitly)\n",
|
|
"%timeit mha_einsum(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "1075abe2-4839-4fd6-af3e-c09bb3651e26",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "1075abe2-4839-4fd6-af3e-c09bb3651e26",
|
|
"outputId": "b63f4769-3be5-44df-b8f2-2ac379be1ff4"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"1.25 ms ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 5) Multi-head attention with PyTorch's scaled dot product attention\n",
|
|
"# NOTE: CUDA kernels run asynchronously; without torch.cuda.synchronize(), %timeit\n",
|
|
"# may under-report GPU time (the CUDA-event benchmarks below synchronize explicitly)\n",
|
|
"%timeit mha_pytorch_scaled(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"id": "218adbaf-f17f-47d9-81d5-41c758218df7",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "218adbaf-f17f-47d9-81d5-41c758218df7",
|
|
"outputId": "a30ab365-865d-4175-f148-dc15abc4e07a"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2.03 ms ± 119 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 6) PyTorch's scaled dot product attention without FlashAttention\n",
|
|
"# NOTE: CUDA kernels run asynchronously; without torch.cuda.synchronize(), %timeit\n",
|
|
"# may under-report GPU time (the CUDA-event benchmarks below synchronize explicitly)\n",
|
|
"%timeit mha_pytorch_sdpa_no_flash(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"id": "868e3670-8edc-47bc-9e06-eb505e44dc9d",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "868e3670-8edc-47bc-9e06-eb505e44dc9d",
|
|
"outputId": "e20e77ac-6573-4830-82c7-795bd139af4f"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"3.05 ms ± 388 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 7) Using PyTorch's torch.nn.MultiheadAttention\n",
|
|
"# NOTE: CUDA kernels run asynchronously; without torch.cuda.synchronize(), %timeit\n",
|
|
"# may under-report GPU time (the CUDA-event benchmarks below synchronize explicitly)\n",
|
|
"%timeit mha_pytorch_class_default(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"id": "944870e6-de54-4e3b-a455-b8f21f6f92c8",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "944870e6-de54-4e3b-a455-b8f21f6f92c8",
|
|
"outputId": "26df6295-fa5c-4b3f-89be-c7183f079fce"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2.37 ms ± 6.53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 8) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n",
|
|
"# NOTE: CUDA kernels run asynchronously; without torch.cuda.synchronize(), %timeit\n",
|
|
"# may under-report GPU time (the CUDA-event benchmarks below synchronize explicitly)\n",
|
|
"%timeit mha_pytorch_class_noweights(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"id": "evKtpb5QN_2A",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "evKtpb5QN_2A",
|
|
"outputId": "23bf5398-c8ec-4463-8af9-17de8f920a33"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"14.6 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## 9) Using PyTorch's FlexAttention\n",
|
|
"\n",
|
|
"# Requires PyTorch 2.5.0 or newer and a CUDA device; `mha_pytorch_flex` is only\n",
|
|
"# created earlier under this same condition\n",
|
|
"if current_version >= required_version and torch.cuda.is_available():\n",
|
|
"    %timeit mha_pytorch_flex(embeddings)\n",
|
|
"else:\n",
|
|
"    print(\"Skipping FlexAttention: requires PyTorch >= 2.5.0 and a CUDA device\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "dabc6575-0316-4640-a729-e616d5c17b73",
|
|
"metadata": {
|
|
"id": "dabc6575-0316-4640-a729-e616d5c17b73"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"\n",
|
|
"# Visualizations"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 35,
|
|
"id": "bbb2f729-d3d8-46d0-b249-9249197ea574",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "bbb2f729-d3d8-46d0-b249-9249197ea574",
|
|
"outputId": "a45fe256-6416-4f43-87d2-27bbf97239e3"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"PyTorch version: 2.5.0.dev20240905+cu121\n",
|
|
"Running on cuda\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Re-seed and report the runtime used for the benchmarks below\n",
|
|
"torch.manual_seed(123)\n",
|
|
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
|
|
"print(f\"PyTorch version: {torch.__version__}\")\n",
|
|
"print(f\"Running on {device}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"id": "b0620bf5",
|
|
"metadata": {
|
|
"id": "b0620bf5"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Implementations to benchmark, keyed by the label shown on the plots\n",
|
|
"functions = {\n",
|
|
"    \"1) MHA wrapper class\": mha_ch03_wrapper,\n",
|
|
"    \"2) MHA Ch03\": mha_ch03,\n",
|
|
"    \"3) MHA with combined QKV weights\": mha_combined_qkv,\n",
|
|
"    \"4) MHA with Einsum\": mha_einsum,\n",
|
|
"    \"5) MHA with PyTorch scaled_dot_product_attention\": mha_pytorch_scaled,\n",
|
|
"    \"6) PyTorch's SDPA, no FlashAttention\": mha_pytorch_sdpa_no_flash,\n",
|
|
"    \"7) PyTorch MHA class defaults\": mha_pytorch_class_default,\n",
|
|
"    \"8) PyTorch MHA with need_weights=False\": mha_pytorch_class_noweights\n",
|
|
"}\n",
|
|
"\n",
|
|
"# `mha_pytorch_flex` only exists if it was instantiated above, which requires\n",
|
|
"# PyTorch >= 2.5 and a CUDA device\n",
|
|
"if current_version >= required_version and torch.cuda.is_available():\n",
|
|
"    functions[\"9) PyTorch's FlexAttention\"] = mha_pytorch_flex"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"id": "CDJAPZaszaqx",
|
|
"metadata": {
|
|
"id": "CDJAPZaszaqx"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"# Dark-mode figure styling shared by all plots below\n",
|
|
"plt.rcParams.update({\n",
|
|
"    \"figure.facecolor\": \"#121212\",\n",
|
|
"    \"axes.facecolor\": \"#121212\",\n",
|
|
"    \"axes.edgecolor\": \"white\",\n",
|
|
"    \"axes.labelcolor\": \"white\",\n",
|
|
"    \"text.color\": \"white\",\n",
|
|
"    \"xtick.color\": \"white\",\n",
|
|
"    \"ytick.color\": \"white\",\n",
|
|
"    \"grid.color\": \"#444444\",\n",
|
|
"    \"lines.linewidth\": 2,\n",
|
|
"    \"lines.markersize\": 8,\n",
|
|
"})\n",
|
|
"\n",
|
|
"def plot_execution_times(functions, execution_means, execution_stds, filename):\n",
|
|
"    \"\"\"Draw a bar chart of mean execution times with std-dev error bars.\n",
|
|
"\n",
|
|
"    Bars are labeled with the keys of `functions` and annotated with their rounded\n",
|
|
"    mean; the figure is saved to `filename` and then shown.\n",
|
|
"    \"\"\"\n",
|
|
"    fig, ax = plt.subplots()\n",
|
|
"    bars = ax.bar(functions.keys(), execution_means, yerr=execution_stds, capsize=5, error_kw={'ecolor': 'grey'})\n",
|
|
"\n",
|
|
"    ax.set_ylabel(\"Execution time (ms)\")\n",
|
|
"    plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\")\n",
|
|
"\n",
|
|
"    # Leave 40% headroom above the tallest bar for the value labels\n",
|
|
"    tallest = max(execution_means)\n",
|
|
"    upper_ylim = tallest + 0.4 * tallest\n",
|
|
"    ax.set_ylim(0, upper_ylim)\n",
|
|
"\n",
|
|
"    label_offset = 0.05 * upper_ylim\n",
|
|
"    for bar in bars:\n",
|
|
"        height = bar.get_height()\n",
|
|
"        x_center = bar.get_x() + bar.get_width()/2\n",
|
|
"        ax.text(x_center, height + label_offset, round(height, 2), ha=\"center\", va=\"bottom\")\n",
|
|
"\n",
|
|
"    fig.tight_layout()\n",
|
|
"    plt.savefig(filename)\n",
|
|
"    plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "4df834dc",
|
|
"metadata": {
|
|
"id": "4df834dc"
|
|
},
|
|
"source": [
|
|
"## Speed comparison (Nvidia A100 GPU) with warmup (forward pass only)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000",
|
|
"metadata": {
|
|
"id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# CUDA benchmark code shared by Andrei Aksionov\n",
|
|
"# and based on code from\n",
|
|
"# https://github.com/cuda-mode/lectures/blob/main/lecture1/pytorch_square.py\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"def time_pytorch_function(func, *args, num_repeats=1_000, num_warmup=5):\n",
|
|
"    \"\"\"Time `func(*args)` with CUDA events.\n",
|
|
"\n",
|
|
"    Runs `num_warmup` untimed calls, then `num_repeats` timed calls, and returns\n",
|
|
"    the (mean, std) of the per-call times in milliseconds.\n",
|
|
"    \"\"\"\n",
|
|
"    start = torch.cuda.Event(enable_timing=True)\n",
|
|
"    end = torch.cuda.Event(enable_timing=True)\n",
|
|
"\n",
|
|
"    # Warmup\n",
|
|
"    for _ in range(num_warmup):\n",
|
|
"        func(*args)\n",
|
|
"    torch.cuda.synchronize()\n",
|
|
"\n",
|
|
"    times = []\n",
|
|
"    for _ in range(num_repeats):\n",
|
|
"        start.record()\n",
|
|
"        func(*args)\n",
|
|
"        end.record()\n",
|
|
"        # Wait for the recorded work to finish before reading the event timing\n",
|
|
"        torch.cuda.synchronize()\n",
|
|
"        times.append(start.elapsed_time(end))\n",
|
|
"\n",
|
|
"    return np.mean(times), np.std(times)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"id": "9dd07a09",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 488
|
|
},
|
|
"id": "9dd07a09",
|
|
"outputId": "491d06f4-a6bc-431a-a1ca-4db38df57e0c"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Forward-pass-only timings: one (mean, std) pair per implementation\n",
|
|
"execution_stats = [time_pytorch_function(fn, embeddings) for fn in functions.values()]\n",
|
|
"execution_means = [mean for mean, _ in execution_stats]\n",
|
|
"execution_stds = [std for _, std in execution_stats]\n",
|
|
"\n",
|
|
"plot_execution_times(functions, execution_means, execution_stds, filename=\"1_forward-only.pdf\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "VQaSerWCOnYB",
|
|
"metadata": {
|
|
"id": "VQaSerWCOnYB"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"\n",
|
|
"## Speed comparison (Nvidia A100 GPU) with warmup (forward and backward pass)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"id": "69e6377b",
|
|
"metadata": {
|
|
"id": "69e6377b"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def forward_backward(func, embeddings):\n",
|
|
"    \"\"\"Run `func(embeddings)` and backpropagate from the summed output.\"\"\"\n",
|
|
"    # Reset any input gradient left over from a previous call\n",
|
|
"    if embeddings.grad is not None:\n",
|
|
"        embeddings.grad.zero_()\n",
|
|
"\n",
|
|
"    output = func(embeddings)\n",
|
|
"    loss = output.sum()\n",
|
|
"    loss.backward()\n",
|
|
"\n",
|
|
"\n",
|
|
"def time_pytorch_function_forward_backward(func, *args, num_repeats=1_000, num_warmup=5):\n",
|
|
"    \"\"\"Time `forward_backward(func, *args)` with CUDA events.\n",
|
|
"\n",
|
|
"    Runs `num_warmup` untimed iterations, so one-time costs such as lazy\n",
|
|
"    compilation are excluded from the timings. It then runs `num_repeats` timed\n",
|
|
"    iterations and returns the (mean, std) of the per-iteration times in milliseconds.\n",
|
|
"    \"\"\"\n",
|
|
"    # CUDA IS ASYNC so can't use python time module\n",
|
|
"    start = torch.cuda.Event(enable_timing=True)\n",
|
|
"    end = torch.cuda.Event(enable_timing=True)\n",
|
|
"\n",
|
|
"    # Warmup\n",
|
|
"    for _ in range(num_warmup):\n",
|
|
"        forward_backward(func, *args)\n",
|
|
"    torch.cuda.synchronize()\n",
|
|
"\n",
|
|
"    times = []\n",
|
|
"    for _ in range(num_repeats):\n",
|
|
"        start.record()\n",
|
|
"        forward_backward(func, *args)\n",
|
|
"        end.record()\n",
|
|
"        # Wait for the recorded work to finish before reading the event timing\n",
|
|
"        torch.cuda.synchronize()\n",
|
|
"        times.append(start.elapsed_time(end))\n",
|
|
"\n",
|
|
"    return np.mean(times), np.std(times)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 41,
|
|
"id": "ReCmeRhCOpm8",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 488
|
|
},
|
|
"id": "ReCmeRhCOpm8",
|
|
"outputId": "2bcfa909-ba87-4d31-b926-bc66e63736cc"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Forward + backward timings: one (mean, std) pair per implementation\n",
|
|
"execution_stats = [time_pytorch_function_forward_backward(fn, embeddings) for fn in functions.values()]\n",
|
|
"execution_means = [mean for mean, _ in execution_stats]\n",
|
|
"execution_stds = [std for _, std in execution_stats]\n",
|
|
"\n",
|
|
"plot_execution_times(functions, execution_means, execution_stds, filename=\"2_forward-and-backward.pdf\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1gWX-Ayqia1k",
|
|
"metadata": {
|
|
"id": "1gWX-Ayqia1k"
|
|
},
|
|
"source": [
|
|
"<br>\n",
|
|
" \n",
|
|
"\n",
|
|
"\n",
|
|
"## Speed comparison (Nvidia A100 GPU) with warmup and compilation (forward and backward pass)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 42,
|
|
"id": "LQDiAPooiYAz",
|
|
"metadata": {
|
|
"id": "LQDiAPooiYAz"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch._dynamo\n",
|
|
"\n",
|
|
"# Keep the benchmark running if compiling one implementation fails: with this flag,\n",
|
|
"# Dynamo is configured to fall back to eager execution instead of raising\n",
|
|
"torch._dynamo.config.suppress_errors = True\n",
|
|
"\n",
|
|
"def prepare_function(fn):\n",
|
|
"    \"\"\"Return `fn` wrapped with `torch.compile`.\"\"\"\n",
|
|
"    return torch.compile(fn)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 43,
|
|
"id": "aac06ffe",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 489
|
|
},
|
|
"id": "aac06ffe",
|
|
"outputId": "098c66b4-1201-4bdd-af23-e634f5ade806"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Same forward + backward benchmark, on torch.compile-d versions of each implementation\n",
|
|
"execution_stats = [\n",
|
|
"    time_pytorch_function_forward_backward(prepare_function(fn), embeddings)\n",
|
|
"    for fn in functions.values()\n",
|
|
"]\n",
|
|
"execution_means = [mean for mean, _ in execution_stats]\n",
|
|
"execution_stds = [std for _, std in execution_stats]\n",
|
|
"\n",
|
|
"plot_execution_times(functions, execution_means, execution_stds, filename=\"3_forward-and-backward-compiled.pdf\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"accelerator": "GPU",
|
|
"colab": {
|
|
"gpuType": "A100",
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|