From fb54b064c9f392767e6a43a2fc69d2beb3ea66ed Mon Sep 17 00:00:00 2001
From: rasbt
Date: Wed, 24 Apr 2024 07:23:11 -0500
Subject: [PATCH] add more experiments

---
 .../02_bonus_additional-experiments/README.md |  25 ++--
 .../additional-experiments.py                 | 133 +++++++++++-------
 2 files changed, 101 insertions(+), 57 deletions(-)

diff --git a/ch06/02_bonus_additional-experiments/README.md b/ch06/02_bonus_additional-experiments/README.md
index b29baad..fcaf441 100644
--- a/ch06/02_bonus_additional-experiments/README.md
+++ b/ch06/02_bonus_additional-experiments/README.md
@@ -1,10 +1,19 @@
 # Additional Experiments
 
-| Model              | Trainable token | Trainable layers | CPU/GPU | Training time | Training acc | Validation acc | Test acc |
-|--------------------|-----------------|------------------|---------|---------------|--------------|----------------|----------|
-| gpt2-small (124M)  | last            | last_block       | V100    | 0.39 min      | 96.63%       | 97.99%         | 94.33%   |
-| gpt2-small (124M)  | first           | last_block       | V100    | 0.37 min      | 78.46%       | 80.54%         | 75.00%   |
-| gpt2-small (124M)  | last            | last_layer       | V100    | 0.33 min      | 78.65%       | 87.25%         | 78.33%   |
-| gpt2-small (124M)  | last            | all              | V100    | 0.94 min      | 99.62%       | 96.64%         | 96.33%   |
-| gpt2-medium (355M) | last            | last_block       | V100    | 0.91 min      | 87.50%       | 51.01%         | 56.67%   |
-| gpt2-large (774M)  | last            | last_block       | V100    | 1.91 min      | 99.52%       | 98.66%         | 96.67%   |
\ No newline at end of file
+The table below adds experiments to answer additional questions about various design choices. The first row uses the same settings as the main chapter and serves as the reference.
+For example,
+
+- comparing rows 1 and 2 answers the question: "What is the performance difference when we train the last token versus the first token?";
+- comparing rows 1 and 3 answers the question: "What is the performance difference when we train only the last layer instead of the last block?";
+- and so forth.
+
+|   | Model              | Weights    | Trainable token | Trainable layers | Context length          | CPU/GPU | Training time | Training acc | Validation acc | Test acc |
+|---|--------------------|------------|-----------------|------------------|-------------------------|---------|---------------|--------------|----------------|----------|
+| 1 | gpt2-small (124M)  | pretrained | last            | last_block       | longest train ex. (120) | V100    | 0.39 min      | 96.63%       | 97.99%         | 94.33%   |
+| 2 | gpt2-small (124M)  | pretrained | first           | last_block       | longest train ex. (120) | V100    | 0.37 min      | 78.46%       | 80.54%         | 75.00%   |
+| 3 | gpt2-small (124M)  | pretrained | last            | last_layer       | longest train ex. (120) | V100    | 0.33 min      | 78.65%       | 87.25%         | 78.33%   |
+| 4 | gpt2-small (124M)  | pretrained | last            | all              | longest train ex. (120) | V100    | 0.94 min      | 99.62%       | 96.64%         | 96.33%   |
+| 5 | gpt2-medium (355M) | pretrained | last            | last_block       | longest train ex. (120) | V100    | 0.91 min      | 87.50%       | 51.01%         | 56.67%   |
+| 6 | gpt2-large (774M)  | pretrained | last            | last_block       | longest train ex. (120) | V100    | 1.91 min      | 99.52%       | 98.66%         | 96.67%   |
+| 7 | gpt2-small (124M)  | random     | last            | all              | longest train ex. (120) | V100    | 0.93 min      | 100.00%      | 97.32%         | 93.00%   |
+| 8 | gpt2-small (124M)  | pretrained | last            | last_block       | context length (1024)   | V100    | 3.24 min      | 83.08%       | 87.92%         | 78.33%   |
\ No newline at end of file
diff --git a/ch06/02_bonus_additional-experiments/additional-experiments.py b/ch06/02_bonus_additional-experiments/additional-experiments.py
index c4414e0..a56224e 100644
--- a/ch06/02_bonus_additional-experiments/additional-experiments.py
+++ b/ch06/02_bonus_additional-experiments/additional-experiments.py
@@ -64,7 +64,7 @@ def download_and_unzip(url, zip_path, extract_to, new_file_path):
         out_file.write(response.read())
 
     # Unzipping the file
-    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+    with zipfile.ZipFile(zip_path, "r") as zip_ref:
         zip_ref.extractall(extract_to)
 
     # Renaming the file to indicate its format
@@ -106,7 +106,7 @@ def create_dataset_csvs(data_file_path):
     test_df.to_csv("test.csv", index=None)
 
 
-def instantiate_model(choose_model):
+def instantiate_model(choose_model, load_weights):
 
     BASE_CONFIG = {
         "vocab_size": 50257,     # Vocabulary size
@@ -123,12 +123,13 @@
     }
 
     BASE_CONFIG.update(model_configs[choose_model])
-
-    model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
-    settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
-
     model = GPTModel(BASE_CONFIG)
-    load_weights_into_gpt(model, params)
+
+    if load_weights:
+        model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
+        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
+        load_weights_into_gpt(model, params)
+
     model.eval()
 
     return model
@@ -246,6 +247,14 @@ if __name__ == "__main__":
             " 'gpt2-large (774M)', 'gpt2-xl (1558M)'."
         )
     )
+    parser.add_argument(
+        "--weights",
+        type=str,
+        default="pretrained",
+        help=(
+            "Whether to use 'pretrained' or 'random' weights."
+        )
+    )
     parser.add_argument(
         "--trainable_layers",
         type=str,
@@ -262,6 +271,15 @@ if __name__ == "__main__":
             "Which token to train. Options: 'first', 'last'."
        )
     )
+    parser.add_argument(
+        "--context_length",
+        type=str,
+        default="longest_training_example",
+        help=(
+            "The context length of the data inputs."
+            " Options: 'longest_training_example', 'model_context_length' or integer value."
+        )
+    )
 
     args = parser.parse_args()
 
@@ -272,6 +290,52 @@ if __name__ == "__main__":
     else:
         raise ValueError("Invalid --trainable_token argument")
+
+    ###############################
+    # Load model
+    ###############################
+
+    if args.weights == "pretrained":
+        load_weights = True
+    elif args.weights == "random":
+        load_weights = False
+    else:
+        raise ValueError("Invalid --weights argument.")
+
+    model = instantiate_model(args.model_size, load_weights)
+    for param in model.parameters():
+        param.requires_grad = False
+
+    if args.model_size == "gpt2-small (124M)":
+        in_features = 768
+    elif args.model_size == "gpt2-medium (355M)":
+        in_features = 1024
+    elif args.model_size == "gpt2-large (774M)":
+        in_features = 1280
+    elif args.model_size == "gpt2-xl (1558M)":
+        in_features = 1600  # gpt2-xl uses a 1600-dimensional embedding
+    else:
+        raise ValueError("Invalid --model_size argument")
+
+    torch.manual_seed(123)
+    model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)
+
+    if args.trainable_layers == "last_layer":
+        pass
+    elif args.trainable_layers == "last_block":
+        for param in model.trf_blocks[-1].parameters():
+            param.requires_grad = True
+        for param in model.final_norm.parameters():
+            param.requires_grad = True
+    elif args.trainable_layers == "all":
+        for param in model.parameters():
+            param.requires_grad = True
+    else:
+        raise ValueError("Invalid --trainable_layers argument.")
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
 
     ###############################
     # Instantiate dataloaders
     ###############################
@@ -291,9 +355,19 @@ if __name__ == "__main__":
 
     tokenizer = tiktoken.get_encoding("gpt2")
 
-    train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
-    val_dataset = SpamDataset(base_path / "validation.csv", max_length=None, tokenizer=tokenizer)
-    test_dataset = SpamDataset(base_path / "test.csv", max_length=None, tokenizer=tokenizer)
+    if args.context_length == "model_context_length":
+        max_length = model.pos_emb.weight.shape[0]
+    elif args.context_length == "longest_training_example":
+        max_length = None
+    else:
+        try:
+            max_length = int(args.context_length)
+        except ValueError:
+            raise ValueError("Invalid --context_length argument")
+
+    train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
+    val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
+    test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
 
     tokenizer = tiktoken.get_encoding("gpt2")
 
@@ -322,45 +396,6 @@ if __name__ == "__main__":
         drop_last=False,
     )
 
-    ###############################
-    # Load model
-    ###############################
-
-    model = instantiate_model(args.model_size)
-    for param in model.parameters():
-        param.requires_grad = False
-
-    if args.model_size == "gpt2-small (124M)":
-        in_features = 768
-    elif args.model_size == "gpt2-medium (355M)":
-        in_features = 1024
-    elif args.model_size == "gpt2-large (774M)":
-        in_features = 1280
-    elif args.model_size == "gpt2-xl (1558M)":
-        in_features = 1280
-    else:
-        raise ValueError("Invalid --model_size argument")
-
-    torch.manual_seed(123)
-    print(model.out_head.weight.shape)
-    model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)
-
-    if args.trainable_layers == "last_layer":
-        pass
-    elif args.trainable_layers == "last_block":
-        for param in model.trf_blocks[-1].parameters():
-            param.requires_grad = True
-        for param in model.final_norm.parameters():
-            param.requires_grad = True
-    elif args.trainable_layers == "all":
-        for param in model.parameters():
-            param.requires_grad = True
-    else:
-        raise ValueError("Invalid --trainable_layers argument.")
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model.to(device)
-
     ###############################
     # Train model
     ###############################
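
A side note on the new "Load model" block added by this patch: the classification head's input width (`in_features`) is hard-coded per `--model_size`. Since the `--context_length` handling already reads the model's context length from `model.pos_emb.weight.shape[0]`, the embedding width could be derived from the same tensor instead of being looked up by model name. The sketch below is illustrative only and not part of the patch; the helper name is made up, and it assumes the `GPTModel` attributes the script already uses (`pos_emb`, `out_head`):

```python
# Illustrative sketch (not part of the patch): derive the classification head's
# input width from the instantiated model instead of hard-coding it per --model_size.
import torch


def attach_classification_head(model, num_classes=2):
    # model.pos_emb.weight has shape (context_length, emb_dim),
    # so its second dimension is the model's embedding width.
    emb_dim = model.pos_emb.weight.shape[1]
    torch.manual_seed(123)  # same seed the script sets before creating the new head
    model.out_head = torch.nn.Linear(in_features=emb_dim, out_features=num_classes)
    return model
```

For reproducing the two new table rows, row 7 corresponds to running the script with `--weights random` and `--trainable_layers all`, and row 8 to `--context_length "model_context_length"`, with the remaining settings as in row 1.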