diff --git a/ch06/02_bonus_additional-experiments/additional-experiments.py b/ch06/02_bonus_additional-experiments/additional-experiments.py
index ccbe060..ae60d8d 100644
--- a/ch06/02_bonus_additional-experiments/additional-experiments.py
+++ b/ch06/02_bonus_additional-experiments/additional-experiments.py
@@ -46,7 +46,7 @@ class LinearWithLoRA(torch.nn.Module):
 
 
 class SpamDataset(Dataset):
-    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, skip_padding=False):
+    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, no_padding=False):
         self.data = pd.read_csv(csv_file)
         self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
 
@@ -56,7 +56,7 @@ class SpamDataset(Dataset):
             for text in self.data["Text"]
         ]
 
-        if skip_padding:
+        if not no_padding:
             # Pad sequences to the longest sequence
             self.encoded_texts = [
                 et + [pad_token_id] * (self.max_length - len(et))
@@ -438,7 +438,7 @@ if __name__ == "__main__":
     if args.context_length == "model_context_length":
         max_length = model.pos_emb.weight.shape[0]
     elif args.context_length == "longest_training_example":
-        train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
+        train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer, no_padding=args.no_padding)
         max_length = train_dataset.max_length
     else:
         try:
@@ -447,9 +447,9 @@ if __name__ == "__main__":
             raise ValueError("Invalid --context_length argument")
 
     if train_dataset is None:
-        train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
-    val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
-    test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
+        train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
+    val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
+    test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
 
     tokenizer = tiktoken.get_encoding("gpt2")
 
@@ -510,4 +510,4 @@ if __name__ == "__main__":
 
     print(f"Training accuracy: {train_accuracy*100:.2f}%")
     print(f"Validation accuracy: {val_accuracy*100:.2f}%")
-    print(f"Test accuracy: {test_accuracy*100:.2f}%")
+    print(f"Test accuracy: {test_accuracy*100:.2f}%")
\ No newline at end of file
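
Note: besides the rename from skip_padding to no_padding, the second hunk fixes an inverted condition: the old code padded only when skip_padding was set. A minimal usage sketch of the behavior expected after this patch follows; it assumes the SpamDataset class from the patched script is in scope and that train.csv has a "Text" column, as in the surrounding code. The variable names here are illustrative only.

import tiktoken
from torch.utils.data import DataLoader

tokenizer = tiktoken.get_encoding("gpt2")

# Default (no_padding=False): every example is right-padded with
# pad_token_id (50256, GPT-2's <|endoftext|>) up to max_length, so
# examples stack into uniform batches.
padded_ds = SpamDataset("train.csv", tokenizer=tokenizer)

# no_padding=True now genuinely skips padding, so sequence lengths
# vary and the loader needs batch_size=1 (or a custom collate_fn).
unpadded_ds = SpamDataset("train.csv", tokenizer=tokenizer, no_padding=True)
loader = DataLoader(unpadded_ds, batch_size=1, shuffle=True)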