mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-31 09:50:23 +00:00 
			
		
		
		
	Fix data download if UCI is temporarily down (#592)
This commit is contained in:
		
							parent
							
								
									0bdcce4e40
								
							
						
					
					
						commit
						d75f74bd0c
					
				| @ -21,34 +21,15 @@ from gpt_download import download_and_load_gpt2 | ||||
| from previous_chapters import GPTModel, load_weights_into_gpt | ||||
| 
 | ||||
| 
 | ||||
| def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=False): | ||||
| def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path): | ||||
|     if data_file_path.exists(): | ||||
|         print(f"{data_file_path} already exists. Skipping download and extraction.") | ||||
|         return | ||||
| 
 | ||||
|     if test_mode:  # Try multiple times since CI sometimes has connectivity issues | ||||
|         max_retries = 5 | ||||
|         delay = 5  # delay between retries in seconds | ||||
|         for attempt in range(max_retries): | ||||
|             try: | ||||
|                 # Downloading the file | ||||
|                 with urllib.request.urlopen(url, timeout=10) as response: | ||||
|                     with open(zip_path, "wb") as out_file: | ||||
|                         out_file.write(response.read()) | ||||
|                 break  # if download is successful, break out of the loop | ||||
|             except urllib.error.URLError as e: | ||||
|                 print(f"Attempt {attempt + 1} failed: {e}") | ||||
|                 if attempt < max_retries - 1: | ||||
|                     time.sleep(delay)  # wait before retrying | ||||
|                 else: | ||||
|                     print("Failed to download file after several attempts.") | ||||
|                     return  # exit if all retries fail | ||||
| 
 | ||||
|     else:  # Code as it appears in the chapter | ||||
|         # Downloading the file | ||||
|         with urllib.request.urlopen(url) as response: | ||||
|             with open(zip_path, "wb") as out_file: | ||||
|                 out_file.write(response.read()) | ||||
|     # Downloading the file | ||||
|     with urllib.request.urlopen(url) as response: | ||||
|         with open(zip_path, "wb") as out_file: | ||||
|             out_file.write(response.read()) | ||||
| 
 | ||||
|     # Unzipping the file | ||||
|     with zipfile.ZipFile(zip_path, "r") as zip_ref: | ||||
| @ -277,15 +258,11 @@ if __name__ == "__main__": | ||||
|     data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv" | ||||
| 
 | ||||
|     try: | ||||
|         download_and_unzip_spam_data( | ||||
|             url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode | ||||
|         ) | ||||
|         download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path) | ||||
|     except (urllib.error.HTTPError, urllib.error.URLError, TimeoutError) as e: | ||||
|         print(f"Primary URL failed: {e}. Trying backup URL...") | ||||
|         backup_url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip" | ||||
|         download_and_unzip_spam_data( | ||||
|             backup_url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode | ||||
|         ) | ||||
|         url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip" | ||||
|         download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path) | ||||
| 
 | ||||
|     df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"]) | ||||
|     balanced_df = create_balanced_dataset(df) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Sebastian Raschka
						Sebastian Raschka