Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2025-09-02 04:48:07 +00:00

Commit fb54b064c9 ("add more experiments"), parent b2cf956054
@@ -1,10 +1,19 @@
 # Additional Experiments
 
-| Model              | Trainable token | Trainable layers | CPU/GPU | Training time | Training acc | Validation acc | Test acc |
-|--------------------|-----------------|------------------|---------|---------------|--------------|----------------|----------|
-| gpt2-small (124M)  | last            | last_block       | V100    | 0.39 min      | 96.63%       | 97.99%         | 94.33%   |
-| gpt2-small (124M)  | first           | last_block       | V100    | 0.37 min      | 78.46%       | 80.54%         | 75.00%   |
-| gpt2-small (124M)  | last            | last_layer       | V100    | 0.33 min      | 78.65%       | 87.25%         | 78.33%   |
-| gpt2-small (124M)  | last            | all              | V100    | 0.94 min      | 99.62%       | 96.64%         | 96.33%   |
-| gpt2-medium (355M) | last            | last_block       | V100    | 0.91 min      | 87.50%       | 51.01%         | 56.67%   |
-| gpt2-large (774M)  | last            | last_block       | V100    | 1.91 min      | 99.52%       | 98.66%         | 96.67%   |
+The table below adds experiments to answer additional questions about various design choices. The first row uses the same settings as the main chapter and is used as a reference.
+
+For example,
+- comparing rows 1 and 2 answers the question: "What is the performance difference when we train the last or first token?";
+- comparing rows 1 and 3 answers the question: "What is the performance difference when we train only the last layer instead of the last block?";
+- and so forth.
+
+|   | Model              | Weights    | Trainable token | Trainable layers | Context length          | CPU/GPU | Training time | Training acc | Validation acc | Test acc |
+|---|--------------------|------------|-----------------|------------------|-------------------------|---------|---------------|--------------|----------------|----------|
+| 1 | gpt2-small (124M)  | pretrained | last            | last_block       | longest train ex. (120) | V100    | 0.39 min      | 96.63%       | 97.99%         | 94.33%   |
+| 2 | gpt2-small (124M)  | pretrained | first           | last_block       | longest train ex. (120) | V100    | 0.37 min      | 78.46%       | 80.54%         | 75.00%   |
+| 3 | gpt2-small (124M)  | pretrained | last            | last_layer       | longest train ex. (120) | V100    | 0.33 min      | 78.65%       | 87.25%         | 78.33%   |
+| 4 | gpt2-small (124M)  | pretrained | last            | all              | longest train ex. (120) | V100    | 0.94 min      | 99.62%       | 96.64%         | 96.33%   |
+| 5 | gpt2-medium (355M) | pretrained | last            | last_block       | longest train ex. (120) | V100    | 0.91 min      | 87.50%       | 51.01%         | 56.67%   |
+| 6 | gpt2-large (774M)  | pretrained | last            | last_block       | longest train ex. (120) | V100    | 1.91 min      | 99.52%       | 98.66%         | 96.67%   |
+| 7 | gpt2-small (124M)  | random     | last            | all              | longest train ex. (120) | V100    | 0.93 min      | 100%         | 97.32%         | 93.00%   |
+| 8 | gpt2-small (124M)  | pretrained | last            | last_block       | context length (1024)   | V100    | 3.24 min      | 83.08%       | 87.92%         | 78.33%   |
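A note on the "Trainable token" column above: the classifier is fed the logits of exactly one token position, and rows 1 and 2 differ only in whether that position is the last or the first token of the input. Below is a minimal sketch of that selection; the shapes and tensor names are illustrative, not taken from the script.

```python
import torch

# Illustrative shapes: batch of 8 sequences, 120 tokens each, 2 output classes.
logits = torch.randn(8, 120, 2)        # stand-in for the per-token model output

last_token_logits = logits[:, -1, :]   # "last": this position attends to the whole input
first_token_logits = logits[:, 0, :]   # "first": under causal attention it only sees itself
```

Since causal attention lets the last position aggregate information from the entire input while the first position cannot, the accuracy gap between rows 1 and 2 is expected.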
@@ -64,7 +64,7 @@ def download_and_unzip(url, zip_path, extract_to, new_file_path):
             out_file.write(response.read())
 
     # Unzipping the file
-    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+    with zipfile.ZipFile(zip_path, "r") as zip_ref:
         zip_ref.extractall(extract_to)
 
     # Renaming the file to indicate its format
@@ -106,7 +106,7 @@ def create_dataset_csvs(data_file_path):
     test_df.to_csv("test.csv", index=None)
 
 
-def instantiate_model(choose_model):
+def instantiate_model(choose_model, load_weights):
 
     BASE_CONFIG = {
         "vocab_size": 50257,  # Vocabulary size
@@ -123,12 +123,13 @@ def instantiate_model(choose_model):
     }
 
     BASE_CONFIG.update(model_configs[choose_model])
 
-    model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
-    settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
-
     model = GPTModel(BASE_CONFIG)
-    load_weights_into_gpt(model, params)
+
+    if load_weights:
+        model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
+        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
+        load_weights_into_gpt(model, params)
 
     model.eval()
     return model
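The new `load_weights` flag follows a simple pattern: construct the randomly initialized model first, then optionally overwrite its parameters with the pretrained GPT-2 checkpoint. Below is a self-contained toy analogue of that pattern; the `make_classifier` helper and the stand-in checkpoint are illustrative only and not part of the script.

```python
import torch

def make_classifier(load_weights: bool) -> torch.nn.Linear:
    """Toy analogue of the new load_weights flag: build the module first,
    then optionally overwrite its random initialization with a checkpoint."""
    layer = torch.nn.Linear(4, 2)  # randomly initialized, like GPTModel(BASE_CONFIG)
    if load_weights:
        # Stand-in for download_and_load_gpt2 + load_weights_into_gpt
        pretrained = {"weight": torch.ones(2, 4), "bias": torch.zeros(2)}
        layer.load_state_dict(pretrained)
    layer.eval()
    return layer

print(make_classifier(load_weights=False).weight)  # random init ("random" in the table)
print(make_classifier(load_weights=True).weight)   # stand-in "pretrained" weights
```

Row 7 of the table corresponds to calling the real `instantiate_model` with `load_weights=False`, i.e. fine-tuning a classifier on top of a randomly initialized GPT-2.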
@@ -246,6 +247,14 @@ if __name__ == "__main__":
             " 'gpt2-large (774M)', 'gpt2-xl (1558M)'."
         )
     )
+    parser.add_argument(
+        "--weights",
+        type=str,
+        default="pretrained",
+        help=(
+            "Whether to use 'pretrained' or 'random' weights."
+        )
+    )
     parser.add_argument(
         "--trainable_layers",
         type=str,
@@ -262,6 +271,15 @@ if __name__ == "__main__":
             "Which token to train. Options: 'first', 'last'."
         )
     )
+    parser.add_argument(
+        "--context_length",
+        type=str,
+        default="longest_training_example",
+        help=(
+            "The context length of the data inputs."
+            "Options: 'longest_training_example', 'model_context_length' or integer value."
+        )
+    )
 
     args = parser.parse_args()
 
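To see the new command-line surface in isolation, here is a standalone sketch of the two flags this commit introduces; the defaults and help text follow the diff, and the argv list passed to `parse_args` is a hypothetical invocation.

```python
import argparse

# Standalone sketch of the two new flags added in the hunks above.
parser = argparse.ArgumentParser()
parser.add_argument("--weights", type=str, default="pretrained",
                    help="Whether to use 'pretrained' or 'random' weights.")
parser.add_argument("--context_length", type=str, default="longest_training_example",
                    help="'longest_training_example', 'model_context_length' or an integer value.")

args = parser.parse_args(["--weights", "random"])
print(args.weights, args.context_length)  # random longest_training_example
```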
@@ -272,6 +290,52 @@ if __name__ == "__main__":
     else:
         raise ValueError("Invalid --trainable_token argument")
 
+
+    ###############################
+    # Load model
+    ###############################
+
+    if args.weights == "pretrained":
+        load_weights = True
+    elif args.weights == "random":
+        load_weights = False
+    else:
+        raise ValueError("Invalid --weights argument.")
+
+    model = instantiate_model(args.model_size, load_weights)
+    for param in model.parameters():
+        param.requires_grad = False
+
+    if args.model_size == "gpt2-small (124M)":
+        in_features = 768
+    elif args.model_size == "gpt2-medium (355M)":
+        in_features = 1024
+    elif args.model_size == "gpt2-large (774M)":
+        in_features = 1280
+    elif args.model_size == "gpt2-xl (1558M)":
+        in_features = 1280
+    else:
+        raise ValueError("Invalid --model_size argument")
+
+    torch.manual_seed(123)
+    model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)
+
+    if args.trainable_layers == "last_layer":
+        pass
+    elif args.trainable_layers == "last_block":
+        for param in model.trf_blocks[-1].parameters():
+            param.requires_grad = True
+        for param in model.final_norm.parameters():
+            param.requires_grad = True
+    elif args.trainable_layers == "all":
+        for param in model.parameters():
+            param.requires_grad = True
+    else:
+        raise ValueError("Invalid --trainable_layers argument.")
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
     ###############################
     # Instantiate dataloaders
     ###############################
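One way to sanity-check the `--trainable_layers` branches above is to count how many parameters still have `requires_grad=True` after freezing. A minimal sketch; the `count_trainable` helper and the toy two-layer module stand in for the real `GPTModel` and are not part of the script.

```python
import torch

def count_trainable(model: torch.nn.Module) -> int:
    # Number of parameters that will receive gradients during fine-tuning.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Toy demonstration with a small stand-in module instead of GPTModel:
toy = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 2))
for p in toy.parameters():
    p.requires_grad = False        # freeze everything, as in the diff
for p in toy[-1].parameters():
    p.requires_grad = True         # unfreeze only the final layer
print(count_trainable(toy))        # 22 = 10*2 weights + 2 biases
```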
@@ -291,9 +355,19 @@ if __name__ == "__main__":
 
     tokenizer = tiktoken.get_encoding("gpt2")
 
-    train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
-    val_dataset = SpamDataset(base_path / "validation.csv", max_length=None, tokenizer=tokenizer)
-    test_dataset = SpamDataset(base_path / "test.csv", max_length=None, tokenizer=tokenizer)
+    if args.context_length == "model_context_length":
+        max_length = model.pos_emb.weight.shape[0]
+    elif args.context_length == "longest_training_example":
+        max_length = None
+    else:
+        try:
+            max_length = int(args.context_length)
+        except ValueError:
+            raise ValueError("Invalid --context_length argument")
+
+    train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
+    val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer)
+    test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
 
     tokenizer = tiktoken.get_encoding("gpt2")
 
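The `--context_length` handling added above can also be read as a small helper with three cases; this is a readability sketch only (the `resolve_max_length` name and default argument are hypothetical), mirroring the branch in the hunk.

```python
def resolve_max_length(context_length_arg, model_context_length=1024):
    # "model_context_length": use the model's positional-embedding length (1024 for GPT-2).
    if context_length_arg == "model_context_length":
        return model_context_length
    # "longest_training_example": None, i.e. fall back to the longest training example.
    if context_length_arg == "longest_training_example":
        return None
    # Otherwise the argument must be an explicit integer, e.g. "1024".
    try:
        return int(context_length_arg)
    except ValueError:
        raise ValueError("Invalid --context_length argument")

print(resolve_max_length("model_context_length"))  # 1024
```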
@@ -322,45 +396,6 @@ if __name__ == "__main__":
         drop_last=False,
     )
 
-    ###############################
-    # Load model
-    ###############################
-
-    model = instantiate_model(args.model_size)
-    for param in model.parameters():
-        param.requires_grad = False
-
-    if args.model_size == "gpt2-small (124M)":
-        in_features = 768
-    elif args.model_size == "gpt2-medium (355M)":
-        in_features = 1024
-    elif args.model_size == "gpt2-large (774M)":
-        in_features = 1280
-    elif args.model_size == "gpt2-xl (1558M)":
-        in_features = 1280
-    else:
-        raise ValueError("Invalid --model_size argument")
-
-    torch.manual_seed(123)
-    print(model.out_head.weight.shape)
-    model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)
-
-    if args.trainable_layers == "last_layer":
-        pass
-    elif args.trainable_layers == "last_block":
-        for param in model.trf_blocks[-1].parameters():
-            param.requires_grad = True
-        for param in model.final_norm.parameters():
-            param.requires_grad = True
-    elif args.trainable_layers == "all":
-        for param in model.parameters():
-            param.requires_grad = True
-    else:
-        raise ValueError("Invalid --trainable_layers argument.")
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model.to(device)
-
     ###############################
     # Train model
     ###############################