fix plotting

This commit is contained in:
rasbt 2024-03-14 07:41:40 -05:00
parent f25760c394
commit ee8efcbcf6
2 changed files with 13 additions and 10 deletions

View File

@@ -119,11 +119,11 @@ def train_model_simple(model, optimizer, device, n_epochs,
print(f"Ep {epoch+1} (Step {global_step}): " print(f"Ep {epoch+1} (Step {global_step}): "
f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}") f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
# Generate text passage # Generate text passage
if index % print_sample_iter == 0: if global_step % print_sample_iter == 0:
generate_and_print_sample( generate_and_print_sample(
model, train_loader.dataset.tokenizer, device, start_context model, train_loader.dataset.tokenizer, device, start_context
) )
if global_step % save_ckpt_freq: if global_step % save_ckpt_freq:
file_name = output_dir / f"model_pg_{global_step}.pth" file_name = output_dir / f"model_pg_{global_step}.pth"
@@ -137,7 +137,7 @@ def train_model_simple(model, optimizer, device, n_epochs,
torch.save(model.state_dict(), file_name) torch.save(model.state_dict(), file_name)
print(f"Saved {file_name}") print(f"Saved {file_name}")
return train_losses, val_losses, tokens_seen return train_losses, val_losses, track_tokens_seen
if __name__ == "__main__": if __name__ == "__main__":
@@ -150,7 +150,7 @@ if __name__ == "__main__":
help='Directory where the model checkpoints will be saved') help='Directory where the model checkpoints will be saved')
parser.add_argument('--n_epochs', type=int, default=1, parser.add_argument('--n_epochs', type=int, default=1,
help='Number of epochs to train the model') help='Number of epochs to train the model')
parser.add_argument('--print_sample_iter', type=int, default=500, parser.add_argument('--print_sample_iter', type=int, default=1000,
help='Iterations between printing sample outputs') help='Iterations between printing sample outputs')
parser.add_argument('--eval_freq', type=int, default=100, parser.add_argument('--eval_freq', type=int, default=100,
help='Frequency of evaluations during training') help='Frequency of evaluations during training')
@@ -205,7 +205,9 @@ if __name__ == "__main__":
start_context="Every effort moves you", start_context="Every effort moves you",
) )
epochs_tensor = torch.linspace(1, args.n_epochs, len(train_losses)) epochs_tensor = torch.linspace(0, args.n_epochs, len(train_losses))
print("debug", epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)
plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, output_dir) plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)
torch.save(model.state_dict(), output_dir / "model_pg_final.pth") torch.save(model.state_dict(), output_dir / "model_pg_final.pth")

View File

@@ -274,8 +274,9 @@ def generate_and_print_sample(model, tokenizer, device, start_context):
context_size = model.pos_emb.weight.shape[0] context_size = model.pos_emb.weight.shape[0]
encoded = text_to_token_ids(start_context, tokenizer).to(device) encoded = text_to_token_ids(start_context, tokenizer).to(device)
with torch.no_grad(): with torch.no_grad():
token_ids = generate_text_simple(model=model, idx=encoded, token_ids = generate_text_simple(
max_new_tokens=50, context_size=context_size) model=model, idx=encoded,
max_new_tokens=50, context_size=context_size)
decoded_text = token_ids_to_text(token_ids, tokenizer) decoded_text = token_ids_to_text(token_ids, tokenizer)
print(decoded_text.replace("\n", " ")) # Compact print format print(decoded_text.replace("\n", " ")) # Compact print format
model.train() model.train()