mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-31 03:50:23 +00:00
fix no padding option
This commit is contained in:
parent
42d003c4ee
commit
623bc19665
@ -46,7 +46,7 @@ class LinearWithLoRA(torch.nn.Module):
|
||||
|
||||
|
||||
class SpamDataset(Dataset):
|
||||
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, skip_padding=False):
|
||||
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, no_padding=False):
|
||||
self.data = pd.read_csv(csv_file)
|
||||
self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
|
||||
|
||||
@ -56,7 +56,7 @@ class SpamDataset(Dataset):
|
||||
for text in self.data["Text"]
|
||||
]
|
||||
|
||||
if skip_padding:
|
||||
if not no_padding:
|
||||
# Pad sequences to the longest sequence
|
||||
self.encoded_texts = [
|
||||
et + [pad_token_id] * (self.max_length - len(et))
|
||||
@ -438,7 +438,7 @@ if __name__ == "__main__":
|
||||
if args.context_length == "model_context_length":
|
||||
max_length = model.pos_emb.weight.shape[0]
|
||||
elif args.context_length == "longest_training_example":
|
||||
train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
|
||||
train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer, no_padding=args.no_padding)
|
||||
max_length = train_dataset.max_length
|
||||
else:
|
||||
try:
|
||||
@ -447,9 +447,9 @@ if __name__ == "__main__":
|
||||
raise ValueError("Invalid --context_length argument")
|
||||
|
||||
if train_dataset is None:
|
||||
train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
|
||||
val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
|
||||
test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
|
||||
train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
|
||||
val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
|
||||
test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
@ -510,4 +510,4 @@ if __name__ == "__main__":
|
||||
|
||||
print(f"Training accuracy: {train_accuracy*100:.2f}%")
|
||||
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
|
||||
print(f"Test accuracy: {test_accuracy*100:.2f}%")
|
||||
print(f"Test accuracy: {test_accuracy*100:.2f}%")
|
Loading…
x
Reference in New Issue
Block a user