From dc26541da2f69064a8f17728d03a6bddeb1fc91d Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Mon, 7 Oct 2024 20:59:43 +0000
Subject: [PATCH] Starting code to build parquets...

---
 pdelfin/train/buildparquetdataset.py | 77 ++++++++++++++++++++++++++++
 pdelfin/train/core/config.py         |  5 +-
 pdelfin/train/dataloader.py          |  3 --
 3 files changed, 80 insertions(+), 5 deletions(-)
 create mode 100644 pdelfin/train/buildparquetdataset.py

diff --git a/pdelfin/train/buildparquetdataset.py b/pdelfin/train/buildparquetdataset.py
new file mode 100644
index 0000000..7fabbe7
--- /dev/null
+++ b/pdelfin/train/buildparquetdataset.py
@@ -0,0 +1,77 @@
+import argparse
+import logging
+from functools import partial
+import os
+import boto3
+from datasets import Dataset
+from botocore.exceptions import NoCredentialsError, PartialCredentialsError
+from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
+
+
+def save_dataset_in_parquet(dataset: Dataset, output_dir: str, rows_per_file: int = 10000, s3_endpoint_url: str = None):
+    logger.info("Saving dataset in Parquet files")
+
+    # Check if the output is an S3 path
+    is_s3 = output_dir.startswith("s3://")
+    if is_s3:
+        s3_client = boto3.client('s3', endpoint_url=s3_endpoint_url) if s3_endpoint_url else boto3.client('s3')
+    else:
+        os.makedirs(output_dir, exist_ok=True)
+
+    total_rows = len(dataset)
+    for start_idx in range(0, total_rows, rows_per_file):
+        end_idx = min(start_idx + rows_per_file, total_rows)
+        file_name = f"dataset_{start_idx}_{end_idx}.parquet"
+        if is_s3:
+            # Saving to S3
+            bucket_name, key_prefix = parse_s3_path(output_dir)
+            output_path = f"{key_prefix}/{file_name}"
+            local_temp_file = f"/tmp/{file_name}"
+            logger.info(f"Saving rows {start_idx} to {end_idx} locally at {local_temp_file}")
+            dataset.select(range(start_idx, end_idx)).to_parquet(local_temp_file)
+            try:
+                logger.info(f"Uploading {local_temp_file} to s3://{bucket_name}/{output_path}")
+                s3_client.upload_file(local_temp_file, bucket_name, output_path)
+            except (NoCredentialsError, PartialCredentialsError) as e:
+                logger.error(f"Failed to upload to S3: {e}")
+                raise
+            finally:
+                os.remove(local_temp_file)
+        else:
+            # Saving locally
+            output_path = os.path.join(output_dir, file_name)
+            logger.info(f"Saving rows {start_idx} to {end_idx} in {output_path}")
+            dataset.select(range(start_idx, end_idx)).to_parquet(output_path)
+
+def parse_s3_path(s3_path: str):
+    """Parses an S3 path into bucket and key prefix."""
+    if not s3_path.startswith("s3://"):
+        raise ValueError("S3 path must start with 's3://'")
+    path = s3_path[5:]
+    bucket_name, _, key_prefix = path.partition('/')
+    return bucket_name, key_prefix
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Process and save dataset as Parquet files.")
+    parser.add_argument("--query_path", type=str, required=True, help="Path to the query dataset JSONL files.")
+    parser.add_argument("--response_path", type=str, required=True, help="Path to the response dataset JSONL files.")
+    parser.add_argument("--output_dir", type=str, required=True, help="Directory or S3 path to save the output Parquet files.")
+    parser.add_argument("--num_proc", type=int, default=32, help="Number of processes to use for data processing.")
+    parser.add_argument("--s3_endpoint_url", type=str, default=None, help="Custom S3 endpoint URL, e.g., for S3-compatible storage.")
+
+    args = parser.parse_args()
+
+    logging.basicConfig(level=logging.INFO)
+    logger = logging.getLogger(__name__)
+
+    # Build the dataset
+    final_dataset = build_batch_query_response_vision_dataset(
+        query_glob_path=args.query_path,
+        response_glob_path=args.response_path,
+        num_proc=args.num_proc
+    )
+
+    # Save the dataset as Parquet files
+    save_dataset_in_parquet(final_dataset, args.output_dir, s3_endpoint_url=args.s3_endpoint_url)
+
+    logger.info("Dataset processing and saving completed.")
\ No newline at end of file
diff --git a/pdelfin/train/core/config.py b/pdelfin/train/core/config.py
index a1e838b..6e677fc 100644
--- a/pdelfin/train/core/config.py
+++ b/pdelfin/train/core/config.py
@@ -75,8 +75,9 @@ class AwsConfig:
 @dataclass
 class SourceConfig:
     name: str = field(help="The name of the source")
-    query_glob_path: str = field(help="The s3 bucket pointing to the inputs sent to OpenAI to generate the silver data")
-    response_glob_path: str = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai")
+    parquet_path: Optional[str] = field(help="The s3/glob path to a bunch of parquet files for a preprocessed dataset.", default=None)
+    query_glob_path: Optional[str] = field(help="The s3 bucket pointing to the inputs sent to OpenAI to generate the silver data", default=None)
+    response_glob_path: Optional[str] = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai", default=None)
 
 
 @dataclass
diff --git a/pdelfin/train/dataloader.py b/pdelfin/train/dataloader.py
index c0caa3d..4aa066d 100644
--- a/pdelfin/train/dataloader.py
+++ b/pdelfin/train/dataloader.py
@@ -240,9 +240,6 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
 
     final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc)
 
-    # Limit the size of the input text not to explode the context size
-    final_dataset = final_dataset.filter(lambda x: len(x["raw_page_text"]) < 4000, num_proc=num_proc)
-
     return final_dataset
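
A minimal sketch of the intended round trip, not part of the patch itself: the shards written by save_dataset_in_parquet can be read back with the Hugging Face datasets Parquet loader, which is presumably what the new SourceConfig.parquet_path field is meant to point at. All paths below are hypothetical examples.

# Build the shards with the new script (hypothetical paths):
#   python pdelfin/train/buildparquetdataset.py \
#       --query_path "s3://my-bucket/queries/*.jsonl" \
#       --response_path "s3://my-bucket/responses/*.jsonl" \
#       --output_dir /data/parquet_dataset
from datasets import load_dataset

# Read every dataset_<start>_<end>.parquet shard back into a single Dataset.
ds = load_dataset(
    "parquet",
    data_files="/data/parquet_dataset/*.parquet",
    split="train",
)
print(len(ds), ds.column_names)

Writing the dataset in rows_per_file-sized shards keeps each Parquet file small enough to stage through /tmp and upload individually, while a single glob over the output directory is enough for the loader to pick up every shard.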