From ee8efcbcf6a57ed94a3a097a8498bcebd8c218b8 Mon Sep 17 00:00:00 2001
From: rasbt
Date: Thu, 14 Mar 2024 07:41:40 -0500
Subject: [PATCH] fix plotting

---
 .../pretraining_simple.py | 18 ++++++++++--------
 .../previous_chapters.py  |  5 +++--
 2 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py b/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py
index 8a738d5..d1763da 100644
--- a/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py
+++ b/ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py
@@ -119,11 +119,11 @@ def train_model_simple(model, optimizer, device, n_epochs,
                         print(f"Ep {epoch+1} (Step {global_step}): "
                               f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
 
-                    # Generate text passage
-                    if index % print_sample_iter == 0:
-                        generate_and_print_sample(
-                            model, train_loader.dataset.tokenizer, device, start_context
-                        )
+                    # Generate text passage
+                    if global_step % print_sample_iter == 0:
+                        generate_and_print_sample(
+                            model, train_loader.dataset.tokenizer, device, start_context
+                        )
 
                 if global_step % save_ckpt_freq:
                     file_name = output_dir / f"model_pg_{global_step}.pth"
@@ -137,7 +137,7 @@ def train_model_simple(model, optimizer, device, n_epochs,
         torch.save(model.state_dict(), file_name)
         print(f"Saved {file_name}")
 
-    return train_losses, val_losses, tokens_seen
+    return train_losses, val_losses, track_tokens_seen
 
 
 if __name__ == "__main__":
@@ -150,7 +150,7 @@ if __name__ == "__main__":
                         help='Directory where the model checkpoints will be saved')
     parser.add_argument('--n_epochs', type=int, default=1,
                         help='Number of epochs to train the model')
-    parser.add_argument('--print_sample_iter', type=int, default=500,
+    parser.add_argument('--print_sample_iter', type=int, default=1000,
                         help='Iterations between printing sample outputs')
     parser.add_argument('--eval_freq', type=int, default=100,
                         help='Frequency of evaluations during training')
@@ -205,7 +205,9 @@ if __name__ == "__main__":
         start_context="Every effort moves you",
     )
 
-    epochs_tensor = torch.linspace(1, args.n_epochs, len(train_losses))
+    epochs_tensor = torch.linspace(0, args.n_epochs, len(train_losses))
+
+    print("debug", epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)
     plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)
 
     torch.save(model.state_dict(), output_dir / "model_pg_final.pth")
diff --git a/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py
index 0cd8d02..4641ba4 100644
--- a/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py
+++ b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py
@@ -274,8 +274,9 @@ def generate_and_print_sample(model, tokenizer, device, start_context):
     context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
     with torch.no_grad():
-        token_ids = generate_text_simple(model=model, idx=encoded,
-                                         max_new_tokens=50, context_size=context_size)
+        token_ids = generate_text_simple(
+            model=model, idx=encoded,
+            max_new_tokens=50, context_size=context_size)
     decoded_text = token_ids_to_text(token_ids, tokenizer)
     print(decoded_text.replace("\n", " "))  # Compact print format
     model.train()
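
Note (not part of the commit): the linspace change above fixes the plot's
x-axis. Because the first loss is recorded at global_step 0, the epoch axis
should start at 0 rather than 1, or the first data point is misaligned.
A minimal sketch of the effect, using hypothetical loss values:

    import torch

    n_epochs = 1
    train_losses = [5.1, 4.2, 3.8, 3.5]  # hypothetical recorded losses

    # Before: torch.linspace(1, n_epochs, len(train_losses)) placed the
    # first point at epoch 1. After the patch the axis starts at 0,
    # matching the first logged step.
    epochs_tensor = torch.linspace(0, n_epochs, len(train_losses))
    print(epochs_tensor)  # tensor([0.0000, 0.3333, 0.6667, 1.0000])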