diff --git a/ch06/01_main-chapter-code/gpt-class-finetune.py b/ch06/01_main-chapter-code/gpt-class-finetune.py index b1c7053..545dfd4 100644 --- a/ch06/01_main-chapter-code/gpt-class-finetune.py +++ b/ch06/01_main-chapter-code/gpt-class-finetune.py @@ -21,15 +21,34 @@ from gpt_download import download_and_load_gpt2 from previous_chapters import GPTModel, load_weights_into_gpt -def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path): +def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=False): if data_file_path.exists(): print(f"{data_file_path} already exists. Skipping download and extraction.") return - # Downloading the file - with urllib.request.urlopen(url) as response: - with open(zip_path, "wb") as out_file: - out_file.write(response.read()) + if test_mode: # Try multiple times since CI sometimes has connectivity issues + max_retries = 5 + delay = 5 # delay between retries in seconds + for attempt in range(max_retries): + try: + # Downloading the file + with urllib.request.urlopen(url, timeout=10) as response: + with open(zip_path, "wb") as out_file: + out_file.write(response.read()) + break # if download is successful, break out of the loop + except urllib.error.URLError as e: + print(f"Attempt {attempt + 1} failed: {e}") + if attempt < max_retries - 1: + time.sleep(delay) # wait before retrying + else: + print("Failed to download file after several attempts.") + return # exit if all retries fail + + else: # Code as it appears in the chapter + # Downloading the file + with urllib.request.urlopen(url) as response: + with open(zip_path, "wb") as out_file: + out_file.write(response.read()) # Unzipping the file with zipfile.ZipFile(zip_path, "r") as zip_ref: @@ -238,6 +257,7 @@ if __name__ == "__main__": ) parser.add_argument( "--test_mode", + default=False, action="store_true", help=("This flag runs the model in test mode for internal testing purposes. " "Otherwise, it runs the model as it is used in the chapter (recommended).") @@ -253,7 +273,7 @@ if __name__ == "__main__": extracted_path = "sms_spam_collection" data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv" - download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path) + download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode) df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"]) balanced_df = create_balanced_dataset(df) balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})