Check that anchor text extraction works for each PDF page when initializing the dataloader

Jake Poznanski 2025-06-30 16:29:33 +00:00
parent dc7fff5bf7
commit 8e5e18f54c
6 changed files with 169 additions and 173 deletions

View File

@@ -5,6 +5,7 @@ import random
 import re
 import subprocess
 from dataclasses import dataclass
+from os import PathLike
 from typing import List, Literal

 import ftfy
@@ -16,7 +17,7 @@ from olmocr.filter.coherency import get_document_coherency


 def get_anchor_text(
-    local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
+    local_pdf_path: str | PathLike, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
 ) -> str:
     assert page > 0, "Pages are 1-indexed in pdf-land"
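
For reference, the widened signature lets callers pass a pathlib.Path (or any os.PathLike) directly. A minimal sketch, assuming get_anchor_text is imported from olmocr.prompts.anchor (the module path is not shown in this diff) and that sample.pdf stands in for any local PDF:

```python
from pathlib import Path

from olmocr.prompts.anchor import get_anchor_text  # import path assumed, not shown in the diff

pdf_path = Path("sample.pdf")  # hypothetical local PDF; a PathLike is now accepted directly

# Pages are 1-indexed, as the assert above enforces.
anchor_text = get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=4000)
print(anchor_text[:200])
```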

View File

@@ -12,6 +12,7 @@ from omegaconf import DictConfig, OmegaConf
 @dataclass
 class PipelineStepConfig:
     """Base configuration for pipeline steps."""
+
     name: str
     enabled: bool = True
@@ -19,6 +20,7 @@ class PipelineStepConfig:
 @dataclass
 class FrontMatterParserConfig(PipelineStepConfig):
     """Configuration for FrontMatterParser step."""
+
     name: str = "FrontMatterParser"
     use_page_response_class: bool = True  # Whether to use PageResponse dataclass
@@ -26,6 +28,7 @@ class FrontMatterParserConfig(PipelineStepConfig):
 @dataclass
 class PDFRendererConfig(PipelineStepConfig):
     """Configuration for PDFRenderer step."""
+
     name: str = "PDFRenderer"
     target_longest_image_dim: int = 1024
@@ -33,6 +36,7 @@ class PDFRendererConfig(PipelineStepConfig):
 @dataclass
 class StaticLengthDocumentAnchoringConfig(PipelineStepConfig):
     """Configuration for StaticLengthDocumentAnchoring step."""
+
     name: str = "StaticLengthDocumentAnchoring"
     target_anchor_text_len: int = 6000
@@ -40,24 +44,28 @@ class StaticLengthDocumentAnchoringConfig(PipelineStepConfig):
 @dataclass
 class FinetuningPromptConfig(PipelineStepConfig):
     """Configuration for FinetuningPrompt step."""
+
     name: str = "FinetuningPrompt"


 @dataclass
 class FrontMatterOutputFormatConfig(PipelineStepConfig):
     """Configuration for FrontMatterOutputFormat step."""
+
     name: str = "FrontMatterOutputFormat"


 @dataclass
 class InstructUserMessagesConfig(PipelineStepConfig):
     """Configuration for InstructUserMessages step."""
+
     name: str = "InstructUserMessages"


 @dataclass
 class TokenizerStepConfig(PipelineStepConfig):
     """Configuration for Tokenizer step."""
+
     name: str = "Tokenizer"
     masking_index: int = -100
     end_of_message_token: str = "<|im_end|>"
@@ -66,6 +74,7 @@ class TokenizerStepConfig(PipelineStepConfig):
 @dataclass
 class DatasetItemConfig:
     """Configuration for a single dataset item."""
+
     root_dir: str
     pipeline: List[Dict[str, Any]] = field(default_factory=list)
@@ -76,6 +85,7 @@ class DatasetItemConfig:
 @dataclass
 class DatasetConfig:
     """Configuration for dataset and data loading."""
+
     train: List[Dict[str, Any]] = field(default_factory=list)
     eval: List[Dict[str, Any]] = field(default_factory=list)
@@ -83,6 +93,7 @@ class DatasetConfig:
 @dataclass
 class ModelConfig:
     """Configuration for model."""
+
     name: str = "Qwen/Qwen2.5-VL-7B-Instruct"
     trust_remote_code: bool = False
@@ -112,6 +123,7 @@ class ModelConfig:
 @dataclass
 class TrainingConfig:
     """Configuration for training parameters."""
+
     output_dir: str = "./outputs"
     num_train_epochs: int = 3
     per_device_train_batch_size: int = 1
@@ -181,6 +193,7 @@ class TrainingConfig:
 @dataclass
 class Config:
     """Main configuration class that combines all sub-configs."""
+
     model: ModelConfig
     dataset: DatasetConfig
     training: TrainingConfig
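
Since the step configs above are ordinary dataclasses, they can also be constructed directly in Python. A short sketch with illustrative values, assuming the classes live in olmocr.train.config alongside Config (as the training script's import suggests):

```python
from olmocr.train.config import (
    PDFRendererConfig,
    StaticLengthDocumentAnchoringConfig,
    TokenizerStepConfig,
)

# Each config mirrors one entry of a dataset's `pipeline` list in the YAML file.
renderer_cfg = PDFRendererConfig(target_longest_image_dim=1024)
anchoring_cfg = StaticLengthDocumentAnchoringConfig(target_anchor_text_len=6000)
tokenizer_cfg = TokenizerStepConfig(masking_index=-100, end_of_message_token="<|im_end|>")

print(renderer_cfg.name, anchoring_cfg.target_anchor_text_len, tokenizer_cfg.enabled)
```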
@@ -208,7 +221,7 @@ class Config:
             raise FileNotFoundError(f"Config file not found: {yaml_path}")

         # Load YAML with OmegaConf for better features
-        with open(yaml_path, 'r') as f:
+        with open(yaml_path, "r") as f:
             yaml_content = yaml.safe_load(f)

         # Create OmegaConf config for interpolation and validation
@@ -221,20 +234,14 @@ class Config:
         cfg_dict = OmegaConf.to_container(cfg, resolve=True)

         # Create sub-configs
-        model_cfg = ModelConfig(**cfg_dict.get('model', {}))
-        dataset_cfg = DatasetConfig(**cfg_dict.get('dataset', {}))
-        training_cfg = TrainingConfig(**cfg_dict.get('training', {}))
+        model_cfg = ModelConfig(**cfg_dict.get("model", {}))
+        dataset_cfg = DatasetConfig(**cfg_dict.get("dataset", {}))
+        training_cfg = TrainingConfig(**cfg_dict.get("training", {}))

         # Create main config
-        main_cfg_dict = {k: v for k, v in cfg_dict.items()
-                         if k not in ['model', 'dataset', 'training']}
+        main_cfg_dict = {k: v for k, v in cfg_dict.items() if k not in ["model", "dataset", "training"]}

-        return cls(
-            model=model_cfg,
-            dataset=dataset_cfg,
-            training=training_cfg,
-            **main_cfg_dict
-        )
+        return cls(model=model_cfg, dataset=dataset_cfg, training=training_cfg, **main_cfg_dict)

     def to_yaml(self, yaml_path: Union[str, Path]) -> None:
         """Save configuration to YAML file."""
@@ -244,7 +251,7 @@ class Config:
         # Convert to OmegaConf for nice YAML output
         cfg = OmegaConf.structured(self)

-        with open(yaml_path, 'w') as f:
+        with open(yaml_path, "w") as f:
             OmegaConf.save(cfg, f)

     def validate(self) -> None:
@@ -252,7 +259,7 @@ class Config:
         # Dataset validation - check all train and eval datasets
         for split_name, datasets in [("train", self.dataset.train), ("eval", self.dataset.eval)]:
             for i, dataset_cfg in enumerate(datasets):
-                root_dir = dataset_cfg.get('root_dir')
+                root_dir = dataset_cfg.get("root_dir")
                 if not root_dir:
                     raise ValueError(f"Missing root_dir for {split_name} dataset {i}")
                 if not os.path.exists(root_dir):
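
Together, from_yaml and validate give the usual load-then-check flow. A minimal sketch using the example config path that the training script defaults to:

```python
from olmocr.train.config import Config

config = Config.from_yaml("olmocr/train/configs/example_config.yaml")
config.validate()  # raises ValueError if a dataset entry is missing root_dir or the path does not exist

print(config.model.name, config.training.output_dir)
```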
@@ -284,60 +291,59 @@ class Config:
         Returns:
             List of initialized pipeline step instances
         """
+        from olmocr.prompts.prompts import PageResponse
         from olmocr.train.dataloader import (
-            FrontMatterParser,
-            PDFRenderer,
-            StaticLengthDocumentAnchoring,
             FinetuningPrompt,
             FrontMatterOutputFormat,
+            FrontMatterParser,
             InstructUserMessages,
-            Tokenizer
+            PDFRenderer,
+            StaticLengthDocumentAnchoring,
+            Tokenizer,
         )
-        from olmocr.prompts.prompts import PageResponse

         steps = []

         for step_config in pipeline_config:
-            if not step_config.get('enabled', True):
+            if not step_config.get("enabled", True):
                 continue

-            step_name = step_config['name']
+            step_name = step_config["name"]

-            if step_name == 'FrontMatterParser':
+            if step_name == "FrontMatterParser":
                 # Handle both old and new config format
-                if 'front_matter_class' in step_config:
-                    front_matter_class = PageResponse if step_config['front_matter_class'] == 'PageResponse' else None
+                if "front_matter_class" in step_config:
+                    front_matter_class = PageResponse if step_config["front_matter_class"] == "PageResponse" else None
                 else:
-                    front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
+                    front_matter_class = PageResponse if step_config.get("use_page_response_class", True) else None
                 steps.append(FrontMatterParser(front_matter_class=front_matter_class))

-            elif step_name == 'PDFRenderer':
-                steps.append(PDFRenderer(
-                    target_longest_image_dim=step_config.get('target_longest_image_dim', 1024),
-                    image_transform=None  # Can be extended later
-                ))
+            elif step_name == "PDFRenderer":
+                steps.append(
+                    PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024), image_transform=None)  # Can be extended later
+                )

-            elif step_name == 'StaticLengthDocumentAnchoring':
-                steps.append(StaticLengthDocumentAnchoring(
-                    target_anchor_text_len=step_config.get('target_anchor_text_len', 6000)
-                ))
+            elif step_name == "StaticLengthDocumentAnchoring":
+                steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000)))

-            elif step_name == 'FinetuningPrompt':
+            elif step_name == "FinetuningPrompt":
                 steps.append(FinetuningPrompt())

-            elif step_name == 'FrontMatterOutputFormat':
+            elif step_name == "FrontMatterOutputFormat":
                 steps.append(FrontMatterOutputFormat())

-            elif step_name == 'InstructUserMessages':
+            elif step_name == "InstructUserMessages":
                 steps.append(InstructUserMessages())

-            elif step_name == 'Tokenizer':
+            elif step_name == "Tokenizer":
                 if processor is None:
                     raise ValueError("Processor must be provided for Tokenizer step")
-                steps.append(Tokenizer(
-                    processor=processor,
-                    masking_index=step_config.get('masking_index', -100),
-                    end_of_message_token=step_config.get('end_of_message_token', '<|im_end|>')
-                ))
+                steps.append(
+                    Tokenizer(
+                        processor=processor,
+                        masking_index=step_config.get("masking_index", -100),
+                        end_of_message_token=step_config.get("end_of_message_token", "<|im_end|>"),
+                    )
+                )

             else:
                 raise ValueError(f"Unknown pipeline step: {step_name}")
@@ -346,11 +352,7 @@ class Config:

 def create_default_config() -> Config:
     """Create a default configuration."""
-    return Config(
-        model=ModelConfig(),
-        dataset=DatasetConfig(),
-        training=TrainingConfig()
-    )
+    return Config(model=ModelConfig(), dataset=DatasetConfig(), training=TrainingConfig())


 if __name__ == "__main__":

View File

@@ -71,7 +71,7 @@ class BaseMarkdownPDFDataset(Dataset):
                 # Verify the resolved path exists
                 if pdf_path.exists():
-                    # Validate PDF - check it loads and has exactly one page
+                    # Validate PDF - check it loads and has exactly one page and that you can get document-anchoring from it
                     try:
                         reader = PdfReader(str(pdf_path))
                         num_pages = len(reader.pages)
@@ -80,6 +80,9 @@ class BaseMarkdownPDFDataset(Dataset):
                             invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
                             continue

+                        # Test that document anchoring works
+                        get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)
+
                         self.samples.append({"markdown_path": md_path, "pdf_path": pdf_path})
                         valid_count += 1
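
The new call means a PDF that opens in PdfReader but cannot yield anchor text is caught by the surrounding try/except at dataset-construction time instead of failing mid-training. A standalone sketch of the same probe, assuming PdfReader comes from pypdf as in the dataloader and using hypothetical paths:

```python
from pypdf import PdfReader

from olmocr.prompts.anchor import get_anchor_text  # import path assumed

candidate_pdfs = ["docs/page1.pdf", "docs/page2.pdf"]  # hypothetical paths
valid, invalid = [], []

for pdf_path in candidate_pdfs:
    try:
        if len(PdfReader(pdf_path).pages) != 1:
            invalid.append((pdf_path, "expected exactly one page"))
            continue
        # Same check the dataset now runs during initialization.
        get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)
        valid.append(pdf_path)
    except Exception as exc:
        invalid.append((pdf_path, str(exc)))

print(f"{len(valid)} valid, {len(invalid)} invalid")
```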

View File

@@ -4,18 +4,18 @@ Simple script to test OlmOCR dataset loading with YAML configuration.

 import argparse
 import logging
-import numpy as np
-from transformers import (
-    AutoProcessor,
-    Qwen2VLForConditionalGeneration,
-    Qwen2_5_VLForConditionalGeneration,
-    Trainer,
-    TrainingArguments,
-    EarlyStoppingCallback
-)
+
+import numpy as np
 import torch
 from torch.utils.data import ConcatDataset
+from transformers import (
+    AutoProcessor,
+    EarlyStoppingCallback,
+    Qwen2_5_VLForConditionalGeneration,
+    Qwen2VLForConditionalGeneration,
+    Trainer,
+    TrainingArguments,
+)

 from olmocr.train.config import Config
 from olmocr.train.dataloader import BaseMarkdownPDFDataset
@@ -34,53 +34,44 @@ class QwenDataCollator:
     def __call__(self, examples):
         # Filter out None values and extract the fields we need
-        batch = {
-            'input_ids': [],
-            'attention_mask': [],
-            'labels': [],
-            'pixel_values': [],
-            'image_grid_thw': []
-        }
+        batch = {"input_ids": [], "attention_mask": [], "labels": [], "pixel_values": [], "image_grid_thw": []}

         for example in examples:
             if example is not None:
                 # Convert numpy arrays to tensors
-                batch['input_ids'].append(torch.from_numpy(example['input_ids']) if isinstance(example['input_ids'], np.ndarray) else example['input_ids'])
-                batch['attention_mask'].append(torch.from_numpy(example['attention_mask']) if isinstance(example['attention_mask'], np.ndarray) else example['attention_mask'])
-                batch['labels'].append(torch.from_numpy(example['labels']) if isinstance(example['labels'], np.ndarray) else example['labels'])
+                batch["input_ids"].append(torch.from_numpy(example["input_ids"]) if isinstance(example["input_ids"], np.ndarray) else example["input_ids"])
+                batch["attention_mask"].append(
+                    torch.from_numpy(example["attention_mask"]) if isinstance(example["attention_mask"], np.ndarray) else example["attention_mask"]
+                )
+                batch["labels"].append(torch.from_numpy(example["labels"]) if isinstance(example["labels"], np.ndarray) else example["labels"])

                 # Handle pixel_values which might be numpy array or already a tensor
-                pixel_values = example['pixel_values']
+                pixel_values = example["pixel_values"]
                 if isinstance(pixel_values, np.ndarray):
                     pixel_values = torch.from_numpy(pixel_values)
-                batch['pixel_values'].append(pixel_values)
+                batch["pixel_values"].append(pixel_values)

                 # Handle image_grid_thw
-                image_grid_thw = example['image_grid_thw']
+                image_grid_thw = example["image_grid_thw"]
                 if isinstance(image_grid_thw, np.ndarray):
                     image_grid_thw = torch.from_numpy(image_grid_thw)
-                batch['image_grid_thw'].append(image_grid_thw)
+                batch["image_grid_thw"].append(image_grid_thw)

         # Convert lists to tensors with proper padding
         # Note: For Qwen2-VL, we typically handle variable length sequences
         # The model's processor should handle the padding internally
         return {
-            'input_ids': torch.stack(batch['input_ids']),
-            'attention_mask': torch.stack(batch['attention_mask']),
-            'labels': torch.stack(batch['labels']),
-            'pixel_values': torch.stack(batch['pixel_values']),  # Stack into tensor
-            'image_grid_thw': torch.stack(batch['image_grid_thw'])
+            "input_ids": torch.stack(batch["input_ids"]),
+            "attention_mask": torch.stack(batch["attention_mask"]),
+            "labels": torch.stack(batch["labels"]),
+            "pixel_values": torch.stack(batch["pixel_values"]),  # Stack into tensor
+            "image_grid_thw": torch.stack(batch["image_grid_thw"]),
         }


 def main():
     parser = argparse.ArgumentParser(description="Train OlmOCR model")
-    parser.add_argument(
-        "--config",
-        type=str,
-        default="olmocr/train/configs/example_config.yaml",
-        help="Path to YAML configuration file"
-    )
+    parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file")

     args = parser.parse_args()
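
To see what the collator produces, here is a toy sketch with two fixed-shape numpy examples (shapes are placeholders; real examples come from the tokenized dataset, and sequences in a batch must share a length for torch.stack to succeed). It assumes it runs in the same module where QwenDataCollator is defined and that the collator needs no constructor arguments beyond what this diff shows:

```python
import numpy as np

dummy = {
    "input_ids": np.zeros(16, dtype=np.int64),
    "attention_mask": np.ones(16, dtype=np.int64),
    "labels": np.full(16, -100, dtype=np.int64),
    "pixel_values": np.zeros((4, 1176), dtype=np.float32),  # placeholder patch features
    "image_grid_thw": np.array([1, 2, 2], dtype=np.int64),
}

collator = QwenDataCollator()  # constructor arguments, if any, are not shown in this diff
batch = collator([dummy, dict(dummy)])
print(batch["input_ids"].shape)  # torch.Size([2, 16])
```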
@@ -130,8 +121,8 @@ def main():
     logger.info("Creating training datasets...")
     train_datasets = []
     for i, dataset_cfg in enumerate(config.dataset.train):
-        root_dir = dataset_cfg['root_dir']
-        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+        root_dir = dataset_cfg["root_dir"]
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)

         logger.info(f"Creating training dataset {i+1} from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
@@ -148,11 +139,11 @@ def main():
     logger.info("Creating evaluation datasets...")
     eval_datasets = {}
     for i, dataset_cfg in enumerate(config.dataset.eval):
-        root_dir = dataset_cfg['root_dir']
-        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+        root_dir = dataset_cfg["root_dir"]
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)

         # Use dataset name if provided, otherwise use root_dir as name
-        dataset_name = dataset_cfg.get('name', f"eval_dataset_{i+1}")
+        dataset_name = dataset_cfg.get("name", f"eval_dataset_{i+1}")

         logger.info(f"Creating evaluation dataset '{dataset_name}' from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
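
Downstream, the per-config training datasets are merged with ConcatDataset, while evaluation datasets stay keyed by name so the Trainer can report metrics per dataset. A toy illustration of the concatenation step:

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Toy stand-ins for the datasets built from each train config entry.
ds_a = TensorDataset(torch.arange(4))
ds_b = TensorDataset(torch.arange(6))

combined = ConcatDataset([ds_a, ds_b])
print(len(combined))  # 10 -- the Trainer sees one dataset spanning both sources
```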
@@ -214,8 +205,7 @@ def main():
     if config.training.use_early_stopping:
         callbacks.append(
             EarlyStoppingCallback(
-                early_stopping_patience=config.training.early_stopping_patience,
-                early_stopping_threshold=config.training.early_stopping_threshold
+                early_stopping_patience=config.training.early_stopping_patience, early_stopping_threshold=config.training.early_stopping_threshold
             )
         )