Rename variable to context_length to make it easier on readers (#106)

* rename to context length

* fix spacing
This commit is contained in:
Sebastian Raschka 2024-04-04 07:27:41 -05:00 committed by GitHub
parent 684562733a
commit ccd7cebbb3
25 changed files with 242 additions and 242 deletions

View File

@ -61,13 +61,13 @@
"from previous_chapters import GPTModel\n",
"\n",
"GPT_CONFIG_124M = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"ctx_len\": 256, # Shortened context length (orig: 1024)\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": False # Query-key-value bias\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 256, # Shortened context length (orig: 1024)\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": False # Query-key-value bias\n",
"}\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
@ -127,8 +127,8 @@
"train_loader = create_dataloader_v1(\n",
" text_data[:split_idx],\n",
" batch_size=2,\n",
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
" drop_last=True,\n",
" shuffle=True\n",
")\n",
@ -136,8 +136,8 @@
"val_loader = create_dataloader_v1(\n",
" text_data[split_idx:],\n",
" batch_size=2,\n",
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
" drop_last=False,\n",
" shuffle=False\n",
")"
@ -755,7 +755,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -61,7 +61,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
@ -74,7 +74,7 @@ class MultiHeadAttention(nn.Module):
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
@ -164,7 +164,7 @@ class TransformerBlock(nn.Module):
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
block_size=cfg["ctx_len"],
context_length=cfg["ctx_len"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])

View File

@ -1772,8 +1772,8 @@
"metadata": {},
"outputs": [],
"source": [
"block_size = max_length\n",
"pos_embedding_layer = torch.nn.Embedding(block_size, output_dim)"
"context_length = max_length\n",
"pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)"
]
},
{
@ -1874,7 +1874,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -87,11 +87,11 @@
"\n",
"vocab_size = 50257\n",
"output_dim = 256\n",
"block_size = 1024\n",
"context_length = 1024\n",
"\n",
"\n",
"token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)\n",
"pos_embedding_layer = torch.nn.Embedding(block_size, output_dim)\n",
"pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)\n",
"\n",
"max_length = 4\n",
"dataloader = create_dataloader_v1(raw_text, batch_size=8, max_length=max_length, stride=max_length)"
@ -150,7 +150,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -294,9 +294,9 @@
"vocab_size = 50257\n",
"output_dim = 256\n",
"max_len = 4\n",
"block_size = max_len\n",
"context_length = max_len\n",
"\n",
"token_embedding_layer = torch.nn.Embedding(block_size, output_dim)\n",
"token_embedding_layer = torch.nn.Embedding(context_length, output_dim)\n",
"pos_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)"
]
},

View File

@ -1275,8 +1275,8 @@
}
],
"source": [
"block_size = attn_scores.shape[0]\n",
"mask_simple = torch.tril(torch.ones(block_size, block_size))\n",
"context_length = attn_scores.shape[0]\n",
"mask_simple = torch.tril(torch.ones(context_length, context_length))\n",
"print(mask_simple)"
]
},
@ -1395,7 +1395,7 @@
}
],
"source": [
"mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
"mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
"masked = attn_scores.masked_fill(mask.bool(), -torch.inf)\n",
"print(masked)"
]
@ -1598,14 +1598,14 @@
"source": [
"class CausalAttention(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
" super().__init__()\n",
" self.d_out = d_out\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.dropout = nn.Dropout(dropout) # New\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
" self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape # New batch dimension b\n",
@ -1624,8 +1624,8 @@
"\n",
"torch.manual_seed(123)\n",
"\n",
"block_size = batch.shape[1]\n",
"ca = CausalAttention(d_in, d_out, block_size, 0.0)\n",
"context_length = batch.shape[1]\n",
"ca = CausalAttention(d_in, d_out, context_length, 0.0)\n",
"\n",
"context_vecs = ca(batch)\n",
"\n",
@ -1713,10 +1713,10 @@
"source": [
"class MultiHeadAttentionWrapper(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" self.heads = nn.ModuleList(\n",
" [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
" [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) \n",
" for _ in range(num_heads)]\n",
" )\n",
"\n",
@ -1726,9 +1726,9 @@
"\n",
"torch.manual_seed(123)\n",
"\n",
"block_size = batch.shape[1] # This is the number of tokens\n",
"context_length = batch.shape[1] # This is the number of tokens\n",
"d_in, d_out = 3, 2\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)\n",
"\n",
"context_vecs = mha(batch)\n",
"\n",
@ -1792,7 +1792,7 @@
],
"source": [
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
"\n",
@ -1805,7 +1805,7 @@
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
" self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape\n",
@ -1848,9 +1848,9 @@
"\n",
"torch.manual_seed(123)\n",
"\n",
"batch_size, block_size, d_in = batch.shape\n",
"batch_size, context_length, d_in = batch.shape\n",
"d_out = 2\n",
"mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads=2)\n",
"mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)\n",
"\n",
"context_vecs = mha(batch)\n",
"\n",

View File

@ -201,7 +201,7 @@
"torch.manual_seed(123)\n",
"\n",
"d_out = 1\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)\n",
"\n",
"context_vecs = mha(batch)\n",
"\n",
@ -247,11 +247,11 @@
"metadata": {},
"source": [
"```python\n",
"block_size = 1024\n",
"context_length = 1024\n",
"d_in, d_out = 768, 768\n",
"num_heads = 12\n",
"\n",
"mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads)\n",
"mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads)\n",
"```"
]
},

View File

@ -116,11 +116,11 @@
"vocab_size = 50257\n",
"output_dim = 256\n",
"max_len = 1024\n",
"block_size = max_len\n",
"context_length = max_len\n",
"\n",
"\n",
"token_embedding_layer = nn.Embedding(vocab_size, output_dim)\n",
"pos_embedding_layer = torch.nn.Embedding(block_size, output_dim)\n",
"pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)\n",
"\n",
"max_length = 4\n",
"dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=max_length)"
@ -187,14 +187,14 @@
"source": [
"class CausalSelfAttention(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
" super().__init__()\n",
" self.d_out = d_out\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.dropout = nn.Dropout(dropout) # New\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
" self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New\n",
"\n",
" def forward(self, x):\n",
" b, n_tokens, d_in = x.shape # New batch dimension b\n",
@ -213,10 +213,10 @@
"\n",
"\n",
"class MultiHeadAttentionWrapper(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" self.heads = nn.ModuleList(\n",
" [CausalSelfAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
" [CausalSelfAttention(d_in, d_out, context_length, dropout, qkv_bias) \n",
" for _ in range(num_heads)]\n",
" )\n",
" self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
@ -243,13 +243,13 @@
"source": [
"torch.manual_seed(123)\n",
"\n",
"block_size = max_length\n",
"context_length = max_length\n",
"d_in = output_dim\n",
"\n",
"num_heads=2\n",
"d_out = d_in // num_heads\n",
"\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads)\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads)\n",
"\n",
"batch = input_embeddings\n",
"context_vecs = mha(batch)\n",
@ -273,7 +273,7 @@
"outputs": [],
"source": [
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
"\n",
@ -286,7 +286,7 @@
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
" self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape\n",
@ -345,11 +345,11 @@
"source": [
"torch.manual_seed(123)\n",
"\n",
"block_size = max_length\n",
"context_length = max_length\n",
"d_in = output_dim\n",
"d_out = d_in\n",
"\n",
"mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads=2)\n",
"mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)\n",
"\n",
"batch = input_embeddings\n",
"context_vecs = mha(batch)\n",
@ -374,7 +374,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -105,7 +105,7 @@
"mha_ch03_wrapper = Ch03_MHA_Wrapper(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim//12,\n",
" block_size=context_len,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
@ -154,7 +154,7 @@
"mha_ch03 = Ch03_MHA(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" block_size=context_len,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
@ -220,13 +220,13 @@
"\n",
"\n",
"class MultiHeadAttentionCombinedQKV(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.block_size = block_size\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
@ -234,7 +234,7 @@
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" self.register_buffer(\n",
" \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
" \"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
@ -278,7 +278,7 @@
"mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" block_size=context_len,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
@ -321,13 +321,13 @@
"outputs": [],
"source": [
"class MHAPyTorchScaledDotProduct(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.block_size = block_size\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
" self.d_out = d_out\n",
"\n",
@ -336,7 +336,7 @@
" self.dropout = dropout\n",
"\n",
" self.register_buffer(\n",
" \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
" \"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
@ -388,7 +388,7 @@
"mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" block_size=context_len,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
@ -446,10 +446,10 @@
"\n",
"\n",
"class MHAPyTorchClass(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False, need_weights=True):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False, need_weights=True):\n",
" super().__init__()\n",
"\n",
" self.block_size = block_size\n",
" self.context_length = context_length\n",
" self.multihead_attn = nn.MultiheadAttention(\n",
" embed_dim=d_out,\n",
" num_heads=num_heads,\n",
@ -461,17 +461,17 @@
"\n",
" self.need_weights = need_weights\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, _ = x.shape\n",
"\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
" if self.block_size >= num_tokens:\n",
" if self.context_length >= num_tokens:\n",
" attn_mask = self.mask[:num_tokens, :num_tokens]\n",
" else:\n",
" attn_mask = self.mask[:self.block_size, :self.block_size]\n",
" attn_mask = self.mask[:self.context_length, :self.context_length]\n",
"\n",
" # attn_mask broadcasting will handle batch_size dimension implicitly\n",
" attn_output, _ = self.multihead_attn(\n",
@ -486,7 +486,7 @@
"mha_pytorch_class_default = MHAPyTorchClass(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" block_size=context_len,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
@ -548,7 +548,7 @@
"mha_pytorch_class_noweights = MHAPyTorchClass(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" block_size=context_len,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False,\n",
@ -1031,7 +1031,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -117,13 +117,13 @@
"outputs": [],
"source": [
"GPT_CONFIG_124M = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"ctx_len\": 1024, # Context length\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": False # Query-Key-Value bias\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 1024, # Context length\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": False # Query-Key-Value bias\n",
"}"
]
},
@ -134,7 +134,7 @@
"source": [
"- We use short variable names to avoid long lines of code later\n",
"- `\"vocab_size\"` indicates a vocabulary size of 50,257 words, supported by the BPE tokenizer discussed in Chapter 2\n",
"- `\"ctx_len\"` represents the model's maximum input token count, as enabled by positional embeddings covered in Chapter 2\n",
"- `\"context_length\"` represents the model's maximum input token count, as enabled by positional embeddings covered in Chapter 2\n",
"- `\"emb_dim\"` is the embedding size for token inputs, converting each input token into a 768-dimensional vector\n",
"- `\"n_heads\"` is the number of attention heads in the multi-head attention mechanism implemented in Chapter 3\n",
"- `\"n_layers\"` is the number of transformer blocks within the model, which we'll implement in upcoming sections\n",
@ -943,7 +943,7 @@
" self.att = MultiHeadAttention(\n",
" d_in=cfg[\"emb_dim\"],\n",
" d_out=cfg[\"emb_dim\"],\n",
" block_size=cfg[\"ctx_len\"],\n",
" context_length=cfg[\"ctx_len\"],\n",
" num_heads=cfg[\"n_heads\"], \n",
" dropout=cfg[\"drop_rate\"],\n",
" qkv_bias=cfg[\"qkv_bias\"])\n",
@ -1489,7 +1489,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -34,11 +34,11 @@
"metadata": {},
"outputs": [],
"source": [
"from gpt import TransformerBlock\n",
"from gpt import Transfocontext_lengthmerBlock\n",
"\n",
"GPT_CONFIG_124M = {\n",
" \"vocab_size\": 50257,\n",
" \"ctx_len\": 1024,\n",
" \"context_length\": 1024,\n",
" \"emb_dim\": 768,\n",
" \"n_heads\": 12,\n",
" \"n_layers\": 12,\n",
@ -139,7 +139,7 @@
"source": [
"GPT_CONFIG_124M = {\n",
" \"vocab_size\": 50257,\n",
" \"ctx_len\": 1024,\n",
" \"context_length\": 1024,\n",
" \"emb_dim\": 768,\n",
" \"n_heads\": 12,\n",
" \"n_layers\": 12,\n",
@ -260,7 +260,7 @@
"source": [
"GPT_CONFIG_124M = {\n",
" \"vocab_size\": 50257,\n",
" \"ctx_len\": 1024,\n",
" \"context_length\": 1024,\n",
" \"emb_dim\": 768,\n",
" \"n_heads\": 12,\n",
" \"n_layers\": 12,\n",
@ -288,7 +288,7 @@
" self.att = MultiHeadAttention(\n",
" d_in=cfg[\"emb_dim\"],\n",
" d_out=cfg[\"emb_dim\"],\n",
" block_size=cfg[\"ctx_len\"],\n",
" context_length=cfg[\"context_length\"],\n",
" num_heads=cfg[\"n_heads\"], \n",
" dropout=cfg[\"drop_rate_attn\"], # NEW: dropout for multi-head attention\n",
" qkv_bias=cfg[\"qkv_bias\"])\n",
@ -319,7 +319,7 @@
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"])\n",
" self.pos_emb = nn.Embedding(cfg[\"ctx_len\"], cfg[\"emb_dim\"])\n",
" self.pos_emb = nn.Embedding(cfg[\"context_length\"], cfg[\"emb_dim\"])\n",
" self.drop_emb = nn.Dropout(cfg[\"drop_rate_emb\"]) # NEW: dropout for embedding layers\n",
"\n",
" self.trf_blocks = nn.Sequential(\n",
@ -370,7 +370,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -54,7 +54,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
# Chapter 3
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
@ -67,7 +67,7 @@ class MultiHeadAttention(nn.Module):
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))
self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
@ -156,7 +156,7 @@ class TransformerBlock(nn.Module):
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
block_size=cfg["ctx_len"],
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])
@ -187,7 +187,7 @@ class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
self.trf_blocks = nn.Sequential(
@ -236,13 +236,13 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
def main():
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size
"ctx_len": 1024, # Context length
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-Key-Value bias
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-Key-Value bias
}
torch.manual_seed(123)
@ -264,7 +264,7 @@ def main():
model=model,
idx=encoded_tensor,
max_new_tokens=10,
context_size=GPT_CONFIG_124M["ctx_len"]
context_size=GPT_CONFIG_124M["context_length"]
)
decoded_text = tokenizer.decode(out.squeeze(0).tolist())

View File

@ -48,7 +48,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
@ -61,7 +61,7 @@ class MultiHeadAttention(nn.Module):
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape

View File

@ -140,13 +140,13 @@
"from previous_chapters import GPTModel\n",
"\n",
"GPT_CONFIG_124M = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"ctx_len\": 256, # Shortened context length (orig: 1024)\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": False # Query-key-value bias\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 256, # Shortened context length (orig: 1024)\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": False # Query-key-value bias\n",
"}\n",
"\n",
"torch.manual_seed(123)\n",
@ -161,10 +161,10 @@
"source": [
"- We use dropout of 0.1 above, but it's relatively common to train LLMs without dropout nowadays\n",
"- Modern LLMs also don't use bias vectors in the `nn.Linear` layers for the query, key, and value matrices (unlike earlier GPT models), which is achieved by setting `\"qkv_bias\": False`\n",
"- We reduce the context length (`ctx_len`) of only 256 tokens to reduce the computational resource requirements for training the model, whereas the original 124 million parameter GPT-2 model used 1024 characters\n",
"- We reduce the context length (`context_length`) of only 256 tokens to reduce the computational resource requirements for training the model, whereas the original 124 million parameter GPT-2 model used 1024 characters\n",
" - This is so that more readers will be able to follow and execute the code examples on their laptop computer\n",
" - However, please feel free to increase the `ctx_len` to 1024 tokens (this would not require any code changes)\n",
" - We will also load a model with a 1024 `ctx_len` later from pretrained weights"
" - However, please feel free to increase the `context_length` to 1024 tokens (this would not require any code changes)\n",
" - We will also load a model with a 1024 `context_length` later from pretrained weights"
]
},
{
@ -219,7 +219,7 @@
" model=model,\n",
" idx=text_to_token_ids(start_context, tokenizer),\n",
" max_new_tokens=10,\n",
" context_size=GPT_CONFIG_124M[\"ctx_len\"]\n",
" context_size=GPT_CONFIG_124M[\"context_length\"]\n",
")\n",
"\n",
"print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
@ -928,8 +928,8 @@
"train_loader = create_dataloader_v1(\n",
" train_data,\n",
" batch_size=2,\n",
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
" drop_last=True,\n",
" shuffle=True\n",
")\n",
@ -937,8 +937,8 @@
"val_loader = create_dataloader_v1(\n",
" val_data,\n",
" batch_size=2,\n",
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
" drop_last=False,\n",
" shuffle=False\n",
")"
@ -953,14 +953,14 @@
"source": [
"# Sanity check\n",
"\n",
"if total_tokens * (train_ratio) < GPT_CONFIG_124M[\"ctx_len\"]:\n",
"if total_tokens * (train_ratio) < GPT_CONFIG_124M[\"context_length\"]:\n",
" print(\"Not enough tokens for the training loader. \"\n",
" \"Try to lower the `GPT_CONFIG_124M['ctx_len']` or \"\n",
" \"Try to lower the `GPT_CONFIG_124M['context_length']` or \"\n",
" \"increase the `training_ratio`\")\n",
"\n",
"if total_tokens * (1-train_ratio) < GPT_CONFIG_124M[\"ctx_len\"]:\n",
"if total_tokens * (1-train_ratio) < GPT_CONFIG_124M[\"context_length\"]:\n",
" print(\"Not enough tokens for the validation loader. \"\n",
" \"Try to lower the `GPT_CONFIG_124M['ctx_len']` or \"\n",
" \"Try to lower the `GPT_CONFIG_124M['context_length']` or \"\n",
" \"decrease the `training_ratio`\")"
]
},
@ -1441,7 +1441,7 @@
" model=model,\n",
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
" max_new_tokens=25,\n",
" context_size=GPT_CONFIG_124M[\"ctx_len\"]\n",
" context_size=GPT_CONFIG_124M[\"context_length\"]\n",
")\n",
"\n",
"print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
@ -1906,7 +1906,7 @@
" model=model,\n",
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
" max_new_tokens=15,\n",
" context_size=GPT_CONFIG_124M[\"ctx_len\"],\n",
" context_size=GPT_CONFIG_124M[\"context_length\"],\n",
" top_k=25,\n",
" temperature=1.4\n",
")\n",
@ -2203,7 +2203,7 @@
"model_name = \"gpt2-small (124M)\" # Example model name\n",
"NEW_CONFIG = GPT_CONFIG_124M.copy()\n",
"NEW_CONFIG.update(model_configs[model_name])\n",
"NEW_CONFIG.update({\"ctx_len\": 1024, \"qkv_bias\": True})\n",
"NEW_CONFIG.update({\"context_length\": 1024, \"qkv_bias\": True})\n",
"\n",
"gpt = GPTModel(NEW_CONFIG)\n",
"gpt.eval();"
@ -2338,7 +2338,7 @@
" model=gpt,\n",
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
" max_new_tokens=25,\n",
" context_size=NEW_CONFIG[\"ctx_len\"],\n",
" context_size=NEW_CONFIG[\"context_length\"],\n",
" top_k=50,\n",
" temperature=1.5\n",
")\n",
@ -2403,7 +2403,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -234,7 +234,7 @@
"\n",
"GPT_CONFIG_124M = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"ctx_len\": 256, # Shortened context length (orig: 1024)\n",
" \"context_length\": 256, # Shortened context length (orig: 1024)\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
@ -286,7 +286,7 @@
" model=model,\n",
" idx=text_to_token_ids(start_context, tokenizer),\n",
" max_new_tokens=25,\n",
" context_size=GPT_CONFIG_124M[\"ctx_len\"]\n",
" context_size=GPT_CONFIG_124M[\"context_length\"]\n",
")\n",
"\n",
"print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
@ -314,7 +314,7 @@
" model=model,\n",
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
" max_new_tokens=25,\n",
" context_size=GPT_CONFIG_124M[\"ctx_len\"],\n",
" context_size=GPT_CONFIG_124M[\"context_length\"],\n",
" top_k=None,\n",
" temperature=0.0\n",
")\n",
@ -344,7 +344,7 @@
" model=model,\n",
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
" max_new_tokens=25,\n",
" context_size=GPT_CONFIG_124M[\"ctx_len\"],\n",
" context_size=GPT_CONFIG_124M[\"context_length\"],\n",
" top_k=None,\n",
" temperature=0.0\n",
")\n",
@ -383,13 +383,13 @@
"\n",
"\n",
"GPT_CONFIG_124M = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"ctx_len\": 256, # Shortened context length (orig: 1024)\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": False # Query-key-value bias\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 256, # Shortened context length (orig: 1024)\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": False # Query-key-value bias\n",
"}\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
@ -451,8 +451,8 @@
"train_loader = create_dataloader_v1(\n",
" train_data,\n",
" batch_size=2,\n",
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
" drop_last=True,\n",
" shuffle=True\n",
")\n",
@ -460,8 +460,8 @@
"val_loader = create_dataloader_v1(\n",
" val_data,\n",
" batch_size=2,\n",
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
" drop_last=False,\n",
" shuffle=False\n",
")"
@ -557,13 +557,13 @@
"\n",
"\n",
"GPT_CONFIG_124M = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"ctx_len\": 256, # Shortened context length (orig: 1024)\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": False # Query-key-value bias\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 256, # Shortened context length (orig: 1024)\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": False # Query-key-value bias\n",
"}\n",
"\n",
"\n",
@ -617,7 +617,7 @@
"model_name = \"gpt2-small (124M)\" # Example model name\n",
"NEW_CONFIG = GPT_CONFIG_124M.copy()\n",
"NEW_CONFIG.update(model_configs[model_name])\n",
"NEW_CONFIG.update({\"ctx_len\": 1024, \"qkv_bias\": True})\n",
"NEW_CONFIG.update({\"context_length\": 1024, \"qkv_bias\": True})\n",
"\n",
"gpt = GPTModel(NEW_CONFIG)\n",
"gpt.eval();"
@ -675,8 +675,8 @@
"train_loader = create_dataloader_v1(\n",
" train_data,\n",
" batch_size=2,\n",
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
" drop_last=True,\n",
" shuffle=True\n",
")\n",
@ -684,8 +684,8 @@
"val_loader = create_dataloader_v1(\n",
" val_data,\n",
" batch_size=2,\n",
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
" drop_last=False,\n",
" shuffle=False\n",
")"
@ -753,7 +753,7 @@
"model_name = \"gpt2-xl (1558M)\"\n",
"NEW_CONFIG = GPT_CONFIG_124M.copy()\n",
"NEW_CONFIG.update(model_configs[model_name])\n",
"NEW_CONFIG.update({\"ctx_len\": 1024, \"qkv_bias\": True})\n",
"NEW_CONFIG.update({\"context_length\": 1024, \"qkv_bias\": True})\n",
"\n",
"gpt = GPTModel(NEW_CONFIG)\n",
"gpt.eval();\n",
@ -811,13 +811,13 @@
"\n",
"\n",
"GPT_CONFIG_124M = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"ctx_len\": 256, # Shortened context length (orig: 1024)\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": False # Query-key-value bias\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 256, # Shortened context length (orig: 1024)\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": False # Query-key-value bias\n",
"}\n",
"\n",
"\n",
@ -859,7 +859,7 @@
"model_name = \"gpt2-xl (1558M)\"\n",
"NEW_CONFIG = GPT_CONFIG_124M.copy()\n",
"NEW_CONFIG.update(model_configs[model_name])\n",
"NEW_CONFIG.update({\"ctx_len\": 1024, \"qkv_bias\": True})\n",
"NEW_CONFIG.update({\"context_length\": 1024, \"qkv_bias\": True})\n",
"\n",
"gpt = GPTModel(NEW_CONFIG)\n",
"gpt.eval()\n",
@ -901,7 +901,7 @@
" model=gpt,\n",
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
" max_new_tokens=25,\n",
" context_size=NEW_CONFIG[\"ctx_len\"],\n",
" context_size=NEW_CONFIG[\"context_length\"],\n",
" top_k=50,\n",
" temperature=1.5\n",
")\n",
@ -926,7 +926,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -234,7 +234,7 @@ def main(gpt_config, input_prompt, model_size):
model=gpt,
idx=text_to_token_ids(input_prompt, tokenizer),
max_new_tokens=30,
context_size=gpt_config["ctx_len"],
context_size=gpt_config["context_length"],
top_k=1,
temperature=1.0
)
@ -250,10 +250,10 @@ if __name__ == "__main__":
INPUT_PROMPT = "Every effort moves"
BASE_CONFIG = {
"vocab_size": 50257, # Vocabulary size
"ctx_len": 1024, # Context length
"drop_rate": 0.0, # Dropout rate
"qkv_bias": True # Query-key-value bias
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"drop_rate": 0.0, # Dropout rate
"qkv_bias": True # Query-key-value bias
}
model_configs = {

View File

@ -166,8 +166,8 @@ def main(gpt_config, hparams):
train_loader = create_dataloader_v1(
text_data[:split_idx],
batch_size=hparams["batch_size"],
max_length=gpt_config["ctx_len"],
stride=gpt_config["ctx_len"],
max_length=gpt_config["context_length"],
stride=gpt_config["context_length"],
drop_last=True,
shuffle=True
)
@ -175,8 +175,8 @@ def main(gpt_config, hparams):
val_loader = create_dataloader_v1(
text_data[split_idx:],
batch_size=hparams["batch_size"],
max_length=gpt_config["ctx_len"],
stride=gpt_config["ctx_len"],
max_length=gpt_config["context_length"],
stride=gpt_config["context_length"],
drop_last=False,
shuffle=False
)
@ -197,13 +197,13 @@ def main(gpt_config, hparams):
if __name__ == "__main__":
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size
"ctx_len": 256, # Shortened context length (orig: 1024)
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-key-value bias
"vocab_size": 50257, # Vocabulary size
"context_length": 256, # Shortened context length (orig: 1024)
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-key-value bias
}
OTHER_HPARAMS = {

View File

@ -54,7 +54,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
# Chapter 3
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
@ -67,7 +67,7 @@ class MultiHeadAttention(nn.Module):
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
@ -156,7 +156,7 @@ class TransformerBlock(nn.Module):
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
block_size=cfg["ctx_len"],
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])
@ -187,7 +187,7 @@ class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
self.trf_blocks = nn.Sequential(
@ -237,13 +237,13 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
if __name__ == "__main__":
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size
"ctx_len": 1024, # Context length
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-Key-Value bias
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-Key-Value bias
}
torch.manual_seed(123)
@ -265,7 +265,7 @@ if __name__ == "__main__":
model=model,
idx=encoded_tensor,
max_new_tokens=10,
context_size=GPT_CONFIG_124M["ctx_len"]
context_size=GPT_CONFIG_124M["context_length"]
)
decoded_text = tokenizer.decode(out.squeeze(0).tolist())

View File

@ -13,10 +13,10 @@ from gpt_train import main
def gpt_config():
return {
"vocab_size": 50257,
"ctx_len": 12, # small for testing efficiency
"emb_dim": 32, # small for testing efficiency
"n_heads": 4, # small for testing efficiency
"n_layers": 2, # small for testing efficiency
"context_length": 12, # small for testing efficiency
"emb_dim": 32, # small for testing efficiency
"n_heads": 4, # small for testing efficiency
"n_layers": 2, # small for testing efficiency
"drop_rate": 0.1,
"qkv_bias": False
}

View File

@ -54,7 +54,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
# Chapter 3
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
@ -67,7 +67,7 @@ class MultiHeadAttention(nn.Module):
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
@ -156,7 +156,7 @@ class TransformerBlock(nn.Module):
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
block_size=cfg["ctx_len"],
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])
@ -187,7 +187,7 @@ class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
self.trf_blocks = nn.Sequential(

View File

@ -146,10 +146,10 @@
"outputs": [],
"source": [
"BASE_CONFIG = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"ctx_len\": 1024, # Context length\n",
" \"drop_rate\": 0.0, # Dropout rate\n",
" \"qkv_bias\": True # Query-key-value bias\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 1024, # Context length\n",
" \"drop_rate\": 0.0, # Dropout rate\n",
" \"qkv_bias\": True # Query-key-value bias\n",
"}\n",
"\n",
"model_configs = {\n",
@ -279,7 +279,7 @@
" model=gpt,\n",
" idx=text_to_token_ids(\"Every effort moves\", tokenizer),\n",
" max_new_tokens=30,\n",
" context_size=BASE_CONFIG[\"ctx_len\"],\n",
" context_size=BASE_CONFIG[\"context_length\"],\n",
" top_k=1,\n",
" temperature=1.0\n",
")\n",
@ -304,7 +304,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -100,8 +100,8 @@ def train_model_simple(model, optimizer, device, n_epochs,
text_data,
train_ratio=train_ratio,
batch_size=batch_size,
max_length=GPT_CONFIG_124M["ctx_len"],
stride=GPT_CONFIG_124M["ctx_len"]
max_length=GPT_CONFIG_124M["context_length"],
stride=GPT_CONFIG_124M["context_length"]
)
print("Training ...")
model.train()
@ -168,13 +168,13 @@ if __name__ == "__main__":
args = parser.parse_args()
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size
"ctx_len": 1024, # Context length
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-key-value bias
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-key-value bias
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

View File

@ -55,7 +55,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
@ -68,7 +68,7 @@ class MultiHeadAttention(nn.Module):
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
@ -158,7 +158,7 @@ class TransformerBlock(nn.Module):
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
block_size=cfg["ctx_len"],
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])
@ -189,7 +189,7 @@ class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
self.trf_blocks = nn.Sequential(

View File

@ -139,21 +139,21 @@ if __name__ == "__main__":
HPARAM_CONFIG = dict(zip(HPARAM_GRID.keys(), combination))
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size
"ctx_len": 256, # Context length -- shortened from original 1024 tokens
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"vocab_size": 50257, # Vocabulary size
"context_length": 256, # Context length -- shortened from original 1024 tokens
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": HPARAM_CONFIG["drop_rate"],
"qkv_bias": False, # Query-Key-Value bias
"qkv_bias": False, # Query-Key-Value bias
}
torch.manual_seed(123)
train_loader = create_dataloader_v1(
text_data[:split_idx],
batch_size=HPARAM_CONFIG["batch_size"],
max_length=GPT_CONFIG_124M["ctx_len"],
stride=GPT_CONFIG_124M["ctx_len"],
max_length=GPT_CONFIG_124M["context_length"],
stride=GPT_CONFIG_124M["context_length"],
drop_last=True,
shuffle=True
)
@ -161,8 +161,8 @@ if __name__ == "__main__":
val_loader = create_dataloader_v1(
text_data[split_idx:],
batch_size=HPARAM_CONFIG["batch_size"],
max_length=GPT_CONFIG_124M["ctx_len"],
stride=GPT_CONFIG_124M["ctx_len"],
max_length=GPT_CONFIG_124M["context_length"],
stride=GPT_CONFIG_124M["context_length"],
drop_last=False,
shuffle=False
)

View File

@ -59,7 +59,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
# Chapter 3
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
@ -72,7 +72,7 @@ class MultiHeadAttention(nn.Module):
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
@ -161,7 +161,7 @@ class TransformerBlock(nn.Module):
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
block_size=cfg["ctx_len"],
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])
@ -192,7 +192,7 @@ class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
self.trf_blocks = nn.Sequential(
@ -242,13 +242,13 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
if __name__ == "__main__":
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size
"ctx_len": 1024, # Context length
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-Key-Value bias
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-Key-Value bias
}
torch.manual_seed(123)
@ -270,7 +270,7 @@ if __name__ == "__main__":
model=model,
idx=encoded_tensor,
max_new_tokens=10,
context_size=GPT_CONFIG_124M["ctx_len"]
context_size=GPT_CONFIG_124M["context_length"]
)
decoded_text = tokenizer.decode(out.squeeze(0).tolist())