From 3b57b6d8c45feae7dd82ac4fcff546ecf76ebd3d Mon Sep 17 00:00:00 2001
From: rasbt
Date: Sat, 18 May 2024 12:08:39 -0500
Subject: [PATCH] make consistent with the latest production version

---
 ch03/01_main-chapter-code/ch03.ipynb | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/ch03/01_main-chapter-code/ch03.ipynb b/ch03/01_main-chapter-code/ch03.ipynb
index 688ff27..e3aa98a 100644
--- a/ch03/01_main-chapter-code/ch03.ipynb
+++ b/ch03/01_main-chapter-code/ch03.ipynb
@@ -1066,7 +1066,6 @@
     "\n",
     "    def __init__(self, d_in, d_out):\n",
     "        super().__init__()\n",
-    "        self.d_out = d_out\n",
     "        self.W_query = nn.Parameter(torch.rand(d_in, d_out))\n",
     "        self.W_key = nn.Parameter(torch.rand(d_in, d_out))\n",
     "        self.W_value = nn.Parameter(torch.rand(d_in, d_out))\n",
@@ -1077,7 +1076,9 @@
     "        values = x @ self.W_value\n",
     "        \n",
     "        attn_scores = queries @ keys.T # omega\n",
-    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
+    "        attn_weights = torch.softmax(\n",
+    "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
+    "        )\n",
     "\n",
     "        context_vec = attn_weights @ values\n",
     "        return context_vec\n",
@@ -1128,7 +1129,6 @@
     "\n",
     "    def __init__(self, d_in, d_out, qkv_bias=False):\n",
     "        super().__init__()\n",
-    "        self.d_out = d_out\n",
     "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
@@ -1598,7 +1598,8 @@
    "source": [
     "class CausalAttention(nn.Module):\n",
     "\n",
-    "    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
+    "    def __init__(self, d_in, d_out, context_length,\n",
+    "                 dropout, qkv_bias=False):\n",
     "        super().__init__()\n",
     "        self.d_out = d_out\n",
     "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
@@ -1616,7 +1617,9 @@
     "        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
     "        attn_scores.masked_fill_( # New, _ ops are in-place\n",
     "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n",
-    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
+    "        attn_weights = torch.softmax(\n",
+    "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
+    "        )\n",
     "        attn_weights = self.dropout(attn_weights) # New\n",
     "\n",
     "        context_vec = attn_weights @ values\n",
@@ -1728,7 +1731,9 @@
     "\n",
     "context_length = batch.shape[1] # This is the number of tokens\n",
     "d_in, d_out = 3, 2\n",
-    "mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)\n",
+    "mha = MultiHeadAttentionWrapper(\n",
+    "    d_in, d_out, context_length, 0.0, num_heads=2\n",
+    ")\n",
     "\n",
     "context_vecs = mha(batch)\n",
     "\n",
@@ -1794,7 +1799,8 @@
     "class MultiHeadAttention(nn.Module):\n",
     "    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
     "        super().__init__()\n",
-    "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
+    "        assert (d_out % num_heads == 0), \\\n",
+    "            \"d_out must be divisible by num_heads\"\n",
     "\n",
     "        self.d_out = d_out\n",
     "        self.num_heads = num_heads\n",
@@ -1805,7 +1811,11 @@
     "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "        self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
     "        self.dropout = nn.Dropout(dropout)\n",
-    "        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
+    "        self.register_buffer(\n",
+    "            \"mask\",\n",
+    "            torch.triu(torch.ones(context_length, context_length),\n",
+    "                       diagonal=1)\n",
+    "        )\n",
     "\n",
     "    def forward(self, x):\n",
     "        b, num_tokens, d_in = x.shape\n",