From adc2964fc51523aecd7d11f487ca3a5422e6156a Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Thu, 4 Apr 2024 20:54:09 -0500
Subject: [PATCH] Fix Loss in Gutenberg bonus section (#109)

---
 ch05/01_main-chapter-code/ch05.ipynb |  6 ++--
 .../pretraining_simple.py            | 34 +++++++++++++------
 .../previous_chapters.py             |  2 +-
 .../tests.py                         | 32 +++++++++++++++++
 4 files changed, 58 insertions(+), 16 deletions(-)
 create mode 100644 ch05/03_bonus_pretraining_on_gutenberg/tests.py

diff --git a/ch05/01_main-chapter-code/ch05.ipynb b/ch05/01_main-chapter-code/ch05.ipynb
index 8c05cf7..e93fb39 100644
--- a/ch05/01_main-chapter-code/ch05.ipynb
+++ b/ch05/01_main-chapter-code/ch05.ipynb
@@ -1081,10 +1081,8 @@
    "source": [
     "def calc_loss_batch(input_batch, target_batch, model, device):\n",
     "    input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
-    "\n",
     "    logits = model(input_batch)\n",
-    "    logits = logits.flatten(0, 1)\n",
-    "    loss = torch.nn.functional.cross_entropy(logits, target_batch.flatten())\n",
+    "    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())\n",
     "    return loss\n",
     "\n",
     "\n",
@@ -2403,7 +2401,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.6"
+   "version": "3.11.4"
   }
  },
  "nbformat": 4,
diff --git a/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py b/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py
index 004d79f..8c83b1f 100644
--- a/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py
+++ b/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py
@@ -164,18 +164,32 @@ if __name__ == "__main__":
                         help='Learning rate for the optimizer')
     parser.add_argument('--batch_size', type=int, default=4,
                         help='Batch size for training')
+    parser.add_argument('--debug', type=bool, default=False,
+                        help='Uses a very small model for debugging purposes')
 
     args = parser.parse_args()
 
-    GPT_CONFIG_124M = {
-        "vocab_size": 50257,     # Vocabulary size
-        "context_length": 1024,  # Context length
-        "emb_dim": 768,          # Embedding dimension
-        "n_heads": 12,           # Number of attention heads
-        "n_layers": 12,          # Number of layers
-        "drop_rate": 0.1,        # Dropout rate
-        "qkv_bias": False        # Query-key-value bias
-    }
+    if args.debug:
+        GPT_CONFIG_124M = {
+            "vocab_size": 50257,   # Vocabulary size
+            "context_length": 10,  # Context length
+            "emb_dim": 12,         # Embedding dimension
+            "n_heads": 2,          # Number of attention heads
+            "n_layers": 2,         # Number of layers
+            "drop_rate": 0.0,      # Dropout rate
+            "qkv_bias": False      # Query-key-value bias
+        }
+
+    else:
+        GPT_CONFIG_124M = {
+            "vocab_size": 50257,     # Vocabulary size
+            "context_length": 1024,  # Context length
+            "emb_dim": 768,          # Embedding dimension
+            "n_heads": 12,           # Number of attention heads
+            "n_layers": 12,          # Number of layers
+            "drop_rate": 0.1,        # Dropout rate
+            "qkv_bias": False        # Query-key-value bias
+        }
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     torch.manual_seed(123)
@@ -210,8 +224,6 @@ if __name__ == "__main__":
     )
 
     epochs_tensor = torch.linspace(0, args.n_epochs, len(train_losses))
-
-    print("debug", epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)
     plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)
 
     torch.save(model.state_dict(), output_dir / "model_pg_final.pth")
diff --git a/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py
index 8a9f6bc..38edd4a 100644
--- a/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py
+++ b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py
@@ -244,7 +244,7 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 def calc_loss_batch(input_batch, target_batch, model, device):
     input_batch, target_batch = input_batch.to(device), target_batch.to(device)
     logits = model(input_batch)
-    loss = torch.nn.functional.cross_entropy(logits.flatten(0, -1), target_batch.flatten())
+    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
     return loss
 
 
diff --git a/ch05/03_bonus_pretraining_on_gutenberg/tests.py b/ch05/03_bonus_pretraining_on_gutenberg/tests.py
new file mode 100644
index 0000000..2a9f80a
--- /dev/null
+++ b/ch05/03_bonus_pretraining_on_gutenberg/tests.py
@@ -0,0 +1,32 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+# File for internal use (unit tests)
+
+from pathlib import Path
+import os
+import subprocess
+
+
+def test_pretraining():
+
+    sequence = "a b c d"
+    repetitions = 1000
+    content = sequence * repetitions
+
+    folder_path = Path("gutenberg") / "data"
+    file_name = "repeated_sequence.txt"
+
+    os.makedirs(folder_path, exist_ok=True)
+
+    with open(folder_path/file_name, "w") as file:
+        file.write(content)
+
+    result = subprocess.run(
+        ["python", "pretraining_simple.py", "--debug", "true"],
+        capture_output=True, text=True
+    )
+    print(result.stdout)
+    assert "Maximum GPU memory allocated" in result.stdout
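
Editor's note on the core change: both calc_loss_batch edits come down to how the logits are flattened before torch.nn.functional.cross_entropy. Below is a minimal, standalone shape sketch that is not part of the patch; the toy batch size, sequence length, and variable names are illustrative assumptions.

import torch

# Toy dimensions (illustrative only): 2 sequences of 5 tokens, GPT-2 vocabulary size.
batch_size, seq_len, vocab_size = 2, 5, 50257
logits = torch.randn(batch_size, seq_len, vocab_size)          # what the model returns
targets = torch.randint(0, vocab_size, (batch_size, seq_len))  # next-token IDs

# Patched version: flatten(0, 1) merges only the batch and sequence dimensions,
# pairing (batch_size*seq_len, vocab_size) logits with (batch_size*seq_len,) targets,
# i.e. the (N, C) vs. (N,) shapes that cross_entropy expects.
loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), targets.flatten())
print(loss.item())

# The replaced flatten(0, -1) also collapses the vocabulary (class) dimension into a
# single 1-D tensor of length batch_size*seq_len*vocab_size, so the logits no longer
# line up with the flattened targets.

The new tests.py exercises this end to end: test_pretraining() writes a small repeated-text corpus under gutenberg/data/ and runs pretraining_simple.py with the new --debug flag, so the tiny configuration keeps the check fast.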