From 0ebc35cf1fe41a3bb6974da75823d78fa28f5251 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Tue, 24 Jun 2025 22:48:36 +0000
Subject: [PATCH] Basic train config loader for datasets

---
 olmocr/train/config.py                   |  68 +++++-----
 olmocr/train/configs/example_config.yaml |  15 +--
 olmocr/train/train.py                    | 157 +++++++++++++++++++++--
 pyproject.toml                           |   1 -
 4 files changed, 193 insertions(+), 48 deletions(-)

diff --git a/olmocr/train/config.py b/olmocr/train/config.py
index 5b8fa76..bd06681 100644
--- a/olmocr/train/config.py
+++ b/olmocr/train/config.py
@@ -64,9 +64,20 @@ class TokenizerStepConfig(PipelineStepConfig):
 
 
 @dataclass
+class DatasetItemConfig:
+    """Configuration for a single dataset item."""
+    root_dir: str
+    pipeline: List[Dict[str, Any]] = field(default_factory=list)
+
+    # Optional sampling
+    max_samples: Optional[int] = None
+
+
+@dataclass
 class DatasetConfig:
     """Configuration for dataset and data loading."""
-    root_dir: str
+    train: List[Dict[str, Any]] = field(default_factory=list)
+    eval: List[Dict[str, Any]] = field(default_factory=list)
 
     # DataLoader configuration
     batch_size: int = 1
@@ -76,32 +87,12 @@ class DatasetConfig:
     pin_memory: bool = True
     prefetch_factor: int = 2
 
-    # Pipeline steps configuration
-    pipeline_steps: List[Dict[str, Any]] = field(default_factory=lambda: [
-        {"name": "FrontMatterParser", "use_page_response_class": True},
-        {"name": "PDFRenderer", "target_longest_image_dim": 1024},
-        {"name": "StaticLengthDocumentAnchoring", "target_anchor_text_len": 6000},
-        {"name": "FinetuningPrompt"},
-        {"name": "FrontMatterOutputFormat"},
-        {"name": "InstructUserMessages"},
-        {"name": "Tokenizer", "masking_index": -100, "end_of_message_token": "<|im_end|>"}
-    ])
-
-    # Optional dataset sampling
-    max_samples: Optional[int] = None
-    validation_split: float = 0.1
+    # Global seed
     seed: int = 42
 
-    # Train/validation split
-    train_indices: Optional[List[int]] = None
-    val_indices: Optional[List[int]] = None
-
     # Caching
     cache_dir: Optional[str] = None
     use_cache: bool = False
-
-    # Data augmentation (future extension)
-    augmentation: Dict[str, Any] = field(default_factory=dict)
 
 
 @dataclass
@@ -280,9 +271,14 @@ class Config:
 
     def validate(self) -> None:
         """Validate configuration values."""
-        # Dataset validation
-        if not os.path.exists(self.dataset.root_dir):
-            raise ValueError(f"Dataset root directory does not exist: {self.dataset.root_dir}")
+        # Dataset validation - check all train and eval datasets
+        for split_name, datasets in [("train", self.dataset.train), ("eval", self.dataset.eval)]:
+            for i, dataset_cfg in enumerate(datasets):
+                root_dir = dataset_cfg.get('root_dir')
+                if not root_dir:
+                    raise ValueError(f"Missing root_dir for {split_name} dataset {i}")
+                if not os.path.exists(root_dir):
+                    raise ValueError(f"Dataset root directory does not exist: {root_dir}")
 
         # Training validation
         if self.training.warmup_steps is not None and self.training.warmup_ratio > 0:
@@ -303,8 +299,16 @@ class Config:
             self.training.logging_dir = os.path.join(self.training.output_dir, "logs")
             Path(self.training.logging_dir).mkdir(parents=True, exist_ok=True)
 
-    def get_pipeline_steps(self, processor=None):
-        """Create actual pipeline step instances from configuration."""
+    def get_pipeline_steps(self, pipeline_config: List[Dict[str, Any]], processor=None):
+        """Create actual pipeline step instances from pipeline configuration.
+
+        Args:
+            pipeline_config: List of pipeline step configurations
+            processor: The model processor (required for Tokenizer step)
+
+        Returns:
+            List of initialized pipeline step instances
+        """
         from olmocr.train.dataloader import (
             FrontMatterParser,
             PDFRenderer,
@@ -317,14 +321,18 @@ class Config:
         from olmocr.prompts.prompts import PageResponse
 
         steps = []
-        for step_config in self.dataset.pipeline_steps:
+        for step_config in pipeline_config:
             if not step_config.get('enabled', True):
                 continue
 
             step_name = step_config['name']
 
             if step_name == 'FrontMatterParser':
-                front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
+                # Handle both old and new config format
+                if 'front_matter_class' in step_config:
+                    front_matter_class = PageResponse if step_config['front_matter_class'] == 'PageResponse' else None
+                else:
+                    front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
                 steps.append(FrontMatterParser(front_matter_class=front_matter_class))
 
             elif step_name == 'PDFRenderer':
@@ -365,7 +373,7 @@ def create_default_config() -> Config:
     """Create a default configuration."""
     return Config(
         model=ModelConfig(),
-        dataset=DatasetConfig(root_dir="/path/to/dataset"),
+        dataset=DatasetConfig(),
         training=TrainingConfig()
     )
 
diff --git a/olmocr/train/configs/example_config.yaml b/olmocr/train/configs/example_config.yaml
index 5326f28..71c6100 100644
--- a/olmocr/train/configs/example_config.yaml
+++ b/olmocr/train/configs/example_config.yaml
@@ -27,8 +27,8 @@ model:
 
 dataset:
   train:
-    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
-      pipeline *basic_pipeline:
+    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
+      pipeline: &basic_pipeline
         - name: FrontMatterParser
           front_matter_class: PageResponse
         - name: PDFRenderer
@@ -41,15 +41,14 @@ dataset:
         - name: Tokenizer
           masking_index: -100
           end_of_message_token: "<|im_end|>"
-    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_books_train_s2pdf/
-      pipeline: *reuse basic_pipeline above*
+    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_01_books_eval_iabooks/
+      pipeline: *basic_pipeline
 
   eval:
     - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
-      pipeline: *reuse basic_pipeline above*
-    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_books_eval_s2pdf/
-      pipeline: *reuse basic_pipeline above*
-
+      pipeline: *basic_pipeline
+    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_01_books_eval_iabooks/
+      pipeline: *basic_pipeline
 
 
 # Training configuration
diff --git a/olmocr/train/train.py b/olmocr/train/train.py
index 0ca3741..4eb6a90 100644
--- a/olmocr/train/train.py
+++ b/olmocr/train/train.py
@@ -1,9 +1,148 @@
-# TODO Overall, this code will read in a config yaml file with omega conf
-# From that config, we are going to use HuggingFace Trainer to train a model
-# TODOS:
-# DONE Build a script to convert olmocr-mix to a new dataloader format
-# DONE Write a new dataloader and collator, with tests that brings in everything, only needs to support batch size 1 for this first version
-# Get a basic config yaml file system working
-# Get a basic hugging face trainer running, supporting Qwen2.5VL for now
-# Saving and restoring training checkpoints
-# Converting training checkpoints to vllm compatible checkpoinst
+"""
+Simple script to test OlmOCR dataset loading with YAML configuration.
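+
+Usage, from the repo root, with the example config added in this patch:
+    python olmocr/train/train.py --config olmocr/train/configs/example_config.yaml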
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from pprint import pprint
+
+from transformers import AutoProcessor
+
+from olmocr.train.config import Config
+from olmocr.train.dataloader import BaseMarkdownPDFDataset
+
+# Configure logging
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+def print_sample(sample, dataset_name):
+    """Pretty print a dataset sample."""
+    print(f"\n{'='*80}")
+    print(f"Sample from: {dataset_name}")
+    print(f"{'='*80}")
+
+    # Print keys
+    print(f"\nAvailable keys: {list(sample.keys())}")
+
+    # Print path information
+    if 'markdown_path' in sample:
+        print(f"\nMarkdown path: {sample['markdown_path']}")
+    if 'pdf_path' in sample:
+        print(f"PDF path: {sample['pdf_path']}")
+
+    # Print page data
+    if 'page_data' in sample:
+        print(f"\nPage data:")
+        print(f"  Primary language: {sample['page_data'].primary_language}")
+        print(f"  Is rotation valid: {sample['page_data'].is_rotation_valid}")
+        print(f"  Rotation correction: {sample['page_data'].rotation_correction}")
+        print(f"  Is table: {sample['page_data'].is_table}")
+        print(f"  Is diagram: {sample['page_data'].is_diagram}")
+        print(f"  Natural text preview: {sample['page_data'].natural_text[:200]}..." if sample['page_data'].natural_text else "  Natural text: None")
+
+    # Print image info
+    if 'image' in sample:
+        print(f"\nImage shape: {sample['image'].size}")
+
+    # Print anchor text preview
+    if 'anchor_text' in sample:
+        print(f"\nAnchor text preview: {sample['anchor_text'][:200]}...")
+
+    # Print instruction prompt preview
+    if 'instruction_prompt' in sample:
+        print(f"\nInstruction prompt preview: {sample['instruction_prompt'][:200]}...")
+
+    # Print response preview
+    if 'response' in sample:
+        print(f"\nResponse preview: {sample['response'][:200]}...")
+
+    # Print tokenization info
+    if 'input_ids' in sample:
+        print(f"\nTokenization info:")
+        print(f"  Input IDs shape: {sample['input_ids'].shape}")
+        print(f"  Attention mask shape: {sample['attention_mask'].shape}")
+        print(f"  Labels shape: {sample['labels'].shape}")
+        if 'pixel_values' in sample:
+            print(f"  Pixel values shape: {sample['pixel_values'].shape}")
+        if 'image_grid_thw' in sample:
+            print(f"  Image grid THW: {sample['image_grid_thw']}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Test OlmOCR dataset loading")
+    parser.add_argument(
+        "--config",
+        type=str,
+        default="olmocr/train/configs/example_config.yaml",
+        help="Path to YAML configuration file"
+    )
+
+    args = parser.parse_args()
+
+    # Load configuration
+    logger.info(f"Loading configuration from: {args.config}")
+    config = Config.from_yaml(args.config)
+
+    # Validate configuration
+    try:
+        config.validate()
+    except ValueError as e:
+        logger.error(f"Configuration validation failed: {e}")
+        return
+
+    # Load processor for tokenization
+    logger.info(f"Loading processor: {config.model.name}")
+    processor = AutoProcessor.from_pretrained(
+        config.model.name,
+        trust_remote_code=config.model.processor_trust_remote_code
+    )
+
+    # Process training datasets
+    print(f"\n{'='*80}")
+    print("TRAINING DATASETS")
+    print(f"{'='*80}")
+
+    for i, dataset_cfg in enumerate(config.dataset.train):
+        root_dir = dataset_cfg['root_dir']
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+
+        logger.info(f"\nCreating training dataset {i+1} from: {root_dir}")
+        dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
+        logger.info(f"Found {len(dataset)} samples")
+
+        if len(dataset) > 0:
+            # Get first sample
+            sample = dataset[0]
+            print_sample(sample, f"Training Dataset {i+1}: {Path(root_dir).name}")
+
+    # Process evaluation datasets
+    print(f"\n\n{'='*80}")
+    print("EVALUATION DATASETS")
+    print(f"{'='*80}")
+
+    for i, dataset_cfg in enumerate(config.dataset.eval):
+        root_dir = dataset_cfg['root_dir']
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+
+        logger.info(f"\nCreating evaluation dataset {i+1} from: {root_dir}")
+        dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
+        logger.info(f"Found {len(dataset)} samples")
+
+        if len(dataset) > 0:
+            # Get first sample
+            sample = dataset[0]
+            print_sample(sample, f"Evaluation Dataset {i+1}: {Path(root_dir).name}")
+
+    print(f"\n{'='*80}")
+    print("Dataset loading test completed!")
+    print(f"{'='*80}")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 1dbf992..6336791 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -105,7 +105,6 @@ train = [
     "s3fs",
     "necessary",
     "einops",
-    "transformers>=4.45.1"
 ]
 
 elo = [
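--
Note, not part of the patch: a minimal sketch of how the new per-dataset
schema is consumed, condensed from main() in train.py above. It assumes the
dataset directories listed in example_config.yaml exist locally.

    from transformers import AutoProcessor

    from olmocr.train.config import Config
    from olmocr.train.dataloader import BaseMarkdownPDFDataset

    # Load and validate the YAML config, then build the processor once
    config = Config.from_yaml("olmocr/train/configs/example_config.yaml")
    config.validate()
    processor = AutoProcessor.from_pretrained(
        config.model.name,
        trust_remote_code=config.model.processor_trust_remote_code,
    )

    # Each train/eval entry carries its own root_dir and pipeline, so steps
    # are instantiated per dataset rather than from one global pipeline list
    for dataset_cfg in config.dataset.train:
        steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)
        dataset = BaseMarkdownPDFDataset(dataset_cfg["root_dir"], steps)
        print(f"{dataset_cfg['root_dir']}: {len(dataset)} samples")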