mirror of https://github.com/allenai/olmocr.git (synced 2025-10-12 16:52:20 +00:00)

Checking that anchor text works for each pdf page when initializing dataloader

This commit is contained in: commit 8e5e18f54c (parent dc7fff5bf7)
@@ -1225,4 +1225,4 @@ async def main():


 if __name__ == "__main__":
     asyncio.run(main())
@@ -5,6 +5,7 @@ import random
 import re
 import subprocess
 from dataclasses import dataclass
+from os import PathLike
 from typing import List, Literal

 import ftfy
@@ -16,7 +17,7 @@ from olmocr.filter.coherency import get_document_coherency


 def get_anchor_text(
-    local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
+    local_pdf_path: str | PathLike, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
 ) -> str:
     assert page > 0, "Pages are 1-indexed in pdf-land"

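With the widened signature, get_anchor_text accepts any os.PathLike as well as a plain string, so a pathlib.Path coming out of the dataloader can be passed through directly. A minimal usage sketch; the PDF path is hypothetical, and the import path olmocr.prompts.anchor is an assumption based on where get_anchor_text appears to live:

```python
from pathlib import Path

from olmocr.prompts.anchor import get_anchor_text

pdf_path = Path("sample_data/one_page_document.pdf")  # hypothetical one-page PDF
# Pages are 1-indexed, per the assert in get_anchor_text.
anchor_text = get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=4000)
print(anchor_text[:200])
```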
@@ -25,4 +25,4 @@ oneshot(model=model, recipe=recipe)
 # Save the model.
 SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic-Recipe"
 model.save_pretrained(SAVE_DIR)
 tokenizer.save_pretrained(SAVE_DIR)
@@ -12,6 +12,7 @@ from omegaconf import DictConfig, OmegaConf
 @dataclass
 class PipelineStepConfig:
     """Base configuration for pipeline steps."""

     name: str
     enabled: bool = True

@@ -19,6 +20,7 @@ class PipelineStepConfig:
 @dataclass
 class FrontMatterParserConfig(PipelineStepConfig):
     """Configuration for FrontMatterParser step."""

     name: str = "FrontMatterParser"
     use_page_response_class: bool = True  # Whether to use PageResponse dataclass

@@ -26,6 +28,7 @@ class FrontMatterParserConfig(PipelineStepConfig):
 @dataclass
 class PDFRendererConfig(PipelineStepConfig):
     """Configuration for PDFRenderer step."""

     name: str = "PDFRenderer"
     target_longest_image_dim: int = 1024

@@ -33,6 +36,7 @@ class PDFRendererConfig(PipelineStepConfig):
 @dataclass
 class StaticLengthDocumentAnchoringConfig(PipelineStepConfig):
     """Configuration for StaticLengthDocumentAnchoring step."""

     name: str = "StaticLengthDocumentAnchoring"
     target_anchor_text_len: int = 6000

@@ -40,24 +44,28 @@ class StaticLengthDocumentAnchoringConfig(PipelineStepConfig):
 @dataclass
 class FinetuningPromptConfig(PipelineStepConfig):
     """Configuration for FinetuningPrompt step."""

     name: str = "FinetuningPrompt"


 @dataclass
 class FrontMatterOutputFormatConfig(PipelineStepConfig):
     """Configuration for FrontMatterOutputFormat step."""

     name: str = "FrontMatterOutputFormat"


 @dataclass
 class InstructUserMessagesConfig(PipelineStepConfig):
     """Configuration for InstructUserMessages step."""

     name: str = "InstructUserMessages"


 @dataclass
 class TokenizerStepConfig(PipelineStepConfig):
     """Configuration for Tokenizer step."""

     name: str = "Tokenizer"
     masking_index: int = -100
     end_of_message_token: str = "<|im_end|>"
@@ -66,16 +74,18 @@ class TokenizerStepConfig:
 @dataclass
 class DatasetItemConfig:
     """Configuration for a single dataset item."""

     root_dir: str
     pipeline: List[Dict[str, Any]] = field(default_factory=list)

     # Optional sampling
     max_samples: Optional[int] = None


 @dataclass
 class DatasetConfig:
     """Configuration for dataset and data loading."""

     train: List[Dict[str, Any]] = field(default_factory=list)
     eval: List[Dict[str, Any]] = field(default_factory=list)

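For orientation, a dataset entry of this shape is what get_pipeline_steps (further down in this file) consumes. The dict below is a hypothetical example assembled from the step-config defaults above, not copied from the repo's configs; the root_dir path is made up:

```python
# Hypothetical train entry mirroring DatasetItemConfig: a root_dir plus an ordered pipeline.
train_entry = {
    "root_dir": "/data/olmocr/train",  # hypothetical path
    "pipeline": [
        {"name": "FrontMatterParser", "use_page_response_class": True},
        {"name": "PDFRenderer", "target_longest_image_dim": 1024},
        {"name": "StaticLengthDocumentAnchoring", "target_anchor_text_len": 6000},
        {"name": "FinetuningPrompt"},
        {"name": "FrontMatterOutputFormat"},
        {"name": "InstructUserMessages"},
        {"name": "Tokenizer", "masking_index": -100, "end_of_message_token": "<|im_end|>"},
    ],
}
```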
@@ -83,23 +93,24 @@ class DatasetConfig:
 @dataclass
 class ModelConfig:
     """Configuration for model."""

     name: str = "Qwen/Qwen2.5-VL-7B-Instruct"
     trust_remote_code: bool = False

     # Model initialization
     load_in_8bit: bool = False
     load_in_4bit: bool = False
     device_map: Optional[Union[str, Dict[str, Any]]] = "auto"
     torch_dtype: str = "auto"  # "auto", "float16", "bfloat16", "float32"

     # Flash attention
     use_flash_attention: bool = True
     attn_implementation: Optional[str] = None  # "flash_attention_2", "sdpa", "eager"

     # Model modifications
     freeze_vision_tower: bool = False
     freeze_language_model: bool = False

     # LoRA configuration (optional)
     use_lora: bool = False
     lora_rank: int = 8
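As a rough sketch of how these ModelConfig fields typically map onto a Hugging Face model load; this is an illustrative assumption, not the loading code from this commit (the training script's actual loading block sits outside the hunks shown below):

```python
import torch
from transformers import Qwen2_5_VLForConditionalGeneration


def load_model(model_cfg: ModelConfig):
    # "auto" is passed through; otherwise resolve the dtype string to a torch dtype.
    dtype = model_cfg.torch_dtype if model_cfg.torch_dtype == "auto" else getattr(torch, model_cfg.torch_dtype)
    attn = model_cfg.attn_implementation or ("flash_attention_2" if model_cfg.use_flash_attention else "sdpa")
    return Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_cfg.name,
        torch_dtype=dtype,
        attn_implementation=attn,
        device_map=model_cfg.device_map,
        trust_remote_code=model_cfg.trust_remote_code,
    )
```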
@@ -107,22 +118,23 @@ class ModelConfig:
     lora_dropout: float = 0.1
     lora_target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj", "k_proj", "o_proj"])
     lora_modules_to_save: Optional[List[str]] = None


 @dataclass
 class TrainingConfig:
     """Configuration for training parameters."""

     output_dir: str = "./outputs"
     num_train_epochs: int = 3
     per_device_train_batch_size: int = 1
     per_device_eval_batch_size: int = 1
     gradient_accumulation_steps: int = 8

     # Learning rate and scheduler
     learning_rate: float = 2e-5
     lr_scheduler_type: str = "cosine"
     warmup_ratio: float = 0.1

     # Optimization
     optim: str = "adamw_torch"
     adam_beta1: float = 0.9
@@ -130,16 +142,16 @@ class TrainingConfig:
     adam_epsilon: float = 1e-8
     weight_decay: float = 0.01
     max_grad_norm: float = 1.0

     # Gradient checkpointing
     gradient_checkpointing: bool = True
     gradient_checkpointing_kwargs: Dict[str, Any] = field(default_factory=lambda: {"use_reentrant": False})

     # Mixed precision
     fp16: bool = False
     bf16: bool = True
     tf32: bool = True  # Enable TF32 on Ampere GPUs

     # Evaluation and checkpointing
     evaluation_strategy: str = "steps"
     eval_steps: int = 500
@@ -149,29 +161,29 @@ class TrainingConfig:
     load_best_model_at_end: bool = True
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False

     # Logging
     logging_dir: Optional[str] = None
     logging_strategy: str = "steps"
     logging_steps: int = 10
     logging_first_step: bool = True
     report_to: List[str] = field(default_factory=lambda: ["wandb"])

     # Other training settings
     seed: int = 42
     data_seed: Optional[int] = None

     # Resume from checkpoint
     resume_from_checkpoint: Optional[str] = None

     # DeepSpeed
     deepspeed: Optional[str] = None

     # Performance
     dataloader_drop_last: bool = True
     dataloader_num_workers: int = 4
     remove_unused_columns: bool = False  # Important for custom datasets

     # Early stopping
     use_early_stopping: bool = False
     early_stopping_patience: int = 3
@@ -181,176 +193,166 @@ class TrainingConfig:
 @dataclass
 class Config:
     """Main configuration class that combines all sub-configs."""

     model: ModelConfig
     dataset: DatasetConfig
     training: TrainingConfig

     # Environment
     project_name: str = "olmocr-training"
     run_name: Optional[str] = None
     tags: List[str] = field(default_factory=list)
     notes: Optional[str] = None

     # Experiment tracking
     experiment_tracker: str = "tensorboard"  # "tensorboard", "wandb", "mlflow"
     wandb_project: Optional[str] = None
     wandb_entity: Optional[str] = None

     # Distributed training
     distributed: bool = False
     local_rank: int = -1

     @classmethod
     def from_yaml(cls, yaml_path: Union[str, Path]) -> "Config":
         """Load configuration from YAML file."""
         yaml_path = Path(yaml_path)
         if not yaml_path.exists():
             raise FileNotFoundError(f"Config file not found: {yaml_path}")

         # Load YAML with OmegaConf for better features
-        with open(yaml_path, 'r') as f:
+        with open(yaml_path, "r") as f:
             yaml_content = yaml.safe_load(f)

         # Create OmegaConf config for interpolation and validation
         cfg = OmegaConf.create(yaml_content)

         # Resolve any interpolations
         OmegaConf.resolve(cfg)

         # Convert to dict and create dataclass
         cfg_dict = OmegaConf.to_container(cfg, resolve=True)

         # Create sub-configs
-        model_cfg = ModelConfig(**cfg_dict.get('model', {}))
-        dataset_cfg = DatasetConfig(**cfg_dict.get('dataset', {}))
-        training_cfg = TrainingConfig(**cfg_dict.get('training', {}))
+        model_cfg = ModelConfig(**cfg_dict.get("model", {}))
+        dataset_cfg = DatasetConfig(**cfg_dict.get("dataset", {}))
+        training_cfg = TrainingConfig(**cfg_dict.get("training", {}))

         # Create main config
-        main_cfg_dict = {k: v for k, v in cfg_dict.items()
-                         if k not in ['model', 'dataset', 'training']}
-
-        return cls(
-            model=model_cfg,
-            dataset=dataset_cfg,
-            training=training_cfg,
-            **main_cfg_dict
-        )
+        main_cfg_dict = {k: v for k, v in cfg_dict.items() if k not in ["model", "dataset", "training"]}
+
+        return cls(model=model_cfg, dataset=dataset_cfg, training=training_cfg, **main_cfg_dict)

     def to_yaml(self, yaml_path: Union[str, Path]) -> None:
         """Save configuration to YAML file."""
         yaml_path = Path(yaml_path)
         yaml_path.parent.mkdir(parents=True, exist_ok=True)

         # Convert to OmegaConf for nice YAML output
         cfg = OmegaConf.structured(self)

-        with open(yaml_path, 'w') as f:
+        with open(yaml_path, "w") as f:
             OmegaConf.save(cfg, f)

     def validate(self) -> None:
         """Validate configuration values."""
         # Dataset validation - check all train and eval datasets
         for split_name, datasets in [("train", self.dataset.train), ("eval", self.dataset.eval)]:
             for i, dataset_cfg in enumerate(datasets):
-                root_dir = dataset_cfg.get('root_dir')
+                root_dir = dataset_cfg.get("root_dir")
                 if not root_dir:
                     raise ValueError(f"Missing root_dir for {split_name} dataset {i}")
                 if not os.path.exists(root_dir):
                     raise ValueError(f"Dataset root directory does not exist: {root_dir}")

         # Training validation
         if self.training.fp16 and self.training.bf16:
             raise ValueError("Cannot use both fp16 and bf16")

         # Model validation
         if self.model.load_in_8bit and self.model.load_in_4bit:
             raise ValueError("Cannot load in both 8bit and 4bit")

         # Output directory
         Path(self.training.output_dir).mkdir(parents=True, exist_ok=True)

         # Logging directory
         if self.training.logging_dir is None:
             self.training.logging_dir = os.path.join(self.training.output_dir, "logs")
         Path(self.training.logging_dir).mkdir(parents=True, exist_ok=True)

     def get_pipeline_steps(self, pipeline_config: List[Dict[str, Any]], processor=None):
         """Create actual pipeline step instances from pipeline configuration.

         Args:
             pipeline_config: List of pipeline step configurations
             processor: The model processor (required for Tokenizer step)

         Returns:
             List of initialized pipeline step instances
         """
+        from olmocr.prompts.prompts import PageResponse
         from olmocr.train.dataloader import (
-            FrontMatterParser,
-            PDFRenderer,
-            StaticLengthDocumentAnchoring,
             FinetuningPrompt,
             FrontMatterOutputFormat,
+            FrontMatterParser,
             InstructUserMessages,
-            Tokenizer
+            PDFRenderer,
+            StaticLengthDocumentAnchoring,
+            Tokenizer,
         )
-        from olmocr.prompts.prompts import PageResponse

         steps = []
         for step_config in pipeline_config:
-            if not step_config.get('enabled', True):
+            if not step_config.get("enabled", True):
                 continue

-            step_name = step_config['name']
+            step_name = step_config["name"]

-            if step_name == 'FrontMatterParser':
+            if step_name == "FrontMatterParser":
                 # Handle both old and new config format
-                if 'front_matter_class' in step_config:
-                    front_matter_class = PageResponse if step_config['front_matter_class'] == 'PageResponse' else None
+                if "front_matter_class" in step_config:
+                    front_matter_class = PageResponse if step_config["front_matter_class"] == "PageResponse" else None
                 else:
-                    front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
+                    front_matter_class = PageResponse if step_config.get("use_page_response_class", True) else None
                 steps.append(FrontMatterParser(front_matter_class=front_matter_class))

-            elif step_name == 'PDFRenderer':
-                steps.append(PDFRenderer(
-                    target_longest_image_dim=step_config.get('target_longest_image_dim', 1024),
-                    image_transform=None # Can be extended later
-                ))
+            elif step_name == "PDFRenderer":
+                steps.append(
+                    PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024), image_transform=None)  # Can be extended later
+                )

-            elif step_name == 'StaticLengthDocumentAnchoring':
-                steps.append(StaticLengthDocumentAnchoring(
-                    target_anchor_text_len=step_config.get('target_anchor_text_len', 6000)
-                ))
+            elif step_name == "StaticLengthDocumentAnchoring":
+                steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000)))

-            elif step_name == 'FinetuningPrompt':
+            elif step_name == "FinetuningPrompt":
                 steps.append(FinetuningPrompt())

-            elif step_name == 'FrontMatterOutputFormat':
+            elif step_name == "FrontMatterOutputFormat":
                 steps.append(FrontMatterOutputFormat())

-            elif step_name == 'InstructUserMessages':
+            elif step_name == "InstructUserMessages":
                 steps.append(InstructUserMessages())

-            elif step_name == 'Tokenizer':
+            elif step_name == "Tokenizer":
                 if processor is None:
                     raise ValueError("Processor must be provided for Tokenizer step")
-                steps.append(Tokenizer(
-                    processor=processor,
-                    masking_index=step_config.get('masking_index', -100),
-                    end_of_message_token=step_config.get('end_of_message_token', '<|im_end|>')
-                ))
+                steps.append(
+                    Tokenizer(
+                        processor=processor,
+                        masking_index=step_config.get("masking_index", -100),
+                        end_of_message_token=step_config.get("end_of_message_token", "<|im_end|>"),
+                    )
+                )
             else:
                 raise ValueError(f"Unknown pipeline step: {step_name}")

         return steps


 def create_default_config() -> Config:
     """Create a default configuration."""
-    return Config(
-        model=ModelConfig(),
-        dataset=DatasetConfig(),
-        training=TrainingConfig()
-    )
+    return Config(model=ModelConfig(), dataset=DatasetConfig(), training=TrainingConfig())


 if __name__ == "__main__":
@@ -358,7 +360,7 @@ if __name__ == "__main__":
     config = create_default_config()
     config.to_yaml("configs/default_config.yaml")
     print("Default config saved to configs/default_config.yaml")

     # Example: Load from YAML
     # loaded_config = Config.from_yaml("configs/default_config.yaml")
     # print(loaded_config)
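Putting the pieces of this config module together, typical usage looks roughly like the sketch below. The YAML path is the training script's default seen further down; treating the first train entry as the one to build pipeline steps for is an assumption made for the example:

```python
from transformers import AutoProcessor

from olmocr.train.config import Config

config = Config.from_yaml("olmocr/train/configs/example_config.yaml")
config.validate()

# The processor is needed so the Tokenizer pipeline step can be constructed.
processor = AutoProcessor.from_pretrained(config.model.name)
pipeline_steps = config.get_pipeline_steps(config.dataset.train[0]["pipeline"], processor)
```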
@@ -71,7 +71,7 @@ class BaseMarkdownPDFDataset(Dataset):

             # Verify the resolved path exists
             if pdf_path.exists():
-                # Validate PDF - check it loads and has exactly one page
+                # Validate PDF - check it loads and has exactly one page and that you can get document-anchoring from it
                 try:
                     reader = PdfReader(str(pdf_path))
                     num_pages = len(reader.pages)
@@ -80,6 +80,9 @@ class BaseMarkdownPDFDataset(Dataset):
                         invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
                         continue

+                    # Test that document anchoring works
+                    get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)
+
                     self.samples.append({"markdown_path": md_path, "pdf_path": pdf_path})
                     valid_count += 1

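In effect, the dataloader now probes document anchoring on page 1 as part of PDF validation, so a PDF whose anchor extraction fails gets handled by the surrounding try/except (only partially visible in the hunks above) like any other invalid PDF. A standalone sketch of the same check, with a hypothetical helper name and assuming PdfReader comes from pypdf as elsewhere in the repo:

```python
from pathlib import Path

from pypdf import PdfReader

from olmocr.prompts.anchor import get_anchor_text


def pdf_is_usable(pdf_path: Path) -> bool:
    """Hypothetical helper mirroring the dataloader's validation of a single-page PDF."""
    try:
        reader = PdfReader(str(pdf_path))
        if len(reader.pages) != 1:
            return False
        # Probe that document anchoring works on the (only) page.
        get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)
        return True
    except Exception:
        return False
```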
@@ -4,18 +4,18 @@ Simple script to test OlmOCR dataset loading with YAML configuration.

 import argparse
 import logging
-import numpy as np

-from transformers import (
-    AutoProcessor,
-    Qwen2VLForConditionalGeneration,
-    Qwen2_5_VLForConditionalGeneration,
-    Trainer,
-    TrainingArguments,
-    EarlyStoppingCallback
-)
+import numpy as np
 import torch
 from torch.utils.data import ConcatDataset
+from transformers import (
+    AutoProcessor,
+    EarlyStoppingCallback,
+    Qwen2_5_VLForConditionalGeneration,
+    Qwen2VLForConditionalGeneration,
+    Trainer,
+    TrainingArguments,
+)

 from olmocr.train.config import Config
 from olmocr.train.dataloader import BaseMarkdownPDFDataset
@@ -31,76 +31,67 @@ logger = logging.getLogger(__name__)

 class QwenDataCollator:
     """Data collator for vision-language models that handles numpy arrays."""

     def __call__(self, examples):
         # Filter out None values and extract the fields we need
-        batch = {
-            'input_ids': [],
-            'attention_mask': [],
-            'labels': [],
-            'pixel_values': [],
-            'image_grid_thw': []
-        }
+        batch = {"input_ids": [], "attention_mask": [], "labels": [], "pixel_values": [], "image_grid_thw": []}

         for example in examples:
             if example is not None:
                 # Convert numpy arrays to tensors
-                batch['input_ids'].append(torch.from_numpy(example['input_ids']) if isinstance(example['input_ids'], np.ndarray) else example['input_ids'])
-                batch['attention_mask'].append(torch.from_numpy(example['attention_mask']) if isinstance(example['attention_mask'], np.ndarray) else example['attention_mask'])
-                batch['labels'].append(torch.from_numpy(example['labels']) if isinstance(example['labels'], np.ndarray) else example['labels'])
+                batch["input_ids"].append(torch.from_numpy(example["input_ids"]) if isinstance(example["input_ids"], np.ndarray) else example["input_ids"])
+                batch["attention_mask"].append(
+                    torch.from_numpy(example["attention_mask"]) if isinstance(example["attention_mask"], np.ndarray) else example["attention_mask"]
+                )
+                batch["labels"].append(torch.from_numpy(example["labels"]) if isinstance(example["labels"], np.ndarray) else example["labels"])

                 # Handle pixel_values which might be numpy array or already a tensor
-                pixel_values = example['pixel_values']
+                pixel_values = example["pixel_values"]
                 if isinstance(pixel_values, np.ndarray):
                     pixel_values = torch.from_numpy(pixel_values)
-                batch['pixel_values'].append(pixel_values)
+                batch["pixel_values"].append(pixel_values)

                 # Handle image_grid_thw
-                image_grid_thw = example['image_grid_thw']
+                image_grid_thw = example["image_grid_thw"]
                 if isinstance(image_grid_thw, np.ndarray):
                     image_grid_thw = torch.from_numpy(image_grid_thw)
-                batch['image_grid_thw'].append(image_grid_thw)
+                batch["image_grid_thw"].append(image_grid_thw)

         # Convert lists to tensors with proper padding
         # Note: For Qwen2-VL, we typically handle variable length sequences
         # The model's processor should handle the padding internally
         return {
-            'input_ids': torch.stack(batch['input_ids']),
-            'attention_mask': torch.stack(batch['attention_mask']),
-            'labels': torch.stack(batch['labels']),
-            'pixel_values': torch.stack(batch['pixel_values']), # Stack into tensor
-            'image_grid_thw': torch.stack(batch['image_grid_thw'])
+            "input_ids": torch.stack(batch["input_ids"]),
+            "attention_mask": torch.stack(batch["attention_mask"]),
+            "labels": torch.stack(batch["labels"]),
+            "pixel_values": torch.stack(batch["pixel_values"]),  # Stack into tensor
+            "image_grid_thw": torch.stack(batch["image_grid_thw"]),
         }


 def main():
     parser = argparse.ArgumentParser(description="Train OlmOCR model")
-    parser.add_argument(
-        "--config",
-        type=str,
-        default="olmocr/train/configs/example_config.yaml",
-        help="Path to YAML configuration file"
-    )
+    parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file")

     args = parser.parse_args()

     # Load configuration
     logger.info(f"Loading configuration from: {args.config}")
     config = Config.from_yaml(args.config)

     # Validate configuration
     try:
         config.validate()
     except ValueError as e:
         logger.error(f"Configuration validation failed: {e}")
         return

     # Load processor for tokenization
     logger.info(f"Loading processor: {config.model.name}")
     processor = AutoProcessor.from_pretrained(
         config.model.name,
     )

     # Load model
     logger.info(f"Loading model: {config.model.name}")
     if "Qwen2.5-VL" in config.model.name:
@@ -121,50 +112,50 @@ def main():
         )
     else:
         raise NotImplementedError()

     # Enable gradient checkpointing if configured
     if config.training.gradient_checkpointing:
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)

     # Create training datasets
     logger.info("Creating training datasets...")
     train_datasets = []
     for i, dataset_cfg in enumerate(config.dataset.train):
-        root_dir = dataset_cfg['root_dir']
-        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+        root_dir = dataset_cfg["root_dir"]
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)

         logger.info(f"Creating training dataset {i+1} from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
         logger.info(f"Found {len(dataset)} samples")

         if len(dataset) > 0:
             train_datasets.append(dataset)

     # Combine all training datasets
     train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
     logger.info(f"Total training samples: {len(train_dataset)}")

     # Create evaluation datasets
     logger.info("Creating evaluation datasets...")
     eval_datasets = {}
     for i, dataset_cfg in enumerate(config.dataset.eval):
-        root_dir = dataset_cfg['root_dir']
-        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+        root_dir = dataset_cfg["root_dir"]
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)

         # Use dataset name if provided, otherwise use root_dir as name
-        dataset_name = dataset_cfg.get('name', f"eval_dataset_{i+1}")
+        dataset_name = dataset_cfg.get("name", f"eval_dataset_{i+1}")

         logger.info(f"Creating evaluation dataset '{dataset_name}' from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
         logger.info(f"Found {len(dataset)} samples")

         if len(dataset) > 0:
             eval_datasets[dataset_name] = dataset

     # Log total evaluation samples across all datasets
     total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
     logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")

     # Set up training arguments
     training_args = TrainingArguments(
         output_dir=config.training.output_dir,
@@ -208,17 +199,16 @@ def main():
         eval_on_start=True,
         run_name=config.run_name,
     )

     # Set up callbacks
     callbacks = []
     if config.training.use_early_stopping:
         callbacks.append(
             EarlyStoppingCallback(
-                early_stopping_patience=config.training.early_stopping_patience,
-                early_stopping_threshold=config.training.early_stopping_threshold
+                early_stopping_patience=config.training.early_stopping_patience, early_stopping_threshold=config.training.early_stopping_threshold
             )
         )

     # Initialize trainer
     logger.info("Initializing trainer...")
     trainer = Trainer(
@@ -229,19 +219,19 @@ def main():
         data_collator=QwenDataCollator(),
         callbacks=callbacks,
     )

     # Start training
     logger.info("Starting training...")
     train_result = trainer.train(resume_from_checkpoint=config.training.resume_from_checkpoint)

     # Save the final model
     logger.info("Saving final model...")
     trainer.save_model()
     trainer.save_state()

     # Log metrics
     logger.info(f"Training completed! Metrics: {train_result.metrics}")


 if __name__ == "__main__":
     main()
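To exercise QwenDataCollator outside the Trainer, a small sketch of how it would plug into a plain PyTorch DataLoader; `train_dataset` is assumed to be a BaseMarkdownPDFDataset built with a Tokenizer step, as in main() above:

```python
from torch.utils.data import DataLoader

# Assumes `train_dataset` yields dicts with input_ids, attention_mask, labels,
# pixel_values and image_grid_thw, as produced by the Tokenizer pipeline step.
# batch_size=1 sidesteps shape mismatches when torch.stack-ing variable-length sequences.
loader = DataLoader(train_dataset, batch_size=1, collate_fn=QwenDataCollator())
batch = next(iter(loader))
print(batch["input_ids"].shape, batch["pixel_values"].shape)
```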