# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import matplotlib.pyplot as plt
import os
import torch
import urllib.request
import tiktoken

# Import from local files
from previous_chapters import GPTModel, create_dataloader_v1, generate_text_simple


def text_to_token_ids(text, tokenizer):
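    """Encode a text string into a (1, num_tokens) tensor of token IDs."""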
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
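    """Decode a (1, num_tokens) tensor of token IDs back into a text string."""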
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())


def calc_loss_batch(input_batch, target_batch, model, device):
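    """Return the cross-entropy loss for a single input/target batch."""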
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
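    """Average the per-batch loss over up to num_batches batches; NaN for an empty loader."""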
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
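    """Compute mean train and validation losses over eval_iter batches, without gradient tracking."""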
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context):
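    """Generate and print a 50-token sample from the model to gauge training progress."""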
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
        decoded_text = token_ids_to_text(token_ids, tokenizer)
        print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()


def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer):
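    """Run a simple pretraining loop: evaluate every eval_freq steps, sample text after each epoch."""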
    # Initialize lists to track losses and tokens seen
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen = 0
    global_step = -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            tokens_seen += input_batch.numel()
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Print a sample text after each epoch
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )

    return train_losses, val_losses, track_tokens_seen


def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
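    """Plot training and validation losses against both epochs and tokens seen."""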
    fig, ax1 = plt.subplots()

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    # plt.show()


def main(gpt_config, settings):
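    """Download the training text if needed, pretrain a GPT model on it, and return the loss history and model."""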
    torch.manual_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ##############################
    # Download data if necessary
    ##############################

    file_path = "the-verdict.txt"
    url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"

    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode('utf-8')
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    else:
        with open(file_path, "r", encoding="utf-8") as file:
            text_data = file.read()

    ##############################
    # Initialize model
    ##############################

    model = GPTModel(gpt_config)
    model.to(device)  # no assignment model = model.to(device) necessary for nn.Module classes
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=settings["learning_rate"], weight_decay=settings["weight_decay"]
    )

    ##############################
    # Set up dataloaders
    ##############################

    # Train/validation ratio
    train_ratio = 0.90
    split_idx = int(train_ratio * len(text_data))

    train_loader = create_dataloader_v1(
        text_data[:split_idx],
        batch_size=settings["batch_size"],
        max_length=gpt_config["context_length"],
        stride=gpt_config["context_length"],
        drop_last=True,
        shuffle=True,
        num_workers=0
    )

    val_loader = create_dataloader_v1(
        text_data[split_idx:],
        batch_size=settings["batch_size"],
        max_length=gpt_config["context_length"],
        stride=gpt_config["context_length"],
        drop_last=False,
        shuffle=False,
        num_workers=0
    )

    ##############################
    # Train model
    ##############################

    tokenizer = tiktoken.get_encoding("gpt2")

    train_losses, val_losses, tokens_seen = train_model_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=settings["num_epochs"], eval_freq=5, eval_iter=1,
        start_context="Every effort moves you", tokenizer=tokenizer
    )

    return train_losses, val_losses, tokens_seen, model


if __name__ == "__main__":

    GPT_CONFIG_124M = {
        "vocab_size": 50257,    # Vocabulary size
        "context_length": 256,  # Shortened context length (orig: 1024)
        "emb_dim": 768,         # Embedding dimension
        "n_heads": 12,          # Number of attention heads
        "n_layers": 12,         # Number of layers
        "drop_rate": 0.1,       # Dropout rate
        "qkv_bias": False       # Query-key-value bias
    }

    OTHER_SETTINGS = {
        "learning_rate": 5e-4,
        "num_epochs": 10,
        "batch_size": 2,
        "weight_decay": 0.1
    }

    ###########################
    # Initiate training
    ###########################

    train_losses, val_losses, tokens_seen, model = main(GPT_CONFIG_124M, OTHER_SETTINGS)

    ###########################
    # After training
    ###########################

    # Plot results
    epochs_tensor = torch.linspace(0, OTHER_SETTINGS["num_epochs"], len(train_losses))
    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
    plt.savefig("loss.pdf")

    # Save and load model
    torch.save(model.state_dict(), "model.pth")
    model = GPTModel(GPT_CONFIG_124M)
    model.load_state_dict(torch.load("model.pth", weights_only=True))
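    model.eval()  # optional: disable dropout before generating text with the reloaded weights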