mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-13 09:12:18 +00:00
Working on a more pipeliney thing
This commit is contained in:
parent
d0df380ae9
commit
d17bef8b4b
@ -1,6 +1,6 @@
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Type
|
||||
from typing import Dict, Any, Optional, Type, List, Callable
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
@ -8,6 +8,7 @@ from torch.utils.data import Dataset
|
||||
from pypdf import PdfReader
|
||||
from tqdm import tqdm
|
||||
from dataclasses import dataclass, fields
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from olmocr.data.renderpdf import render_pdf_to_base64png
|
||||
|
||||
@ -37,69 +38,20 @@ class StandardFrontMatter:
|
||||
raise TypeError("is_diagram must be of type bool.")
|
||||
|
||||
|
||||
class MarkdownPDFDocumentDataset(Dataset):
|
||||
def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
|
||||
"""
|
||||
Initialize the dataset by finding all markdown files with corresponding PDFs.
|
||||
class PipelineStep(ABC):
|
||||
"""Abstract base class for pipeline steps."""
|
||||
|
||||
Args:
|
||||
root_dir: Path to the root folder containing processed markdown and PDF files
|
||||
target_longest_image_dim: Target dimension for the longest side of the image
|
||||
image_transform: Optional transform to apply to the PDF images
|
||||
front_matter_class: Optional dataclass type to validate front matter against
|
||||
"""
|
||||
self.root_dir = Path(root_dir)
|
||||
self.image_transform = image_transform
|
||||
self.target_longest_image_dim = target_longest_image_dim
|
||||
@abstractmethod
|
||||
def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Process a sample and return the modified sample."""
|
||||
pass
|
||||
|
||||
|
||||
class FrontMatterParser(PipelineStep):
|
||||
"""Pipeline step that parses front matter from markdown content."""
|
||||
|
||||
def __init__(self, front_matter_class: Optional[Type] = None):
    """
    Args:
        front_matter_class: Optional dataclass type used by process() when
            converting parsed front matter (presumably left as a plain dict
            when None — confirm in _parse_front_matter).
    """
    self.front_matter_class = front_matter_class
|
||||
self.samples = []
|
||||
|
||||
# Find all markdown files recursively
|
||||
print(f"Scanning for markdown files in {self.root_dir}...")
|
||||
md_files = list(self.root_dir.rglob("*.md"))
|
||||
|
||||
# Verify each markdown file has a corresponding PDF
|
||||
valid_count = 0
|
||||
invalid_pdfs = []
|
||||
|
||||
print(f"Validating {len(md_files)} markdown-PDF pairs...")
|
||||
for md_path in tqdm(md_files, desc="Validating PDFs"):
|
||||
# Look for PDF with same stem (filename without extension)
|
||||
pdf_path = md_path.with_suffix('.pdf')
|
||||
|
||||
if pdf_path.exists() or pdf_path.is_symlink():
|
||||
# Resolve symlink if it is one
|
||||
if pdf_path.is_symlink():
|
||||
pdf_path = pdf_path.resolve()
|
||||
|
||||
# Verify the resolved path exists
|
||||
if pdf_path.exists():
|
||||
# Validate PDF - check it loads and has exactly one page
|
||||
try:
|
||||
reader = PdfReader(str(pdf_path))
|
||||
num_pages = len(reader.pages)
|
||||
|
||||
if num_pages != 1:
|
||||
invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
|
||||
continue
|
||||
|
||||
self.samples.append({
|
||||
'markdown_path': md_path,
|
||||
'pdf_path': pdf_path
|
||||
})
|
||||
valid_count += 1
|
||||
|
||||
except Exception as e:
|
||||
invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
|
||||
|
||||
print(f"Found {valid_count} valid markdown-PDF pairs")
|
||||
|
||||
if invalid_pdfs:
|
||||
print(f"\nWarning: {len(invalid_pdfs)} invalid PDFs found:")
|
||||
for pdf_path, reason in invalid_pdfs[:5]: # Show first 5
|
||||
print(f" - {pdf_path.name}: {reason}")
|
||||
if len(invalid_pdfs) > 5:
|
||||
print(f" ... and {len(invalid_pdfs) - 5} more")
|
||||
|
||||
def _extract_front_matter_and_text(self, markdown_content: str) -> tuple[str, str]:
|
||||
"""Extract raw front matter string and text from markdown content."""
|
||||
@ -165,6 +117,119 @@ class MarkdownPDFDocumentDataset(Dataset):
|
||||
|
||||
return self.front_matter_class(**kwargs)
|
||||
|
||||
def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    """Parse front matter from the sample's markdown content.

    Lazily loads 'markdown_content' from 'markdown_path' if an earlier
    pipeline step has not already done so, splits it into front matter and
    body text, and converts the front matter via self.front_matter_class.

    Args:
        sample: Pipeline sample dict; must contain 'markdown_path' and may
            contain a preloaded 'markdown_content'.

    Returns:
        The same sample dict, with 'text' (markdown body without front
        matter) and 'front_matter' (parsed front matter) added.

    Raises:
        ValueError: If the front matter cannot be parsed for this file.
    """
    # Read markdown content only if not already loaded by a previous step.
    if 'markdown_content' not in sample:
        sample['markdown_content'] = sample['markdown_path'].read_text(encoding='utf-8')

    # Split the raw front matter from the body, then parse it into a dict.
    front_matter_str, text = self._extract_front_matter_and_text(sample['markdown_content'])
    front_matter = self._parse_front_matter_string(front_matter_str)

    # Parse front matter to dataclass if specified. Fix: chain with
    # `from e` so the original traceback is explicitly preserved as the
    # cause of the ValueError instead of relying on implicit context.
    try:
        parsed_front_matter = self._parse_front_matter(front_matter)
    except Exception as e:
        raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}") from e

    # Update sample in place and return it so steps can be chained.
    sample['text'] = text
    sample['front_matter'] = parsed_front_matter

    return sample
|
||||
|
||||
|
||||
class PDFRenderer(PipelineStep):
    """Pipeline step that renders a single PDF page to a PIL image."""

    def __init__(self, target_longest_image_dim: int, image_transform: Optional[Callable] = None, page_num: int = 1):
        """
        Args:
            target_longest_image_dim: Target size in pixels for the longest
                side of the rendered page image.
            image_transform: Optional callable applied to the PIL image
                after rendering (e.g. augmentation or tensor conversion).
            page_num: 1-based page number to render. Defaults to 1,
                matching the previously hard-coded behavior; parameterized
                so multi-page PDFs can be rendered at other pages.
        """
        self.target_longest_image_dim = target_longest_image_dim
        self.image_transform = image_transform
        self.page_num = page_num

    def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Render the sample's PDF page and store it under sample['image']."""
        # render_pdf_to_base64png returns a base64-encoded PNG of the page.
        base64_png = render_pdf_to_base64png(
            str(sample['pdf_path']),
            page_num=self.page_num,
            target_longest_image_dim=self.target_longest_image_dim
        )
        png_bytes = base64.b64decode(base64_png)
        image = Image.open(BytesIO(png_bytes))

        # Apply transform if provided.
        if self.image_transform:
            image = self.image_transform(image)

        # Update sample in place and return it so steps can be chained.
        sample['image'] = image

        return sample
|
||||
|
||||
|
||||
class BaseMarkdownPDFDataset(Dataset):
|
||||
"""Base dataset class that loads and verifies markdown-PDF pairs."""
|
||||
|
||||
def __init__(self, root_dir: str | PathLike, pipeline_steps: Optional[List[PipelineStep]] = None):
    """
    Initialize the dataset by finding all markdown files with corresponding PDFs.

    Args:
        root_dir: Path to the root folder containing processed markdown and PDF files
        pipeline_steps: Optional list of pipeline steps to apply to each sample
    """
    self.root_dir = Path(root_dir)
    self.pipeline_steps = pipeline_steps or []
    self.samples = []

    # Recursively collect every markdown file under the root directory.
    print(f"Scanning for markdown files in {self.root_dir}...")
    md_files = list(self.root_dir.rglob("*.md"))

    valid_count = 0
    invalid_pdfs = []

    # Keep only pairs whose sibling PDF loads cleanly and has exactly one page.
    print(f"Validating {len(md_files)} markdown-PDF pairs...")
    for md_path in tqdm(md_files, desc="Validating PDFs"):
        candidate = md_path.with_suffix('.pdf')

        # No PDF (and no symlink to one) next to this markdown file: skip it.
        if not (candidate.exists() or candidate.is_symlink()):
            continue

        # Follow a symlink to its target; a dangling link is skipped silently.
        pdf_path = candidate.resolve() if candidate.is_symlink() else candidate
        if not pdf_path.exists():
            continue

        try:
            num_pages = len(PdfReader(str(pdf_path)).pages)
        except Exception as e:
            invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
            continue

        if num_pages != 1:
            invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
            continue

        self.samples.append({'markdown_path': md_path, 'pdf_path': pdf_path})
        valid_count += 1

    print(f"Found {valid_count} valid markdown-PDF pairs")

    if invalid_pdfs:
        print(f"\nWarning: {len(invalid_pdfs)} invalid PDFs found:")
        for pdf_path, reason in invalid_pdfs[:5]:  # Show first 5
            print(f" - {pdf_path.name}: {reason}")
        if len(invalid_pdfs) > 5:
            print(f" ... and {len(invalid_pdfs) - 5} more")
|
||||
|
||||
def __len__(self) -> int:
    # One entry per validated markdown/PDF pair discovered during __init__.
    return len(self.samples)
|
||||
|
||||
@ -173,40 +238,43 @@ class MarkdownPDFDocumentDataset(Dataset):
|
||||
Get a single sample from the dataset.
|
||||
|
||||
Returns:
|
||||
dict containing:
|
||||
- 'image': PIL Image of the rendered PDF page
|
||||
dict containing at minimum:
|
||||
- 'markdown_path': Path to the markdown file
|
||||
- 'pdf_path': Path to the PDF file
|
||||
- 'text': Text content without front matter
|
||||
- 'front_matter': Dict with parsed front matter
|
||||
|
||||
Additional fields will be added by pipeline steps.
|
||||
"""
|
||||
sample = self.samples[idx]
|
||||
# Start with basic sample info
|
||||
sample = self.samples[idx].copy()
|
||||
|
||||
# Read and parse markdown file
|
||||
markdown_content = sample['markdown_path'].read_text(encoding='utf-8')
|
||||
front_matter_str, text = self._extract_front_matter_and_text(markdown_content)
|
||||
front_matter = self._parse_front_matter_string(front_matter_str)
|
||||
# Apply pipeline steps
|
||||
for step in self.pipeline_steps:
|
||||
sample = step.process(sample)
|
||||
|
||||
# Render PDF to image
|
||||
base64_png = render_pdf_to_base64png(str(sample['pdf_path']), page_num=1, target_longest_image_dim=self.target_longest_image_dim)
|
||||
png_bytes = base64.b64decode(base64_png)
|
||||
image = Image.open(BytesIO(png_bytes))
|
||||
return sample
|
||||
|
||||
# Apply transform if provided
|
||||
if self.image_transform:
|
||||
image = self.image_transform(image)
|
||||
|
||||
# Parse front matter to dataclass if specified
|
||||
try:
|
||||
parsed_front_matter = self._parse_front_matter(front_matter)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}")
|
||||
class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
|
||||
"""Dataset that includes front matter parsing and PDF rendering by default."""
|
||||
|
||||
return {
|
||||
'image': image,
|
||||
'pdf_path': str(sample['pdf_path']),
|
||||
'text': text,
|
||||
'front_matter': parsed_front_matter
|
||||
}
|
||||
def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
    """
    Initialize the dataset with default pipeline steps.

    Args:
        root_dir: Path to the root folder containing processed markdown and PDF files
        target_longest_image_dim: Target dimension for the longest side of the image
        image_transform: Optional transform to apply to the PDF images
        front_matter_class: Optional dataclass type to validate front matter against
    """
    # Default pipeline: parse front matter first, then render the PDF page,
    # and hand the steps to the base class which does pair discovery.
    default_steps = [
        FrontMatterParser(front_matter_class),
        PDFRenderer(target_longest_image_dim, image_transform),
    ]
    super().__init__(root_dir, default_steps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -222,11 +290,46 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Test dataset initialization
|
||||
print(f"\nTesting dataset with root directory: {args.root_dir}")
|
||||
dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, front_matter_class=StandardFrontMatter, image_transform=None)
|
||||
# Test base dataset without any pipeline steps
|
||||
print(f"\n=== Testing base dataset without pipeline steps ===")
|
||||
base_dataset = BaseMarkdownPDFDataset(args.root_dir)
|
||||
print(f"Dataset length: {len(base_dataset)}")
|
||||
|
||||
print(f"\nDataset length: {len(dataset)}")
|
||||
if len(base_dataset) > 0:
|
||||
print("\nFirst sample (no pipeline):")
|
||||
sample = base_dataset[0]
|
||||
print(f" Keys: {list(sample.keys())}")
|
||||
print(f" Markdown: {sample['markdown_path'].name}")
|
||||
print(f" PDF: {sample['pdf_path'].name}")
|
||||
|
||||
# Test with individual pipeline steps
|
||||
print(f"\n=== Testing with individual pipeline steps ===")
|
||||
pipeline_dataset = BaseMarkdownPDFDataset(
|
||||
args.root_dir,
|
||||
pipeline_steps=[
|
||||
FrontMatterParser(StandardFrontMatter),
|
||||
PDFRenderer(target_longest_image_dim=1024)
|
||||
]
|
||||
)
|
||||
|
||||
if len(pipeline_dataset) > 0:
|
||||
print("\nFirst sample (with pipeline):")
|
||||
sample = pipeline_dataset[0]
|
||||
print(f" Keys: {list(sample.keys())}")
|
||||
print(f" Front Matter: {sample['front_matter']}")
|
||||
print(f" Image size: {sample['image'].size}")
|
||||
print(f" Text preview: {sample['text'][:100]}...")
|
||||
|
||||
# Test the convenience dataset class
|
||||
print(f"\n=== Testing MarkdownPDFDocumentDataset (convenience class) ===")
|
||||
dataset = MarkdownPDFDocumentDataset(
|
||||
args.root_dir,
|
||||
target_longest_image_dim=1024,
|
||||
front_matter_class=StandardFrontMatter,
|
||||
image_transform=None
|
||||
)
|
||||
|
||||
print(f"Dataset length: {len(dataset)}")
|
||||
|
||||
if len(dataset) > 0:
|
||||
# Show first few samples
|
||||
@ -242,4 +345,4 @@ if __name__ == "__main__":
|
||||
print(f"Image size: {first_sample['image'].size}")
|
||||
print(f"PDF Path: {first_sample['pdf_path']}")
|
||||
print(f"Front Matter: {first_sample['front_matter']}")
|
||||
print(f"Text: {first_sample['text']}...")
|
||||
print(f"Text: {first_sample['text'][:200]}...")
|
||||
|
Loading…
x
Reference in New Issue
Block a user