Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-10-31 01:41:26 +00:00)

fix no padding option

Author: rasbt
Parent: 42d003c4ee
Commit: 623bc19665
@@ -46,7 +46,7 @@ class LinearWithLoRA(torch.nn.Module):
 
 
 class SpamDataset(Dataset):
-    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, skip_padding=False):
+    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, no_padding=False):
         self.data = pd.read_csv(csv_file)
         self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
 
@@ -56,7 +56,7 @@ class SpamDataset(Dataset):
             for text in self.data["Text"]
         ]
 
-        if skip_padding:
+        if not no_padding:
             # Pad sequences to the longest sequence
             self.encoded_texts = [
                 et + [pad_token_id] * (self.max_length - len(et))
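Note that the two hunks above fix two things at once: the keyword argument skip_padding is renamed to no_padding, and the inverted condition is corrected. Previously, "if skip_padding:" ran the padding branch exactly when the caller asked to skip it; padding should instead be applied unless no_padding is set. A minimal standalone sketch of the corrected behavior (the function name and its standalone form are illustrative, not code from the repository):

# Sketch of the corrected padding logic; not repository code.
def encode_texts(texts, tokenizer, max_length, pad_token_id=50256, no_padding=False):
    # Tokenize and truncate to max_length, as SpamDataset does.
    encoded = [tokenizer.encode(text)[:max_length] for text in texts]
    if not no_padding:  # pad unless the caller explicitly opts out
        encoded = [
            et + [pad_token_id] * (max_length - len(et))
            for et in encoded
        ]
    return encoded

With no_padding=False (the default), shorter sequences are right-padded with the GPT-2 end-of-text token id 50256; with no_padding=True they keep their original lengths.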
@@ -438,7 +438,7 @@ if __name__ == "__main__":
         if args.context_length == "model_context_length":
             max_length = model.pos_emb.weight.shape[0]
         elif args.context_length == "longest_training_example":
-            train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
+            train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer, no_padding=args.no_padding)
             max_length = train_dataset.max_length
         else:
             try:
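For reference, args.context_length and args.no_padding come from the script's command-line interface. A plausible argparse setup consistent with the values the diff checks for (the default and help strings here are assumptions, not the script's actual definitions):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--context_length",
    default="longest_training_example",  # assumed default
    help="'model_context_length', 'longest_training_example', or an integer",
)
parser.add_argument(
    "--no_padding",
    action="store_true",
    help="Disable padding so sequences keep their original lengths",
)
args = parser.parse_args()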
@@ -447,9 +447,9 @@ if __name__ == "__main__":
                 raise ValueError("Invalid --context_length argument")
 
     if train_dataset is None:
-        train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
-    val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
-    test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
+        train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
+    val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
+    test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
 
     tokenizer = tiktoken.get_encoding("gpt2")
 
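One practical consequence of no_padding=True is that the sequences in a dataset no longer share a common length, so PyTorch's default collate function cannot stack them into a batch; training then requires batch_size=1 or a custom collate_fn. A short usage sketch (base_path and the DataLoader settings are assumptions; SpamDataset is the class patched above):

from pathlib import Path

import tiktoken
from torch.utils.data import DataLoader

base_path = Path(".")  # assumed location of train.csv
tokenizer = tiktoken.get_encoding("gpt2")

# SpamDataset is the class patched in this commit.
train_dataset = SpamDataset(
    base_path / "train.csv",
    max_length=None,       # fall back to the longest training example
    tokenizer=tokenizer,
    no_padding=True,
)

# Unpadded sequences vary in length, so the default collate_fn can only
# handle one example per batch.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)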