mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-31 12:00:23 +00:00
add test mode for dataset download
This commit is contained in:
parent
bdea15f6c6
commit
5541f7c8fe
@ -21,15 +21,34 @@ from gpt_download import download_and_load_gpt2
|
|||||||
from previous_chapters import GPTModel, load_weights_into_gpt
|
from previous_chapters import GPTModel, load_weights_into_gpt
|
||||||
|
|
||||||
|
|
||||||
def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
|
def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=False):
|
||||||
if data_file_path.exists():
|
if data_file_path.exists():
|
||||||
print(f"{data_file_path} already exists. Skipping download and extraction.")
|
print(f"{data_file_path} already exists. Skipping download and extraction.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Downloading the file
|
if test_mode: # Try multiple times since CI sometimes has connectivity issues
|
||||||
with urllib.request.urlopen(url) as response:
|
max_retries = 5
|
||||||
with open(zip_path, "wb") as out_file:
|
delay = 5 # delay between retries in seconds
|
||||||
out_file.write(response.read())
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
# Downloading the file
|
||||||
|
with urllib.request.urlopen(url, timeout=10) as response:
|
||||||
|
with open(zip_path, "wb") as out_file:
|
||||||
|
out_file.write(response.read())
|
||||||
|
break # if download is successful, break out of the loop
|
||||||
|
except urllib.error.URLError as e:
|
||||||
|
print(f"Attempt {attempt + 1} failed: {e}")
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
time.sleep(delay) # wait before retrying
|
||||||
|
else:
|
||||||
|
print("Failed to download file after several attempts.")
|
||||||
|
return # exit if all retries fail
|
||||||
|
|
||||||
|
else: # Code as it appears in the chapter
|
||||||
|
# Downloading the file
|
||||||
|
with urllib.request.urlopen(url) as response:
|
||||||
|
with open(zip_path, "wb") as out_file:
|
||||||
|
out_file.write(response.read())
|
||||||
|
|
||||||
# Unzipping the file
|
# Unzipping the file
|
||||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||||
@ -238,6 +257,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--test_mode",
|
"--test_mode",
|
||||||
|
default=False,
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help=("This flag runs the model in test mode for internal testing purposes. "
|
help=("This flag runs the model in test mode for internal testing purposes. "
|
||||||
"Otherwise, it runs the model as it is used in the chapter (recommended).")
|
"Otherwise, it runs the model as it is used in the chapter (recommended).")
|
||||||
@ -253,7 +273,7 @@ if __name__ == "__main__":
|
|||||||
extracted_path = "sms_spam_collection"
|
extracted_path = "sms_spam_collection"
|
||||||
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"
|
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"
|
||||||
|
|
||||||
download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
|
download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode)
|
||||||
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
|
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
|
||||||
balanced_df = create_balanced_dataset(df)
|
balanced_df = create_balanced_dataset(df)
|
||||||
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
|
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user