diff --git a/olmocr/train/train.py b/olmocr/train/train.py
index 4eb6a90..5ec9630 100644
--- a/olmocr/train/train.py
+++ b/olmocr/train/train.py
@@ -4,10 +4,16 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
 
 import argparse
 import logging
-from pathlib import Path
-from pprint import pprint
 
-from transformers import AutoProcessor
+from transformers import (
+    AutoProcessor,
+    Qwen2VLForConditionalGeneration,
+    Trainer,
+    TrainingArguments,
+    EarlyStoppingCallback,
+)
+import torch
+from torch.utils.data import ConcatDataset
 
 from olmocr.train.config import Config
 from olmocr.train.dataloader import BaseMarkdownPDFDataset
@@ -21,61 +27,44 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
-def print_sample(sample, dataset_name):
-    """Pretty print a dataset sample."""
-    print(f"\n{'='*80}")
-    print(f"Sample from: {dataset_name}")
-    print(f"{'='*80}")
+def create_data_collator():
+    """Create a data collator for vision-language models."""
+    def collate_fn(examples):
+        # Filter out None values and extract the fields we need
+        batch = {
+            'input_ids': [],
+            'attention_mask': [],
+            'labels': [],
+            'pixel_values': [],
+            'image_grid_thw': []
+        }
+
+        for example in examples:
+            if example is not None:
+                batch['input_ids'].append(example['input_ids'])
+                batch['attention_mask'].append(example['attention_mask'])
+                batch['labels'].append(example['labels'])
+                batch['pixel_values'].append(example['pixel_values'])
+                batch['image_grid_thw'].append(example['image_grid_thw'])
+
+        # input_ids / attention_mask / labels are assumed to be padded to a
+        # fixed length by the tokenization pipeline step (torch.stack requires
+        # equal shapes). pixel_values and image_grid_thw carry a variable
+        # number of image patches/images per sample, so Qwen2-VL expects them
+        # concatenated along the first dimension, not stacked into a new one.
+        return {
+            'input_ids': torch.stack(batch['input_ids']),
+            'attention_mask': torch.stack(batch['attention_mask']),
+            'labels': torch.stack(batch['labels']),
+            'pixel_values': torch.cat(batch['pixel_values'], dim=0),
+            'image_grid_thw': torch.cat(batch['image_grid_thw'], dim=0),
+        }
 
-    # Print keys
-    print(f"\nAvailable keys: {list(sample.keys())}")
-
-    # Print path information
-    if 'markdown_path' in sample:
-        print(f"\nMarkdown path: {sample['markdown_path']}")
-    if 'pdf_path' in sample:
-        print(f"PDF path: {sample['pdf_path']}")
-
-    # Print page data
-    if 'page_data' in sample:
-        print(f"\nPage data:")
-        print(f"  Primary language: {sample['page_data'].primary_language}")
-        print(f"  Is rotation valid: {sample['page_data'].is_rotation_valid}")
-        print(f"  Rotation correction: {sample['page_data'].rotation_correction}")
-        print(f"  Is table: {sample['page_data'].is_table}")
-        print(f"  Is diagram: {sample['page_data'].is_diagram}")
-        print(f"  Natural text preview: {sample['page_data'].natural_text[:200]}..."
-              if sample['page_data'].natural_text else "  Natural text: None")
-
-    # Print image info
-    if 'image' in sample:
-        print(f"\nImage shape: {sample['image'].size}")
-
-    # Print anchor text preview
-    if 'anchor_text' in sample:
-        print(f"\nAnchor text preview: {sample['anchor_text'][:200]}...")
-
-    # Print instruction prompt preview
-    if 'instruction_prompt' in sample:
-        print(f"\nInstruction prompt preview: {sample['instruction_prompt'][:200]}...")
-
-    # Print response preview
-    if 'response' in sample:
-        print(f"\nResponse preview: {sample['response'][:200]}...")
-
-    # Print tokenization info
-    if 'input_ids' in sample:
-        print(f"\nTokenization info:")
-        print(f"  Input IDs shape: {sample['input_ids'].shape}")
-        print(f"  Attention mask shape: {sample['attention_mask'].shape}")
-        print(f"  Labels shape: {sample['labels'].shape}")
-        if 'pixel_values' in sample:
-            print(f"  Pixel values shape: {sample['pixel_values'].shape}")
-        if 'image_grid_thw' in sample:
-            print(f"  Image grid THW: {sample['image_grid_thw']}")
+    return collate_fn
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Test OlmOCR dataset loading")
+    parser = argparse.ArgumentParser(description="Train OlmOCR model")
     parser.add_argument(
         "--config",
         type=str,
@@ -103,45 +90,139 @@ def main():
         trust_remote_code=config.model.processor_trust_remote_code
     )
 
-    # Process training datasets
-    print(f"\n{'='*80}")
-    print("TRAINING DATASETS")
-    print(f"{'='*80}")
+    # Load model
+    logger.info(f"Loading model: {config.model.name}")
+    model = Qwen2VLForConditionalGeneration.from_pretrained(
+        config.model.name,
+        torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
+        device_map=config.model.device_map,
+        trust_remote_code=config.model.trust_remote_code,
+        attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
+    )
+
+    # Enable gradient checkpointing if configured
+    if config.training.gradient_checkpointing:
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)
+        model.config.use_cache = False  # KV cache is incompatible with gradient checkpointing
 
+    # Create training datasets
+    logger.info("Creating training datasets...")
+    train_datasets = []
     for i, dataset_cfg in enumerate(config.dataset.train):
         root_dir = dataset_cfg['root_dir']
         pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
 
-        logger.info(f"\nCreating training dataset {i+1} from: {root_dir}")
+        logger.info(f"Creating training dataset {i+1} from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
         logger.info(f"Found {len(dataset)} samples")
 
         if len(dataset) > 0:
-            # Get first sample
-            sample = dataset[0]
-            print_sample(sample, f"Training Dataset {i+1}: {Path(root_dir).name}")
+            train_datasets.append(dataset)
 
-    # Process evaluation datasets
-    print(f"\n\n{'='*80}")
-    print("EVALUATION DATASETS")
-    print(f"{'='*80}")
+    # Combine all training datasets; fail fast if nothing was found
+    if not train_datasets:
+        raise ValueError("No training samples found in any configured dataset")
+    train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
+    logger.info(f"Total training samples: {len(train_dataset)}")
 
+    # Create evaluation datasets
+    logger.info("Creating evaluation datasets...")
+    eval_datasets = []
     for i, dataset_cfg in enumerate(config.dataset.eval):
         root_dir = dataset_cfg['root_dir']
         pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
 
-        logger.info(f"\nCreating evaluation dataset {i+1} from: {root_dir}")
+        logger.info(f"Creating evaluation dataset {i+1} from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
logger.info(f"Found {len(dataset)} samples") if len(dataset) > 0: - # Get first sample - sample = dataset[0] - print_sample(sample, f"Evaluation Dataset {i+1}: {Path(root_dir).name}") + eval_datasets.append(dataset) - print(f"\n{'='*80}") - print("Dataset loading test completed!") - print(f"{'='*80}") + # Combine all evaluation datasets + eval_dataset = ConcatDataset(eval_datasets) if len(eval_datasets) > 1 else eval_datasets[0] + logger.info(f"Total evaluation samples: {len(eval_dataset)}") + + # Set up training arguments + training_args = TrainingArguments( + output_dir=config.training.output_dir, + num_train_epochs=config.training.num_train_epochs, + per_device_train_batch_size=config.training.per_device_train_batch_size, + per_device_eval_batch_size=config.training.per_device_eval_batch_size, + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + learning_rate=config.training.learning_rate, + lr_scheduler_type=config.training.lr_scheduler_type, + warmup_ratio=config.training.warmup_ratio, + warmup_steps=config.training.warmup_steps, + optim=config.training.optim, + adam_beta1=config.training.adam_beta1, + adam_beta2=config.training.adam_beta2, + adam_epsilon=config.training.adam_epsilon, + weight_decay=config.training.weight_decay, + max_grad_norm=config.training.max_grad_norm, + fp16=config.training.fp16, + bf16=config.training.bf16, + tf32=config.training.tf32, + eval_strategy=config.training.evaluation_strategy, + eval_steps=config.training.eval_steps, + save_strategy=config.training.save_strategy, + save_steps=config.training.save_steps, + save_total_limit=config.training.save_total_limit, + load_best_model_at_end=config.training.load_best_model_at_end, + metric_for_best_model=config.training.metric_for_best_model, + greater_is_better=config.training.greater_is_better, + logging_dir=config.training.logging_dir, + logging_strategy=config.training.logging_strategy, + logging_steps=config.training.logging_steps, + logging_first_step=config.training.logging_first_step, + report_to=config.training.report_to, + seed=config.training.seed, + data_seed=config.training.data_seed, + push_to_hub=config.training.push_to_hub, + hub_model_id=config.training.hub_model_id, + hub_strategy=config.training.hub_strategy, + resume_from_checkpoint=config.training.resume_from_checkpoint, + deepspeed=config.training.deepspeed, + dataloader_drop_last=config.training.dataloader_drop_last, + dataloader_num_workers=config.training.dataloader_num_workers, + remove_unused_columns=config.training.remove_unused_columns, + run_name=config.run_name, + ) + + # Set up callbacks + callbacks = [] + if config.training.use_early_stopping: + callbacks.append( + EarlyStoppingCallback( + early_stopping_patience=config.training.early_stopping_patience, + early_stopping_threshold=config.training.early_stopping_threshold + ) + ) + + # Initialize trainer + logger.info("Initializing trainer...") + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + data_collator=create_data_collator(), + callbacks=callbacks, + ) + + # Start training + logger.info("Starting training...") + train_result = trainer.train(resume_from_checkpoint=config.training.resume_from_checkpoint) + + # Save the final model + logger.info("Saving final model...") + trainer.save_model() + trainer.save_state() + + # Log metrics + logger.info(f"Training completed! Metrics: {train_result.metrics}") if __name__ == "__main__":