Working on a more pipeliney thing

Jake Poznanski 2025-06-11 21:51:24 +00:00
parent d0df380ae9
commit d17bef8b4b


@@ -1,6 +1,6 @@
from os import PathLike
from pathlib import Path
-from typing import Dict, Any, Optional, Type
+from typing import Dict, Any, Optional, Type, List, Callable
import base64
from io import BytesIO
from PIL import Image
@@ -8,6 +8,7 @@ from torch.utils.data import Dataset
from pypdf import PdfReader
from tqdm import tqdm
from dataclasses import dataclass, fields
from abc import ABC, abstractmethod
from olmocr.data.renderpdf import render_pdf_to_base64png
@@ -37,69 +38,20 @@ class StandardFrontMatter:
raise TypeError("is_diagram must be of type bool.")
-class MarkdownPDFDocumentDataset(Dataset):
-    def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
-        """
-        Initialize the dataset by finding all markdown files with corresponding PDFs.
-        Args:
-            root_dir: Path to the root folder containing processed markdown and PDF files
-            target_longest_image_dim: Target dimension for the longest side of the image
-            image_transform: Optional transform to apply to the PDF images
-            front_matter_class: Optional dataclass type to validate front matter against
-        """
-        self.root_dir = Path(root_dir)
-        self.image_transform = image_transform
-        self.target_longest_image_dim = target_longest_image_dim
-        self.samples = []
-        # Find all markdown files recursively
-        print(f"Scanning for markdown files in {self.root_dir}...")
-        md_files = list(self.root_dir.rglob("*.md"))
-        # Verify each markdown file has a corresponding PDF
-        valid_count = 0
-        invalid_pdfs = []
-        print(f"Validating {len(md_files)} markdown-PDF pairs...")
-        for md_path in tqdm(md_files, desc="Validating PDFs"):
-            # Look for PDF with same stem (filename without extension)
-            pdf_path = md_path.with_suffix('.pdf')
-            if pdf_path.exists() or pdf_path.is_symlink():
-                # Resolve symlink if it is one
-                if pdf_path.is_symlink():
-                    pdf_path = pdf_path.resolve()
-                # Verify the resolved path exists
-                if pdf_path.exists():
-                    # Validate PDF - check it loads and has exactly one page
-                    try:
-                        reader = PdfReader(str(pdf_path))
-                        num_pages = len(reader.pages)
-                        if num_pages != 1:
-                            invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
-                            continue
-                        self.samples.append({
-                            'markdown_path': md_path,
-                            'pdf_path': pdf_path
-                        })
-                        valid_count += 1
-                    except Exception as e:
-                        invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
-        print(f"Found {valid_count} valid markdown-PDF pairs")
-        if invalid_pdfs:
-            print(f"\nWarning: {len(invalid_pdfs)} invalid PDFs found:")
-            for pdf_path, reason in invalid_pdfs[:5]:  # Show first 5
-                print(f"  - {pdf_path.name}: {reason}")
-            if len(invalid_pdfs) > 5:
-                print(f"  ... and {len(invalid_pdfs) - 5} more")
+class PipelineStep(ABC):
+    """Abstract base class for pipeline steps."""
+    @abstractmethod
+    def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        """Process a sample and return the modified sample."""
+        pass
+class FrontMatterParser(PipelineStep):
+    """Pipeline step that parses front matter from markdown content."""
+    def __init__(self, front_matter_class: Optional[Type] = None):
+        self.front_matter_class = front_matter_class
def _extract_front_matter_and_text(self, markdown_content: str) -> tuple[str, str]:
"""Extract raw front matter string and text from markdown content."""
@@ -165,6 +117,119 @@ class MarkdownPDFDocumentDataset(Dataset):
return self.front_matter_class(**kwargs)
def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Parse front matter from markdown content."""
# Read markdown content if not already loaded
if 'markdown_content' not in sample:
sample['markdown_content'] = sample['markdown_path'].read_text(encoding='utf-8')
# Extract and parse front matter
front_matter_str, text = self._extract_front_matter_and_text(sample['markdown_content'])
front_matter = self._parse_front_matter_string(front_matter_str)
# Parse front matter to dataclass if specified
try:
parsed_front_matter = self._parse_front_matter(front_matter)
except Exception as e:
raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}")
# Update sample
sample['text'] = text
sample['front_matter'] = parsed_front_matter
return sample
class PDFRenderer(PipelineStep):
"""Pipeline step that renders PDF to image."""
def __init__(self, target_longest_image_dim: int, image_transform: Optional[Callable] = None):
self.target_longest_image_dim = target_longest_image_dim
self.image_transform = image_transform
def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Render PDF to image."""
# Render PDF to image
base64_png = render_pdf_to_base64png(
str(sample['pdf_path']),
page_num=1,
target_longest_image_dim=self.target_longest_image_dim
)
png_bytes = base64.b64decode(base64_png)
image = Image.open(BytesIO(png_bytes))
# Apply transform if provided
if self.image_transform:
image = self.image_transform(image)
# Update sample
sample['image'] = image
return sample
class BaseMarkdownPDFDataset(Dataset):
"""Base dataset class that loads and verifies markdown-PDF pairs."""
def __init__(self, root_dir: str | PathLike, pipeline_steps: Optional[List[PipelineStep]] = None):
"""
Initialize the dataset by finding all markdown files with corresponding PDFs.
Args:
root_dir: Path to the root folder containing processed markdown and PDF files
pipeline_steps: Optional list of pipeline steps to apply to each sample
"""
self.root_dir = Path(root_dir)
self.pipeline_steps = pipeline_steps or []
self.samples = []
# Find all markdown files recursively
print(f"Scanning for markdown files in {self.root_dir}...")
md_files = list(self.root_dir.rglob("*.md"))
# Verify each markdown file has a corresponding PDF
valid_count = 0
invalid_pdfs = []
print(f"Validating {len(md_files)} markdown-PDF pairs...")
for md_path in tqdm(md_files, desc="Validating PDFs"):
# Look for PDF with same stem (filename without extension)
pdf_path = md_path.with_suffix('.pdf')
if pdf_path.exists() or pdf_path.is_symlink():
# Resolve symlink if it is one
if pdf_path.is_symlink():
pdf_path = pdf_path.resolve()
# Verify the resolved path exists
if pdf_path.exists():
# Validate PDF - check it loads and has exactly one page
try:
reader = PdfReader(str(pdf_path))
num_pages = len(reader.pages)
if num_pages != 1:
invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
continue
self.samples.append({
'markdown_path': md_path,
'pdf_path': pdf_path
})
valid_count += 1
except Exception as e:
invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
print(f"Found {valid_count} valid markdown-PDF pairs")
if invalid_pdfs:
print(f"\nWarning: {len(invalid_pdfs)} invalid PDFs found:")
for pdf_path, reason in invalid_pdfs[:5]: # Show first 5
print(f" - {pdf_path.name}: {reason}")
if len(invalid_pdfs) > 5:
print(f" ... and {len(invalid_pdfs) - 5} more")
def __len__(self) -> int:
return len(self.samples)
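With loading and validation now isolated in BaseMarkdownPDFDataset, parsing and rendering are opt-in. A usage sketch (the root path is a placeholder and the lambda stands in for any PIL-level transform):

dataset = BaseMarkdownPDFDataset(
    '/data/processed_docs',                      # placeholder root_dir
    pipeline_steps=[
        FrontMatterParser(StandardFrontMatter),  # adds 'text' and 'front_matter'
        PDFRenderer(
            target_longest_image_dim=1024,
            image_transform=lambda img: img.convert('RGB'),
        ),                                       # adds 'image'
    ],
)
sample = dataset[0]  # includes markdown_path, pdf_path, text, front_matter, image

Dropping PDFRenderer from the list gives a text-only dataset without paying for rasterization.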
@@ -173,40 +238,43 @@ class MarkdownPDFDocumentDataset(Dataset):
Get a single sample from the dataset.
Returns:
-            dict containing:
-                - 'image': PIL Image of the rendered PDF page
-                - 'text': Text content without front matter
-                - 'front_matter': Dict with parsed front matter
+            dict containing at minimum:
+                - 'markdown_path': Path to the markdown file
+                - 'pdf_path': Path to the PDF file
+            Additional fields will be added by pipeline steps.
        """
-        sample = self.samples[idx]
-        # Read and parse markdown file
-        markdown_content = sample['markdown_path'].read_text(encoding='utf-8')
-        front_matter_str, text = self._extract_front_matter_and_text(markdown_content)
-        front_matter = self._parse_front_matter_string(front_matter_str)
-        # Render PDF to image
-        base64_png = render_pdf_to_base64png(str(sample['pdf_path']), page_num=1, target_longest_image_dim=self.target_longest_image_dim)
-        png_bytes = base64.b64decode(base64_png)
-        image = Image.open(BytesIO(png_bytes))
-        # Apply transform if provided
-        if self.image_transform:
-            image = self.image_transform(image)
-        # Parse front matter to dataclass if specified
-        try:
-            parsed_front_matter = self._parse_front_matter(front_matter)
-        except Exception as e:
-            raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}")
-        return {
-            'image': image,
-            'pdf_path': str(sample['pdf_path']),
-            'text': text,
-            'front_matter': parsed_front_matter
-        }
+        # Start with basic sample info
+        sample = self.samples[idx].copy()
+        # Apply pipeline steps
+        for step in self.pipeline_steps:
+            sample = step.process(sample)
+        return sample
class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
    """Dataset that includes front matter parsing and PDF rendering by default."""
def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
"""
Initialize the dataset with default pipeline steps.
Args:
root_dir: Path to the root folder containing processed markdown and PDF files
target_longest_image_dim: Target dimension for the longest side of the image
image_transform: Optional transform to apply to the PDF images
front_matter_class: Optional dataclass type to validate front matter against
"""
# Create default pipeline steps
pipeline_steps = [
FrontMatterParser(front_matter_class),
PDFRenderer(target_longest_image_dim, image_transform)
]
# Initialize base class with pipeline
super().__init__(root_dir, pipeline_steps)
if __name__ == "__main__":
@ -222,11 +290,46 @@ if __name__ == "__main__":
args = parser.parse_args()
-    # Test dataset initialization
-    print(f"\nTesting dataset with root directory: {args.root_dir}")
-    dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, front_matter_class=StandardFrontMatter, image_transform=None)
-    print(f"\nDataset length: {len(dataset)}")
+    # Test base dataset without any pipeline steps
+    print(f"\n=== Testing base dataset without pipeline steps ===")
+    base_dataset = BaseMarkdownPDFDataset(args.root_dir)
+    print(f"Dataset length: {len(base_dataset)}")
if len(base_dataset) > 0:
print("\nFirst sample (no pipeline):")
sample = base_dataset[0]
print(f" Keys: {list(sample.keys())}")
print(f" Markdown: {sample['markdown_path'].name}")
print(f" PDF: {sample['pdf_path'].name}")
# Test with individual pipeline steps
print(f"\n=== Testing with individual pipeline steps ===")
pipeline_dataset = BaseMarkdownPDFDataset(
args.root_dir,
pipeline_steps=[
FrontMatterParser(StandardFrontMatter),
PDFRenderer(target_longest_image_dim=1024)
]
)
if len(pipeline_dataset) > 0:
print("\nFirst sample (with pipeline):")
sample = pipeline_dataset[0]
print(f" Keys: {list(sample.keys())}")
print(f" Front Matter: {sample['front_matter']}")
print(f" Image size: {sample['image'].size}")
print(f" Text preview: {sample['text'][:100]}...")
# Test the convenience dataset class
print(f"\n=== Testing MarkdownPDFDocumentDataset (convenience class) ===")
dataset = MarkdownPDFDocumentDataset(
args.root_dir,
target_longest_image_dim=1024,
front_matter_class=StandardFrontMatter,
image_transform=None
)
print(f"Dataset length: {len(dataset)}")
if len(dataset) > 0:
# Show first few samples
@ -242,4 +345,4 @@ if __name__ == "__main__":
print(f"Image size: {first_sample['image'].size}")
print(f"PDF Path: {first_sample['pdf_path']}")
print(f"Front Matter: {first_sample['front_matter']}")
print(f"Text: {first_sample['text']}...")
print(f"Text: {first_sample['text'][:200]}...")