# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# A minimal instruction finetuning file based on the code in chapter 7

from functools import partial
from importlib.metadata import version
import json
import os
import re
import time
import urllib.request

import matplotlib.pyplot as plt
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Import from local files in this folder
from gpt_download import download_and_load_gpt2
from previous_chapters import (
    calc_loss_loader,
    generate,
    GPTModel,
    load_weights_into_gpt,
    text_to_token_ids,
    token_ids_to_text,
    train_model_simple,
)


class InstructionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data

        # Pre-tokenize texts
        self.encoded_texts = []
        for entry in data:
            instruction_plus_input = format_input(entry)
            response_text = f"\n\n### Response:\n{entry['output']}"
            full_text = instruction_plus_input + response_text
            self.encoded_texts.append(
                tokenizer.encode(full_text)
            )

    def __getitem__(self, index):
        return self.encoded_texts[index]

    def __len__(self):
        return len(self.data)


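# The dataset above returns one variable-length list of token IDs per example;
# padding, target shifting, and batching are deferred to custom_collate_fn below.

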
def custom_collate_fn(
    batch,
    pad_token_id=50256,
    ignore_index=-100,
    allowed_max_length=None,
    device="cpu"
):
    # Find the longest sequence in the batch
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs and targets
    inputs_lst, targets_lst = [], []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to batch_max_length
        padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))
        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs
        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets

        # Replace all but the first padding token in targets by ignore_index
        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        if indices.numel() > 1:
            targets[indices[1:]] = ignore_index

        # Optionally truncate to maximum sequence length
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]

        inputs_lst.append(inputs)
        targets_lst.append(targets)

    # Convert lists of inputs and targets to tensors and transfer to target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)

    return inputs_tensor, targets_tensor
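
# A minimal sketch of what custom_collate_fn produces, using toy token IDs
# (pad_token_id=0 is used here only for readability; for GPT-2 it is 50256):
#
#   batch = [[1, 2, 3], [4, 5]]
#   inputs, targets = custom_collate_fn(batch, pad_token_id=0)
#   inputs  -> tensor([[1, 2, 3], [4, 5, 0]])
#   targets -> tensor([[2, 3, 0], [5, 0, -100]])
#
# Each targets row is its inputs row shifted left by one token; every pad
# position after the first is set to -100 so that cross entropy ignores it.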

def download_and_load_file(file_path, url):

    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    else:
        with open(file_path, "r", encoding="utf-8") as file:
            text_data = file.read()

    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    return data
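
# Each item in the loaded JSON list is a dict with "instruction", "input", and
# "output" keys, where "input" may be an empty string. An illustrative entry:
#
#   {"instruction": "Identify the correct spelling of the following word.",
#    "input": "Ocassion",
#    "output": "The correct spelling is 'Occasion.'"}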

def format_input(entry):
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""

    return instruction_text + input_text
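
# For an entry with a non-empty "input" field, format_input produces an
# Alpaca-style prompt along these lines (the "### Response:" section is
# appended separately by InstructionDataset during training):
#
#   Below is an instruction that describes a task. Write a response that
#   appropriately completes the request.
#
#   ### Instruction:
#   Identify the correct spelling of the following word.
#
#   ### Input:
#   Ocassion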

def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots(figsize=(12, 6))

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    plot_name = "loss-plot-standalone.pdf"
    plt.savefig(plot_name)
    print(f"Plot saved as {plot_name}")
    # plt.show()


def main(test_mode=False):
    #######################################
    # Print package versions
    #######################################
    print()
    pkgs = [
        "matplotlib",  # Plotting library
        "tiktoken",    # Tokenizer
        "torch",       # Deep learning library
        "tqdm",        # Progress bar
        "tensorflow",  # For OpenAI's pretrained weights
    ]
    for p in pkgs:
        print(f"{p} version: {version(p)}")
    print(50*"-")

    #######################################
    # Download and prepare dataset
    #######################################
    file_path = "instruction-data.json"
    url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch07/01_main-chapter-code/instruction-data.json"
    data = download_and_load_file(file_path, url)

    train_portion = int(len(data) * 0.85)  # 85% for training
    test_portion = int(len(data) * 0.1)    # 10% for testing

    train_data = data[:train_portion]
    test_data = data[train_portion:train_portion + test_portion]
    val_data = data[train_portion + test_portion:]  # Remaining 5% for validation
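    # With the chapter's 1,100-entry dataset, this yields 935 training,
    # 110 test, and 55 validation examples.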
    # Use a very small subset for testing purposes
    if test_mode:
        train_data = train_data[:10]
        val_data = val_data[:10]
        test_data = test_data[:10]

    print("Training set length:", len(train_data))
    print("Validation set length:", len(val_data))
    print("Test set length:", len(test_data))
    print(50*"-")

    tokenizer = tiktoken.get_encoding("gpt2")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    print(50*"-")

    # partial() pre-fills device and allowed_max_length so that the DataLoader
    # can call the collate function with the batch as its only argument
    customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=1024)

    num_workers = 0
    batch_size = 8

    torch.manual_seed(123)

    train_dataset = InstructionDataset(train_data, tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers
    )

    val_dataset = InstructionDataset(val_data, tokenizer)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )
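    # Each iteration over train_loader yields an (inputs, targets) pair of
    # int64 tensors, each of shape (batch_size, length of the longest tokenized
    # example in the batch), truncated to at most allowed_max_length=1024 columns.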
    #######################################
    # Load pretrained model
    #######################################

    # Small GPT model for testing purposes
    if test_mode:
        BASE_CONFIG = {
            "vocab_size": 50257,
            "context_length": 120,
            "drop_rate": 0.0,
            "qkv_bias": False,
            "emb_dim": 12,
            "n_layers": 1,
            "n_heads": 2
        }
        model = GPTModel(BASE_CONFIG)
        model.eval()
        device = "cpu"
        CHOOSE_MODEL = "Small test model"

    # Code as it is used in the main chapter
    else:
        BASE_CONFIG = {
            "vocab_size": 50257,     # Vocabulary size
            "context_length": 1024,  # Context length
            "drop_rate": 0.0,        # Dropout rate
            "qkv_bias": True         # Query-key-value bias
        }

        model_configs = {
            "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
            "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
            "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
            "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
        }

        CHOOSE_MODEL = "gpt2-medium (355M)"

        BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

        model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")  # e.g., "gpt2-medium (355M)" -> "355M"
        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

        model = GPTModel(BASE_CONFIG)
        load_weights_into_gpt(model, params)
        model.eval()
        model.to(device)

    print("Loaded model:", CHOOSE_MODEL)
    print(50*"-")

    #######################################
    # Finetuning the model
    #######################################
    print("Initial losses")
    with torch.no_grad():  # No gradients needed for this evaluation-only pass
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)

    print("   Training loss:", train_loss)
    print("   Validation loss:", val_loss)

    start_time = time.time()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005, weight_decay=0.1)

    num_epochs = 2

    torch.manual_seed(123)
    # eval_freq: evaluate every 5 steps; eval_iter: average the loss over 5 batches
    train_losses, val_losses, tokens_seen = train_model_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=num_epochs, eval_freq=5, eval_iter=5,
        start_context=format_input(val_data[0]), tokenizer=tokenizer
    )

    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")

    epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
    print(50*"-")

    #######################################
    # Saving results
    #######################################
    print("Generating responses")
    for i, entry in tqdm(enumerate(test_data), total=len(test_data)):

        input_text = format_input(entry)

        token_ids = generate(
            model=model,
            idx=text_to_token_ids(input_text, tokenizer).to(device),
            max_new_tokens=256,
            context_size=BASE_CONFIG["context_length"],
            eos_id=50256
        )
        generated_text = token_ids_to_text(token_ids, tokenizer)
        # The generated text echoes the prompt, so slice off the prompt and the
        # "### Response:" header to keep only the model's answer
        response_text = generated_text[len(input_text):].replace("### Response:", "").strip()

        test_data[i]["model_response"] = response_text

    test_data_path = "instruction-data-with-response-standalone.json"
    with open(test_data_path, "w", encoding="utf-8") as file:
        json.dump(test_data, file, indent=4)  # "indent" for pretty-printing
    print(f"Responses saved as {test_data_path}")

    file_name = f"{re.sub(r'[ ()]', '', CHOOSE_MODEL)}-sft-standalone.pth"
    torch.save(model.state_dict(), file_name)
    print(f"Model saved as {file_name}")


if __name__ == "__main__":

    import argparse

    parser = argparse.ArgumentParser(
        description="Finetune a GPT model for instruction following"
    )
    parser.add_argument(
        "--test_mode",
        default=False,
        action="store_true",
        help=("This flag runs the model in test mode for internal testing purposes. "
              "Otherwise, it runs the model as it is used in the chapter (recommended).")
    )
    args = parser.parse_args()

    main(args.test_mode)