mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-13 09:12:18 +00:00
Working on a more pipeliney thing
This commit is contained in:
parent
d0df380ae9
commit
d17bef8b4b
@ -1,6 +1,6 @@
|
|||||||
from os import PathLike
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, Optional, Type
|
from typing import Dict, Any, Optional, Type, List, Callable
|
||||||
import base64
|
import base64
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -8,6 +8,7 @@ from torch.utils.data import Dataset
|
|||||||
from pypdf import PdfReader
|
from pypdf import PdfReader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from olmocr.data.renderpdf import render_pdf_to_base64png
|
from olmocr.data.renderpdf import render_pdf_to_base64png
|
||||||
|
|
||||||
@ -37,69 +38,20 @@ class StandardFrontMatter:
|
|||||||
raise TypeError("is_diagram must be of type bool.")
|
raise TypeError("is_diagram must be of type bool.")
|
||||||
|
|
||||||
|
|
||||||
class MarkdownPDFDocumentDataset(Dataset):
|
class PipelineStep(ABC):
|
||||||
def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
|
"""Abstract base class for pipeline steps."""
|
||||||
"""
|
|
||||||
Initialize the dataset by finding all markdown files with corresponding PDFs.
|
|
||||||
|
|
||||||
Args:
|
@abstractmethod
|
||||||
root_dir: Path to the root folder containing processed markdown and PDF files
|
def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
target_longest_image_dim: Target dimension for the longest side of the image
|
"""Process a sample and return the modified sample."""
|
||||||
image_transform: Optional transform to apply to the PDF images
|
pass
|
||||||
front_matter_class: Optional dataclass type to validate front matter against
|
|
||||||
"""
|
|
||||||
self.root_dir = Path(root_dir)
|
class FrontMatterParser(PipelineStep):
|
||||||
self.image_transform = image_transform
|
"""Pipeline step that parses front matter from markdown content."""
|
||||||
self.target_longest_image_dim = target_longest_image_dim
|
|
||||||
|
def __init__(self, front_matter_class: Optional[Type] = None):
|
||||||
self.front_matter_class = front_matter_class
|
self.front_matter_class = front_matter_class
|
||||||
self.samples = []
|
|
||||||
|
|
||||||
# Find all markdown files recursively
|
|
||||||
print(f"Scanning for markdown files in {self.root_dir}...")
|
|
||||||
md_files = list(self.root_dir.rglob("*.md"))
|
|
||||||
|
|
||||||
# Verify each markdown file has a corresponding PDF
|
|
||||||
valid_count = 0
|
|
||||||
invalid_pdfs = []
|
|
||||||
|
|
||||||
print(f"Validating {len(md_files)} markdown-PDF pairs...")
|
|
||||||
for md_path in tqdm(md_files, desc="Validating PDFs"):
|
|
||||||
# Look for PDF with same stem (filename without extension)
|
|
||||||
pdf_path = md_path.with_suffix('.pdf')
|
|
||||||
|
|
||||||
if pdf_path.exists() or pdf_path.is_symlink():
|
|
||||||
# Resolve symlink if it is one
|
|
||||||
if pdf_path.is_symlink():
|
|
||||||
pdf_path = pdf_path.resolve()
|
|
||||||
|
|
||||||
# Verify the resolved path exists
|
|
||||||
if pdf_path.exists():
|
|
||||||
# Validate PDF - check it loads and has exactly one page
|
|
||||||
try:
|
|
||||||
reader = PdfReader(str(pdf_path))
|
|
||||||
num_pages = len(reader.pages)
|
|
||||||
|
|
||||||
if num_pages != 1:
|
|
||||||
invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.samples.append({
|
|
||||||
'markdown_path': md_path,
|
|
||||||
'pdf_path': pdf_path
|
|
||||||
})
|
|
||||||
valid_count += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
|
|
||||||
|
|
||||||
print(f"Found {valid_count} valid markdown-PDF pairs")
|
|
||||||
|
|
||||||
if invalid_pdfs:
|
|
||||||
print(f"\nWarning: {len(invalid_pdfs)} invalid PDFs found:")
|
|
||||||
for pdf_path, reason in invalid_pdfs[:5]: # Show first 5
|
|
||||||
print(f" - {pdf_path.name}: {reason}")
|
|
||||||
if len(invalid_pdfs) > 5:
|
|
||||||
print(f" ... and {len(invalid_pdfs) - 5} more")
|
|
||||||
|
|
||||||
def _extract_front_matter_and_text(self, markdown_content: str) -> tuple[str, str]:
|
def _extract_front_matter_and_text(self, markdown_content: str) -> tuple[str, str]:
|
||||||
"""Extract raw front matter string and text from markdown content."""
|
"""Extract raw front matter string and text from markdown content."""
|
||||||
@ -165,6 +117,119 @@ class MarkdownPDFDocumentDataset(Dataset):
|
|||||||
|
|
||||||
return self.front_matter_class(**kwargs)
|
return self.front_matter_class(**kwargs)
|
||||||
|
|
||||||
|
def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Parse front matter from markdown content."""
|
||||||
|
# Read markdown content if not already loaded
|
||||||
|
if 'markdown_content' not in sample:
|
||||||
|
sample['markdown_content'] = sample['markdown_path'].read_text(encoding='utf-8')
|
||||||
|
|
||||||
|
# Extract and parse front matter
|
||||||
|
front_matter_str, text = self._extract_front_matter_and_text(sample['markdown_content'])
|
||||||
|
front_matter = self._parse_front_matter_string(front_matter_str)
|
||||||
|
|
||||||
|
# Parse front matter to dataclass if specified
|
||||||
|
try:
|
||||||
|
parsed_front_matter = self._parse_front_matter(front_matter)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}")
|
||||||
|
|
||||||
|
# Update sample
|
||||||
|
sample['text'] = text
|
||||||
|
sample['front_matter'] = parsed_front_matter
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class PDFRenderer(PipelineStep):
|
||||||
|
"""Pipeline step that renders PDF to image."""
|
||||||
|
|
||||||
|
def __init__(self, target_longest_image_dim: int, image_transform: Optional[Callable] = None):
|
||||||
|
self.target_longest_image_dim = target_longest_image_dim
|
||||||
|
self.image_transform = image_transform
|
||||||
|
|
||||||
|
def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Render PDF to image."""
|
||||||
|
# Render PDF to image
|
||||||
|
base64_png = render_pdf_to_base64png(
|
||||||
|
str(sample['pdf_path']),
|
||||||
|
page_num=1,
|
||||||
|
target_longest_image_dim=self.target_longest_image_dim
|
||||||
|
)
|
||||||
|
png_bytes = base64.b64decode(base64_png)
|
||||||
|
image = Image.open(BytesIO(png_bytes))
|
||||||
|
|
||||||
|
# Apply transform if provided
|
||||||
|
if self.image_transform:
|
||||||
|
image = self.image_transform(image)
|
||||||
|
|
||||||
|
# Update sample
|
||||||
|
sample['image'] = image
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMarkdownPDFDataset(Dataset):
|
||||||
|
"""Base dataset class that loads and verifies markdown-PDF pairs."""
|
||||||
|
|
||||||
|
def __init__(self, root_dir: str | PathLike, pipeline_steps: Optional[List[PipelineStep]] = None):
|
||||||
|
"""
|
||||||
|
Initialize the dataset by finding all markdown files with corresponding PDFs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root_dir: Path to the root folder containing processed markdown and PDF files
|
||||||
|
pipeline_steps: Optional list of pipeline steps to apply to each sample
|
||||||
|
"""
|
||||||
|
self.root_dir = Path(root_dir)
|
||||||
|
self.pipeline_steps = pipeline_steps or []
|
||||||
|
self.samples = []
|
||||||
|
|
||||||
|
# Find all markdown files recursively
|
||||||
|
print(f"Scanning for markdown files in {self.root_dir}...")
|
||||||
|
md_files = list(self.root_dir.rglob("*.md"))
|
||||||
|
|
||||||
|
# Verify each markdown file has a corresponding PDF
|
||||||
|
valid_count = 0
|
||||||
|
invalid_pdfs = []
|
||||||
|
|
||||||
|
print(f"Validating {len(md_files)} markdown-PDF pairs...")
|
||||||
|
for md_path in tqdm(md_files, desc="Validating PDFs"):
|
||||||
|
# Look for PDF with same stem (filename without extension)
|
||||||
|
pdf_path = md_path.with_suffix('.pdf')
|
||||||
|
|
||||||
|
if pdf_path.exists() or pdf_path.is_symlink():
|
||||||
|
# Resolve symlink if it is one
|
||||||
|
if pdf_path.is_symlink():
|
||||||
|
pdf_path = pdf_path.resolve()
|
||||||
|
|
||||||
|
# Verify the resolved path exists
|
||||||
|
if pdf_path.exists():
|
||||||
|
# Validate PDF - check it loads and has exactly one page
|
||||||
|
try:
|
||||||
|
reader = PdfReader(str(pdf_path))
|
||||||
|
num_pages = len(reader.pages)
|
||||||
|
|
||||||
|
if num_pages != 1:
|
||||||
|
invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.samples.append({
|
||||||
|
'markdown_path': md_path,
|
||||||
|
'pdf_path': pdf_path
|
||||||
|
})
|
||||||
|
valid_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
|
||||||
|
|
||||||
|
print(f"Found {valid_count} valid markdown-PDF pairs")
|
||||||
|
|
||||||
|
if invalid_pdfs:
|
||||||
|
print(f"\nWarning: {len(invalid_pdfs)} invalid PDFs found:")
|
||||||
|
for pdf_path, reason in invalid_pdfs[:5]: # Show first 5
|
||||||
|
print(f" - {pdf_path.name}: {reason}")
|
||||||
|
if len(invalid_pdfs) > 5:
|
||||||
|
print(f" ... and {len(invalid_pdfs) - 5} more")
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.samples)
|
return len(self.samples)
|
||||||
|
|
||||||
@ -173,40 +238,43 @@ class MarkdownPDFDocumentDataset(Dataset):
|
|||||||
Get a single sample from the dataset.
|
Get a single sample from the dataset.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict containing:
|
dict containing at minimum:
|
||||||
- 'image': PIL Image of the rendered PDF page
|
- 'markdown_path': Path to the markdown file
|
||||||
- 'pdf_path': Path to the PDF file
|
- 'pdf_path': Path to the PDF file
|
||||||
- 'text': Text content without front matter
|
|
||||||
- 'front_matter': Dict with parsed front matter
|
Additional fields will be added by pipeline steps.
|
||||||
"""
|
"""
|
||||||
sample = self.samples[idx]
|
# Start with basic sample info
|
||||||
|
sample = self.samples[idx].copy()
|
||||||
|
|
||||||
# Read and parse markdown file
|
# Apply pipeline steps
|
||||||
markdown_content = sample['markdown_path'].read_text(encoding='utf-8')
|
for step in self.pipeline_steps:
|
||||||
front_matter_str, text = self._extract_front_matter_and_text(markdown_content)
|
sample = step.process(sample)
|
||||||
front_matter = self._parse_front_matter_string(front_matter_str)
|
|
||||||
|
|
||||||
# Render PDF to image
|
return sample
|
||||||
base64_png = render_pdf_to_base64png(str(sample['pdf_path']), page_num=1, target_longest_image_dim=self.target_longest_image_dim)
|
|
||||||
png_bytes = base64.b64decode(base64_png)
|
|
||||||
image = Image.open(BytesIO(png_bytes))
|
|
||||||
|
|
||||||
# Apply transform if provided
|
|
||||||
if self.image_transform:
|
|
||||||
image = self.image_transform(image)
|
|
||||||
|
|
||||||
# Parse front matter to dataclass if specified
|
class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
|
||||||
try:
|
"""Dataset that includes front matter parsing and PDF rendering by default."""
|
||||||
parsed_front_matter = self._parse_front_matter(front_matter)
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}")
|
|
||||||
|
|
||||||
return {
|
def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
|
||||||
'image': image,
|
"""
|
||||||
'pdf_path': str(sample['pdf_path']),
|
Initialize the dataset with default pipeline steps.
|
||||||
'text': text,
|
|
||||||
'front_matter': parsed_front_matter
|
Args:
|
||||||
}
|
root_dir: Path to the root folder containing processed markdown and PDF files
|
||||||
|
target_longest_image_dim: Target dimension for the longest side of the image
|
||||||
|
image_transform: Optional transform to apply to the PDF images
|
||||||
|
front_matter_class: Optional dataclass type to validate front matter against
|
||||||
|
"""
|
||||||
|
# Create default pipeline steps
|
||||||
|
pipeline_steps = [
|
||||||
|
FrontMatterParser(front_matter_class),
|
||||||
|
PDFRenderer(target_longest_image_dim, image_transform)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Initialize base class with pipeline
|
||||||
|
super().__init__(root_dir, pipeline_steps)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -222,11 +290,46 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Test dataset initialization
|
# Test base dataset without any pipeline steps
|
||||||
print(f"\nTesting dataset with root directory: {args.root_dir}")
|
print(f"\n=== Testing base dataset without pipeline steps ===")
|
||||||
dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, front_matter_class=StandardFrontMatter, image_transform=None)
|
base_dataset = BaseMarkdownPDFDataset(args.root_dir)
|
||||||
|
print(f"Dataset length: {len(base_dataset)}")
|
||||||
|
|
||||||
print(f"\nDataset length: {len(dataset)}")
|
if len(base_dataset) > 0:
|
||||||
|
print("\nFirst sample (no pipeline):")
|
||||||
|
sample = base_dataset[0]
|
||||||
|
print(f" Keys: {list(sample.keys())}")
|
||||||
|
print(f" Markdown: {sample['markdown_path'].name}")
|
||||||
|
print(f" PDF: {sample['pdf_path'].name}")
|
||||||
|
|
||||||
|
# Test with individual pipeline steps
|
||||||
|
print(f"\n=== Testing with individual pipeline steps ===")
|
||||||
|
pipeline_dataset = BaseMarkdownPDFDataset(
|
||||||
|
args.root_dir,
|
||||||
|
pipeline_steps=[
|
||||||
|
FrontMatterParser(StandardFrontMatter),
|
||||||
|
PDFRenderer(target_longest_image_dim=1024)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(pipeline_dataset) > 0:
|
||||||
|
print("\nFirst sample (with pipeline):")
|
||||||
|
sample = pipeline_dataset[0]
|
||||||
|
print(f" Keys: {list(sample.keys())}")
|
||||||
|
print(f" Front Matter: {sample['front_matter']}")
|
||||||
|
print(f" Image size: {sample['image'].size}")
|
||||||
|
print(f" Text preview: {sample['text'][:100]}...")
|
||||||
|
|
||||||
|
# Test the convenience dataset class
|
||||||
|
print(f"\n=== Testing MarkdownPDFDocumentDataset (convenience class) ===")
|
||||||
|
dataset = MarkdownPDFDocumentDataset(
|
||||||
|
args.root_dir,
|
||||||
|
target_longest_image_dim=1024,
|
||||||
|
front_matter_class=StandardFrontMatter,
|
||||||
|
image_transform=None
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Dataset length: {len(dataset)}")
|
||||||
|
|
||||||
if len(dataset) > 0:
|
if len(dataset) > 0:
|
||||||
# Show first few samples
|
# Show first few samples
|
||||||
@ -242,4 +345,4 @@ if __name__ == "__main__":
|
|||||||
print(f"Image size: {first_sample['image'].size}")
|
print(f"Image size: {first_sample['image'].size}")
|
||||||
print(f"PDF Path: {first_sample['pdf_path']}")
|
print(f"PDF Path: {first_sample['pdf_path']}")
|
||||||
print(f"Front Matter: {first_sample['front_matter']}")
|
print(f"Front Matter: {first_sample['front_matter']}")
|
||||||
print(f"Text: {first_sample['text']}...")
|
print(f"Text: {first_sample['text'][:200]}...")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user