Claude generated train script

Jake Poznanski 2025-06-24 22:56:35 +00:00
parent 0ebc35cf1f
commit 91e7b5ce3f


@@ -4,10 +4,16 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
 import argparse
 import logging
-from pathlib import Path
-from pprint import pprint
-from transformers import AutoProcessor
+from transformers import (
+    AutoProcessor,
+    Qwen2VLForConditionalGeneration,
+    Trainer,
+    TrainingArguments,
+    EarlyStoppingCallback
+)
+import torch
+from torch.utils.data import ConcatDataset
 from olmocr.train.config import Config
 from olmocr.train.dataloader import BaseMarkdownPDFDataset
@@ -21,61 +27,42 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
-def print_sample(sample, dataset_name):
-    """Pretty print a dataset sample."""
-    print(f"\n{'='*80}")
-    print(f"Sample from: {dataset_name}")
-    print(f"{'='*80}")
-
-    # Print keys
-    print(f"\nAvailable keys: {list(sample.keys())}")
-
-    # Print path information
-    if 'markdown_path' in sample:
-        print(f"\nMarkdown path: {sample['markdown_path']}")
-    if 'pdf_path' in sample:
-        print(f"PDF path: {sample['pdf_path']}")
-
-    # Print page data
-    if 'page_data' in sample:
-        print(f"\nPage data:")
-        print(f"  Primary language: {sample['page_data'].primary_language}")
-        print(f"  Is rotation valid: {sample['page_data'].is_rotation_valid}")
-        print(f"  Rotation correction: {sample['page_data'].rotation_correction}")
-        print(f"  Is table: {sample['page_data'].is_table}")
-        print(f"  Is diagram: {sample['page_data'].is_diagram}")
-        print(f"  Natural text preview: {sample['page_data'].natural_text[:200]}..." if sample['page_data'].natural_text else "  Natural text: None")
-
-    # Print image info
-    if 'image' in sample:
-        print(f"\nImage shape: {sample['image'].size}")
-
-    # Print anchor text preview
-    if 'anchor_text' in sample:
-        print(f"\nAnchor text preview: {sample['anchor_text'][:200]}...")
-
-    # Print instruction prompt preview
-    if 'instruction_prompt' in sample:
-        print(f"\nInstruction prompt preview: {sample['instruction_prompt'][:200]}...")
-
-    # Print response preview
-    if 'response' in sample:
-        print(f"\nResponse preview: {sample['response'][:200]}...")
-
-    # Print tokenization info
-    if 'input_ids' in sample:
-        print(f"\nTokenization info:")
-        print(f"  Input IDs shape: {sample['input_ids'].shape}")
-        print(f"  Attention mask shape: {sample['attention_mask'].shape}")
-        print(f"  Labels shape: {sample['labels'].shape}")
-        if 'pixel_values' in sample:
-            print(f"  Pixel values shape: {sample['pixel_values'].shape}")
-        if 'image_grid_thw' in sample:
-            print(f"  Image grid THW: {sample['image_grid_thw']}")
+def create_data_collator():
+    """Create a data collator for vision-language models."""
+    def collate_fn(examples):
+        # Filter out None values and extract the fields we need
+        batch = {
+            'input_ids': [],
+            'attention_mask': [],
+            'labels': [],
+            'pixel_values': [],
+            'image_grid_thw': []
+        }
+
+        for example in examples:
+            if example is not None:
+                batch['input_ids'].append(example['input_ids'])
+                batch['attention_mask'].append(example['attention_mask'])
+                batch['labels'].append(example['labels'])
+                batch['pixel_values'].append(example['pixel_values'])
+                batch['image_grid_thw'].append(example['image_grid_thw'])
+
+        # Convert lists to tensors with proper padding
+        # Note: For Qwen2-VL, we typically handle variable length sequences
+        # The model's processor should handle the padding internally
+        return {
+            'input_ids': torch.stack(batch['input_ids']),
+            'attention_mask': torch.stack(batch['attention_mask']),
+            'labels': torch.stack(batch['labels']),
+            'pixel_values': batch['pixel_values'],  # Keep as list for now
+            'image_grid_thw': torch.stack(batch['image_grid_thw'])
+        }
+
+    return collate_fn
 
 def main():
-    parser = argparse.ArgumentParser(description="Test OlmOCR dataset loading")
+    parser = argparse.ArgumentParser(description="Train OlmOCR model")
     parser.add_argument(
         "--config",
         type=str,
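A note on the collator added in the hunk above: the torch.stack calls require every tensor field to have identical shapes across the batch, so each example must already be padded to a common sequence length by the pipeline. A minimal hypothetical smoke test, with create_data_collator from this diff assumed to be in scope (field shapes are illustrative, not from the commit):

import torch

collate = create_data_collator()

# Two fake examples; every tensor field shares the same sequence length,
# which torch.stack requires.
examples = [
    {
        'input_ids': torch.randint(0, 1000, (16,)),
        'attention_mask': torch.ones(16, dtype=torch.long),
        'labels': torch.randint(0, 1000, (16,)),
        'pixel_values': torch.randn(4, 1176),    # 4 flattened patches (hypothetical)
        'image_grid_thw': torch.tensor([[1, 2, 2]]),
    }
    for _ in range(2)
]

batch = collate(examples)
print(batch['input_ids'].shape)       # torch.Size([2, 16])
print(batch['image_grid_thw'].shape)  # torch.Size([2, 1, 3])
print(len(batch['pixel_values']))     # 2 (kept as a list)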
@@ -103,45 +90,134 @@ def main():
         trust_remote_code=config.model.processor_trust_remote_code
     )
 
-    # Process training datasets
-    print(f"\n{'='*80}")
-    print("TRAINING DATASETS")
-    print(f"{'='*80}")
+    # Load model
+    logger.info(f"Loading model: {config.model.name}")
+    model = Qwen2VLForConditionalGeneration.from_pretrained(
+        config.model.name,
+        torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
+        device_map=config.model.device_map,
+        trust_remote_code=config.model.trust_remote_code,
+        attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
+    )
+
+    # Enable gradient checkpointing if configured
+    if config.training.gradient_checkpointing:
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)
+
+    # Create training datasets
+    logger.info("Creating training datasets...")
+    train_datasets = []
     for i, dataset_cfg in enumerate(config.dataset.train):
         root_dir = dataset_cfg['root_dir']
         pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
-        logger.info(f"\nCreating training dataset {i+1} from: {root_dir}")
+        logger.info(f"Creating training dataset {i+1} from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
         logger.info(f"Found {len(dataset)} samples")
         if len(dataset) > 0:
-            # Get first sample
-            sample = dataset[0]
-            print_sample(sample, f"Training Dataset {i+1}: {Path(root_dir).name}")
+            train_datasets.append(dataset)
 
-    # Process evaluation datasets
-    print(f"\n\n{'='*80}")
-    print("EVALUATION DATASETS")
-    print(f"{'='*80}")
+    # Combine all training datasets
+    train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
+    logger.info(f"Total training samples: {len(train_dataset)}")
+
+    # Create evaluation datasets
+    logger.info("Creating evaluation datasets...")
+    eval_datasets = []
     for i, dataset_cfg in enumerate(config.dataset.eval):
         root_dir = dataset_cfg['root_dir']
         pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
-        logger.info(f"\nCreating evaluation dataset {i+1} from: {root_dir}")
+        logger.info(f"Creating evaluation dataset {i+1} from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
         logger.info(f"Found {len(dataset)} samples")
         if len(dataset) > 0:
-            # Get first sample
-            sample = dataset[0]
-            print_sample(sample, f"Evaluation Dataset {i+1}: {Path(root_dir).name}")
-
-    print(f"\n{'='*80}")
-    print("Dataset loading test completed!")
-    print(f"{'='*80}")
+            eval_datasets.append(dataset)
+
+    # Combine all evaluation datasets
+    eval_dataset = ConcatDataset(eval_datasets) if len(eval_datasets) > 1 else eval_datasets[0]
+    logger.info(f"Total evaluation samples: {len(eval_dataset)}")
+
+    # Set up training arguments
+    training_args = TrainingArguments(
+        output_dir=config.training.output_dir,
+        num_train_epochs=config.training.num_train_epochs,
+        per_device_train_batch_size=config.training.per_device_train_batch_size,
+        per_device_eval_batch_size=config.training.per_device_eval_batch_size,
+        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
+        learning_rate=config.training.learning_rate,
+        lr_scheduler_type=config.training.lr_scheduler_type,
+        warmup_ratio=config.training.warmup_ratio,
+        warmup_steps=config.training.warmup_steps,
+        optim=config.training.optim,
+        adam_beta1=config.training.adam_beta1,
+        adam_beta2=config.training.adam_beta2,
+        adam_epsilon=config.training.adam_epsilon,
+        weight_decay=config.training.weight_decay,
+        max_grad_norm=config.training.max_grad_norm,
+        fp16=config.training.fp16,
+        bf16=config.training.bf16,
+        tf32=config.training.tf32,
+        eval_strategy=config.training.evaluation_strategy,
+        eval_steps=config.training.eval_steps,
+        save_strategy=config.training.save_strategy,
+        save_steps=config.training.save_steps,
+        save_total_limit=config.training.save_total_limit,
+        load_best_model_at_end=config.training.load_best_model_at_end,
+        metric_for_best_model=config.training.metric_for_best_model,
+        greater_is_better=config.training.greater_is_better,
+        logging_dir=config.training.logging_dir,
+        logging_strategy=config.training.logging_strategy,
+        logging_steps=config.training.logging_steps,
+        logging_first_step=config.training.logging_first_step,
+        report_to=config.training.report_to,
+        seed=config.training.seed,
+        data_seed=config.training.data_seed,
+        push_to_hub=config.training.push_to_hub,
+        hub_model_id=config.training.hub_model_id,
+        hub_strategy=config.training.hub_strategy,
+        resume_from_checkpoint=config.training.resume_from_checkpoint,
+        deepspeed=config.training.deepspeed,
+        dataloader_drop_last=config.training.dataloader_drop_last,
+        dataloader_num_workers=config.training.dataloader_num_workers,
+        remove_unused_columns=config.training.remove_unused_columns,
+        run_name=config.run_name,
+    )
+
+    # Set up callbacks
+    callbacks = []
+    if config.training.use_early_stopping:
+        callbacks.append(
+            EarlyStoppingCallback(
+                early_stopping_patience=config.training.early_stopping_patience,
+                early_stopping_threshold=config.training.early_stopping_threshold
+            )
+        )
+
+    # Initialize trainer
+    logger.info("Initializing trainer...")
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        data_collator=create_data_collator(),
+        callbacks=callbacks,
+    )
+
+    # Start training
+    logger.info("Starting training...")
+    train_result = trainer.train(resume_from_checkpoint=config.training.resume_from_checkpoint)
+
+    # Save the final model
+    logger.info("Saving final model...")
+    trainer.save_model()
+    trainer.save_state()
+
+    # Log metrics
+    logger.info(f"Training completed! Metrics: {train_result.metrics}")
 
 if __name__ == "__main__":
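One caveat worth noting: because collate_fn stacks input_ids, attention_mask, and labels directly, it will raise a RuntimeError if the dataset pipeline ever emits sequences of different lengths; the comment in the commit suggests the processor is expected to pad upstream. A minimal alternative sketch, not part of this commit, that pads in the collator instead (the pad_token_id default and the dim-0 concatenation layout for pixel_values are assumptions):

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_variable_length(examples, pad_token_id=0):
    # Drop failed samples, mirroring the None-filtering in the commit
    examples = [ex for ex in examples if ex is not None]
    return {
        # Right-pad text fields to the longest sequence in the batch
        'input_ids': pad_sequence([ex['input_ids'] for ex in examples],
                                  batch_first=True, padding_value=pad_token_id),
        'attention_mask': pad_sequence([ex['attention_mask'] for ex in examples],
                                       batch_first=True, padding_value=0),
        # -100 is the ignore index used by the cross-entropy loss in transformers
        'labels': pad_sequence([ex['labels'] for ex in examples],
                               batch_first=True, padding_value=-100),
        # Qwen2-VL consumes flattened image patches; concatenating along dim 0,
        # with image_grid_thw describing each image, is one common layout
        'pixel_values': torch.cat([ex['pixel_values'] for ex in examples], dim=0),
        'image_grid_thw': torch.cat([ex['image_grid_thw'] for ex in examples], dim=0),
    }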