diff --git a/olmocr/train/dataloader.py b/olmocr/train/dataloader.py
index fab2229..401b7a7 100644
--- a/olmocr/train/dataloader.py
+++ b/olmocr/train/dataloader.py
@@ -1,12 +1,13 @@
 import base64
 import logging
 from abc import ABC, abstractmethod
+from concurrent.futures import ProcessPoolExecutor
 from dataclasses import dataclass, fields
 from functools import reduce
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Type, TypeAlias
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeAlias
 
 import numpy as np
 import yaml
@@ -26,6 +27,51 @@ Sample: TypeAlias = Dict[str, Any]
 logger = logging.getLogger(__name__)
 
 
+def validate_pdf_pair(md_path: Path) -> Tuple[Optional[Dict[str, Path]], Optional[Tuple[Path, str]]]:
+    """Validate a single markdown-PDF pair.
+
+    Defined at module level so that it can be pickled and submitted to a
+    ProcessPoolExecutor worker.
+
+    Args:
+        md_path: Path to the markdown file
+
+    Returns:
+        Tuple of (valid_sample, invalid_pdf_info)
+        - valid_sample: Dict with markdown_path and pdf_path if valid, None otherwise
+        - invalid_pdf_info: Tuple of (pdf_path, reason) if invalid, None otherwise
+        Both are None when no PDF (or only a dangling symlink) exists for the markdown file.
+    """
+    # Look for PDF with same stem (filename without extension)
+    pdf_path = md_path.with_suffix(".pdf")
+
+    if pdf_path.exists() or pdf_path.is_symlink():
+        # Resolve symlink if it is one
+        if pdf_path.is_symlink():
+            pdf_path = pdf_path.resolve()
+
+        # Verify the resolved path exists
+        if pdf_path.exists():
+            # Validate PDF - check it loads and has exactly one page and that you can get document-anchoring from it
+            try:
+                reader = PdfReader(str(pdf_path))
+                num_pages = len(reader.pages)
+
+                if num_pages != 1:
+                    return None, (pdf_path, f"Expected 1 page, found {num_pages}")
+
+                # Test that document anchoring works
+                from olmocr.prompts.anchor import get_anchor_text
+                get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)
+
+                return {"markdown_path": md_path, "pdf_path": pdf_path}, None
+
+            except Exception as e:
+                return None, (pdf_path, f"Failed to load: {str(e)}")
+
+    return None, None
+
+
 @dataclass(frozen=True, slots=True)
 class PipelineStep(ABC):
     """Abstract base class for pipeline steps."""
@@ -55,39 +101,35 @@ class BaseMarkdownPDFDataset(Dataset):
         logger.info(f"Scanning for markdown files in {self.root_dir}...")
         md_files = list(self.root_dir.rglob("*.md"))
 
-        # Verify each markdown file has a corresponding PDF
+        # Verify each markdown file has a corresponding PDF using ProcessPoolExecutor
         valid_count = 0
         invalid_pdfs = []
 
-        logger.info(f"Validating {len(md_files)} markdown-PDF pairs...")
-        for md_path in tqdm(md_files, desc="Validating PDFs"):
-            # Look for PDF with same stem (filename without extension)
-            pdf_path = md_path.with_suffix(".pdf")
-
-            if pdf_path.exists() or pdf_path.is_symlink():
-                # Resolve symlink if it is one
-                if pdf_path.is_symlink():
-                    pdf_path = pdf_path.resolve()
-
-                # Verify the resolved path exists
-                if pdf_path.exists():
-                    # Validate PDF - check it loads and has exactly one page and that you can get document-anchoring from it
+        logger.info(f"Validating {len(md_files)} markdown-PDF pairs using ProcessPoolExecutor...")
+
+        # Use ProcessPoolExecutor for parallel validation
+        with ProcessPoolExecutor(max_workers=8) as executor:
+            # Submit all validation tasks, keeping the futures in md_files order
+            futures = [executor.submit(validate_pdf_pair, md_path) for md_path in md_files]
+
+            # Consume results in submission order (not completion order) so that
+            # self.samples follows md_files order regardless of worker timing.
+            with tqdm(total=len(md_files), desc="Validating PDFs") as pbar:
+                for md_path, future in zip(md_files, futures):
                     try:
-                        reader = PdfReader(str(pdf_path))
-                        num_pages = len(reader.pages)
-
-                        if num_pages != 1:
-                            invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
-                            continue
-
-                        # Test that document anchoring works
-                        get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)
-
-                        self.samples.append({"markdown_path": md_path, "pdf_path": pdf_path})
-                        valid_count += 1
-
+                        valid_sample, invalid_pdf_info = future.result()
+
+                        if valid_sample:
+                            self.samples.append(valid_sample)
+                            valid_count += 1
+                        elif invalid_pdf_info:
+                            invalid_pdfs.append(invalid_pdf_info)
+
                     except Exception as e:
-                        invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
+                        logger.error(f"Error processing {md_path}: {str(e)}")
+                        invalid_pdfs.append((md_path.with_suffix(".pdf"), f"Processing error: {str(e)}"))
+
+                    pbar.update(1)
 
         logger.info(f"Found {valid_count} valid markdown-PDF pairs")
 