improve instructions

This commit is contained in:
rasbt 2024-04-02 07:12:22 -05:00
parent 776a517d18
commit f30dd2dd2b
2 changed files with 12 additions and 9 deletions

View File

@ -34,10 +34,10 @@ Follow these steps to download the dataset:
Next, run the `prepare_dataset.py` script, which concatenates the text files (60,173 as of this writing) into fewer, larger files so that they can be transferred and accessed more efficiently:
```
prepare_dataset.py \
--data_dir "gutenberg/data" \
python prepare_dataset.py \
--data_dir gutenberg/data \
--max_size_mb 500 \
--output_dir "gutenberg_preprocessed"
--output_dir gutenberg_preprocessed
```
> [!TIP]
@ -53,7 +53,7 @@ prepare_dataset.py \
You can run the pretraining script as follows. Note that the additional command-line arguments are shown with their default values for illustration purposes:
```bash
pretraining_simple.py \
python pretraining_simple.py \
--data_dir "gutenberg_preprocessed" \
--n_epochs 1 \
--batch_size 4 \

View File

@ -9,6 +9,7 @@ Script that processes the Project Gutenberg files into fewer larger files.
import argparse
import os
import re
def combine_files(file_paths, target_dir, max_size_mb=500, separator="<|endoftext|>", fallback_encoding="latin1"):
@ -29,6 +30,8 @@ def combine_files(file_paths, target_dir, max_size_mb=500, separator="<|endoftex
with open(file_path, "r", encoding=fallback_encoding) as file:
content = file.read()
# Regular expression to replace multiple blank lines with a single blank line
content = re.sub(r'\n\s*\n', '\n\n', content)
estimated_size = len(content.encode("utf-8"))
if current_size + estimated_size > max_size_mb * 1024 * 1024:
@ -46,11 +49,12 @@ def combine_files(file_paths, target_dir, max_size_mb=500, separator="<|endoftex
target_file_path = os.path.join(target_dir, f"combined_{file_counter}.txt")
with open(target_file_path, "w", encoding="utf-8") as target_file:
target_file.write(separator.join(current_content))
return file_counter
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GPT Model Training Configuration")
parser = argparse.ArgumentParser(description="Preprocess and combine text files for pretraining")
parser.add_argument("--data_dir", type=str, default="gutenberg/data",
help="Directory containing the downloaded raw training data")
@ -64,7 +68,6 @@ if __name__ == "__main__":
all_files = [os.path.join(path, name) for path, subdirs, files in os.walk(args.data_dir)
for name in files if name.endswith((".txt", ".txt.utf8")) and "raw" not in path]
target_dir = "path_to_your_large_files"
print(f"{len(all_files)} files to process.")
combine_files(all_files, args.output_dir)
print(f"{len(all_files)} file(s) to process.")
file_counter = combine_files(all_files, args.output_dir)
print(f"{file_counter} file(s) saved in {os.path.abspath(args.output_dir)}")