Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2025-10-30 17:29:59 +00:00
fix no padding option

Author: rasbt
parent 42d003c4ee
commit 623bc19665
@@ -46,7 +46,7 @@ class LinearWithLoRA(torch.nn.Module):
 
 
 class SpamDataset(Dataset):
-    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, skip_padding=False):
+    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, no_padding=False):
         self.data = pd.read_csv(csv_file)
         self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
 
@@ -56,7 +56,7 @@ class SpamDataset(Dataset):
             for text in self.data["Text"]
         ]
 
-        if skip_padding:
+        if not no_padding:
             # Pad sequences to the longest sequence
             self.encoded_texts = [
                 et + [pad_token_id] * (self.max_length - len(et))
@@ -438,7 +438,7 @@ if __name__ == "__main__":
         if args.context_length == "model_context_length":
             max_length = model.pos_emb.weight.shape[0]
         elif args.context_length == "longest_training_example":
-            train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
+            train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer, no_padding=args.no_padding)
             max_length = train_dataset.max_length
         else:
             try:
@@ -447,9 +447,9 @@ if __name__ == "__main__":
                 raise ValueError("Invalid --context_length argument")
 
     if train_dataset is None:
-        train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
-    val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
-    test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, skip_padding=args.no_padding)
+        train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
+    val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
+    test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
 
     tokenizer = tiktoken.get_encoding("gpt2")
 
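For context, here is a minimal sketch of the padding behavior this change settles on. It is not code from the repository: encode_texts and the sample texts are hypothetical and exist only for illustration; only the no_padding flag, the pad_token_id default of 50256, and the `if not no_padding:` branch mirror the diff above.

import tiktoken


def encode_texts(texts, tokenizer, max_length, pad_token_id=50256, no_padding=False):
    # Hypothetical helper distilling the SpamDataset logic after this commit:
    # encode each text, truncate to max_length, then right-pad with pad_token_id
    # unless no_padding is set.
    encoded = [tokenizer.encode(text)[:max_length] for text in texts]
    if not no_padding:
        # Pad sequences to a common length (capped at max_length)
        encoded = [et + [pad_token_id] * (max_length - len(et)) for et in encoded]
    return encoded


tokenizer = tiktoken.get_encoding("gpt2")
texts = ["You won a free prize, reply now", "ok see you later"]

padded = encode_texts(texts, tokenizer, max_length=12)
unpadded = encode_texts(texts, tokenizer, max_length=12, no_padding=True)

print([len(t) for t in padded])    # all 12: shorter texts are filled with 50256
print([len(t) for t in unpadded])  # original token counts are kept

With no_padding=True the encoded texts keep their natural, variable lengths, which the script wires through args.no_padding for the train, validation, and test datasets.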