mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-09-25 16:17:10 +00:00
Make datasets and loaders compatible with multiprocessing (#118)
This commit is contained in:
parent
8fe63a9a0e
commit
bae4b0fb08
3
.gitignore
vendored
3
.gitignore
vendored
@ -14,7 +14,8 @@ ch05/01_main-chapter-code/model.pth
|
|||||||
ch05/01_main-chapter-code/model_and_optimizer.pth
|
ch05/01_main-chapter-code/model_and_optimizer.pth
|
||||||
ch05/03_bonus_pretraining_on_gutenberg/model_checkpoints
|
ch05/03_bonus_pretraining_on_gutenberg/model_checkpoints
|
||||||
|
|
||||||
# Preprocessing output folders
|
# Datasets
|
||||||
|
ch05/03_bonus_pretraining_on_gutenberg/gutenberg
|
||||||
ch05/03_bonus_pretraining_on_gutenberg/gutenberg_preprocessed
|
ch05/03_bonus_pretraining_on_gutenberg/gutenberg_preprocessed
|
||||||
|
|
||||||
# Temporary OS-related files
|
# Temporary OS-related files
|
||||||
|
@ -47,7 +47,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"torch version: 2.2.1\n"
|
"torch version: 2.2.2\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -130,7 +130,8 @@
|
|||||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||||
" drop_last=True,\n",
|
" drop_last=True,\n",
|
||||||
" shuffle=True\n",
|
" shuffle=True,\n",
|
||||||
|
" num_workers=0\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"val_loader = create_dataloader_v1(\n",
|
"val_loader = create_dataloader_v1(\n",
|
||||||
@ -139,7 +140,8 @@
|
|||||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||||
" drop_last=False,\n",
|
" drop_last=False,\n",
|
||||||
" shuffle=False\n",
|
" shuffle=False,\n",
|
||||||
|
" num_workers=0\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -500,7 +502,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def train_model(model, train_loader, val_loader, optimizer, device, n_epochs,\n",
|
"def train_model(model, train_loader, val_loader, optimizer, device, n_epochs,\n",
|
||||||
" eval_freq, eval_iter, start_context, warmup_steps=10,\n",
|
" eval_freq, eval_iter, start_context, tokenizer, warmup_steps=10,\n",
|
||||||
" initial_lr=3e-05, min_lr=1e-6):\n",
|
" initial_lr=3e-05, min_lr=1e-6):\n",
|
||||||
"\n",
|
"\n",
|
||||||
" train_losses, val_losses, track_tokens_seen, track_lrs = [], [], [], []\n",
|
" train_losses, val_losses, track_tokens_seen, track_lrs = [], [], [], []\n",
|
||||||
@ -562,8 +564,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" # Generate and print a sample from the model to monitor progress\n",
|
" # Generate and print a sample from the model to monitor progress\n",
|
||||||
" generate_and_print_sample(\n",
|
" generate_and_print_sample(\n",
|
||||||
" model, train_loader.dataset.tokenizer,\n",
|
" model, tokenizer, device, start_context\n",
|
||||||
" device, start_context\n",
|
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return train_losses, val_losses, track_tokens_seen, track_lrs"
|
" return train_losses, val_losses, track_tokens_seen, track_lrs"
|
||||||
@ -625,18 +626,21 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"import tiktoken\n",
|
||||||
|
"\n",
|
||||||
"torch.manual_seed(123)\n",
|
"torch.manual_seed(123)\n",
|
||||||
"model = GPTModel(GPT_CONFIG_124M)\n",
|
"model = GPTModel(GPT_CONFIG_124M)\n",
|
||||||
"model.to(device)\n",
|
"model.to(device)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"peak_lr = 5e-4\n",
|
"peak_lr = 5e-4\n",
|
||||||
"optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.1)\n",
|
"optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.1)\n",
|
||||||
|
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"n_epochs = 15\n",
|
"n_epochs = 15\n",
|
||||||
"train_losses, val_losses, tokens_seen, lrs = train_model(\n",
|
"train_losses, val_losses, tokens_seen, lrs = train_model(\n",
|
||||||
" model, train_loader, val_loader, optimizer, device, n_epochs=n_epochs,\n",
|
" model, train_loader, val_loader, optimizer, device, n_epochs=n_epochs,\n",
|
||||||
" eval_freq=5, eval_iter=1, start_context=\"Every effort moves you\",\n",
|
" eval_freq=5, eval_iter=1, start_context=\"Every effort moves you\",\n",
|
||||||
" warmup_steps=10, initial_lr=1e-5, min_lr=1e-5\n",
|
" tokenizer=tokenizer, warmup_steps=10, initial_lr=1e-5, min_lr=1e-5\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -705,7 +709,7 @@
|
|||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"/var/folders/jg/tpqyh1fd5js5wsr1d138k3n40000gn/T/ipykernel_34986/3589549395.py:5: UserWarning: The figure layout has changed to tight\n",
|
"/var/folders/jg/tpqyh1fd5js5wsr1d138k3n40000gn/T/ipykernel_9436/3589549395.py:5: UserWarning: The figure layout has changed to tight\n",
|
||||||
" plt.tight_layout(); plt.savefig(\"3.pdf\")\n"
|
" plt.tight_layout(); plt.savefig(\"3.pdf\")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -755,7 +759,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.6"
|
"version": "3.11.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -20,12 +20,11 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
class GPTDatasetV1(Dataset):
|
class GPTDatasetV1(Dataset):
|
||||||
def __init__(self, txt, tokenizer, max_length, stride):
|
def __init__(self, txt, tokenizer, max_length, stride):
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.input_ids = []
|
self.input_ids = []
|
||||||
self.target_ids = []
|
self.target_ids = []
|
||||||
|
|
||||||
# Tokenize the entire text
|
# Tokenize the entire text
|
||||||
token_ids = self.tokenizer.encode(txt)
|
token_ids = tokenizer.encode(txt)
|
||||||
|
|
||||||
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
||||||
for i in range(0, len(token_ids) - max_length, stride):
|
for i in range(0, len(token_ids) - max_length, stride):
|
||||||
@ -42,7 +41,7 @@ class GPTDatasetV1(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||||
stride=128, shuffle=True, drop_last=True):
|
stride=128, shuffle=True, drop_last=True, num_workers=0):
|
||||||
# Initialize the tokenizer
|
# Initialize the tokenizer
|
||||||
tokenizer = tiktoken.get_encoding("gpt2")
|
tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
|
|
||||||
@ -51,7 +50,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
|||||||
|
|
||||||
# Create dataloader
|
# Create dataloader
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
|
||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"torch version: 2.2.1\n",
|
"torch version: 2.2.2\n",
|
||||||
"tiktoken version: 0.5.1\n"
|
"tiktoken version: 0.5.1\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -724,7 +724,7 @@
|
|||||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
||||||
"Cell \u001b[0;32mIn[16], line 5\u001b[0m\n\u001b[1;32m 1\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m SimpleTokenizerV1(vocab)\n\u001b[1;32m 3\u001b[0m text \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHello, do you like tea. Is this-- a test?\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m\n",
|
"Cell \u001b[0;32mIn[16], line 5\u001b[0m\n\u001b[1;32m 1\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m SimpleTokenizerV1(vocab)\n\u001b[1;32m 3\u001b[0m text \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHello, do you like tea. Is this-- a test?\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
"Cell \u001b[0;32mIn[12], line 9\u001b[0m, in \u001b[0;36mSimpleTokenizerV1.encode\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 7\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m([,.?_!\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m]|--|\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms)\u001b[39m\u001b[38;5;124m'\u001b[39m, text)\n\u001b[1;32m 8\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m [item\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m preprocessed \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstrip()]\n\u001b[0;32m----> 9\u001b[0m ids \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstr_to_int[s] \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m preprocessed]\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
|
"Cell \u001b[0;32mIn[12], line 9\u001b[0m, in \u001b[0;36mSimpleTokenizerV1.encode\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 7\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m([,.?_!\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m]|--|\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms)\u001b[39m\u001b[38;5;124m'\u001b[39m, text)\n\u001b[1;32m 8\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m [item\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m preprocessed \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstrip()]\n\u001b[0;32m----> 9\u001b[0m ids \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstr_to_int\u001b[49m\u001b[43m[\u001b[49m\u001b[43ms\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mpreprocessed\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
|
||||||
"Cell \u001b[0;32mIn[12], line 9\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 7\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m([,.?_!\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m]|--|\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms)\u001b[39m\u001b[38;5;124m'\u001b[39m, text)\n\u001b[1;32m 8\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m [item\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m preprocessed \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstrip()]\n\u001b[0;32m----> 9\u001b[0m ids \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstr_to_int\u001b[49m\u001b[43m[\u001b[49m\u001b[43ms\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m preprocessed]\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
|
"Cell \u001b[0;32mIn[12], line 9\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 7\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m([,.?_!\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m]|--|\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms)\u001b[39m\u001b[38;5;124m'\u001b[39m, text)\n\u001b[1;32m 8\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m [item\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m preprocessed \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstrip()]\n\u001b[0;32m----> 9\u001b[0m ids \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstr_to_int\u001b[49m\u001b[43m[\u001b[49m\u001b[43ms\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m preprocessed]\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
|
||||||
"\u001b[0;31mKeyError\u001b[0m: 'Hello'"
|
"\u001b[0;31mKeyError\u001b[0m: 'Hello'"
|
||||||
]
|
]
|
||||||
@ -957,7 +957,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 28,
|
"execution_count": 24,
|
||||||
"id": "ede1d41f-934b-4bf4-8184-54394a257a94",
|
"id": "ede1d41f-934b-4bf4-8184-54394a257a94",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -967,7 +967,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 29,
|
"execution_count": 25,
|
||||||
"id": "48967a77-7d17-42bf-9e92-fc619d63a59e",
|
"id": "48967a77-7d17-42bf-9e92-fc619d63a59e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -988,7 +988,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 30,
|
"execution_count": 26,
|
||||||
"id": "6ad3312f-a5f7-4efc-9d7d-8ea09d7b5128",
|
"id": "6ad3312f-a5f7-4efc-9d7d-8ea09d7b5128",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -998,7 +998,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 31,
|
"execution_count": 27,
|
||||||
"id": "5ff2cd85-7cfb-4325-b390-219938589428",
|
"id": "5ff2cd85-7cfb-4325-b390-219938589428",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -1020,7 +1020,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 32,
|
"execution_count": 28,
|
||||||
"id": "d26a48bb-f82e-41a8-a955-a1c9cf9d50ab",
|
"id": "d26a48bb-f82e-41a8-a955-a1c9cf9d50ab",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -1080,7 +1080,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 33,
|
"execution_count": 29,
|
||||||
"id": "848d5ade-fd1f-46c3-9e31-1426e315c71b",
|
"id": "848d5ade-fd1f-46c3-9e31-1426e315c71b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -1111,7 +1111,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 34,
|
"execution_count": 30,
|
||||||
"id": "e84424a7-646d-45b6-99e3-80d15fb761f2",
|
"id": "e84424a7-646d-45b6-99e3-80d15fb761f2",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -1121,7 +1121,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 35,
|
"execution_count": 31,
|
||||||
"id": "dfbff852-a92f-48c8-a46d-143a0f109f40",
|
"id": "dfbff852-a92f-48c8-a46d-143a0f109f40",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -1154,7 +1154,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 36,
|
"execution_count": 32,
|
||||||
"id": "d97b031e-ed55-409d-95f2-aeb38c6fe366",
|
"id": "d97b031e-ed55-409d-95f2-aeb38c6fe366",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -1179,7 +1179,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 37,
|
"execution_count": 33,
|
||||||
"id": "f57bd746-dcbf-4433-8e24-ee213a8c34a1",
|
"id": "f57bd746-dcbf-4433-8e24-ee213a8c34a1",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -1221,7 +1221,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 38,
|
"execution_count": 34,
|
||||||
"id": "e1770134-e7f3-4725-a679-e04c3be48cac",
|
"id": "e1770134-e7f3-4725-a679-e04c3be48cac",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -1229,7 +1229,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"PyTorch version: 2.1.0\n"
|
"PyTorch version: 2.2.2\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -1258,7 +1258,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 39,
|
"execution_count": 35,
|
||||||
"id": "74b41073-4c9f-46e2-a1bd-d38e4122b375",
|
"id": "74b41073-4c9f-46e2-a1bd-d38e4122b375",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -1268,12 +1268,11 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"class GPTDatasetV1(Dataset):\n",
|
"class GPTDatasetV1(Dataset):\n",
|
||||||
" def __init__(self, txt, tokenizer, max_length, stride):\n",
|
" def __init__(self, txt, tokenizer, max_length, stride):\n",
|
||||||
" self.tokenizer = tokenizer\n",
|
|
||||||
" self.input_ids = []\n",
|
" self.input_ids = []\n",
|
||||||
" self.target_ids = []\n",
|
" self.target_ids = []\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Tokenize the entire text\n",
|
" # Tokenize the entire text\n",
|
||||||
" token_ids = self.tokenizer.encode(txt, allowed_special={'<|endoftext|>'})\n",
|
" token_ids = tokenizer.encode(txt, allowed_special={\"<|endoftext|>\"})\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Use a sliding window to chunk the book into overlapping sequences of max_length\n",
|
" # Use a sliding window to chunk the book into overlapping sequences of max_length\n",
|
||||||
" for i in range(0, len(token_ids) - max_length, stride):\n",
|
" for i in range(0, len(token_ids) - max_length, stride):\n",
|
||||||
@ -1291,12 +1290,12 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 40,
|
"execution_count": 36,
|
||||||
"id": "5eb30ebe-97b3-43c5-9ff1-a97d621b3c4e",
|
"id": "5eb30ebe-97b3-43c5-9ff1-a97d621b3c4e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True, drop_last=True):\n",
|
"def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True, drop_last=True, num_workers=0):\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Initialize the tokenizer\n",
|
" # Initialize the tokenizer\n",
|
||||||
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
||||||
@ -1306,7 +1305,12 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" # Create dataloader\n",
|
" # Create dataloader\n",
|
||||||
" dataloader = DataLoader(\n",
|
" dataloader = DataLoader(\n",
|
||||||
" dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)\n",
|
" dataset,\n",
|
||||||
|
" batch_size=batch_size,\n",
|
||||||
|
" shuffle=shuffle,\n",
|
||||||
|
" drop_last=drop_last,\n",
|
||||||
|
" num_workers=0\n",
|
||||||
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return dataloader"
|
" return dataloader"
|
||||||
]
|
]
|
||||||
@ -1321,7 +1325,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 41,
|
"execution_count": 37,
|
||||||
"id": "df31d96c-6bfd-4564-a956-6192242d7579",
|
"id": "df31d96c-6bfd-4564-a956-6192242d7579",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -1332,7 +1336,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 42,
|
"execution_count": 38,
|
||||||
"id": "9226d00c-ad9a-4949-a6e4-9afccfc7214f",
|
"id": "9226d00c-ad9a-4949-a6e4-9afccfc7214f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -1354,7 +1358,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 43,
|
"execution_count": 39,
|
||||||
"id": "10deb4bc-4de1-4d20-921e-4b1c7a0e1a6d",
|
"id": "10deb4bc-4de1-4d20-921e-4b1c7a0e1a6d",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -1398,7 +1402,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 44,
|
"execution_count": 40,
|
||||||
"id": "1916e7a6-f03d-4f09-91a6-d0bdbac5a58c",
|
"id": "1916e7a6-f03d-4f09-91a6-d0bdbac5a58c",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -1473,7 +1477,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 46,
|
"execution_count": 41,
|
||||||
"id": "15a6304c-9474-4470-b85d-3991a49fa653",
|
"id": "15a6304c-9474-4470-b85d-3991a49fa653",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -1491,7 +1495,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 47,
|
"execution_count": 42,
|
||||||
"id": "93cb2cee-9aa6-4bb8-8977-c65661d16eda",
|
"id": "93cb2cee-9aa6-4bb8-8977-c65661d16eda",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -1513,7 +1517,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 49,
|
"execution_count": 43,
|
||||||
"id": "a686eb61-e737-4351-8f1c-222913d47468",
|
"id": "a686eb61-e737-4351-8f1c-222913d47468",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -1554,7 +1558,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 50,
|
"execution_count": 44,
|
||||||
"id": "e43600ba-f287-4746-8ddf-d0f71a9023ca",
|
"id": "e43600ba-f287-4746-8ddf-d0f71a9023ca",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -1581,7 +1585,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 51,
|
"execution_count": 45,
|
||||||
"id": "50280ead-0363-44c8-8c35-bb885d92c8b7",
|
"id": "50280ead-0363-44c8-8c35-bb885d92c8b7",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -1874,7 +1878,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.10"
|
"version": "3.11.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -31,7 +31,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 1,
|
||||||
"id": "0ed4b7db-3b47-4fd3-a4a6-5f4ed5dd166e",
|
"id": "0ed4b7db-3b47-4fd3-a4a6-5f4ed5dd166e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -43,12 +43,11 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"class GPTDatasetV1(Dataset):\n",
|
"class GPTDatasetV1(Dataset):\n",
|
||||||
" def __init__(self, txt, tokenizer, max_length, stride):\n",
|
" def __init__(self, txt, tokenizer, max_length, stride):\n",
|
||||||
" self.tokenizer = tokenizer\n",
|
|
||||||
" self.input_ids = []\n",
|
" self.input_ids = []\n",
|
||||||
" self.target_ids = []\n",
|
" self.target_ids = []\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Tokenize the entire text\n",
|
" # Tokenize the entire text\n",
|
||||||
" token_ids = self.tokenizer.encode(txt, allowed_special={'<|endoftext|>'})\n",
|
" token_ids = tokenizer.encode(txt, allowed_special={\"<|endoftext|>\"})\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Use a sliding window to chunk the book into overlapping sequences of max_length\n",
|
" # Use a sliding window to chunk the book into overlapping sequences of max_length\n",
|
||||||
" for i in range(0, len(token_ids) - max_length, stride):\n",
|
" for i in range(0, len(token_ids) - max_length, stride):\n",
|
||||||
@ -65,7 +64,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def create_dataloader_v1(txt, batch_size=4, max_length=256, \n",
|
"def create_dataloader_v1(txt, batch_size=4, max_length=256, \n",
|
||||||
" stride=128, shuffle=True, drop_last=True):\n",
|
" stride=128, shuffle=True, drop_last=True, num_workers=0):\n",
|
||||||
" # Initialize the tokenizer\n",
|
" # Initialize the tokenizer\n",
|
||||||
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -74,7 +73,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" # Create dataloader\n",
|
" # Create dataloader\n",
|
||||||
" dataloader = DataLoader(\n",
|
" dataloader = DataLoader(\n",
|
||||||
" dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)\n",
|
" dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return dataloader\n",
|
" return dataloader\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -99,7 +98,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 2,
|
||||||
"id": "664397bc-6daa-4b88-90aa-e8fc1fbd5846",
|
"id": "664397bc-6daa-4b88-90aa-e8fc1fbd5846",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -117,7 +116,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 3,
|
||||||
"id": "d3664332-e6bb-447e-8b96-203aafde8b24",
|
"id": "d3664332-e6bb-447e-8b96-203aafde8b24",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -150,7 +149,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.10"
|
"version": "3.11.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -13,13 +13,12 @@ from torch.utils.data import Dataset, DataLoader
|
|||||||
|
|
||||||
|
|
||||||
class GPTDatasetV1(Dataset):
|
class GPTDatasetV1(Dataset):
|
||||||
def __init__(self, txt, tokenizer, max_length, stride):
|
def __init__(self, txt, tokenizer, max_length, stride, num_workers=0):
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.input_ids = []
|
self.input_ids = []
|
||||||
self.target_ids = []
|
self.target_ids = []
|
||||||
|
|
||||||
# Tokenize the entire text
|
# Tokenize the entire text
|
||||||
token_ids = self.tokenizer.encode(txt)
|
token_ids = tokenizer.encode(txt)
|
||||||
|
|
||||||
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
||||||
for i in range(0, len(token_ids) - max_length, stride):
|
for i in range(0, len(token_ids) - max_length, stride):
|
||||||
@ -36,7 +35,7 @@ class GPTDatasetV1(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||||
stride=128, shuffle=True, drop_last=True):
|
stride=128, shuffle=True, drop_last=True, num_workers=0):
|
||||||
# Initialize the tokenizer
|
# Initialize the tokenizer
|
||||||
tokenizer = tiktoken.get_encoding("gpt2")
|
tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
|
|
||||||
|
@ -11,7 +11,6 @@ from torch.utils.data import Dataset, DataLoader
|
|||||||
|
|
||||||
class GPTDatasetV1(Dataset):
|
class GPTDatasetV1(Dataset):
|
||||||
def __init__(self, txt, tokenizer, max_length, stride):
|
def __init__(self, txt, tokenizer, max_length, stride):
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.input_ids = []
|
self.input_ids = []
|
||||||
self.target_ids = []
|
self.target_ids = []
|
||||||
|
|
||||||
@ -33,7 +32,7 @@ class GPTDatasetV1(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||||
stride=128, shuffle=True, drop_last=True):
|
stride=128, shuffle=True, drop_last=True, num_workers=0):
|
||||||
# Initialize the tokenizer
|
# Initialize the tokenizer
|
||||||
tokenizer = tiktoken.get_encoding("gpt2")
|
tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
|
|
||||||
@ -42,7 +41,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
|||||||
|
|
||||||
# Create dataloader
|
# Create dataloader
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
|
||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
|
0
ch05/01_main-chapter-code/.gitignore
vendored
Normal file
0
ch05/01_main-chapter-code/.gitignore
vendored
Normal file
File diff suppressed because one or more lines are too long
@ -473,7 +473,8 @@
|
|||||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||||
" drop_last=True,\n",
|
" drop_last=True,\n",
|
||||||
" shuffle=True\n",
|
" shuffle=True,\n",
|
||||||
|
" num_workers=0\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"val_loader = create_dataloader_v1(\n",
|
"val_loader = create_dataloader_v1(\n",
|
||||||
@ -482,7 +483,8 @@
|
|||||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||||
" drop_last=False,\n",
|
" drop_last=False,\n",
|
||||||
" shuffle=False\n",
|
" shuffle=False,\n",
|
||||||
|
" num_workers=0\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -697,7 +699,8 @@
|
|||||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||||
" drop_last=True,\n",
|
" drop_last=True,\n",
|
||||||
" shuffle=True\n",
|
" shuffle=True,\n",
|
||||||
|
" num_workers=0\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"val_loader = create_dataloader_v1(\n",
|
"val_loader = create_dataloader_v1(\n",
|
||||||
@ -706,7 +709,8 @@
|
|||||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||||
" drop_last=False,\n",
|
" drop_last=False,\n",
|
||||||
" shuffle=False\n",
|
" shuffle=False,\n",
|
||||||
|
" num_workers=0\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -945,7 +949,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.12"
|
"version": "3.11.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -7,6 +7,8 @@ import matplotlib.pyplot as plt
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
|
||||||
# Import from local files
|
# Import from local files
|
||||||
from previous_chapters import GPTModel, create_dataloader_v1, generate_text_simple
|
from previous_chapters import GPTModel, create_dataloader_v1, generate_text_simple
|
||||||
@ -69,7 +71,7 @@ def generate_and_print_sample(model, tokenizer, device, start_context):
|
|||||||
|
|
||||||
|
|
||||||
def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
|
def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
|
||||||
eval_freq, eval_iter, start_context):
|
eval_freq, eval_iter, start_context, tokenizer):
|
||||||
# Initialize lists to track losses and tokens seen
|
# Initialize lists to track losses and tokens seen
|
||||||
train_losses, val_losses, track_tokens_seen = [], [], []
|
train_losses, val_losses, track_tokens_seen = [], [], []
|
||||||
tokens_seen = 0
|
tokens_seen = 0
|
||||||
@ -99,7 +101,7 @@ def train_model_simple(model, train_loader, val_loader, optimizer, device, num_e
|
|||||||
|
|
||||||
# Print a sample text after each epoch
|
# Print a sample text after each epoch
|
||||||
generate_and_print_sample(
|
generate_and_print_sample(
|
||||||
model, train_loader.dataset.tokenizer, device, start_context
|
model, tokenizer, device, start_context
|
||||||
)
|
)
|
||||||
|
|
||||||
return train_losses, val_losses, track_tokens_seen
|
return train_losses, val_losses, track_tokens_seen
|
||||||
@ -169,7 +171,8 @@ def main(gpt_config, settings):
|
|||||||
max_length=gpt_config["context_length"],
|
max_length=gpt_config["context_length"],
|
||||||
stride=gpt_config["context_length"],
|
stride=gpt_config["context_length"],
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
shuffle=True
|
shuffle=True,
|
||||||
|
num_workers=0
|
||||||
)
|
)
|
||||||
|
|
||||||
val_loader = create_dataloader_v1(
|
val_loader = create_dataloader_v1(
|
||||||
@ -178,17 +181,20 @@ def main(gpt_config, settings):
|
|||||||
max_length=gpt_config["context_length"],
|
max_length=gpt_config["context_length"],
|
||||||
stride=gpt_config["context_length"],
|
stride=gpt_config["context_length"],
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
shuffle=False
|
shuffle=False,
|
||||||
|
num_workers=0
|
||||||
)
|
)
|
||||||
|
|
||||||
##############################
|
##############################
|
||||||
# Train model
|
# Train model
|
||||||
##############################
|
##############################
|
||||||
|
|
||||||
|
tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
|
|
||||||
train_losses, val_losses, tokens_seen = train_model_simple(
|
train_losses, val_losses, tokens_seen = train_model_simple(
|
||||||
model, train_loader, val_loader, optimizer, device,
|
model, train_loader, val_loader, optimizer, device,
|
||||||
num_epochs=settings["num_epochs"], eval_freq=5, eval_iter=1,
|
num_epochs=settings["num_epochs"], eval_freq=5, eval_iter=1,
|
||||||
start_context="Every effort moves you",
|
start_context="Every effort moves you", tokenizer=tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
return train_losses, val_losses, tokens_seen, model
|
return train_losses, val_losses, tokens_seen, model
|
||||||
|
@ -14,12 +14,11 @@ from torch.utils.data import Dataset, DataLoader
|
|||||||
|
|
||||||
class GPTDatasetV1(Dataset):
|
class GPTDatasetV1(Dataset):
|
||||||
def __init__(self, txt, tokenizer, max_length, stride):
|
def __init__(self, txt, tokenizer, max_length, stride):
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.input_ids = []
|
self.input_ids = []
|
||||||
self.target_ids = []
|
self.target_ids = []
|
||||||
|
|
||||||
# Tokenize the entire text
|
# Tokenize the entire text
|
||||||
token_ids = self.tokenizer.encode(txt)
|
token_ids = tokenizer.encode(txt)
|
||||||
|
|
||||||
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
||||||
for i in range(0, len(token_ids) - max_length, stride):
|
for i in range(0, len(token_ids) - max_length, stride):
|
||||||
@ -36,7 +35,7 @@ class GPTDatasetV1(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||||
stride=128, shuffle=True, drop_last=True):
|
stride=128, shuffle=True, drop_last=True, num_workers=0):
|
||||||
# Initialize the tokenizer
|
# Initialize the tokenizer
|
||||||
tokenizer = tiktoken.get_encoding("gpt2")
|
tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
|
|
||||||
@ -45,7 +44,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
|||||||
|
|
||||||
# Create dataloader
|
# Create dataloader
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
|
||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
|
@ -14,12 +14,11 @@ from torch.utils.data import Dataset, DataLoader
|
|||||||
|
|
||||||
class GPTDatasetV1(Dataset):
|
class GPTDatasetV1(Dataset):
|
||||||
def __init__(self, txt, tokenizer, max_length, stride):
|
def __init__(self, txt, tokenizer, max_length, stride):
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.input_ids = []
|
self.input_ids = []
|
||||||
self.target_ids = []
|
self.target_ids = []
|
||||||
|
|
||||||
# Tokenize the entire text
|
# Tokenize the entire text
|
||||||
token_ids = self.tokenizer.encode(txt)
|
token_ids = tokenizer.encode(txt)
|
||||||
|
|
||||||
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
||||||
for i in range(0, len(token_ids) - max_length, stride):
|
for i in range(0, len(token_ids) - max_length, stride):
|
||||||
@ -36,7 +35,7 @@ class GPTDatasetV1(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||||
stride=128, shuffle=True, drop_last=True):
|
stride=128, shuffle=True, drop_last=True, num_workers=0):
|
||||||
# Initialize the tokenizer
|
# Initialize the tokenizer
|
||||||
tokenizer = tiktoken.get_encoding("gpt2")
|
tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
|
|
||||||
@ -45,7 +44,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
|||||||
|
|
||||||
# Create dataloader
|
# Create dataloader
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
|
||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
|
import tiktoken
|
||||||
import torch
|
import torch
|
||||||
from previous_chapters import (
|
from previous_chapters import (
|
||||||
create_dataloader_v1,
|
create_dataloader_v1,
|
||||||
@ -32,7 +33,7 @@ def read_text_file(file_path):
|
|||||||
return text_data
|
return text_data
|
||||||
|
|
||||||
|
|
||||||
def create_dataloaders(text_data, train_ratio, batch_size, max_length, stride):
|
def create_dataloaders(text_data, train_ratio, batch_size, max_length, stride, num_workers=0):
|
||||||
split_idx = int(train_ratio * len(text_data))
|
split_idx = int(train_ratio * len(text_data))
|
||||||
train_loader = create_dataloader_v1(
|
train_loader = create_dataloader_v1(
|
||||||
text_data[:split_idx],
|
text_data[:split_idx],
|
||||||
@ -40,7 +41,8 @@ def create_dataloaders(text_data, train_ratio, batch_size, max_length, stride):
|
|||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
shuffle=True
|
shuffle=True,
|
||||||
|
num_workers=num_workers
|
||||||
)
|
)
|
||||||
val_loader = create_dataloader_v1(
|
val_loader = create_dataloader_v1(
|
||||||
text_data[split_idx:],
|
text_data[split_idx:],
|
||||||
@ -48,7 +50,8 @@ def create_dataloaders(text_data, train_ratio, batch_size, max_length, stride):
|
|||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
shuffle=False
|
shuffle=False,
|
||||||
|
num_workers=num_workers
|
||||||
)
|
)
|
||||||
return train_loader, val_loader
|
return train_loader, val_loader
|
||||||
|
|
||||||
@ -78,7 +81,7 @@ def print_eta(start_time, book_start_time, index, total_files):
|
|||||||
|
|
||||||
def train_model_simple(model, optimizer, device, n_epochs,
|
def train_model_simple(model, optimizer, device, n_epochs,
|
||||||
eval_freq, eval_iter, print_sample_iter, start_context,
|
eval_freq, eval_iter, print_sample_iter, start_context,
|
||||||
output_dir, save_ckpt_freq,
|
output_dir, save_ckpt_freq, tokenizer,
|
||||||
batch_size=1024, train_ratio=0.90):
|
batch_size=1024, train_ratio=0.90):
|
||||||
|
|
||||||
train_losses, val_losses, track_tokens_seen = [], [], []
|
train_losses, val_losses, track_tokens_seen = [], [], []
|
||||||
@ -101,7 +104,8 @@ def train_model_simple(model, optimizer, device, n_epochs,
|
|||||||
train_ratio=train_ratio,
|
train_ratio=train_ratio,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_length=GPT_CONFIG_124M["context_length"],
|
max_length=GPT_CONFIG_124M["context_length"],
|
||||||
stride=GPT_CONFIG_124M["context_length"]
|
stride=GPT_CONFIG_124M["context_length"],
|
||||||
|
num_workers=0
|
||||||
)
|
)
|
||||||
print("Training ...")
|
print("Training ...")
|
||||||
model.train()
|
model.train()
|
||||||
@ -126,7 +130,7 @@ def train_model_simple(model, optimizer, device, n_epochs,
|
|||||||
# Generate text passage
|
# Generate text passage
|
||||||
if global_step % print_sample_iter == 0:
|
if global_step % print_sample_iter == 0:
|
||||||
generate_and_print_sample(
|
generate_and_print_sample(
|
||||||
model, train_loader.dataset.tokenizer, device, start_context
|
model, tokenizer, device, start_context
|
||||||
)
|
)
|
||||||
|
|
||||||
if global_step % save_ckpt_freq:
|
if global_step % save_ckpt_freq:
|
||||||
@ -196,6 +200,7 @@ if __name__ == "__main__":
|
|||||||
model = GPTModel(GPT_CONFIG_124M)
|
model = GPTModel(GPT_CONFIG_124M)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.1)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.1)
|
||||||
|
tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
|
|
||||||
data_dir = args.data_dir
|
data_dir = args.data_dir
|
||||||
all_files = [os.path.join(path, name) for path, subdirs, files
|
all_files = [os.path.join(path, name) for path, subdirs, files
|
||||||
@ -221,6 +226,7 @@ if __name__ == "__main__":
|
|||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
save_ckpt_freq=args.save_ckpt_freq,
|
save_ckpt_freq=args.save_ckpt_freq,
|
||||||
start_context="Every effort moves you",
|
start_context="Every effort moves you",
|
||||||
|
tokenizer=tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
epochs_tensor = torch.linspace(0, args.n_epochs, len(train_losses))
|
epochs_tensor = torch.linspace(0, args.n_epochs, len(train_losses))
|
||||||
|
@ -21,11 +21,10 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
class GPTDatasetV1(Dataset):
|
class GPTDatasetV1(Dataset):
|
||||||
def __init__(self, txt, tokenizer, max_length, stride):
|
def __init__(self, txt, tokenizer, max_length, stride):
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.input_ids = []
|
self.input_ids = []
|
||||||
self.target_ids = []
|
self.target_ids = []
|
||||||
|
|
||||||
token_ids = self.tokenizer.encode(txt, allowed_special={'<|endoftext|>'})
|
token_ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})
|
||||||
|
|
||||||
for i in range(0, len(token_ids) - max_length, stride):
|
for i in range(0, len(token_ids) - max_length, stride):
|
||||||
input_chunk = token_ids[i:i + max_length]
|
input_chunk = token_ids[i:i + max_length]
|
||||||
@ -41,11 +40,11 @@ class GPTDatasetV1(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True, num_workers=0):
    """Create a DataLoader over sliding-window token chunks of ``txt``.

    The text is tokenized with the tiktoken "gpt2" encoding and wrapped in a
    ``GPTDatasetV1``, which is then served through a ``DataLoader``.

    Args:
        txt: Raw text to tokenize.
        batch_size: Number of samples per batch.
        max_length: Window length forwarded to ``GPTDatasetV1``.
        stride: Step between consecutive windows, forwarded to
            ``GPTDatasetV1``.
        shuffle: Whether the DataLoader shuffles samples.
        drop_last: Whether the DataLoader drops a final incomplete batch.
        num_workers: Number of DataLoader worker processes; 0 loads data in
            the calling process.

    Returns:
        The configured ``DataLoader``.
    """
    tokenizer = tiktoken.get_encoding("gpt2")
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Forward the caller's num_workers; it was previously hard-coded to 0,
    # which silently ignored the argument.
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )

    return dataloader
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import tiktoken
|
||||||
import torch
|
import torch
|
||||||
from previous_chapters import GPTModel, create_dataloader_v1
|
from previous_chapters import GPTModel, create_dataloader_v1
|
||||||
|
|
||||||
@ -58,7 +59,7 @@ def evaluate_model(model, train_loader, val_loader, device, eval_iter):
|
|||||||
|
|
||||||
def train_model(model, train_loader, val_loader, optimizer, device,
|
def train_model(model, train_loader, val_loader, optimizer, device,
|
||||||
n_epochs, eval_freq, eval_iter,
|
n_epochs, eval_freq, eval_iter,
|
||||||
encoded_start_context, warmup_iters=10,
|
encoded_start_context, tokenizer, warmup_iters=10,
|
||||||
initial_lr=3e-05, min_lr=1e-6):
|
initial_lr=3e-05, min_lr=1e-6):
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
|
||||||
@ -120,6 +121,7 @@ if __name__ == "__main__":
|
|||||||
with open(os.path.join(script_dir, "the-verdict.txt"), "r", encoding="utf-8") as file:
|
with open(os.path.join(script_dir, "the-verdict.txt"), "r", encoding="utf-8") as file:
|
||||||
text_data = file.read()
|
text_data = file.read()
|
||||||
|
|
||||||
|
tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
train_ratio = 0.95
|
train_ratio = 0.95
|
||||||
@ -155,7 +157,8 @@ if __name__ == "__main__":
|
|||||||
max_length=GPT_CONFIG_124M["context_length"],
|
max_length=GPT_CONFIG_124M["context_length"],
|
||||||
stride=GPT_CONFIG_124M["context_length"],
|
stride=GPT_CONFIG_124M["context_length"],
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
shuffle=True
|
shuffle=True,
|
||||||
|
num_workers=0
|
||||||
)
|
)
|
||||||
|
|
||||||
val_loader = create_dataloader_v1(
|
val_loader = create_dataloader_v1(
|
||||||
@ -164,7 +167,8 @@ if __name__ == "__main__":
|
|||||||
max_length=GPT_CONFIG_124M["context_length"],
|
max_length=GPT_CONFIG_124M["context_length"],
|
||||||
stride=GPT_CONFIG_124M["context_length"],
|
stride=GPT_CONFIG_124M["context_length"],
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
shuffle=False
|
shuffle=False,
|
||||||
|
num_workers=0
|
||||||
)
|
)
|
||||||
|
|
||||||
model = GPTModel(GPT_CONFIG_124M)
|
model = GPTModel(GPT_CONFIG_124M)
|
||||||
@ -176,7 +180,7 @@ if __name__ == "__main__":
|
|||||||
weight_decay=HPARAM_CONFIG["weight_decay"]
|
weight_decay=HPARAM_CONFIG["weight_decay"]
|
||||||
)
|
)
|
||||||
|
|
||||||
encoded_start_context = train_loader.dataset.tokenizer.encode("Nevertheless")
|
encoded_start_context = tokenizer.encode("Nevertheless")
|
||||||
encoded_tensor = torch.tensor(encoded_start_context).unsqueeze(0)
|
encoded_tensor = torch.tensor(encoded_start_context).unsqueeze(0)
|
||||||
|
|
||||||
train_loss, val_loss = train_model(
|
train_loss, val_loss = train_model(
|
||||||
@ -184,6 +188,7 @@ if __name__ == "__main__":
|
|||||||
n_epochs=HPARAM_CONFIG["n_epochs"],
|
n_epochs=HPARAM_CONFIG["n_epochs"],
|
||||||
eval_freq=5, eval_iter=1,
|
eval_freq=5, eval_iter=1,
|
||||||
encoded_start_context=encoded_tensor,
|
encoded_start_context=encoded_tensor,
|
||||||
|
tokenizer=tokenizer,
|
||||||
warmup_iters=HPARAM_CONFIG["warmup_iters"],
|
warmup_iters=HPARAM_CONFIG["warmup_iters"],
|
||||||
initial_lr=HPARAM_CONFIG["initial_lr"],
|
initial_lr=HPARAM_CONFIG["initial_lr"],
|
||||||
min_lr=HPARAM_CONFIG["min_lr"]
|
min_lr=HPARAM_CONFIG["min_lr"]
|
||||||
|
@ -19,12 +19,11 @@ from torch.utils.data import Dataset, DataLoader
|
|||||||
|
|
||||||
class GPTDatasetV1(Dataset):
|
class GPTDatasetV1(Dataset):
|
||||||
def __init__(self, txt, tokenizer, max_length, stride):
|
def __init__(self, txt, tokenizer, max_length, stride):
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.input_ids = []
|
self.input_ids = []
|
||||||
self.target_ids = []
|
self.target_ids = []
|
||||||
|
|
||||||
# Tokenize the entire text
|
# Tokenize the entire text
|
||||||
token_ids = self.tokenizer.encode(txt)
|
token_ids = tokenizer.encode(txt)
|
||||||
|
|
||||||
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
||||||
for i in range(0, len(token_ids) - max_length, stride):
|
for i in range(0, len(token_ids) - max_length, stride):
|
||||||
@ -46,11 +45,11 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
|||||||
tokenizer = tiktoken.get_encoding("gpt2")
|
tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
|
|
||||||
# Create dataset
|
# Create dataset
|
||||||
dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
|
dataset = GPTDatasetV1(txt, tokenizer, max_length, stride, num_workers=0)
|
||||||
|
|
||||||
# Create dataloader
|
# Create dataloader
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
|
||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user