From 8e5e18f54cfda0066e2dd14aae3798c1dfa41a2c Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Mon, 30 Jun 2025 16:29:33 +0000 Subject: [PATCH] Checking that anchor text extraction works for each PDF page when initializing the dataloader --- olmocr/pipeline.py | 2 +- olmocr/prompts/anchor.py | 3 +- olmocr/train/compressqwen2checkpoint.py | 2 +- olmocr/train/config.py | 204 ++++++++++++------------ olmocr/train/dataloader.py | 5 +- olmocr/train/train.py | 126 +++++++-------- 6 files changed, 169 insertions(+), 173 deletions(-) diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py index 12bfd78..f056cd6 100644 --- a/olmocr/pipeline.py +++ b/olmocr/pipeline.py @@ -1225,4 +1225,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/olmocr/prompts/anchor.py b/olmocr/prompts/anchor.py index 2bac5db..68e8304 100644 --- a/olmocr/prompts/anchor.py +++ b/olmocr/prompts/anchor.py @@ -5,6 +5,7 @@ import random import re import subprocess from dataclasses import dataclass +from os import PathLike from typing import List, Literal import ftfy @@ -16,7 +17,7 @@ from olmocr.filter.coherency import get_document_coherency def get_anchor_text( - local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000 + local_pdf_path: str | PathLike, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000 ) -> str: assert page > 0, "Pages are 1-indexed in pdf-land" diff --git a/olmocr/train/compressqwen2checkpoint.py b/olmocr/train/compressqwen2checkpoint.py index a12d31a..89483e0 100644 --- a/olmocr/train/compressqwen2checkpoint.py +++ b/olmocr/train/compressqwen2checkpoint.py @@ -25,4 +25,4 @@ oneshot(model=model, recipe=recipe) # Save the model.
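# (oneshot() applies the quantization recipe to the model in place, so the save
#  below writes out the already-compressed weights; the tokenizer is saved
#  alongside to keep the checkpoint directory self-contained.)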
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic-Recipe" model.save_pretrained(SAVE_DIR) -tokenizer.save_pretrained(SAVE_DIR) \ No newline at end of file +tokenizer.save_pretrained(SAVE_DIR) diff --git a/olmocr/train/config.py b/olmocr/train/config.py index 89bcbb6..42f1c41 100644 --- a/olmocr/train/config.py +++ b/olmocr/train/config.py @@ -12,6 +12,7 @@ from omegaconf import DictConfig, OmegaConf @dataclass class PipelineStepConfig: """Base configuration for pipeline steps.""" + name: str enabled: bool = True @@ -19,6 +20,7 @@ class PipelineStepConfig: @dataclass class FrontMatterParserConfig(PipelineStepConfig): """Configuration for FrontMatterParser step.""" + name: str = "FrontMatterParser" use_page_response_class: bool = True # Whether to use PageResponse dataclass @@ -26,6 +28,7 @@ class FrontMatterParserConfig(PipelineStepConfig): @dataclass class PDFRendererConfig(PipelineStepConfig): """Configuration for PDFRenderer step.""" + name: str = "PDFRenderer" target_longest_image_dim: int = 1024 @@ -33,6 +36,7 @@ class PDFRendererConfig(PipelineStepConfig): @dataclass class StaticLengthDocumentAnchoringConfig(PipelineStepConfig): """Configuration for StaticLengthDocumentAnchoring step.""" + name: str = "StaticLengthDocumentAnchoring" target_anchor_text_len: int = 6000 @@ -40,24 +44,28 @@ class StaticLengthDocumentAnchoringConfig(PipelineStepConfig): @dataclass class FinetuningPromptConfig(PipelineStepConfig): """Configuration for FinetuningPrompt step.""" + name: str = "FinetuningPrompt" @dataclass class FrontMatterOutputFormatConfig(PipelineStepConfig): """Configuration for FrontMatterOutputFormat step.""" + name: str = "FrontMatterOutputFormat" @dataclass class InstructUserMessagesConfig(PipelineStepConfig): """Configuration for InstructUserMessages step.""" + name: str = "InstructUserMessages" @dataclass class TokenizerStepConfig(PipelineStepConfig): """Configuration for Tokenizer step.""" + name: str = "Tokenizer" masking_index: int = -100 end_of_message_token: str = "<|im_end|>" @@ -66,16 +74,18 @@ class TokenizerStepConfig(PipelineStepConfig): @dataclass class DatasetItemConfig: """Configuration for a single dataset item.""" + root_dir: str pipeline: List[Dict[str, Any]] = field(default_factory=list) - + # Optional sampling max_samples: Optional[int] = None - -@dataclass + +@dataclass class DatasetConfig: """Configuration for dataset and data loading.""" + train: List[Dict[str, Any]] = field(default_factory=list) eval: List[Dict[str, Any]] = field(default_factory=list) @@ -83,23 +93,24 @@ class DatasetConfig: @dataclass class ModelConfig: """Configuration for model.""" + name: str = "Qwen/Qwen2.5-VL-7B-Instruct" trust_remote_code: bool = False - + # Model initialization load_in_8bit: bool = False load_in_4bit: bool = False device_map: Optional[Union[str, Dict[str, Any]]] = "auto" torch_dtype: str = "auto" # "auto", "float16", "bfloat16", "float32" - + # Flash attention use_flash_attention: bool = True attn_implementation: Optional[str] = None # "flash_attention_2", "sdpa", "eager" - + # Model modifications freeze_vision_tower: bool = False freeze_language_model: bool = False - + # LoRA configuration (optional) use_lora: bool = False lora_rank: int = 8 @@ -107,22 +118,23 @@ class ModelConfig: lora_dropout: float = 0.1 lora_target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj", "k_proj", "o_proj"]) lora_modules_to_save: Optional[List[str]] = None - + @dataclass class TrainingConfig: """Configuration for training parameters.""" + output_dir: str = "./outputs" 
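# Most of the fields below mirror the transformers.TrainingArguments options of
# the same name; train.py forwards them when it builds the HF TrainingArguments.
# A few (e.g. the early-stopping settings) are instead handled via callbacks.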
num_train_epochs: int = 3 per_device_train_batch_size: int = 1 per_device_eval_batch_size: int = 1 gradient_accumulation_steps: int = 8 - + # Learning rate and scheduler learning_rate: float = 2e-5 lr_scheduler_type: str = "cosine" warmup_ratio: float = 0.1 - + # Optimization optim: str = "adamw_torch" adam_beta1: float = 0.9 @@ -130,16 +142,16 @@ class TrainingConfig: adam_epsilon: float = 1e-8 weight_decay: float = 0.01 max_grad_norm: float = 1.0 - + # Gradient checkpointing gradient_checkpointing: bool = True gradient_checkpointing_kwargs: Dict[str, Any] = field(default_factory=lambda: {"use_reentrant": False}) - + # Mixed precision fp16: bool = False bf16: bool = True tf32: bool = True # Enable TF32 on Ampere GPUs - + # Evaluation and checkpointing evaluation_strategy: str = "steps" eval_steps: int = 500 @@ -149,29 +161,29 @@ class TrainingConfig: load_best_model_at_end: bool = True metric_for_best_model: str = "eval_loss" greater_is_better: bool = False - + # Logging logging_dir: Optional[str] = None logging_strategy: str = "steps" logging_steps: int = 10 logging_first_step: bool = True report_to: List[str] = field(default_factory=lambda: ["wandb"]) - + # Other training settings seed: int = 42 data_seed: Optional[int] = None - + # Resume from checkpoint resume_from_checkpoint: Optional[str] = None - + # DeepSpeed deepspeed: Optional[str] = None - + # Performance dataloader_drop_last: bool = True dataloader_num_workers: int = 4 remove_unused_columns: bool = False # Important for custom datasets - + # Early stopping use_early_stopping: bool = False early_stopping_patience: int = 3 @@ -181,176 +193,166 @@ class TrainingConfig: @dataclass class Config: """Main configuration class that combines all sub-configs.""" + model: ModelConfig dataset: DatasetConfig training: TrainingConfig - + # Environment project_name: str = "olmocr-training" run_name: Optional[str] = None tags: List[str] = field(default_factory=list) notes: Optional[str] = None - + # Experiment tracking experiment_tracker: str = "tensorboard" # "tensorboard", "wandb", "mlflow" wandb_project: Optional[str] = None wandb_entity: Optional[str] = None - + # Distributed training distributed: bool = False local_rank: int = -1 - + @classmethod def from_yaml(cls, yaml_path: Union[str, Path]) -> "Config": """Load configuration from YAML file.""" yaml_path = Path(yaml_path) if not yaml_path.exists(): raise FileNotFoundError(f"Config file not found: {yaml_path}") - + # Load YAML with OmegaConf for better features - with open(yaml_path, 'r') as f: + with open(yaml_path, "r") as f: yaml_content = yaml.safe_load(f) - + # Create OmegaConf config for interpolation and validation cfg = OmegaConf.create(yaml_content) - + # Resolve any interpolations OmegaConf.resolve(cfg) - + # Convert to dict and create dataclass cfg_dict = OmegaConf.to_container(cfg, resolve=True) - + # Create sub-configs - model_cfg = ModelConfig(**cfg_dict.get('model', {})) - dataset_cfg = DatasetConfig(**cfg_dict.get('dataset', {})) - training_cfg = TrainingConfig(**cfg_dict.get('training', {})) - + model_cfg = ModelConfig(**cfg_dict.get("model", {})) + dataset_cfg = DatasetConfig(**cfg_dict.get("dataset", {})) + training_cfg = TrainingConfig(**cfg_dict.get("training", {})) + # Create main config - main_cfg_dict = {k: v for k, v in cfg_dict.items() - if k not in ['model', 'dataset', 'training']} - - return cls( - model=model_cfg, - dataset=dataset_cfg, - training=training_cfg, - **main_cfg_dict - ) - + main_cfg_dict = {k: v for k, v in cfg_dict.items() if k not in ["model", 
"dataset", "training"]} + + return cls(model=model_cfg, dataset=dataset_cfg, training=training_cfg, **main_cfg_dict) + def to_yaml(self, yaml_path: Union[str, Path]) -> None: """Save configuration to YAML file.""" yaml_path = Path(yaml_path) yaml_path.parent.mkdir(parents=True, exist_ok=True) - + # Convert to OmegaConf for nice YAML output cfg = OmegaConf.structured(self) - - with open(yaml_path, 'w') as f: + + with open(yaml_path, "w") as f: OmegaConf.save(cfg, f) - + def validate(self) -> None: """Validate configuration values.""" # Dataset validation - check all train and eval datasets for split_name, datasets in [("train", self.dataset.train), ("eval", self.dataset.eval)]: for i, dataset_cfg in enumerate(datasets): - root_dir = dataset_cfg.get('root_dir') + root_dir = dataset_cfg.get("root_dir") if not root_dir: raise ValueError(f"Missing root_dir for {split_name} dataset {i}") if not os.path.exists(root_dir): raise ValueError(f"Dataset root directory does not exist: {root_dir}") - + # Training validation if self.training.fp16 and self.training.bf16: raise ValueError("Cannot use both fp16 and bf16") - + # Model validation if self.model.load_in_8bit and self.model.load_in_4bit: raise ValueError("Cannot load in both 8bit and 4bit") - + # Output directory Path(self.training.output_dir).mkdir(parents=True, exist_ok=True) - + # Logging directory if self.training.logging_dir is None: self.training.logging_dir = os.path.join(self.training.output_dir, "logs") Path(self.training.logging_dir).mkdir(parents=True, exist_ok=True) - + def get_pipeline_steps(self, pipeline_config: List[Dict[str, Any]], processor=None): """Create actual pipeline step instances from pipeline configuration. - + Args: pipeline_config: List of pipeline step configurations processor: The model processor (required for Tokenizer step) - + Returns: List of initialized pipeline step instances """ + from olmocr.prompts.prompts import PageResponse from olmocr.train.dataloader import ( - FrontMatterParser, - PDFRenderer, - StaticLengthDocumentAnchoring, FinetuningPrompt, FrontMatterOutputFormat, + FrontMatterParser, InstructUserMessages, - Tokenizer + PDFRenderer, + StaticLengthDocumentAnchoring, + Tokenizer, ) - from olmocr.prompts.prompts import PageResponse - + steps = [] for step_config in pipeline_config: - if not step_config.get('enabled', True): + if not step_config.get("enabled", True): continue - - step_name = step_config['name'] - - if step_name == 'FrontMatterParser': + + step_name = step_config["name"] + + if step_name == "FrontMatterParser": # Handle both old and new config format - if 'front_matter_class' in step_config: - front_matter_class = PageResponse if step_config['front_matter_class'] == 'PageResponse' else None + if "front_matter_class" in step_config: + front_matter_class = PageResponse if step_config["front_matter_class"] == "PageResponse" else None else: - front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None + front_matter_class = PageResponse if step_config.get("use_page_response_class", True) else None steps.append(FrontMatterParser(front_matter_class=front_matter_class)) - - elif step_name == 'PDFRenderer': - steps.append(PDFRenderer( - target_longest_image_dim=step_config.get('target_longest_image_dim', 1024), - image_transform=None # Can be extended later - )) - - elif step_name == 'StaticLengthDocumentAnchoring': - steps.append(StaticLengthDocumentAnchoring( - target_anchor_text_len=step_config.get('target_anchor_text_len', 6000) - )) - - elif step_name == 
'FinetuningPrompt': + + elif step_name == "PDFRenderer": + steps.append( + PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024), image_transform=None) # Can be extended later + ) + + elif step_name == "StaticLengthDocumentAnchoring": + steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000))) + + elif step_name == "FinetuningPrompt": steps.append(FinetuningPrompt()) - - elif step_name == 'FrontMatterOutputFormat': + + elif step_name == "FrontMatterOutputFormat": steps.append(FrontMatterOutputFormat()) - - elif step_name == 'InstructUserMessages': + + elif step_name == "InstructUserMessages": steps.append(InstructUserMessages()) - - elif step_name == 'Tokenizer': + + elif step_name == "Tokenizer": if processor is None: raise ValueError("Processor must be provided for Tokenizer step") - steps.append(Tokenizer( - processor=processor, - masking_index=step_config.get('masking_index', -100), - end_of_message_token=step_config.get('end_of_message_token', '<|im_end|>') - )) + steps.append( + Tokenizer( + processor=processor, + masking_index=step_config.get("masking_index", -100), + end_of_message_token=step_config.get("end_of_message_token", "<|im_end|>"), + ) + ) else: raise ValueError(f"Unknown pipeline step: {step_name}") - + return steps def create_default_config() -> Config: """Create a default configuration.""" - return Config( - model=ModelConfig(), - dataset=DatasetConfig(), - training=TrainingConfig() - ) + return Config(model=ModelConfig(), dataset=DatasetConfig(), training=TrainingConfig()) if __name__ == "__main__": @@ -358,7 +360,7 @@ if __name__ == "__main__": config = create_default_config() config.to_yaml("configs/default_config.yaml") print("Default config saved to configs/default_config.yaml") - + # Example: Load from YAML # loaded_config = Config.from_yaml("configs/default_config.yaml") - # print(loaded_config) \ No newline at end of file + # print(loaded_config) diff --git a/olmocr/train/dataloader.py b/olmocr/train/dataloader.py index 4caab91..fab2229 100644 --- a/olmocr/train/dataloader.py +++ b/olmocr/train/dataloader.py @@ -71,7 +71,7 @@ class BaseMarkdownPDFDataset(Dataset): # Verify the resolved path exists if pdf_path.exists(): - # Validate PDF - check it loads and has exactly one page + # Validate PDF - check it loads and has exactly one page and that you can get document-anchoring from it try: reader = PdfReader(str(pdf_path)) num_pages = len(reader.pages) @@ -80,6 +80,9 @@ class BaseMarkdownPDFDataset(Dataset): invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}")) continue + # Test that document anchoring works + get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100) + self.samples.append({"markdown_path": md_path, "pdf_path": pdf_path}) valid_count += 1 diff --git a/olmocr/train/train.py b/olmocr/train/train.py index 8bde005..d0c5e0b 100644 --- a/olmocr/train/train.py +++ b/olmocr/train/train.py @@ -4,18 +4,18 @@ Simple script to test OlmOCR dataset loading with YAML configuration. 
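# Overall flow: parse --config, load and validate the YAML Config, build the
# processor and model, assemble train/eval datasets from their per-dataset
# pipelines, then hand everything to a transformers Trainer.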
import argparse import logging -import numpy as np -from transformers import ( - AutoProcessor, - Qwen2VLForConditionalGeneration, - Qwen2_5_VLForConditionalGeneration, - Trainer, - TrainingArguments, - EarlyStoppingCallback -) +import numpy as np import torch from torch.utils.data import ConcatDataset +from transformers import ( + AutoProcessor, + EarlyStoppingCallback, + Qwen2_5_VLForConditionalGeneration, + Qwen2VLForConditionalGeneration, + Trainer, + TrainingArguments, +) from olmocr.train.config import Config from olmocr.train.dataloader import BaseMarkdownPDFDataset @@ -31,76 +31,67 @@ logger = logging.getLogger(__name__) class QwenDataCollator: """Data collator for vision-language models that handles numpy arrays.""" - + def __call__(self, examples): # Filter out None values and extract the fields we need - batch = { - 'input_ids': [], - 'attention_mask': [], - 'labels': [], - 'pixel_values': [], - 'image_grid_thw': [] - } - + batch = {"input_ids": [], "attention_mask": [], "labels": [], "pixel_values": [], "image_grid_thw": []} + for example in examples: if example is not None: # Convert numpy arrays to tensors - batch['input_ids'].append(torch.from_numpy(example['input_ids']) if isinstance(example['input_ids'], np.ndarray) else example['input_ids']) - batch['attention_mask'].append(torch.from_numpy(example['attention_mask']) if isinstance(example['attention_mask'], np.ndarray) else example['attention_mask']) - batch['labels'].append(torch.from_numpy(example['labels']) if isinstance(example['labels'], np.ndarray) else example['labels']) - + batch["input_ids"].append(torch.from_numpy(example["input_ids"]) if isinstance(example["input_ids"], np.ndarray) else example["input_ids"]) + batch["attention_mask"].append( + torch.from_numpy(example["attention_mask"]) if isinstance(example["attention_mask"], np.ndarray) else example["attention_mask"] + ) + batch["labels"].append(torch.from_numpy(example["labels"]) if isinstance(example["labels"], np.ndarray) else example["labels"]) + # Handle pixel_values which might be numpy array or already a tensor - pixel_values = example['pixel_values'] + pixel_values = example["pixel_values"] if isinstance(pixel_values, np.ndarray): pixel_values = torch.from_numpy(pixel_values) - batch['pixel_values'].append(pixel_values) - + batch["pixel_values"].append(pixel_values) + # Handle image_grid_thw - image_grid_thw = example['image_grid_thw'] + image_grid_thw = example["image_grid_thw"] if isinstance(image_grid_thw, np.ndarray): image_grid_thw = torch.from_numpy(image_grid_thw) - batch['image_grid_thw'].append(image_grid_thw) - + batch["image_grid_thw"].append(image_grid_thw) + # Convert lists to tensors with proper padding # Note: For Qwen2-VL, we typically handle variable length sequences # The model's processor should handle the padding internally return { - 'input_ids': torch.stack(batch['input_ids']), - 'attention_mask': torch.stack(batch['attention_mask']), - 'labels': torch.stack(batch['labels']), - 'pixel_values': torch.stack(batch['pixel_values']), # Stack into tensor - 'image_grid_thw': torch.stack(batch['image_grid_thw']) + "input_ids": torch.stack(batch["input_ids"]), + "attention_mask": torch.stack(batch["attention_mask"]), + "labels": torch.stack(batch["labels"]), + "pixel_values": torch.stack(batch["pixel_values"]), # Stack into tensor + "image_grid_thw": torch.stack(batch["image_grid_thw"]), } def main(): parser = argparse.ArgumentParser(description="Train OlmOCR model") - parser.add_argument( - "--config", - type=str, - 
default="olmocr/train/configs/example_config.yaml", - help="Path to YAML configuration file" - ) - + parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file") + args = parser.parse_args() - + # Load configuration logger.info(f"Loading configuration from: {args.config}") config = Config.from_yaml(args.config) - + # Validate configuration try: config.validate() except ValueError as e: logger.error(f"Configuration validation failed: {e}") return - + # Load processor for tokenization logger.info(f"Loading processor: {config.model.name}") processor = AutoProcessor.from_pretrained( config.model.name, ) - + # Load model logger.info(f"Loading model: {config.model.name}") if "Qwen2.5-VL" in config.model.name: @@ -121,50 +112,50 @@ def main(): ) else: raise NotImplementedError() - + # Enable gradient checkpointing if configured if config.training.gradient_checkpointing: model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs) - + # Create training datasets logger.info("Creating training datasets...") train_datasets = [] for i, dataset_cfg in enumerate(config.dataset.train): - root_dir = dataset_cfg['root_dir'] - pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor) - + root_dir = dataset_cfg["root_dir"] + pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor) + logger.info(f"Creating training dataset {i+1} from: {root_dir}") dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps) logger.info(f"Found {len(dataset)} samples") - + if len(dataset) > 0: train_datasets.append(dataset) - + # Combine all training datasets train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0] logger.info(f"Total training samples: {len(train_dataset)}") - + # Create evaluation datasets logger.info("Creating evaluation datasets...") eval_datasets = {} for i, dataset_cfg in enumerate(config.dataset.eval): - root_dir = dataset_cfg['root_dir'] - pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor) - + root_dir = dataset_cfg["root_dir"] + pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor) + # Use dataset name if provided, otherwise use root_dir as name - dataset_name = dataset_cfg.get('name', f"eval_dataset_{i+1}") - + dataset_name = dataset_cfg.get("name", f"eval_dataset_{i+1}") + logger.info(f"Creating evaluation dataset '{dataset_name}' from: {root_dir}") dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps) logger.info(f"Found {len(dataset)} samples") - + if len(dataset) > 0: eval_datasets[dataset_name] = dataset - + # Log total evaluation samples across all datasets total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values()) logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}") - + # Set up training arguments training_args = TrainingArguments( output_dir=config.training.output_dir, @@ -208,17 +199,16 @@ def main(): eval_on_start=True, run_name=config.run_name, ) - + # Set up callbacks callbacks = [] if config.training.use_early_stopping: callbacks.append( EarlyStoppingCallback( - early_stopping_patience=config.training.early_stopping_patience, - early_stopping_threshold=config.training.early_stopping_threshold + early_stopping_patience=config.training.early_stopping_patience, early_stopping_threshold=config.training.early_stopping_threshold ) ) - + # Initialize trainer 
logger.info("Initializing trainer...") trainer = Trainer( @@ -229,19 +219,19 @@ def main(): data_collator=QwenDataCollator(), callbacks=callbacks, ) - + # Start training logger.info("Starting training...") train_result = trainer.train(resume_from_checkpoint=config.training.resume_from_checkpoint) - + # Save the final model logger.info("Saving final model...") trainer.save_model() trainer.save_state() - + # Log metrics logger.info(f"Training completed! Metrics: {train_result.metrics}") if __name__ == "__main__": - main() \ No newline at end of file + main()