From d75f74bd0cc43052e49f40f8a1621a409732b5f9 Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Mon, 31 Mar 2025 16:25:53 -0500
Subject: [PATCH] Fix data download if UCI is temporarily down (#592)

---
 .../gpt_class_finetune.py | 39 ++++---------
 1 file changed, 8 insertions(+), 31 deletions(-)

diff --git a/ch06/01_main-chapter-code/gpt_class_finetune.py b/ch06/01_main-chapter-code/gpt_class_finetune.py
index 8e925dd..8308304 100644
--- a/ch06/01_main-chapter-code/gpt_class_finetune.py
+++ b/ch06/01_main-chapter-code/gpt_class_finetune.py
@@ -21,34 +21,15 @@ from gpt_download import download_and_load_gpt2
 from previous_chapters import GPTModel, load_weights_into_gpt
 
 
-def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=False):
+def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
     if data_file_path.exists():
         print(f"{data_file_path} already exists. Skipping download and extraction.")
         return
 
-    if test_mode:  # Try multiple times since CI sometimes has connectivity issues
-        max_retries = 5
-        delay = 5  # delay between retries in seconds
-        for attempt in range(max_retries):
-            try:
-                # Downloading the file
-                with urllib.request.urlopen(url, timeout=10) as response:
-                    with open(zip_path, "wb") as out_file:
-                        out_file.write(response.read())
-                break  # if download is successful, break out of the loop
-            except urllib.error.URLError as e:
-                print(f"Attempt {attempt + 1} failed: {e}")
-                if attempt < max_retries - 1:
-                    time.sleep(delay)  # wait before retrying
-                else:
-                    print("Failed to download file after several attempts.")
-                    return  # exit if all retries fail
-
-    else:  # Code as it appears in the chapter
-        # Downloading the file
-        with urllib.request.urlopen(url) as response:
-            with open(zip_path, "wb") as out_file:
-                out_file.write(response.read())
+    # Downloading the file
+    with urllib.request.urlopen(url) as response:
+        with open(zip_path, "wb") as out_file:
+            out_file.write(response.read())
 
     # Unzipping the file
     with zipfile.ZipFile(zip_path, "r") as zip_ref:
@@ -277,15 +258,11 @@ if __name__ == "__main__":
     data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"
 
     try:
-        download_and_unzip_spam_data(
-            url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode
-        )
+        download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
     except (urllib.error.HTTPError, urllib.error.URLError, TimeoutError) as e:
         print(f"Primary URL failed: {e}. Trying backup URL...")
-        backup_url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip"
-        download_and_unzip_spam_data(
-            backup_url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode
-        )
+        url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip"
+        download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
 
     df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
     balanced_df = create_balanced_dataset(df)