diff --git a/olmocr/data/prepare_olmocrmix.py b/olmocr/data/prepare_olmocrmix.py
index 928f3a2..814f01f 100644
--- a/olmocr/data/prepare_olmocrmix.py
+++ b/olmocr/data/prepare_olmocrmix.py
@@ -139,12 +139,40 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination:
     else:
         print(f"Downloading dataset {dataset_path} to {hugging_face_dir}...")
 
-        # Download the entire repository including PDFs and parquet files
-        local_dir = snapshot_download(
-            repo_id=dataset_path,
-            repo_type="dataset",
-            local_dir=hugging_face_dir,
-        )
+        # For allenai/olmOCR-mix-0225, download everything as before;
+        # for other datasets, filter the download to only the needed files
+        if dataset_path == "allenai/olmOCR-mix-0225":
+            # Download the entire repository including PDFs and parquet files
+            local_dir = snapshot_download(
+                repo_id=dataset_path,
+                repo_type="dataset",
+                local_dir=hugging_face_dir,
+            )
+        else:
+            # For other datasets, only download the specific parquet file and related PDF tarballs
+            # Construct the dataset tag used for filtering
+            dataset_tag = f"{subset}_{split}"
+
+            # Define patterns to allow:
+            # 1. The specific parquet file
+            # 2. Related PDF tarballs in the pdf_tarballs directory
+            # 3. README and metadata files (for dataset info)
+            # 4. urls.jsonl for URL mappings, if it exists
+            allow_patterns = [
+                f"{dataset_tag}.parquet",
+                f"pdf_tarballs/{dataset_tag}_*.tar.gz",
+                "README.md",
+                "*.json*",  # Metadata JSON/JSONL files, including urls.jsonl
+            ]
+
+            print(f"Filtering download to patterns: {allow_patterns}")
+
+            local_dir = snapshot_download(
+                repo_id=dataset_path,
+                repo_type="dataset",
+                local_dir=hugging_face_dir,
+                allow_patterns=allow_patterns,
+            )
 
     print(f"Downloaded to: {local_dir}")
 
@@ -178,8 +206,15 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination:
     if existing_pdfs:
         print(f"Found {len(existing_pdfs)} already extracted PDFs in {extracted_dir}, skipping extraction step")
     else:
-        # Find all tarball files
-        tarball_files = list(pdf_tarballs_dir.glob("*.tar*")) + list(pdf_tarballs_dir.glob("*.tgz"))
+        # Find tarball files based on dataset type
+        if dataset_path == "allenai/olmOCR-mix-0225":
+            # Extract all tarballs for the full dataset
+            tarball_files = list(pdf_tarballs_dir.glob("*.tar*")) + list(pdf_tarballs_dir.glob("*.tgz"))
+        else:
+            # Only extract tarballs matching the dataset_tag pattern
+            dataset_tag = f"{subset}_{split}"
+            tarball_files = list(pdf_tarballs_dir.glob(f"{dataset_tag}_*.tar*")) + list(pdf_tarballs_dir.glob(f"{dataset_tag}_*.tgz"))
+            print(f"Filtering tarballs to pattern: {dataset_tag}_*")
 
         if tarball_files:
             print(f"\nFound {len(tarball_files)} PDF tarballs to extract...")
@@ -328,8 +363,11 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination:
 def main():
     parser = argparse.ArgumentParser(description="Prepare OLMoCR mix dataset")
     parser.add_argument("--dataset-path", type=str, default="allenai/olmOCR-mix-0225", help="HuggingFace dataset path (e.g., 'allenai/olmocr-mix')")
-    parser.add_argument("--subset", type=str, default="00_documents", required=True, help="Dataset subset name")
-    parser.add_argument("--split", type=str, default="eval_s2pdf", required=True, help="Dataset split ex eval_s2pdf")
+
+    # Add subset and split to the parser (not the group); they are validated after parsing
+    parser.add_argument("--subset", type=str, default=None, help="Dataset subset name")
+    parser.add_argument("--split", type=str, default=None, help="Dataset split, e.g. eval_s2pdf")
+    parser.add_argument("--destination", type=str, required=True, help="Destination directory path")
     parser.add_argument("--max-examples", type=int, default=None, help="Maximum number of examples to process (default: all)")
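For reference, the selective-download path in the first hunk relies on `huggingface_hub.snapshot_download` and its `allow_patterns` argument, which restricts the snapshot to files matching the given glob patterns. A minimal standalone sketch of that technique follows; the repo id, tag, and local directory are illustrative placeholders, not values taken from this patch:

```python
# Minimal sketch: pattern-filtered dataset download with huggingface_hub.
# The concrete repo id, tag, and directory below are hypothetical.
from huggingface_hub import snapshot_download

dataset_tag = "00_documents_eval_s2pdf"  # stands in for f"{subset}_{split}"

local_dir = snapshot_download(
    repo_id="allenai/olmOCR-mix-0225",   # any HF dataset repo works here
    repo_type="dataset",
    local_dir="/tmp/olmocr_mix",
    allow_patterns=[
        f"{dataset_tag}.parquet",                # the single parquet shard
        f"pdf_tarballs/{dataset_tag}_*.tar.gz",  # only the matching PDF tarballs
        "README.md",
        "*.json*",                               # metadata JSON/JSONL files
    ],
)
print(f"Filtered snapshot downloaded to {local_dir}")
```

Files not matching any pattern are never fetched, which is what keeps the per-subset download small compared to pulling the whole repository.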
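The argparse hunk makes `--subset` and `--split` optional at parse time and, per its comment, defers the requirement check. The validation itself is outside the lines shown in this diff; a hedged sketch of what such a post-parse check could look like, assuming the check simply enforces both flags together:

```python
import argparse

# Hypothetical sketch of the deferred subset/split validation the comment
# alludes to; the real check lives outside the lines shown in this diff.
parser = argparse.ArgumentParser(description="Prepare OLMoCR mix dataset")
parser.add_argument("--subset", type=str, default=None, help="Dataset subset name")
parser.add_argument("--split", type=str, default=None, help="Dataset split, e.g. eval_s2pdf")
parser.add_argument("--destination", type=str, required=True, help="Destination directory path")
args = parser.parse_args()

# With defaults of None, the requirement can be enforced after parsing,
# e.g. only for dataset paths that actually need an explicit subset/split.
if args.subset is None or args.split is None:
    parser.error("--subset and --split are required")
```

Deferring the check this way lets callers that derive subset/split from the dataset path skip the flags entirely, while preserving a clear error for the cases that still need them.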