mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-02 11:04:25 +00:00
Claude generated train script
This commit is contained in:
parent
0ebc35cf1f
commit
91e7b5ce3f
@ -4,10 +4,16 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from transformers import AutoProcessor
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
EarlyStoppingCallback
|
||||
)
|
||||
import torch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from olmocr.train.config import Config
|
||||
from olmocr.train.dataloader import BaseMarkdownPDFDataset
|
||||
@ -21,61 +27,42 @@ logging.basicConfig(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def print_sample(sample, dataset_name):
|
||||
"""Pretty print a dataset sample."""
|
||||
print(f"\n{'='*80}")
|
||||
print(f"Sample from: {dataset_name}")
|
||||
print(f"{'='*80}")
|
||||
def create_data_collator():
|
||||
"""Create a data collator for vision-language models."""
|
||||
def collate_fn(examples):
|
||||
# Filter out None values and extract the fields we need
|
||||
batch = {
|
||||
'input_ids': [],
|
||||
'attention_mask': [],
|
||||
'labels': [],
|
||||
'pixel_values': [],
|
||||
'image_grid_thw': []
|
||||
}
|
||||
|
||||
for example in examples:
|
||||
if example is not None:
|
||||
batch['input_ids'].append(example['input_ids'])
|
||||
batch['attention_mask'].append(example['attention_mask'])
|
||||
batch['labels'].append(example['labels'])
|
||||
batch['pixel_values'].append(example['pixel_values'])
|
||||
batch['image_grid_thw'].append(example['image_grid_thw'])
|
||||
|
||||
# Convert lists to tensors with proper padding
|
||||
# Note: For Qwen2-VL, we typically handle variable length sequences
|
||||
# The model's processor should handle the padding internally
|
||||
return {
|
||||
'input_ids': torch.stack(batch['input_ids']),
|
||||
'attention_mask': torch.stack(batch['attention_mask']),
|
||||
'labels': torch.stack(batch['labels']),
|
||||
'pixel_values': batch['pixel_values'], # Keep as list for now
|
||||
'image_grid_thw': torch.stack(batch['image_grid_thw'])
|
||||
}
|
||||
|
||||
# Print keys
|
||||
print(f"\nAvailable keys: {list(sample.keys())}")
|
||||
|
||||
# Print path information
|
||||
if 'markdown_path' in sample:
|
||||
print(f"\nMarkdown path: {sample['markdown_path']}")
|
||||
if 'pdf_path' in sample:
|
||||
print(f"PDF path: {sample['pdf_path']}")
|
||||
|
||||
# Print page data
|
||||
if 'page_data' in sample:
|
||||
print(f"\nPage data:")
|
||||
print(f" Primary language: {sample['page_data'].primary_language}")
|
||||
print(f" Is rotation valid: {sample['page_data'].is_rotation_valid}")
|
||||
print(f" Rotation correction: {sample['page_data'].rotation_correction}")
|
||||
print(f" Is table: {sample['page_data'].is_table}")
|
||||
print(f" Is diagram: {sample['page_data'].is_diagram}")
|
||||
print(f" Natural text preview: {sample['page_data'].natural_text[:200]}..." if sample['page_data'].natural_text else " Natural text: None")
|
||||
|
||||
# Print image info
|
||||
if 'image' in sample:
|
||||
print(f"\nImage shape: {sample['image'].size}")
|
||||
|
||||
# Print anchor text preview
|
||||
if 'anchor_text' in sample:
|
||||
print(f"\nAnchor text preview: {sample['anchor_text'][:200]}...")
|
||||
|
||||
# Print instruction prompt preview
|
||||
if 'instruction_prompt' in sample:
|
||||
print(f"\nInstruction prompt preview: {sample['instruction_prompt'][:200]}...")
|
||||
|
||||
# Print response preview
|
||||
if 'response' in sample:
|
||||
print(f"\nResponse preview: {sample['response'][:200]}...")
|
||||
|
||||
# Print tokenization info
|
||||
if 'input_ids' in sample:
|
||||
print(f"\nTokenization info:")
|
||||
print(f" Input IDs shape: {sample['input_ids'].shape}")
|
||||
print(f" Attention mask shape: {sample['attention_mask'].shape}")
|
||||
print(f" Labels shape: {sample['labels'].shape}")
|
||||
if 'pixel_values' in sample:
|
||||
print(f" Pixel values shape: {sample['pixel_values'].shape}")
|
||||
if 'image_grid_thw' in sample:
|
||||
print(f" Image grid THW: {sample['image_grid_thw']}")
|
||||
return collate_fn
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Test OlmOCR dataset loading")
|
||||
parser = argparse.ArgumentParser(description="Train OlmOCR model")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
@ -103,45 +90,134 @@ def main():
|
||||
trust_remote_code=config.model.processor_trust_remote_code
|
||||
)
|
||||
|
||||
# Process training datasets
|
||||
print(f"\n{'='*80}")
|
||||
print("TRAINING DATASETS")
|
||||
print(f"{'='*80}")
|
||||
# Load model
|
||||
logger.info(f"Loading model: {config.model.name}")
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
config.model.name,
|
||||
torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
|
||||
device_map=config.model.device_map,
|
||||
trust_remote_code=config.model.trust_remote_code,
|
||||
attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
|
||||
)
|
||||
|
||||
# Enable gradient checkpointing if configured
|
||||
if config.training.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)
|
||||
|
||||
# Create training datasets
|
||||
logger.info("Creating training datasets...")
|
||||
train_datasets = []
|
||||
for i, dataset_cfg in enumerate(config.dataset.train):
|
||||
root_dir = dataset_cfg['root_dir']
|
||||
pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
|
||||
|
||||
logger.info(f"\nCreating training dataset {i+1} from: {root_dir}")
|
||||
logger.info(f"Creating training dataset {i+1} from: {root_dir}")
|
||||
dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
|
||||
logger.info(f"Found {len(dataset)} samples")
|
||||
|
||||
if len(dataset) > 0:
|
||||
# Get first sample
|
||||
sample = dataset[0]
|
||||
print_sample(sample, f"Training Dataset {i+1}: {Path(root_dir).name}")
|
||||
train_datasets.append(dataset)
|
||||
|
||||
# Process evaluation datasets
|
||||
print(f"\n\n{'='*80}")
|
||||
print("EVALUATION DATASETS")
|
||||
print(f"{'='*80}")
|
||||
# Combine all training datasets
|
||||
train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
|
||||
logger.info(f"Total training samples: {len(train_dataset)}")
|
||||
|
||||
# Create evaluation datasets
|
||||
logger.info("Creating evaluation datasets...")
|
||||
eval_datasets = []
|
||||
for i, dataset_cfg in enumerate(config.dataset.eval):
|
||||
root_dir = dataset_cfg['root_dir']
|
||||
pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
|
||||
|
||||
logger.info(f"\nCreating evaluation dataset {i+1} from: {root_dir}")
|
||||
logger.info(f"Creating evaluation dataset {i+1} from: {root_dir}")
|
||||
dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
|
||||
logger.info(f"Found {len(dataset)} samples")
|
||||
|
||||
if len(dataset) > 0:
|
||||
# Get first sample
|
||||
sample = dataset[0]
|
||||
print_sample(sample, f"Evaluation Dataset {i+1}: {Path(root_dir).name}")
|
||||
eval_datasets.append(dataset)
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print("Dataset loading test completed!")
|
||||
print(f"{'='*80}")
|
||||
# Combine all evaluation datasets
|
||||
eval_dataset = ConcatDataset(eval_datasets) if len(eval_datasets) > 1 else eval_datasets[0]
|
||||
logger.info(f"Total evaluation samples: {len(eval_dataset)}")
|
||||
|
||||
# Set up training arguments
|
||||
training_args = TrainingArguments(
|
||||
output_dir=config.training.output_dir,
|
||||
num_train_epochs=config.training.num_train_epochs,
|
||||
per_device_train_batch_size=config.training.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=config.training.per_device_eval_batch_size,
|
||||
gradient_accumulation_steps=config.training.gradient_accumulation_steps,
|
||||
learning_rate=config.training.learning_rate,
|
||||
lr_scheduler_type=config.training.lr_scheduler_type,
|
||||
warmup_ratio=config.training.warmup_ratio,
|
||||
warmup_steps=config.training.warmup_steps,
|
||||
optim=config.training.optim,
|
||||
adam_beta1=config.training.adam_beta1,
|
||||
adam_beta2=config.training.adam_beta2,
|
||||
adam_epsilon=config.training.adam_epsilon,
|
||||
weight_decay=config.training.weight_decay,
|
||||
max_grad_norm=config.training.max_grad_norm,
|
||||
fp16=config.training.fp16,
|
||||
bf16=config.training.bf16,
|
||||
tf32=config.training.tf32,
|
||||
eval_strategy=config.training.evaluation_strategy,
|
||||
eval_steps=config.training.eval_steps,
|
||||
save_strategy=config.training.save_strategy,
|
||||
save_steps=config.training.save_steps,
|
||||
save_total_limit=config.training.save_total_limit,
|
||||
load_best_model_at_end=config.training.load_best_model_at_end,
|
||||
metric_for_best_model=config.training.metric_for_best_model,
|
||||
greater_is_better=config.training.greater_is_better,
|
||||
logging_dir=config.training.logging_dir,
|
||||
logging_strategy=config.training.logging_strategy,
|
||||
logging_steps=config.training.logging_steps,
|
||||
logging_first_step=config.training.logging_first_step,
|
||||
report_to=config.training.report_to,
|
||||
seed=config.training.seed,
|
||||
data_seed=config.training.data_seed,
|
||||
push_to_hub=config.training.push_to_hub,
|
||||
hub_model_id=config.training.hub_model_id,
|
||||
hub_strategy=config.training.hub_strategy,
|
||||
resume_from_checkpoint=config.training.resume_from_checkpoint,
|
||||
deepspeed=config.training.deepspeed,
|
||||
dataloader_drop_last=config.training.dataloader_drop_last,
|
||||
dataloader_num_workers=config.training.dataloader_num_workers,
|
||||
remove_unused_columns=config.training.remove_unused_columns,
|
||||
run_name=config.run_name,
|
||||
)
|
||||
|
||||
# Set up callbacks
|
||||
callbacks = []
|
||||
if config.training.use_early_stopping:
|
||||
callbacks.append(
|
||||
EarlyStoppingCallback(
|
||||
early_stopping_patience=config.training.early_stopping_patience,
|
||||
early_stopping_threshold=config.training.early_stopping_threshold
|
||||
)
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
logger.info("Initializing trainer...")
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=create_data_collator(),
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
# Start training
|
||||
logger.info("Starting training...")
|
||||
train_result = trainer.train(resume_from_checkpoint=config.training.resume_from_checkpoint)
|
||||
|
||||
# Save the final model
|
||||
logger.info("Saving final model...")
|
||||
trainer.save_model()
|
||||
trainer.save_state()
|
||||
|
||||
# Log metrics
|
||||
logger.info(f"Training completed! Metrics: {train_result.metrics}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user