mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-09-22 14:44:06 +00:00)

fix plotting

commit: ee8efcbcf6
parent: f25760c394
@@ -119,11 +119,11 @@ def train_model_simple(model, optimizer, device, n_epochs,
                         print(f"Ep {epoch+1} (Step {global_step}): "
                               f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

-                    # Generate text passage
-                    if index % print_sample_iter == 0:
-                        generate_and_print_sample(
-                            model, train_loader.dataset.tokenizer, device, start_context
-                        )
+                    # Generate text passage
+                    if global_step % print_sample_iter == 0:
+                        generate_and_print_sample(
+                            model, train_loader.dataset.tokenizer, device, start_context
+                        )

                     if global_step % save_ckpt_freq:
                         file_name = output_dir / f"model_pg_{global_step}.pth"
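The change here keys text-sample generation on global_step, the batch counter that runs across the entire training job, instead of index, a counter from an enclosing loop that resets and therefore fired at an uneven cadence. A minimal, self-contained sketch of the difference, using a hypothetical two-epoch, five-batch loop (none of the numbers come from the commit):

    print_sample_iter = 4
    global_step = -1
    for epoch in range(2):
        for index in range(5):      # local counter, resets every epoch
            global_step += 1        # global counter, never resets
            if global_step % print_sample_iter == 0:
                print(f"sample at epoch={epoch} index={index} step={global_step}")

Keying on global_step prints a sample every print_sample_iter batches (steps 0, 4, 8 above), no matter where an epoch boundary falls.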
@@ -137,7 +137,7 @@ def train_model_simple(model, optimizer, device, n_epochs,
                         torch.save(model.state_dict(), file_name)
                         print(f"Saved {file_name}")

-    return train_losses, val_losses, tokens_seen
+    return train_losses, val_losses, track_tokens_seen


 if __name__ == "__main__":
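This hunk fixes the return value: track_tokens_seen is the list the function appends to at each logged step, while tokens_seen is only the running integer total, so returning the latter handed the plotting code a scalar where it expected a series. A hedged sketch of the accumulator pattern the diff implies (the per-batch token counts below are made up):

    track_tokens_seen = []   # one entry per logged step; what the plot consumes
    tokens_seen = 0          # plain running total
    for num_tokens in (512, 512, 256):   # hypothetical per-batch token counts
        tokens_seen += num_tokens
        track_tokens_seen.append(tokens_seen)
    print(tokens_seen)        # 1280
    print(track_tokens_seen)  # [512, 1024, 1280]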
@@ -150,7 +150,7 @@ if __name__ == "__main__":
                         help='Directory where the model checkpoints will be saved')
     parser.add_argument('--n_epochs', type=int, default=1,
                         help='Number of epochs to train the model')
-    parser.add_argument('--print_sample_iter', type=int, default=500,
+    parser.add_argument('--print_sample_iter', type=int, default=1000,
                         help='Iterations between printing sample outputs')
     parser.add_argument('--eval_freq', type=int, default=100,
                         help='Frequency of evaluations during training')
@@ -205,7 +205,9 @@ if __name__ == "__main__":
         start_context="Every effort moves you",
     )

-    epochs_tensor = torch.linspace(1, args.n_epochs, len(train_losses))
+    epochs_tensor = torch.linspace(0, args.n_epochs, len(train_losses))
+
+    print("debug", epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)
     plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)

     torch.save(model.state_dict(), output_dir / "model_pg_final.pth")
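The plotting fix starts the epoch axis at 0 rather than 1: torch.linspace(0, args.n_epochs, len(train_losses)) spreads the recorded losses from the start of training to the end of the final epoch, whereas starting at 1 squeezes every point into the final stretch and misplaces the early losses. A small runnable check with made-up values:

    import torch

    n_epochs = 2
    train_losses = [5.1, 4.3, 3.9, 3.6]  # four hypothetical evaluation points
    old = torch.linspace(1, n_epochs, len(train_losses))
    new = torch.linspace(0, n_epochs, len(train_losses))
    print(old)  # tensor([1.0000, 1.3333, 1.6667, 2.0000]) -- all inside epoch 2
    print(new)  # tensor([0.0000, 0.6667, 1.3333, 2.0000]) -- spans the whole run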
@@ -274,8 +274,9 @@ def generate_and_print_sample(model, tokenizer, device, start_context):
     context_size = model.pos_emb.weight.shape[0]
     encoded = text_to_token_ids(start_context, tokenizer).to(device)
     with torch.no_grad():
-        token_ids = generate_text_simple(model=model, idx=encoded,
-                                         max_new_tokens=50, context_size=context_size)
+        token_ids = generate_text_simple(
+            model=model, idx=encoded,
+            max_new_tokens=50, context_size=context_size)
     decoded_text = token_ids_to_text(token_ids, tokenizer)
     print(decoded_text.replace("\n", " ")) # Compact print format
     model.train()