make consistent with the latest production version

Author: rasbt
Date: 2024-05-18 12:08:39 -05:00
Parent: ea9da3a89c
Commit: 3b57b6d8c4


@@ -1066,7 +1066,6 @@
  "\n",
  "    def __init__(self, d_in, d_out):\n",
  "        super().__init__()\n",
- "        self.d_out = d_out\n",
  "        self.W_query = nn.Parameter(torch.rand(d_in, d_out))\n",
  "        self.W_key = nn.Parameter(torch.rand(d_in, d_out))\n",
  "        self.W_value = nn.Parameter(torch.rand(d_in, d_out))\n",
@@ -1077,7 +1076,9 @@
  "        values = x @ self.W_value\n",
  "        \n",
  "        attn_scores = queries @ keys.T # omega\n",
- "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
+ "        attn_weights = torch.softmax(\n",
+ "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
+ "        )\n",
  "\n",
  "        context_vec = attn_weights @ values\n",
  "        return context_vec\n",
@@ -1128,7 +1129,6 @@
  "\n",
  "    def __init__(self, d_in, d_out, qkv_bias=False):\n",
  "        super().__init__()\n",
- "        self.d_out = d_out\n",
  "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
  "        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
  "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
@@ -1598,7 +1598,8 @@
  "source": [
  "class CausalAttention(nn.Module):\n",
  "\n",
- "    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
+ "    def __init__(self, d_in, d_out, context_length,\n",
+ "                 dropout, qkv_bias=False):\n",
  "        super().__init__()\n",
  "        self.d_out = d_out\n",
  "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
@@ -1616,7 +1617,9 @@
  "        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
  "        attn_scores.masked_fill_( # New, _ ops are in-place\n",
  "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n",
- "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
+ "        attn_weights = torch.softmax(\n",
+ "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
+ "        )\n",
  "        attn_weights = self.dropout(attn_weights) # New\n",
  "\n",
  "        context_vec = attn_weights @ values\n",
@@ -1728,7 +1731,9 @@
  "\n",
  "context_length = batch.shape[1] # This is the number of tokens\n",
  "d_in, d_out = 3, 2\n",
- "mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)\n",
+ "mha = MultiHeadAttentionWrapper(\n",
+ "    d_in, d_out, context_length, 0.0, num_heads=2\n",
+ ")\n",
  "\n",
  "context_vecs = mha(batch)\n",
  "\n",
@@ -1794,7 +1799,8 @@
  "class MultiHeadAttention(nn.Module):\n",
  "    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
  "        super().__init__()\n",
- "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
+ "        assert (d_out % num_heads == 0), \\\n",
+ "            \"d_out must be divisible by num_heads\"\n",
  "\n",
  "        self.d_out = d_out\n",
  "        self.num_heads = num_heads\n",
@@ -1805,7 +1811,11 @@
  "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
  "        self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
  "        self.dropout = nn.Dropout(dropout)\n",
- "        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
+ "        self.register_buffer(\n",
+ "            \"mask\",\n",
+ "            torch.triu(torch.ones(context_length, context_length),\n",
+ "                       diagonal=1)\n",
+ "        )\n",
  "\n",
  "    def forward(self, x):\n",
  "        b, num_tokens, d_in = x.shape\n",