mirror of https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-10-27 07:49:25 +00:00

make consistent with the latest production version

This commit is contained in:
parent ea9da3a89c
commit 3b57b6d8c4
@@ -1066,7 +1066,6 @@
     "\n",
     "    def __init__(self, d_in, d_out):\n",
     "        super().__init__()\n",
-    "        self.d_out = d_out\n",
     "        self.W_query = nn.Parameter(torch.rand(d_in, d_out))\n",
     "        self.W_key = nn.Parameter(torch.rand(d_in, d_out))\n",
     "        self.W_value = nn.Parameter(torch.rand(d_in, d_out))\n",
@@ -1077,7 +1076,9 @@
     "        values = x @ self.W_value\n",
     "\n",
     "        attn_scores = queries @ keys.T # omega\n",
-    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
+    "        attn_weights = torch.softmax(\n",
+    "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
+    "        )\n",
     "\n",
     "        context_vec = attn_weights @ values\n",
     "        return context_vec\n",
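For orientation, this is roughly how the SelfAttention_v1 cell reads once the two hunks above are applied: the unused d_out attribute is gone and the long softmax call is wrapped. It is a sketch reassembled from the diff context; the class header, imports, and any lines outside the hunks are assumptions taken from the surrounding notebook, not part of this diff.

```python
import torch
import torch.nn as nn


# Sketch of SelfAttention_v1 after this commit; class name, imports, and the
# lines outside the two hunks are assumed from the surrounding notebook.
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T  # unnormalized scores ("omega")
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1  # scale by sqrt(d_k)
        )

        context_vec = attn_weights @ values
        return context_vec


# Quick shape check with a hypothetical 6-token, 3-dimensional input
sa_v1 = SelfAttention_v1(d_in=3, d_out=2)
print(sa_v1(torch.rand(6, 3)).shape)  # torch.Size([6, 2])
```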
@@ -1128,7 +1129,6 @@
     "\n",
     "    def __init__(self, d_in, d_out, qkv_bias=False):\n",
     "        super().__init__()\n",
-    "        self.d_out = d_out\n",
     "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
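The hunk above belongs to the nn.Linear-based variant (SelfAttention_v2 in the book), which swaps the raw nn.Parameter matrices for linear layers, chiefly for their better default weight initialization. Below is a hedged sketch of the full class under that assumption; only the __init__ lines appear in this diff, so the forward pass shown here is taken from the surrounding notebook, not from the hunk.

```python
import torch
import torch.nn as nn


# Sketch of SelfAttention_v2; only the __init__ lines appear in the hunk above,
# the forward pass is assumed from the notebook.
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec
```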
@@ -1598,7 +1598,8 @@
    "source": [
     "class CausalAttention(nn.Module):\n",
     "\n",
-    "    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
+    "    def __init__(self, d_in, d_out, context_length,\n",
+    "                 dropout, qkv_bias=False):\n",
     "        super().__init__()\n",
     "        self.d_out = d_out\n",
     "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
@@ -1616,7 +1617,9 @@
     "        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
     "        attn_scores.masked_fill_(  # New, _ ops are in-place\n",
     "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n",
-    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
+    "        attn_weights = torch.softmax(\n",
+    "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
+    "        )\n",
     "        attn_weights = self.dropout(attn_weights) # New\n",
     "\n",
     "        context_vec = attn_weights @ values\n",
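The masked_fill_ line kept as context in this hunk is the causal-masking step; the rewrapped softmax then normalizes only the scores that survive the mask. A minimal standalone sketch of that mechanism, with made-up shapes and scores:

```python
import torch

# Made-up shapes/scores; only the masking mechanism itself matters here.
num_tokens, d_k = 4, 2
attn_scores = torch.rand(num_tokens, num_tokens)

# 1s strictly above the diagonal mark "future" token positions
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1)

# Trailing underscore = in-place op, as noted in the diff comment
attn_scores.masked_fill_(mask.bool(), -torch.inf)

# Softmax turns the -inf entries into zero attention weights
attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
print(attn_weights)  # each row attends only to itself and earlier positions
```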
@@ -1728,7 +1731,9 @@
     "\n",
     "context_length = batch.shape[1] # This is the number of tokens\n",
     "d_in, d_out = 3, 2\n",
-    "mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)\n",
+    "mha = MultiHeadAttentionWrapper(\n",
+    "    d_in, d_out, context_length, 0.0, num_heads=2\n",
+    ")\n",
     "\n",
     "context_vecs = mha(batch)\n",
     "\n",
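As a usage note, the wrapper concatenates its heads' outputs along the last dimension, so the call above with d_out=2 and num_heads=2 yields 4-dimensional context vectors per token. A toy sketch of just that concatenation idea; the nn.Linear "heads" and the batch tensor below are hypothetical stand-ins, not the notebook's CausalAttention heads or its real batch:

```python
import torch
import torch.nn as nn


# Toy wrapper: each "head" maps d_in -> d_out and the outputs are concatenated,
# giving d_out * num_heads features per token. nn.Linear heads stand in for the
# notebook's CausalAttention heads; the batch tensor is made up.
class ToyHeadWrapper(nn.Module):
    def __init__(self, d_in, d_out, num_heads):
        super().__init__()
        self.heads = nn.ModuleList(
            [nn.Linear(d_in, d_out) for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)


batch = torch.rand(2, 6, 3)  # (batch, num_tokens, d_in), values made up
toy = ToyHeadWrapper(d_in=3, d_out=2, num_heads=2)
print(toy(batch).shape)      # torch.Size([2, 6, 4]) = d_out * num_heads
```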
@@ -1794,7 +1799,8 @@
     "class MultiHeadAttention(nn.Module):\n",
     "    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
     "        super().__init__()\n",
-    "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
+    "        assert (d_out % num_heads == 0), \\\n",
+    "            \"d_out must be divisible by num_heads\"\n",
     "\n",
     "        self.d_out = d_out\n",
     "        self.num_heads = num_heads\n",
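The rewrapped assert guards the later split of d_out into num_heads equal slices of head_dim = d_out // num_heads. A tiny sketch of the arithmetic it protects, with made-up numbers:

```python
# Each head receives an equal slice of the output dimension.
d_out, num_heads = 4, 2
assert (d_out % num_heads == 0), \
    "d_out must be divisible by num_heads"
head_dim = d_out // num_heads
print(head_dim)  # 2 dimensions per head; with num_heads=3 the assert would fire
```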
@@ -1805,7 +1811,11 @@
     "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
     "        self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
     "        self.dropout = nn.Dropout(dropout)\n",
-    "        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
+    "        self.register_buffer(\n",
+    "            \"mask\",\n",
+    "            torch.triu(torch.ones(context_length, context_length),\n",
+    "                       diagonal=1)\n",
+    "        )\n",
     "\n",
     "    def forward(self, x):\n",
     "        b, num_tokens, d_in = x.shape\n",