diff --git a/ch03/01_main-chapter-code/ch03.ipynb b/ch03/01_main-chapter-code/ch03.ipynb
index 734bcdf..8e70dda 100644
--- a/ch03/01_main-chapter-code/ch03.ipynb
+++ b/ch03/01_main-chapter-code/ch03.ipynb
@@ -1865,7 +1865,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.12"
+   "version": "3.10.6"
   }
  },
 "nbformat": 4,
diff --git a/ch03/02_bonus_efficient-multihead-attention/ch03.py b/ch03/02_bonus_efficient-multihead-attention/ch03.py
index f7343d7..ed77eb0 100644
--- a/ch03/02_bonus_efficient-multihead-attention/ch03.py
+++ b/ch03/02_bonus_efficient-multihead-attention/ch03.py
@@ -2,6 +2,47 @@ import torch
 import torch.nn as nn
 
 
+class CausalAttention(nn.Module):
+
+    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):
+        super().__init__()
+        self.d_out = d_out
+        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.dropout = nn.Dropout(dropout)  # New
+        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))  # New
+
+    def forward(self, x):
+        b, num_tokens, d_in = x.shape  # New batch dimension b
+        keys = self.W_key(x)
+        queries = self.W_query(x)
+        values = self.W_value(x)
+
+        attn_scores = queries @ keys.transpose(1, 2)  # Changed transpose
+        attn_scores.masked_fill_(  # New, _ ops are in-place
+            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
+        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
+        attn_weights = self.dropout(attn_weights)  # New
+
+        context_vec = attn_weights @ values
+        return context_vec
+
+
+class MultiHeadAttentionWrapper(nn.Module):
+
+    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
+        super().__init__()
+        self.heads = nn.ModuleList(
+            [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias)
+             for _ in range(num_heads)]
+        )
+
+    def forward(self, x):
+        return torch.cat([head(x) for head in self.heads], dim=-1)
+
+
+
 class MultiHeadAttention(nn.Module):
     def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
         super().__init__()
diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
index c1abfc9..d2bfba2 100644
--- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
+++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
@@ -13,7 +13,7 @@
    "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
    "metadata": {},
    "source": [
-    "## Multi-head attention implementation from chapter 3"
+    "## Multi-head attention implementations from chapter 3"
    ]
   },
   {
@@ -36,6 +36,36 @@
   {
    "cell_type": "code",
    "execution_count": 2,
+   "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([8, 1024, 9216])\n"
+     ]
+    }
+   ],
+   "source": [
+    "from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_1\n",
+    "\n",
+    "mha_ch03_1 = Ch03_MHA_1(\n",
+    "    d_in=embed_dim,\n",
+    "    d_out=embed_dim,\n",
+    "    block_size=context_len,\n",
+    "    dropout=0.0,\n",
+    "    num_heads=12,\n",
+    "    qkv_bias=False\n",
+    ")\n",
+    "\n",
+    "out = mha_ch03_1(embeddings)\n",
+    "print(out.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
    "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
    "metadata": {},
    "outputs": [
@@ -48,9 +78,9 @@
     }
    ],
    "source": [
-    "from ch03 import MultiHeadAttention as Ch03_MHA\n",
+    "from ch03 import MultiHeadAttention as Ch03_MHA_2\n",
     "\n",
-    "mha_ch03 = Ch03_MHA(\n",
+    "mha_ch03_2 = Ch03_MHA_2(\n",
     "    d_in=embed_dim,\n",
     "    d_out=embed_dim,\n",
     "    block_size=context_len,\n",
@@ -59,7 +89,7 @@
     "    qkv_bias=False\n",
     ")\n",
     "\n",
-    "out = mha_ch03(embeddings)\n",
+    "out = mha_ch03_2(embeddings)\n",
     "print(out.shape)"
    ]
   },
@@ -89,7 +119,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
    "metadata": {},
    "outputs": [
@@ -192,7 +222,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5",
    "metadata": {},
    "outputs": [],
@@ -243,7 +273,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
    "metadata": {},
    "outputs": [
@@ -279,7 +309,25 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
+   "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "879 ms ± 4.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%timeit mha_ch03_1(embeddings)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
    "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
    "metadata": {},
    "outputs": [
@@ -287,17 +335,17 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "253 ms ± 9.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+      "259 ms ± 7.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
      ]
     }
    ],
    "source": [
-    "%timeit mha_ch03(embeddings)"
+    "%timeit mha_ch03_2(embeddings)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 9,
    "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
    "metadata": {},
    "outputs": [
@@ -305,7 +353,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "309 ms ± 26.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+      "290 ms ± 2.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
      ]
     }
    ],
@@ -315,7 +363,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 10,
    "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
    "metadata": {},
    "outputs": [
@@ -323,7 +371,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "90.4 ms ± 719 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+      "91.5 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
      ]
     }
    ],