diff --git a/olmocr/train/dataloader.py b/olmocr/train/dataloader.py
index a9f7120..012bc9a 100644
--- a/olmocr/train/dataloader.py
+++ b/olmocr/train/dataloader.py
@@ -1,6 +1,6 @@
 from os import PathLike
 from pathlib import Path
-from typing import Dict, Any, Optional, Type
+from typing import Dict, Any, Optional, Type, List, Callable
 import base64
 from io import BytesIO
 from PIL import Image
@@ -8,6 +8,7 @@ from torch.utils.data import Dataset
 from pypdf import PdfReader
 from tqdm import tqdm
 from dataclasses import dataclass, fields
+from abc import ABC, abstractmethod
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
 
@@ -35,71 +36,22 @@ class StandardFrontMatter:
             raise TypeError("is_table must be of type bool.")
         if not isinstance(self.is_diagram, bool):
             raise TypeError("is_diagram must be of type bool.")
-
-class MarkdownPDFDocumentDataset(Dataset):
-    def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
-        """
-        Initialize the dataset by finding all markdown files with corresponding PDFs.
-
-        Args:
-            root_dir: Path to the root folder containing processed markdown and PDF files
-            target_longest_image_dim: Target dimension for the longest side of the image
-            image_transform: Optional transform to apply to the PDF images
-            front_matter_class: Optional dataclass type to validate front matter against
-        """
-        self.root_dir = Path(root_dir)
-        self.image_transform = image_transform
-        self.target_longest_image_dim = target_longest_image_dim
+
+class PipelineStep(ABC):
+    """Abstract base class for pipeline steps."""
+
+    @abstractmethod
+    def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        """Process a sample and return the modified sample."""
+        pass
+
+
+class FrontMatterParser(PipelineStep):
+    """Pipeline step that parses front matter from markdown content."""
+
+    def __init__(self, front_matter_class: Optional[Type] = None):
         self.front_matter_class = front_matter_class
-        self.samples = []
-
-        # Find all markdown files recursively
-        print(f"Scanning for markdown files in {self.root_dir}...")
-        md_files = list(self.root_dir.rglob("*.md"))
-
-        # Verify each markdown file has a corresponding PDF
-        valid_count = 0
-        invalid_pdfs = []
-
-        print(f"Validating {len(md_files)} markdown-PDF pairs...")
-        for md_path in tqdm(md_files, desc="Validating PDFs"):
-            # Look for PDF with same stem (filename without extension)
-            pdf_path = md_path.with_suffix('.pdf')
-
-            if pdf_path.exists() or pdf_path.is_symlink():
-                # Resolve symlink if it is one
-                if pdf_path.is_symlink():
-                    pdf_path = pdf_path.resolve()
-
-                # Verify the resolved path exists
-                if pdf_path.exists():
-                    # Validate PDF - check it loads and has exactly one page
-                    try:
-                        reader = PdfReader(str(pdf_path))
-                        num_pages = len(reader.pages)
-
-                        if num_pages != 1:
-                            invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
-                            continue
-
-                        self.samples.append({
-                            'markdown_path': md_path,
-                            'pdf_path': pdf_path
-                        })
-                        valid_count += 1
-
-                    except Exception as e:
-                        invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
-
-        print(f"Found {valid_count} valid markdown-PDF pairs")
-
-        if invalid_pdfs:
-            print(f"\nWarning: {len(invalid_pdfs)} invalid PDFs found:")
-            for pdf_path, reason in invalid_pdfs[:5]:  # Show first 5
-                print(f"  - {pdf_path.name}: {reason}")
-            if len(invalid_pdfs) > 5:
-                print(f"  ... and {len(invalid_pdfs) - 5} more")
 
     def _extract_front_matter_and_text(self, markdown_content: str) -> tuple[str, str]:
         """Extract raw front matter string and text from markdown content."""
@@ -165,6 +117,119 @@ class MarkdownPDFDocumentDataset(Dataset):
 
         return self.front_matter_class(**kwargs)
 
+    def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        """Parse front matter from markdown content."""
+        # Read markdown content if not already loaded
+        if 'markdown_content' not in sample:
+            sample['markdown_content'] = sample['markdown_path'].read_text(encoding='utf-8')
+
+        # Extract and parse front matter
+        front_matter_str, text = self._extract_front_matter_and_text(sample['markdown_content'])
+        front_matter = self._parse_front_matter_string(front_matter_str)
+
+        # Parse front matter to dataclass if specified
+        try:
+            parsed_front_matter = self._parse_front_matter(front_matter)
+        except Exception as e:
+            raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}")
+
+        # Update sample
+        sample['text'] = text
+        sample['front_matter'] = parsed_front_matter
+
+        return sample
+
+
+class PDFRenderer(PipelineStep):
+    """Pipeline step that renders PDF to image."""
+
+    def __init__(self, target_longest_image_dim: int, image_transform: Optional[Callable] = None):
+        self.target_longest_image_dim = target_longest_image_dim
+        self.image_transform = image_transform
+
+    def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        """Render PDF to image."""
+        # Render PDF to image
+        base64_png = render_pdf_to_base64png(
+            str(sample['pdf_path']),
+            page_num=1,
+            target_longest_image_dim=self.target_longest_image_dim
+        )
+        png_bytes = base64.b64decode(base64_png)
+        image = Image.open(BytesIO(png_bytes))
+
+        # Apply transform if provided
+        if self.image_transform:
+            image = self.image_transform(image)
+
+        # Update sample
+        sample['image'] = image
+
+        return sample
+
+
+class BaseMarkdownPDFDataset(Dataset):
+    """Base dataset class that loads and verifies markdown-PDF pairs."""
+
+    def __init__(self, root_dir: str | PathLike, pipeline_steps: Optional[List[PipelineStep]] = None):
+        """
+        Initialize the dataset by finding all markdown files with corresponding PDFs.
+
+        Args:
+            root_dir: Path to the root folder containing processed markdown and PDF files
+            pipeline_steps: Optional list of pipeline steps to apply to each sample
+        """
+        self.root_dir = Path(root_dir)
+        self.pipeline_steps = pipeline_steps or []
+        self.samples = []
+
+        # Find all markdown files recursively
+        print(f"Scanning for markdown files in {self.root_dir}...")
+        md_files = list(self.root_dir.rglob("*.md"))
+
+        # Verify each markdown file has a corresponding PDF
+        valid_count = 0
+        invalid_pdfs = []
+
+        print(f"Validating {len(md_files)} markdown-PDF pairs...")
+        for md_path in tqdm(md_files, desc="Validating PDFs"):
+            # Look for PDF with same stem (filename without extension)
+            pdf_path = md_path.with_suffix('.pdf')
+
+            if pdf_path.exists() or pdf_path.is_symlink():
+                # Resolve symlink if it is one
+                if pdf_path.is_symlink():
+                    pdf_path = pdf_path.resolve()
+
+                # Verify the resolved path exists
+                if pdf_path.exists():
+                    # Validate PDF - check it loads and has exactly one page
+                    try:
+                        reader = PdfReader(str(pdf_path))
+                        num_pages = len(reader.pages)
+
+                        if num_pages != 1:
+                            invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
+                            continue
+
+                        self.samples.append({
+                            'markdown_path': md_path,
+                            'pdf_path': pdf_path
+                        })
+                        valid_count += 1
+
+                    except Exception as e:
+                        invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
+
+        print(f"Found {valid_count} valid markdown-PDF pairs")
+
+        if invalid_pdfs:
+            print(f"\nWarning: {len(invalid_pdfs)} invalid PDFs found:")
+            for pdf_path, reason in invalid_pdfs[:5]:  # Show first 5
+                print(f"  - {pdf_path.name}: {reason}")
+            if len(invalid_pdfs) > 5:
+                print(f"  ... and {len(invalid_pdfs) - 5} more")
+
     def __len__(self) -> int:
         return len(self.samples)
 
@@ -173,40 +238,43 @@
         Get a single sample from the dataset.
 
         Returns:
-            dict containing:
-            - 'image': PIL Image of the rendered PDF page
+            dict containing at minimum:
+            - 'markdown_path': Path to the markdown file
             - 'pdf_path': Path to the PDF file
-            - 'text': Text content without front matter
-            - 'front_matter': Dict with parsed front matter
+
+        Additional fields will be added by pipeline steps.
         """
-        sample = self.samples[idx]
+        # Start with basic sample info
+        sample = self.samples[idx].copy()
 
-        # Read and parse markdown file
-        markdown_content = sample['markdown_path'].read_text(encoding='utf-8')
-        front_matter_str, text = self._extract_front_matter_and_text(markdown_content)
-        front_matter = self._parse_front_matter_string(front_matter_str)
+        # Apply pipeline steps
+        for step in self.pipeline_steps:
+            sample = step.process(sample)
 
-        # Render PDF to image
-        base64_png = render_pdf_to_base64png(str(sample['pdf_path']), page_num=1, target_longest_image_dim=self.target_longest_image_dim)
-        png_bytes = base64.b64decode(base64_png)
-        image = Image.open(BytesIO(png_bytes))
+        return sample
+
+
+class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
+    """Dataset that includes front matter parsing and PDF rendering by default."""
+
+    def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
+        """
+        Initialize the dataset with default pipeline steps.
 
-        # Apply transform if provided
-        if self.image_transform:
-            image = self.image_transform(image)
+        Args:
+            root_dir: Path to the root folder containing processed markdown and PDF files
+            target_longest_image_dim: Target dimension for the longest side of the image
+            image_transform: Optional transform to apply to the PDF images
+            front_matter_class: Optional dataclass type to validate front matter against
+        """
+        # Create default pipeline steps
+        pipeline_steps = [
+            FrontMatterParser(front_matter_class),
+            PDFRenderer(target_longest_image_dim, image_transform)
+        ]
 
-        # Parse front matter to dataclass if specified
-        try:
-            parsed_front_matter = self._parse_front_matter(front_matter)
-        except Exception as e:
-            raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}")
-
-        return {
-            'image': image,
-            'pdf_path': str(sample['pdf_path']),
-            'text': text,
-            'front_matter': parsed_front_matter
-        }
+        # Initialize base class with pipeline
+        super().__init__(root_dir, pipeline_steps)
 
 
 if __name__ == "__main__":
@@ -222,11 +290,46 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
-    # Test dataset initialization
-    print(f"\nTesting dataset with root directory: {args.root_dir}")
-    dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, front_matter_class=StandardFrontMatter, image_transform=None)
+    # Test base dataset without any pipeline steps
+    print(f"\n=== Testing base dataset without pipeline steps ===")
+    base_dataset = BaseMarkdownPDFDataset(args.root_dir)
+    print(f"Dataset length: {len(base_dataset)}")
 
-    print(f"\nDataset length: {len(dataset)}")
+    if len(base_dataset) > 0:
+        print("\nFirst sample (no pipeline):")
+        sample = base_dataset[0]
+        print(f"  Keys: {list(sample.keys())}")
+        print(f"  Markdown: {sample['markdown_path'].name}")
+        print(f"  PDF: {sample['pdf_path'].name}")
+
+    # Test with individual pipeline steps
+    print(f"\n=== Testing with individual pipeline steps ===")
+    pipeline_dataset = BaseMarkdownPDFDataset(
+        args.root_dir,
+        pipeline_steps=[
+            FrontMatterParser(StandardFrontMatter),
+            PDFRenderer(target_longest_image_dim=1024)
+        ]
+    )
+
+    if len(pipeline_dataset) > 0:
+        print("\nFirst sample (with pipeline):")
+        sample = pipeline_dataset[0]
+        print(f"  Keys: {list(sample.keys())}")
+        print(f"  Front Matter: {sample['front_matter']}")
+        print(f"  Image size: {sample['image'].size}")
+        print(f"  Text preview: {sample['text'][:100]}...")
+
+    # Test the convenience dataset class
+    print(f"\n=== Testing MarkdownPDFDocumentDataset (convenience class) ===")
+    dataset = MarkdownPDFDocumentDataset(
+        args.root_dir,
+        target_longest_image_dim=1024,
+        front_matter_class=StandardFrontMatter,
+        image_transform=None
+    )
+
+    print(f"Dataset length: {len(dataset)}")
 
     if len(dataset) > 0:
         # Show first few samples
@@ -242,4 +345,4 @@ if __name__ == "__main__":
         print(f"Image size: {first_sample['image'].size}")
         print(f"PDF Path: {first_sample['pdf_path']}")
         print(f"Front Matter: {first_sample['front_matter']}")
-        print(f"Text: {first_sample['text']}...")
+        print(f"Text: {first_sample['text'][:200]}...")
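
Note for reviewers: a minimal usage sketch of the pipeline API introduced above, not part of the diff itself. It assumes the classes land in olmocr.train.dataloader as shown; the TextTruncator step and the root directory path are hypothetical and only illustrate how a custom step composes with FrontMatterParser and PDFRenderer.

from typing import Any, Dict

from olmocr.train.dataloader import (
    BaseMarkdownPDFDataset,
    FrontMatterParser,
    PDFRenderer,
    PipelineStep,
    StandardFrontMatter,
)


class TextTruncator(PipelineStep):
    """Hypothetical custom step: cap the extracted text at a maximum length."""

    def __init__(self, max_chars: int = 4000):
        self.max_chars = max_chars

    def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        # Assumes FrontMatterParser ran earlier and populated sample['text'].
        sample['text'] = sample['text'][: self.max_chars]
        return sample


# Steps run in list order; each one reads and adds keys on the sample dict.
dataset = BaseMarkdownPDFDataset(
    "path/to/processed_docs",  # hypothetical root directory
    pipeline_steps=[
        FrontMatterParser(StandardFrontMatter),
        PDFRenderer(target_longest_image_dim=1024),
        TextTruncator(max_chars=4000),
    ],
)

sample = dataset[0]
print(sample['front_matter'], sample['image'].size, len(sample['text']))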