# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
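
# This script performs an exhaustive grid search over the training
# hyperparameters defined in HPARAM_GRID below and reports the configuration
# that reaches the lowest validation loss on "the-verdict.txt".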
					
						
import itertools
import math
import os

import tiktoken
import torch
from previous_chapters import GPTModel, create_dataloader_v1


# Define a grid of hyperparameters to search over
HPARAM_GRID = {
    "batch_size": [2, 4, 8, 16],
    "drop_rate": [0.0, 0.1, 0.2],
    "warmup_iters": [10, 20, 30],
    "weight_decay": [0.1, 0.01, 0.0],
    "peak_lr": [0.0001, 0.0005, 0.001, 0.005],
    "initial_lr": [0.00005, 0.0001],
    "min_lr": [0.00005, 0.00001, 0.0001],
    "n_epochs": [5, 10, 15, 20, 25],
}
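# The grid above yields 4 * 3 * 3 * 3 * 4 * 2 * 3 * 5 = 12,960 configurations in total.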
					
						
def calc_loss_loader(data_loader, model, device, num_batches=None):
    # Average the cross-entropy loss over up to `num_batches` batches of the loader
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
					
						
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)

    logits = model(input_batch)
    logits = logits.view(-1, logits.size(-1))
    loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1))
    return loss
					
						
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    # Compute mean train/val losses over `eval_iter` batches without gradient tracking
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss
					
						
def train_model(model, train_loader, val_loader, optimizer, device,
                n_epochs, eval_freq, eval_iter,
                encoded_start_context, tokenizer, warmup_iters=10,
                initial_lr=3e-05, min_lr=1e-6):
    # Note: eval_freq, encoded_start_context, and tokenizer are accepted here
    # but not used inside this training loop
    global_step = 0

    # The peak learning rate is whatever the optimizer was constructed with (peak_lr)
    max_lr = optimizer.param_groups[0]["lr"]

    # Calculate total number of iterations
    total_training_iters = len(train_loader) * n_epochs

    # Calculate the learning rate increment at each step during warmup
    lr_increment = (optimizer.param_groups[0]["lr"] - initial_lr) / warmup_iters

    for epoch in range(n_epochs):
        model.train()
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()

            # Increment the global step at the beginning of the iteration
            global_step += 1

            # Warmup: adjust learning rate linearly
            if global_step <= warmup_iters:
                lr = initial_lr + global_step * lr_increment
            # Cosine annealing phase
            else:
                progress = (global_step - warmup_iters) / (total_training_iters - warmup_iters)
                lr = min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
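            # (after warmup, the learning rate follows half a cosine from max_lr down to min_lr)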
					
						
            # Apply the calculated learning rate
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()

            # Apply gradient clipping
            if global_step >= warmup_iters:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
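                # (clip_grad_norm_ rescales all gradients if their combined norm exceeds 1.0)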
					
						
            optimizer.step()

    train_loss, val_loss = evaluate_model(model, train_loader, val_loader, device, eval_iter)
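    # (final losses for this run, averaged over up to eval_iter batches of each loader)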
					
						
    return train_loss, val_loss


if __name__ == "__main__":
					
						
    # Generate all combinations of hyperparameters
    hyperparameter_combinations = list(itertools.product(*HPARAM_GRID.values()))
    total_combinations = len(hyperparameter_combinations)
    print(f"Total hyperparameter configurations: {total_combinations}")

    # Placeholders for the best losses and best hyperparameters
    best_val_loss = float("inf")
    best_train_loss = float("inf")
    best_hparams = {}

    script_path = os.path.abspath(__file__)
    script_dir = os.path.dirname(script_path)
    with open(os.path.join(script_dir, "the-verdict.txt"), "r", encoding="utf-8") as file:
        text_data = file.read()

    tokenizer = tiktoken.get_encoding("gpt2")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_ratio = 0.95
    split_idx = int(train_ratio * len(text_data))
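    # (the first 95% of the text is used for training, the remaining 5% for validation)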
					
						
    torch.manual_seed(123)

    interrupted = False
    current_config = 0
    for combination in hyperparameter_combinations:

        try:
            current_config += 1
            print(f"Evaluating configuration {current_config} of {total_combinations}")

            # Unpack the current combination of hyperparameters
            HPARAM_CONFIG = dict(zip(HPARAM_GRID.keys(), combination))

            GPT_CONFIG_124M = {
                "vocab_size": 50257,     # Vocabulary size
                "context_length": 256,   # Context length -- shortened from the original 1024 tokens
                "emb_dim": 768,          # Embedding dimension
                "n_heads": 12,           # Number of attention heads
                "n_layers": 12,          # Number of layers
                "drop_rate": HPARAM_CONFIG["drop_rate"],
                "qkv_bias": False,       # Query-Key-Value bias
            }

            # Re-seed so that every configuration starts from the same weight
            # initialization and data shuffling order
            torch.manual_seed(123)
            train_loader = create_dataloader_v1(
                text_data[:split_idx],
                batch_size=HPARAM_CONFIG["batch_size"],
                max_length=GPT_CONFIG_124M["context_length"],
                stride=GPT_CONFIG_124M["context_length"],
                drop_last=True,
                shuffle=True,
                num_workers=0
            )

            val_loader = create_dataloader_v1(
                text_data[split_idx:],
                batch_size=HPARAM_CONFIG["batch_size"],
                max_length=GPT_CONFIG_124M["context_length"],
                stride=GPT_CONFIG_124M["context_length"],
                drop_last=False,
                shuffle=False,
                num_workers=0
            )

            model = GPTModel(GPT_CONFIG_124M)
            model.to(device)

            optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=HPARAM_CONFIG["peak_lr"],
                weight_decay=HPARAM_CONFIG["weight_decay"]
            )

            encoded_start_context = tokenizer.encode("Nevertheless")
            encoded_tensor = torch.tensor(encoded_start_context).unsqueeze(0)

            train_loss, val_loss = train_model(
                model, train_loader, val_loader, optimizer, device,
                n_epochs=HPARAM_CONFIG["n_epochs"],
                eval_freq=5, eval_iter=1,
                encoded_start_context=encoded_tensor,
                tokenizer=tokenizer,
                warmup_iters=HPARAM_CONFIG["warmup_iters"],
                initial_lr=HPARAM_CONFIG["initial_lr"],
                min_lr=HPARAM_CONFIG["min_lr"]
            )

            # Log the best hyperparameters based on validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_train_loss = train_loss
                best_hparams = HPARAM_CONFIG
					
						
        except KeyboardInterrupt:
            # Allow stopping the search early with Ctrl-C and still report the best result found so far
            print("Hyperparameter search completed.")
            print(f"Best hyperparameters: {best_hparams}")
            print(f"Best Val loss: {best_val_loss} | Training loss {best_train_loss}")
            interrupted = True
            break

    if not interrupted:
        print("Hyperparameter search completed.")
        print(f"Best hyperparameters: {best_hparams}")
        print(f"Best Val loss: {best_val_loss} | Training loss {best_train_loss}")
										 |  |  |         print(f"Best Val loss: {best_val_loss} | Training loss {train_loss}") |