diff --git a/olmocr/train/prepare_olmocrmix.py b/olmocr/train/prepare_olmocrmix.py index 75cbe1d..e9d3fe1 100644 --- a/olmocr/train/prepare_olmocrmix.py +++ b/olmocr/train/prepare_olmocrmix.py @@ -4,10 +4,24 @@ from os import PathLike from pathlib import Path from typing import Optional import pandas as pd +import tarfile +from concurrent.futures import ProcessPoolExecutor, as_completed +from tqdm import tqdm from huggingface_hub import snapshot_download +def extract_tarball(tarball_path: Path, extract_dir: Path) -> int: + """Extract a single tarball and return the number of files extracted.""" + try: + with tarfile.open(tarball_path, 'r') as tar: + tar.extractall(extract_dir) + return len(tar.getmembers()) + except Exception as e: + print(f"Error extracting {tarball_path}: {e}") + return 0 + + def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination: str | PathLike, max_examples: Optional[int] = None) -> str: """ Prepare OLMoCR mix dataset by downloading from HuggingFace and organizing into a folder structure. @@ -52,8 +66,44 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination: else: raise NotImplementedError() + # Step 3: Extract PDF tarballs + pdf_tarballs_dir = dest_path / "hugging_face" / "pdf_tarballs" + if pdf_tarballs_dir.exists(): + extracted_dir = pdf_tarballs_dir / "extracted" + extracted_dir.mkdir(exist_ok=True) + + # Find all tarball files + tarball_files = list(pdf_tarballs_dir.glob("*.tar*")) + list(pdf_tarballs_dir.glob("*.tgz")) + + if tarball_files: + print(f"\nFound {len(tarball_files)} PDF tarballs to extract...") + + # Use ProcessPoolExecutor for parallel extraction + with ProcessPoolExecutor() as executor: + # Submit all tasks + future_to_tarball = {} + for tarball in tarball_files: + future = executor.submit(extract_tarball, tarball, extracted_dir) + future_to_tarball[future] = tarball + + # Process results as they complete with progress bar + total_files_extracted = 0 + with tqdm(total=len(tarball_files), desc="Extracting tarballs") as pbar: + for future in as_completed(future_to_tarball): + tarball = future_to_tarball[future] + try: + files_extracted = future.result() + total_files_extracted += files_extracted + pbar.set_postfix({"files": total_files_extracted}) + except Exception as e: + print(f"\nError with {tarball.name}: {e}") + pbar.update(1) + + print(f"Extracted {total_files_extracted} files from tarballs to {extracted_dir}") + else: + print(f"No PDF tarballs directory found at {pdf_tarballs_dir}") - # Step 3: Process parquet files + # Step 4: Process parquet files total_processed = 0 total_errors = 0 @@ -117,6 +167,9 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination: print(f"Completed! Processed {total_processed} examples to {processed_dir}") print(f"Total errors: {total_errors}") + + + return str(processed_dir)