mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-12-12 23:42:17 +00:00
also add simple wrapper
This commit is contained in:
parent
571377a2d6
commit
b6fe1a37b3
@ -1865,7 +1865,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.12"
|
"version": "3.10.6"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
@ -2,6 +2,47 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class CausalAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):
|
||||||
|
super().__init__()
|
||||||
|
self.d_out = d_out
|
||||||
|
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||||||
|
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||||||
|
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||||||
|
self.dropout = nn.Dropout(dropout) # New
|
||||||
|
self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, num_tokens, d_in = x.shape # New batch dimension b
|
||||||
|
keys = self.W_key(x)
|
||||||
|
queries = self.W_query(x)
|
||||||
|
values = self.W_value(x)
|
||||||
|
|
||||||
|
attn_scores = queries @ keys.transpose(1, 2) # Changed transpose
|
||||||
|
attn_scores.masked_fill_( # New, _ ops are in-place
|
||||||
|
self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
|
||||||
|
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||||||
|
attn_weights = self.dropout(attn_weights) # New
|
||||||
|
|
||||||
|
context_vec = attn_weights @ values
|
||||||
|
return context_vec
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttentionWrapper(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
|
||||||
|
super().__init__()
|
||||||
|
self.heads = nn.ModuleList(
|
||||||
|
[CausalAttention(d_in, d_out, block_size, dropout, qkv_bias)
|
||||||
|
for _ in range(num_heads)]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.cat([head(x) for head in self.heads], dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
|
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@ -13,7 +13,7 @@
|
|||||||
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
|
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Multi-head attention implementation from chapter 3"
|
"## Multi-head attention implementations from chapter 3"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -36,6 +36,36 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 2,
|
||||||
|
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"torch.Size([8, 1024, 9216])\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_1\n",
|
||||||
|
"\n",
|
||||||
|
"mha_ch03_1 = Ch03_MHA_1(\n",
|
||||||
|
" d_in=embed_dim,\n",
|
||||||
|
" d_out=embed_dim,\n",
|
||||||
|
" block_size=context_len,\n",
|
||||||
|
" dropout=0.0,\n",
|
||||||
|
" num_heads=12,\n",
|
||||||
|
" qkv_bias=False\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"out = mha_ch03_1(embeddings)\n",
|
||||||
|
"print(out.shape)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
|
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -48,9 +78,9 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from ch03 import MultiHeadAttention as Ch03_MHA\n",
|
"from ch03 import MultiHeadAttention as Ch03_MHA_2\n",
|
||||||
"\n",
|
"\n",
|
||||||
"mha_ch03 = Ch03_MHA(\n",
|
"mha_ch03_2 = Ch03_MHA_2(\n",
|
||||||
" d_in=embed_dim,\n",
|
" d_in=embed_dim,\n",
|
||||||
" d_out=embed_dim,\n",
|
" d_out=embed_dim,\n",
|
||||||
" block_size=context_len,\n",
|
" block_size=context_len,\n",
|
||||||
@ -59,7 +89,7 @@
|
|||||||
" qkv_bias=False\n",
|
" qkv_bias=False\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"out = mha_ch03(embeddings)\n",
|
"out = mha_ch03_2(embeddings)\n",
|
||||||
"print(out.shape)"
|
"print(out.shape)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -89,7 +119,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 4,
|
||||||
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
|
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -192,7 +222,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 5,
|
||||||
"id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5",
|
"id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -243,7 +273,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 6,
|
||||||
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
|
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -279,7 +309,25 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 7,
|
||||||
|
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"879 ms ± 4.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"%timeit mha_ch03_1(embeddings)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
|
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -287,17 +335,17 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"253 ms ± 9.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
"259 ms ± 7.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"%timeit mha_ch03(embeddings)"
|
"%timeit mha_ch03_2(embeddings)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 9,
|
||||||
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
|
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -305,7 +353,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"309 ms ± 26.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
"290 ms ± 2.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -315,7 +363,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 10,
|
||||||
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
|
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -323,7 +371,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"90.4 ms ± 719 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
"91.5 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user