Check that anchor text extraction works for each PDF page when initializing the dataloader

Jake Poznanski 2025-06-30 16:29:33 +00:00
parent dc7fff5bf7
commit 8e5e18f54c
6 changed files with 169 additions and 173 deletions
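
In short: when the dataset is built, every candidate PDF is now probed with a cheap `get_anchor_text` call, so documents whose anchor text cannot be extracted are rejected at dataloader-initialization time instead of failing mid-training. A minimal sketch of the idea (the helper name and import path are assumptions, not code from this commit):

```python
from olmocr.prompts.anchor import get_anchor_text  # import path assumed

def pdf_has_anchor_text(pdf_path) -> bool:
    """Hypothetical helper: probe page 1 the way the dataloader now does."""
    try:
        # A tiny target_length keeps the probe cheap at init time.
        get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)
        return True
    except Exception:
        return False
```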

View File

@@ -1225,4 +1225,4 @@ async def main():
 if __name__ == "__main__":
     asyncio.run(main())

View File

@@ -5,6 +5,7 @@ import random
 import re
 import subprocess
 from dataclasses import dataclass
+from os import PathLike
 from typing import List, Literal

 import ftfy

@@ -16,7 +17,7 @@ from olmocr.filter.coherency import get_document_coherency

 def get_anchor_text(
-    local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
+    local_pdf_path: str | PathLike, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
 ) -> str:
     assert page > 0, "Pages are 1-indexed in pdf-land"
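
The `local_pdf_path` parameter is widened from `str` to `str | PathLike`, so the dataloader can pass `pathlib.Path` objects directly. A quick illustration of what the new signature accepts (file name hypothetical):

```python
from pathlib import Path

from olmocr.prompts.anchor import get_anchor_text  # import path assumed

# Both calls are now valid; previously only the str form type-checked.
text_a = get_anchor_text("sample.pdf", page=1, pdf_engine="pdfreport", target_length=4000)
text_b = get_anchor_text(Path("sample.pdf"), page=1, pdf_engine="pdfreport", target_length=4000)
```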

View File

@@ -25,4 +25,4 @@ oneshot(model=model, recipe=recipe)
 # Save the model.
 SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic-Recipe"
 model.save_pretrained(SAVE_DIR)
 tokenizer.save_pretrained(SAVE_DIR)

View File

@@ -12,6 +12,7 @@ from omegaconf import DictConfig, OmegaConf
 @dataclass
 class PipelineStepConfig:
     """Base configuration for pipeline steps."""
+
     name: str
     enabled: bool = True
@@ -19,6 +20,7 @@ class PipelineStepConfig:
 @dataclass
 class FrontMatterParserConfig(PipelineStepConfig):
     """Configuration for FrontMatterParser step."""
+
     name: str = "FrontMatterParser"
     use_page_response_class: bool = True  # Whether to use PageResponse dataclass
@@ -26,6 +28,7 @@ class FrontMatterParserConfig(PipelineStepConfig):
 @dataclass
 class PDFRendererConfig(PipelineStepConfig):
     """Configuration for PDFRenderer step."""
+
     name: str = "PDFRenderer"
     target_longest_image_dim: int = 1024
@@ -33,6 +36,7 @@ class PDFRendererConfig(PipelineStepConfig):
 @dataclass
 class StaticLengthDocumentAnchoringConfig(PipelineStepConfig):
     """Configuration for StaticLengthDocumentAnchoring step."""
+
     name: str = "StaticLengthDocumentAnchoring"
     target_anchor_text_len: int = 6000
@@ -40,24 +44,28 @@ class StaticLengthDocumentAnchoringConfig(PipelineStepConfig):
 @dataclass
 class FinetuningPromptConfig(PipelineStepConfig):
     """Configuration for FinetuningPrompt step."""
+
     name: str = "FinetuningPrompt"


 @dataclass
 class FrontMatterOutputFormatConfig(PipelineStepConfig):
     """Configuration for FrontMatterOutputFormat step."""
+
     name: str = "FrontMatterOutputFormat"


 @dataclass
 class InstructUserMessagesConfig(PipelineStepConfig):
     """Configuration for InstructUserMessages step."""
+
     name: str = "InstructUserMessages"


 @dataclass
 class TokenizerStepConfig(PipelineStepConfig):
     """Configuration for Tokenizer step."""
+
     name: str = "Tokenizer"
     masking_index: int = -100
     end_of_message_token: str = "<|im_end|>"
@@ -66,16 +74,18 @@ class TokenizerStepConfig(PipelineStepConfig):
 @dataclass
 class DatasetItemConfig:
     """Configuration for a single dataset item."""
+
     root_dir: str
     pipeline: List[Dict[str, Any]] = field(default_factory=list)

     # Optional sampling
     max_samples: Optional[int] = None


 @dataclass
 class DatasetConfig:
     """Configuration for dataset and data loading."""
+
     train: List[Dict[str, Any]] = field(default_factory=list)
     eval: List[Dict[str, Any]] = field(default_factory=list)
@@ -83,23 +93,24 @@ class DatasetConfig:
 @dataclass
 class ModelConfig:
     """Configuration for model."""
+
     name: str = "Qwen/Qwen2.5-VL-7B-Instruct"
     trust_remote_code: bool = False

     # Model initialization
     load_in_8bit: bool = False
     load_in_4bit: bool = False
     device_map: Optional[Union[str, Dict[str, Any]]] = "auto"
     torch_dtype: str = "auto"  # "auto", "float16", "bfloat16", "float32"

     # Flash attention
     use_flash_attention: bool = True
     attn_implementation: Optional[str] = None  # "flash_attention_2", "sdpa", "eager"

     # Model modifications
     freeze_vision_tower: bool = False
     freeze_language_model: bool = False

     # LoRA configuration (optional)
     use_lora: bool = False
     lora_rank: int = 8
@@ -107,22 +118,23 @@ class ModelConfig:
     lora_dropout: float = 0.1
     lora_target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj", "k_proj", "o_proj"])
     lora_modules_to_save: Optional[List[str]] = None


 @dataclass
 class TrainingConfig:
     """Configuration for training parameters."""
+
     output_dir: str = "./outputs"
     num_train_epochs: int = 3
     per_device_train_batch_size: int = 1
     per_device_eval_batch_size: int = 1
     gradient_accumulation_steps: int = 8

     # Learning rate and scheduler
     learning_rate: float = 2e-5
     lr_scheduler_type: str = "cosine"
     warmup_ratio: float = 0.1

     # Optimization
     optim: str = "adamw_torch"
     adam_beta1: float = 0.9
@@ -130,16 +142,16 @@ class TrainingConfig:
     adam_epsilon: float = 1e-8
     weight_decay: float = 0.01
     max_grad_norm: float = 1.0

     # Gradient checkpointing
     gradient_checkpointing: bool = True
     gradient_checkpointing_kwargs: Dict[str, Any] = field(default_factory=lambda: {"use_reentrant": False})

     # Mixed precision
     fp16: bool = False
     bf16: bool = True
     tf32: bool = True  # Enable TF32 on Ampere GPUs

     # Evaluation and checkpointing
     evaluation_strategy: str = "steps"
     eval_steps: int = 500
@@ -149,29 +161,29 @@ class TrainingConfig:
     load_best_model_at_end: bool = True
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False

     # Logging
     logging_dir: Optional[str] = None
     logging_strategy: str = "steps"
     logging_steps: int = 10
     logging_first_step: bool = True
     report_to: List[str] = field(default_factory=lambda: ["wandb"])

     # Other training settings
     seed: int = 42
     data_seed: Optional[int] = None

     # Resume from checkpoint
     resume_from_checkpoint: Optional[str] = None

     # DeepSpeed
     deepspeed: Optional[str] = None

     # Performance
     dataloader_drop_last: bool = True
     dataloader_num_workers: int = 4
     remove_unused_columns: bool = False  # Important for custom datasets

     # Early stopping
     use_early_stopping: bool = False
     early_stopping_patience: int = 3
@@ -181,176 +193,166 @@ class TrainingConfig:
 @dataclass
 class Config:
     """Main configuration class that combines all sub-configs."""
+
     model: ModelConfig
     dataset: DatasetConfig
     training: TrainingConfig

     # Environment
     project_name: str = "olmocr-training"
     run_name: Optional[str] = None
     tags: List[str] = field(default_factory=list)
     notes: Optional[str] = None

     # Experiment tracking
     experiment_tracker: str = "tensorboard"  # "tensorboard", "wandb", "mlflow"
     wandb_project: Optional[str] = None
     wandb_entity: Optional[str] = None

     # Distributed training
     distributed: bool = False
     local_rank: int = -1

     @classmethod
     def from_yaml(cls, yaml_path: Union[str, Path]) -> "Config":
         """Load configuration from YAML file."""
         yaml_path = Path(yaml_path)
         if not yaml_path.exists():
             raise FileNotFoundError(f"Config file not found: {yaml_path}")
         # Load YAML with OmegaConf for better features
-        with open(yaml_path, 'r') as f:
+        with open(yaml_path, "r") as f:
             yaml_content = yaml.safe_load(f)

         # Create OmegaConf config for interpolation and validation
         cfg = OmegaConf.create(yaml_content)

         # Resolve any interpolations
         OmegaConf.resolve(cfg)

         # Convert to dict and create dataclass
         cfg_dict = OmegaConf.to_container(cfg, resolve=True)

         # Create sub-configs
-        model_cfg = ModelConfig(**cfg_dict.get('model', {}))
-        dataset_cfg = DatasetConfig(**cfg_dict.get('dataset', {}))
-        training_cfg = TrainingConfig(**cfg_dict.get('training', {}))
+        model_cfg = ModelConfig(**cfg_dict.get("model", {}))
+        dataset_cfg = DatasetConfig(**cfg_dict.get("dataset", {}))
+        training_cfg = TrainingConfig(**cfg_dict.get("training", {}))

         # Create main config
-        main_cfg_dict = {k: v for k, v in cfg_dict.items()
-                         if k not in ['model', 'dataset', 'training']}
-
-        return cls(
-            model=model_cfg,
-            dataset=dataset_cfg,
-            training=training_cfg,
-            **main_cfg_dict
-        )
+        main_cfg_dict = {k: v for k, v in cfg_dict.items() if k not in ["model", "dataset", "training"]}
+
+        return cls(model=model_cfg, dataset=dataset_cfg, training=training_cfg, **main_cfg_dict)
     def to_yaml(self, yaml_path: Union[str, Path]) -> None:
         """Save configuration to YAML file."""
         yaml_path = Path(yaml_path)
         yaml_path.parent.mkdir(parents=True, exist_ok=True)

         # Convert to OmegaConf for nice YAML output
         cfg = OmegaConf.structured(self)
-        with open(yaml_path, 'w') as f:
+        with open(yaml_path, "w") as f:
             OmegaConf.save(cfg, f)

     def validate(self) -> None:
         """Validate configuration values."""
         # Dataset validation - check all train and eval datasets
         for split_name, datasets in [("train", self.dataset.train), ("eval", self.dataset.eval)]:
             for i, dataset_cfg in enumerate(datasets):
-                root_dir = dataset_cfg.get('root_dir')
+                root_dir = dataset_cfg.get("root_dir")
                 if not root_dir:
                     raise ValueError(f"Missing root_dir for {split_name} dataset {i}")
                 if not os.path.exists(root_dir):
                     raise ValueError(f"Dataset root directory does not exist: {root_dir}")
         # Training validation
         if self.training.fp16 and self.training.bf16:
             raise ValueError("Cannot use both fp16 and bf16")

         # Model validation
         if self.model.load_in_8bit and self.model.load_in_4bit:
             raise ValueError("Cannot load in both 8bit and 4bit")

         # Output directory
         Path(self.training.output_dir).mkdir(parents=True, exist_ok=True)

         # Logging directory
         if self.training.logging_dir is None:
             self.training.logging_dir = os.path.join(self.training.output_dir, "logs")
         Path(self.training.logging_dir).mkdir(parents=True, exist_ok=True)

     def get_pipeline_steps(self, pipeline_config: List[Dict[str, Any]], processor=None):
         """Create actual pipeline step instances from pipeline configuration.

         Args:
             pipeline_config: List of pipeline step configurations
             processor: The model processor (required for Tokenizer step)

         Returns:
             List of initialized pipeline step instances
         """
-        from olmocr.train.dataloader import (
-            FrontMatterParser,
-            PDFRenderer,
-            StaticLengthDocumentAnchoring,
-            FinetuningPrompt,
-            FrontMatterOutputFormat,
-            InstructUserMessages,
-            Tokenizer
-        )
-        from olmocr.prompts.prompts import PageResponse
+        from olmocr.prompts.prompts import PageResponse
+        from olmocr.train.dataloader import (
+            FinetuningPrompt,
+            FrontMatterOutputFormat,
+            FrontMatterParser,
+            InstructUserMessages,
+            PDFRenderer,
+            StaticLengthDocumentAnchoring,
+            Tokenizer,
+        )
         steps = []

         for step_config in pipeline_config:
-            if not step_config.get('enabled', True):
+            if not step_config.get("enabled", True):
                 continue

-            step_name = step_config['name']
+            step_name = step_config["name"]

-            if step_name == 'FrontMatterParser':
+            if step_name == "FrontMatterParser":
                 # Handle both old and new config format
-                if 'front_matter_class' in step_config:
-                    front_matter_class = PageResponse if step_config['front_matter_class'] == 'PageResponse' else None
+                if "front_matter_class" in step_config:
+                    front_matter_class = PageResponse if step_config["front_matter_class"] == "PageResponse" else None
                 else:
-                    front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
+                    front_matter_class = PageResponse if step_config.get("use_page_response_class", True) else None
                 steps.append(FrontMatterParser(front_matter_class=front_matter_class))
-            elif step_name == 'PDFRenderer':
-                steps.append(PDFRenderer(
-                    target_longest_image_dim=step_config.get('target_longest_image_dim', 1024),
-                    image_transform=None  # Can be extended later
-                ))
-
-            elif step_name == 'StaticLengthDocumentAnchoring':
-                steps.append(StaticLengthDocumentAnchoring(
-                    target_anchor_text_len=step_config.get('target_anchor_text_len', 6000)
-                ))
-
-            elif step_name == 'FinetuningPrompt':
+            elif step_name == "PDFRenderer":
+                steps.append(
+                    PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024), image_transform=None)  # Can be extended later
+                )
+
+            elif step_name == "StaticLengthDocumentAnchoring":
+                steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000)))
+
+            elif step_name == "FinetuningPrompt":
                 steps.append(FinetuningPrompt())

-            elif step_name == 'FrontMatterOutputFormat':
+            elif step_name == "FrontMatterOutputFormat":
                 steps.append(FrontMatterOutputFormat())

-            elif step_name == 'InstructUserMessages':
+            elif step_name == "InstructUserMessages":
                 steps.append(InstructUserMessages())

-            elif step_name == 'Tokenizer':
+            elif step_name == "Tokenizer":
                 if processor is None:
                     raise ValueError("Processor must be provided for Tokenizer step")
-                steps.append(Tokenizer(
-                    processor=processor,
-                    masking_index=step_config.get('masking_index', -100),
-                    end_of_message_token=step_config.get('end_of_message_token', '<|im_end|>')
-                ))
+                steps.append(
+                    Tokenizer(
+                        processor=processor,
+                        masking_index=step_config.get("masking_index", -100),
+                        end_of_message_token=step_config.get("end_of_message_token", "<|im_end|>"),
+                    )
+                )
             else:
                 raise ValueError(f"Unknown pipeline step: {step_name}")

         return steps
 def create_default_config() -> Config:
     """Create a default configuration."""
-    return Config(
-        model=ModelConfig(),
-        dataset=DatasetConfig(),
-        training=TrainingConfig()
-    )
+    return Config(model=ModelConfig(), dataset=DatasetConfig(), training=TrainingConfig())


 if __name__ == "__main__":
@@ -358,7 +360,7 @@ if __name__ == "__main__":
     config = create_default_config()
     config.to_yaml("configs/default_config.yaml")
     print("Default config saved to configs/default_config.yaml")

     # Example: Load from YAML
     # loaded_config = Config.from_yaml("configs/default_config.yaml")
     # print(loaded_config)
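
For reference, the config module round-trips through YAML; a minimal sketch of the intended usage, based on the methods above (paths hypothetical, module path assumed):

```python
from olmocr.train.config import Config, create_default_config  # module path assumed

# Build defaults, persist them, reload, and validate.
config = create_default_config()
config.to_yaml("configs/default_config.yaml")

loaded = Config.from_yaml("configs/default_config.yaml")
loaded.validate()  # raises ValueError on e.g. fp16+bf16 or a missing dataset root_dir
```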

View File

@@ -71,7 +71,7 @@ class BaseMarkdownPDFDataset(Dataset):
                 # Verify the resolved path exists
                 if pdf_path.exists():
-                    # Validate PDF - check it loads and has exactly one page
+                    # Validate PDF - check it loads and has exactly one page and that you can get document-anchoring from it
                     try:
                         reader = PdfReader(str(pdf_path))
                         num_pages = len(reader.pages)
@@ -80,6 +80,9 @@ class BaseMarkdownPDFDataset(Dataset):
                             invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
                             continue

+                        # Test that document anchoring works
+                        get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)
+
                         self.samples.append({"markdown_path": md_path, "pdf_path": pdf_path})
                         valid_count += 1
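
Because the probe runs inside the same `try` block as `PdfReader`, an anchor-text extraction failure is caught like any other parse failure and the PDF is skipped rather than crashing the run. A condensed sketch of the validation step as it reads after this change (the standalone helper and its `pypdf` import are assumptions; the real code is inline in the dataset constructor):

```python
from pypdf import PdfReader  # assumed source of PdfReader

from olmocr.prompts.anchor import get_anchor_text  # import path assumed

def validate_pdf(pdf_path) -> str | None:
    """Hypothetical helper: return None if the PDF is usable, else a reason string."""
    try:
        reader = PdfReader(str(pdf_path))
        num_pages = len(reader.pages)
        if num_pages != 1:
            return f"Expected 1 page, found {num_pages}"
        # New in this commit: make sure document anchoring works on page 1.
        get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)
        return None
    except Exception as e:
        return str(e)
```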

View File

@@ -4,18 +4,18 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
 import argparse
 import logging

 import numpy as np
-from transformers import (
-    AutoProcessor,
-    Qwen2VLForConditionalGeneration,
-    Qwen2_5_VLForConditionalGeneration,
-    Trainer,
-    TrainingArguments,
-    EarlyStoppingCallback
-)
 import torch
 from torch.utils.data import ConcatDataset
+from transformers import (
+    AutoProcessor,
+    EarlyStoppingCallback,
+    Qwen2_5_VLForConditionalGeneration,
+    Qwen2VLForConditionalGeneration,
+    Trainer,
+    TrainingArguments,
+)

 from olmocr.train.config import Config
 from olmocr.train.dataloader import BaseMarkdownPDFDataset
@@ -31,76 +31,67 @@ logger = logging.getLogger(__name__)
 class QwenDataCollator:
     """Data collator for vision-language models that handles numpy arrays."""

     def __call__(self, examples):
         # Filter out None values and extract the fields we need
-        batch = {
-            'input_ids': [],
-            'attention_mask': [],
-            'labels': [],
-            'pixel_values': [],
-            'image_grid_thw': []
-        }
+        batch = {"input_ids": [], "attention_mask": [], "labels": [], "pixel_values": [], "image_grid_thw": []}

         for example in examples:
             if example is not None:
                 # Convert numpy arrays to tensors
-                batch['input_ids'].append(torch.from_numpy(example['input_ids']) if isinstance(example['input_ids'], np.ndarray) else example['input_ids'])
-                batch['attention_mask'].append(torch.from_numpy(example['attention_mask']) if isinstance(example['attention_mask'], np.ndarray) else example['attention_mask'])
-                batch['labels'].append(torch.from_numpy(example['labels']) if isinstance(example['labels'], np.ndarray) else example['labels'])
+                batch["input_ids"].append(torch.from_numpy(example["input_ids"]) if isinstance(example["input_ids"], np.ndarray) else example["input_ids"])
+                batch["attention_mask"].append(
+                    torch.from_numpy(example["attention_mask"]) if isinstance(example["attention_mask"], np.ndarray) else example["attention_mask"]
+                )
+                batch["labels"].append(torch.from_numpy(example["labels"]) if isinstance(example["labels"], np.ndarray) else example["labels"])

                 # Handle pixel_values which might be numpy array or already a tensor
-                pixel_values = example['pixel_values']
+                pixel_values = example["pixel_values"]
                 if isinstance(pixel_values, np.ndarray):
                     pixel_values = torch.from_numpy(pixel_values)
-                batch['pixel_values'].append(pixel_values)
+                batch["pixel_values"].append(pixel_values)

                 # Handle image_grid_thw
-                image_grid_thw = example['image_grid_thw']
+                image_grid_thw = example["image_grid_thw"]
                 if isinstance(image_grid_thw, np.ndarray):
                     image_grid_thw = torch.from_numpy(image_grid_thw)
-                batch['image_grid_thw'].append(image_grid_thw)
+                batch["image_grid_thw"].append(image_grid_thw)

         # Convert lists to tensors with proper padding
         # Note: For Qwen2-VL, we typically handle variable length sequences
         # The model's processor should handle the padding internally
         return {
-            'input_ids': torch.stack(batch['input_ids']),
-            'attention_mask': torch.stack(batch['attention_mask']),
-            'labels': torch.stack(batch['labels']),
-            'pixel_values': torch.stack(batch['pixel_values']),  # Stack into tensor
-            'image_grid_thw': torch.stack(batch['image_grid_thw'])
+            "input_ids": torch.stack(batch["input_ids"]),
+            "attention_mask": torch.stack(batch["attention_mask"]),
+            "labels": torch.stack(batch["labels"]),
+            "pixel_values": torch.stack(batch["pixel_values"]),  # Stack into tensor
+            "image_grid_thw": torch.stack(batch["image_grid_thw"]),
         }


 def main():
     parser = argparse.ArgumentParser(description="Train OlmOCR model")
-    parser.add_argument(
-        "--config",
-        type=str,
-        default="olmocr/train/configs/example_config.yaml",
-        help="Path to YAML configuration file"
-    )
+    parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file")
     args = parser.parse_args()

     # Load configuration
     logger.info(f"Loading configuration from: {args.config}")
     config = Config.from_yaml(args.config)

     # Validate configuration
     try:
         config.validate()
     except ValueError as e:
         logger.error(f"Configuration validation failed: {e}")
         return

     # Load processor for tokenization
     logger.info(f"Loading processor: {config.model.name}")
     processor = AutoProcessor.from_pretrained(
         config.model.name,
     )

     # Load model
     logger.info(f"Loading model: {config.model.name}")
     if "Qwen2.5-VL" in config.model.name:
@@ -121,50 +112,50 @@ def main():
         )
     else:
         raise NotImplementedError()

     # Enable gradient checkpointing if configured
     if config.training.gradient_checkpointing:
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)

     # Create training datasets
     logger.info("Creating training datasets...")
     train_datasets = []
     for i, dataset_cfg in enumerate(config.dataset.train):
-        root_dir = dataset_cfg['root_dir']
-        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+        root_dir = dataset_cfg["root_dir"]
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)

         logger.info(f"Creating training dataset {i+1} from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
         logger.info(f"Found {len(dataset)} samples")

         if len(dataset) > 0:
             train_datasets.append(dataset)

     # Combine all training datasets
     train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
     logger.info(f"Total training samples: {len(train_dataset)}")

     # Create evaluation datasets
     logger.info("Creating evaluation datasets...")
     eval_datasets = {}
     for i, dataset_cfg in enumerate(config.dataset.eval):
-        root_dir = dataset_cfg['root_dir']
-        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+        root_dir = dataset_cfg["root_dir"]
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)

         # Use dataset name if provided, otherwise use root_dir as name
-        dataset_name = dataset_cfg.get('name', f"eval_dataset_{i+1}")
+        dataset_name = dataset_cfg.get("name", f"eval_dataset_{i+1}")

         logger.info(f"Creating evaluation dataset '{dataset_name}' from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
         logger.info(f"Found {len(dataset)} samples")

         if len(dataset) > 0:
             eval_datasets[dataset_name] = dataset

     # Log total evaluation samples across all datasets
     total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
     logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")

     # Set up training arguments
     training_args = TrainingArguments(
         output_dir=config.training.output_dir,
@@ -208,17 +199,16 @@ def main():
         eval_on_start=True,
         run_name=config.run_name,
     )

     # Set up callbacks
     callbacks = []
     if config.training.use_early_stopping:
         callbacks.append(
             EarlyStoppingCallback(
-                early_stopping_patience=config.training.early_stopping_patience,
-                early_stopping_threshold=config.training.early_stopping_threshold
+                early_stopping_patience=config.training.early_stopping_patience, early_stopping_threshold=config.training.early_stopping_threshold
             )
         )

     # Initialize trainer
     logger.info("Initializing trainer...")
     trainer = Trainer(
@@ -229,19 +219,19 @@ def main():
         data_collator=QwenDataCollator(),
         callbacks=callbacks,
     )

     # Start training
     logger.info("Starting training...")
     train_result = trainer.train(resume_from_checkpoint=config.training.resume_from_checkpoint)

     # Save the final model
     logger.info("Saving final model...")
     trainer.save_model()
     trainer.save_state()

     # Log metrics
     logger.info(f"Training completed! Metrics: {train_result.metrics}")


 if __name__ == "__main__":
     main()
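
`QwenDataCollator` stacks with `torch.stack`, so every example in a batch must already carry equally shaped arrays. A hypothetical smoke test (the shapes, the 1176 patch dimension, and the module path are assumptions for illustration):

```python
import numpy as np

from olmocr.train.train import QwenDataCollator  # module path assumed

def fake_example(seq_len=8, patches=4):
    # Already-padded example; all examples in a batch must share these shapes.
    return {
        "input_ids": np.zeros(seq_len, dtype=np.int64),
        "attention_mask": np.ones(seq_len, dtype=np.int64),
        "labels": np.full(seq_len, -100, dtype=np.int64),
        "pixel_values": np.zeros((patches, 1176), dtype=np.float32),
        "image_grid_thw": np.array([1, 2, 2], dtype=np.int64),
    }

batch = QwenDataCollator()([fake_example(), fake_example()])
print({k: tuple(v.shape) for k, v in batch.items()})
# e.g. input_ids -> (2, 8), pixel_values -> (2, 4, 1176), image_grid_thw -> (2, 3)
```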