mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-25 06:49:34 +00:00 
			
		
		
		
	add test mode for dataset download
This commit is contained in:
		
							parent
							
								
									bdea15f6c6
								
							
						
					
					
						commit
						5541f7c8fe
					
				| @ -21,15 +21,34 @@ from gpt_download import download_and_load_gpt2 | |||||||
| from previous_chapters import GPTModel, load_weights_into_gpt | from previous_chapters import GPTModel, load_weights_into_gpt | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path): | def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=False): | ||||||
|     if data_file_path.exists(): |     if data_file_path.exists(): | ||||||
|         print(f"{data_file_path} already exists. Skipping download and extraction.") |         print(f"{data_file_path} already exists. Skipping download and extraction.") | ||||||
|         return |         return | ||||||
| 
 | 
 | ||||||
|     # Downloading the file |     if test_mode:  # Try multiple times since CI sometimes has connectivity issues | ||||||
|     with urllib.request.urlopen(url) as response: |         max_retries = 5 | ||||||
|         with open(zip_path, "wb") as out_file: |         delay = 5  # delay between retries in seconds | ||||||
|             out_file.write(response.read()) |         for attempt in range(max_retries): | ||||||
|  |             try: | ||||||
|  |                 # Downloading the file | ||||||
|  |                 with urllib.request.urlopen(url, timeout=10) as response: | ||||||
|  |                     with open(zip_path, "wb") as out_file: | ||||||
|  |                         out_file.write(response.read()) | ||||||
|  |                 break  # if download is successful, break out of the loop | ||||||
|  |             except urllib.error.URLError as e: | ||||||
|  |                 print(f"Attempt {attempt + 1} failed: {e}") | ||||||
|  |                 if attempt < max_retries - 1: | ||||||
|  |                     time.sleep(delay)  # wait before retrying | ||||||
|  |                 else: | ||||||
|  |                     print("Failed to download file after several attempts.") | ||||||
|  |                     return  # exit if all retries fail | ||||||
|  | 
 | ||||||
|  |     else:  # Code as it appears in the chapter | ||||||
|  |         # Downloading the file | ||||||
|  |         with urllib.request.urlopen(url) as response: | ||||||
|  |             with open(zip_path, "wb") as out_file: | ||||||
|  |                 out_file.write(response.read()) | ||||||
| 
 | 
 | ||||||
|     # Unzipping the file |     # Unzipping the file | ||||||
|     with zipfile.ZipFile(zip_path, "r") as zip_ref: |     with zipfile.ZipFile(zip_path, "r") as zip_ref: | ||||||
| @ -238,6 +257,7 @@ if __name__ == "__main__": | |||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--test_mode", |         "--test_mode", | ||||||
|  |         default=False, | ||||||
|         action="store_true", |         action="store_true", | ||||||
|         help=("This flag runs the model in test mode for internal testing purposes. " |         help=("This flag runs the model in test mode for internal testing purposes. " | ||||||
|               "Otherwise, it runs the model as it is used in the chapter (recommended).") |               "Otherwise, it runs the model as it is used in the chapter (recommended).") | ||||||
| @ -253,7 +273,7 @@ if __name__ == "__main__": | |||||||
|     extracted_path = "sms_spam_collection" |     extracted_path = "sms_spam_collection" | ||||||
|     data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv" |     data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv" | ||||||
| 
 | 
 | ||||||
|     download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path) |     download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode) | ||||||
|     df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"]) |     df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"]) | ||||||
|     balanced_df = create_balanced_dataset(df) |     balanced_df = create_balanced_dataset(df) | ||||||
|     balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1}) |     balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1}) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 rasbt
						rasbt