mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-01-06 12:11:20 +00:00
mha variants
This commit is contained in:
parent
d4754f1bdd
commit
87fcfd9245
@ -41,12 +41,15 @@ Alternatively, you can view this and other files on GitHub at [https://github.co
|
||||
| Ch 6: Finetuning for Text Classification | Q2 2024 | ... |
|
||||
| Ch 7: Finetuning with Human Feedback | Q2 2024 | ... |
|
||||
| Ch 8: Using Large Language Models in Practice | Q2/3 2024 | ... |
|
||||
| Appendix A: Introduction to PyTorch* | - [code-part1.ipynb](appendix-A/03_main-chapter-code/code-part1.ipynb)<br/>- [code-part2.ipynb](appendix-A/03_main-chapter-code/code-part2.ipynb)<br/>- [DDP-script.py](appendix-A/03_main-chapter-code/DDP-script.py)<br/>- [exercise-solutions.ipynb](appendix-A/03_main-chapter-code/exercise-solutions.ipynb) | [./appendix-A](./appendix-A) |
|
||||
| Appendix A: Introduction to PyTorch | - [code-part1.ipynb](appendix-A/03_main-chapter-code/code-part1.ipynb)<br/>- [code-part2.ipynb](appendix-A/03_main-chapter-code/code-part2.ipynb)<br/>- [DDP-script.py](appendix-A/03_main-chapter-code/DDP-script.py)<br/>- [exercise-solutions.ipynb](appendix-A/03_main-chapter-code/exercise-solutions.ipynb) | [./appendix-A](./appendix-A) |
|
||||
| Appendix B: References and Further Reading | No code | |
|
||||
| Appendix C: Exercises | No code | |
|
||||
|
||||
|
||||
<br>
|
||||
|
||||
> [!TIP]
|
||||
> Please see [this](appendix-A/01_optional-python-setup-preferences) and [this](appendix-A/02_installing-python-libraries) folder if you need more guidance on installing Python and Python packages.)
|
||||
> Please see [this](appendix-A/01_optional-python-setup-preferences) and [this](appendix-A/02_installing-python-libraries) folder if you need more guidance on installing Python and Python packages.
|
||||
|
||||
|
||||
|
||||
|
||||
@ -1637,7 +1637,7 @@
|
||||
"class MultiHeadAttention(nn.Module):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
|
||||
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
|
||||
"\n",
|
||||
" self.d_out = d_out\n",
|
||||
" self.num_heads = num_heads\n",
|
||||
@ -1865,7 +1865,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -243,7 +243,7 @@
|
||||
"class MultiHeadAttention(nn.Module):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
|
||||
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
|
||||
"\n",
|
||||
" self.d_out = d_out\n",
|
||||
" self.num_heads = num_heads\n",
|
||||
@ -342,7 +342,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
3
ch03/02_bonus_efficient-multihead-attention/README.md
Normal file
3
ch03/02_bonus_efficient-multihead-attention/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
# More Efficient Multi-Head Attention Implementations
|
||||
|
||||
- [mha-implementations.ipynb](mha-implementations.ipynb) contains and compares different implementations of multi-head attention
|
||||
58
ch03/02_bonus_efficient-multihead-attention/ch03.py
Normal file
58
ch03/02_bonus_efficient-multihead-attention/ch03.py
Normal file
@ -0,0 +1,58 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
|
||||
super().__init__()
|
||||
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
|
||||
|
||||
self.d_out = d_out
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
|
||||
|
||||
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||||
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||||
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||||
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))
|
||||
|
||||
def forward(self, x):
|
||||
b, num_tokens, d_in = x.shape
|
||||
|
||||
keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
|
||||
queries = self.W_query(x)
|
||||
values = self.W_value(x)
|
||||
|
||||
# We implicitly split the matrix by adding a `num_heads` dimension
|
||||
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
|
||||
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
|
||||
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
|
||||
keys = keys.transpose(1, 2)
|
||||
queries = queries.transpose(1, 2)
|
||||
values = values.transpose(1, 2)
|
||||
|
||||
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||
# Original mask truncated to the number of tokens and converted to boolean
|
||||
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
||||
# Unsqueeze the mask to match dimensions
|
||||
mask_unsqueezed = mask_bool.unsqueeze(0)
|
||||
# Use the unsqueezed mask to fill attention scores
|
||||
attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)
|
||||
|
||||
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||||
attn_weights = self.dropout(attn_weights)
|
||||
|
||||
# Shape: (b, num_tokens, num_heads, head_dim)
|
||||
context_vec = (attn_weights @ values).transpose(1, 2)
|
||||
|
||||
# Combine heads, where self.d_out = self.num_heads * self.head_dim
|
||||
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
|
||||
context_vec = self.out_proj(context_vec) # optional projection
|
||||
|
||||
return context_vec
|
||||
@ -0,0 +1,356 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6f678e62-7bcb-4405-86ae-dce94f494303",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Appendix D: Efficient Multi-Head Attention Implementations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Multi-head attention implementation from chapter 3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"torch.manual_seed(123)\n",
|
||||
"\n",
|
||||
"batch_size = 8\n",
|
||||
"context_len = 1024\n",
|
||||
"embed_dim = 768\n",
|
||||
"embeddings = torch.randn((batch_size, context_len, embed_dim))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([8, 1024, 768])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from ch03 import MultiHeadAttention as Ch03_MHA\n",
|
||||
"\n",
|
||||
"mha_ch03 = Ch03_MHA(\n",
|
||||
" d_in=embed_dim,\n",
|
||||
" d_out=embed_dim,\n",
|
||||
" block_size=context_len,\n",
|
||||
" dropout=0.0,\n",
|
||||
" num_heads=12,\n",
|
||||
" qkv_bias=False\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"out = mha_ch03(embeddings)\n",
|
||||
"print(out.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## An alternative multi-head attention with combined weights"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- The code for the `MultiHeadAttentionAlt` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n",
|
||||
"- The main difference between the `MultiHeadAttentionAlt` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionAlt` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)` instead of separate weight matrices:\n",
|
||||
"\n",
|
||||
" - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
|
||||
" - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
|
||||
" - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
|
||||
" \n",
|
||||
"- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n",
|
||||
"- Using `q, k, v = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([8, 1024, 768])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch.nn as nn\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class MultiHeadAttentionAlt(nn.Module):\n",
|
||||
" def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
|
||||
"\n",
|
||||
" self.num_heads = num_heads\n",
|
||||
" self.block_size = block_size\n",
|
||||
" self.head_dim = d_out // num_heads\n",
|
||||
"\n",
|
||||
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
|
||||
" self.proj = nn.Linear(d_in, d_out)\n",
|
||||
" self.dropout = nn.Dropout(dropout)\n",
|
||||
"\n",
|
||||
" self.register_buffer(\n",
|
||||
" \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" batch_size, num_tokens, embed_dim = x.shape\n",
|
||||
"\n",
|
||||
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
|
||||
" qkv = self.qkv(x)\n",
|
||||
"\n",
|
||||
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
|
||||
" qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
|
||||
"\n",
|
||||
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
|
||||
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
|
||||
"\n",
|
||||
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n",
|
||||
" queries, keys, values = qkv.unbind(0)\n",
|
||||
"\n",
|
||||
" # (b, num_head, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n",
|
||||
" attn_scores = queries @ keys.transpose(-2, -1)\n",
|
||||
" attn_scores = attn_scores.masked_fill(\n",
|
||||
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**-0.5, dim=-1)\n",
|
||||
" attn_weights = self.dropout(attn_weights)\n",
|
||||
"\n",
|
||||
" # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n",
|
||||
" context_vec = attn_weights @ values\n",
|
||||
"\n",
|
||||
" # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n",
|
||||
" context_vec = context_vec.transpose(1, 2)\n",
|
||||
"\n",
|
||||
" # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)\n",
|
||||
" context_vec = context_vec.reshape(batch_size, num_tokens, embed_dim)\n",
|
||||
"\n",
|
||||
" context_vec = self.proj(context_vec)\n",
|
||||
"\n",
|
||||
" return context_vec\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"mha_alt = MultiHeadAttentionAlt(\n",
|
||||
" d_in=embed_dim,\n",
|
||||
" d_out=embed_dim,\n",
|
||||
" block_size=context_len,\n",
|
||||
" dropout=0.0,\n",
|
||||
" num_heads=12,\n",
|
||||
" qkv_bias=False\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"out = mha_alt(embeddings)\n",
|
||||
"print(out.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Multihead attention with PyTorch's scaled dot product attention"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f78e346f-3b85-44e6-9feb-f01131381148",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which implements a memory-optimized version of self-attention calld [flash attention](https://arxiv.org/abs/2205.14135)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class MultiHeadAttentionPyTorch(nn.Module):\n",
|
||||
" def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
|
||||
"\n",
|
||||
" self.num_heads = num_heads\n",
|
||||
" self.block_size = block_size\n",
|
||||
" self.head_dim = d_out // num_heads\n",
|
||||
" self.d_out = d_out\n",
|
||||
"\n",
|
||||
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
|
||||
" self.proj = nn.Linear(d_in, d_out)\n",
|
||||
" self.dropout = dropout\n",
|
||||
"\n",
|
||||
" self.register_buffer(\n",
|
||||
" \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" batch_size, num_tokens, embed_dim = x.shape\n",
|
||||
"\n",
|
||||
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
|
||||
" qkv = self.qkv(x)\n",
|
||||
"\n",
|
||||
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
|
||||
" qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
|
||||
"\n",
|
||||
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
|
||||
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
|
||||
"\n",
|
||||
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n",
|
||||
" q, k, v = qkv.unbind(0)\n",
|
||||
"\n",
|
||||
" use_dropout = 0. if not self.training else self.dropout\n",
|
||||
" context_vec = torch.nn.functional.scaled_dot_product_attention(q, k, v, \n",
|
||||
" attn_mask=None, dropout_p=use_dropout, is_causal=True)\n",
|
||||
"\n",
|
||||
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
|
||||
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
|
||||
"\n",
|
||||
" return context_vec"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([8, 1024, 768])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"mha_pytorch = MultiHeadAttentionPyTorch(\n",
|
||||
" d_in=embed_dim,\n",
|
||||
" d_out=embed_dim,\n",
|
||||
" block_size=context_len,\n",
|
||||
" dropout=0.0,\n",
|
||||
" num_heads=12,\n",
|
||||
" qkv_bias=False\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"out = mha_pytorch(embeddings)\n",
|
||||
"print(out.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8877de71-f84f-4f6d-bc87-7552013b6301",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Speed comparison"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"253 ms ± 9.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%timeit mha_ch03(embeddings)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"309 ms ± 26.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%timeit mha_alt(embeddings)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"90.4 ms ± 719 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%timeit mha_pytorch(embeddings)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@ -1,3 +1,4 @@
|
||||
# Chapter 3: Coding Attention Mechanisms
|
||||
|
||||
- [01_main-chapter-code](01_main-chapter-code) contains the main chapter code.
|
||||
- [01_main-chapter-code](01_main-chapter-code) contains the main chapter code.
|
||||
- [02_bonus_efficient-multihead-attention](02_bonus_efficient-multihead-attention) implements and compares different implementation variants of multihead-attention
|
||||
@ -56,7 +56,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
|
||||
super().__init__()
|
||||
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
|
||||
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
|
||||
|
||||
self.d_out = d_out
|
||||
self.num_heads = num_heads
|
||||
|
||||
@ -45,7 +45,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
|
||||
super().__init__()
|
||||
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
|
||||
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
|
||||
|
||||
self.d_out = d_out
|
||||
self.num_heads = num_heads
|
||||
|
||||
@ -56,7 +56,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
|
||||
super().__init__()
|
||||
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
|
||||
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
|
||||
|
||||
self.d_out = d_out
|
||||
self.num_heads = num_heads
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user