Basic train config loader for datasets

Jake Poznanski 2025-06-24 22:48:36 +00:00
parent b93c262dca
commit 0ebc35cf1f
4 changed files with 193 additions and 48 deletions

View File

@@ -63,10 +63,21 @@ class TokenizerStepConfig(PipelineStepConfig):
end_of_message_token: str = "<|im_end|>"
@dataclass
class DatasetItemConfig:
"""Configuration for a single dataset item."""
root_dir: str
pipeline: List[Dict[str, Any]] = field(default_factory=list)
# Optional sampling
max_samples: Optional[int] = None
@dataclass
class DatasetConfig:
"""Configuration for dataset and data loading."""
root_dir: str
train: List[Dict[str, Any]] = field(default_factory=list)
eval: List[Dict[str, Any]] = field(default_factory=list)
# DataLoader configuration
batch_size: int = 1
@@ -76,33 +87,13 @@ class DatasetConfig:
pin_memory: bool = True
prefetch_factor: int = 2
# Pipeline steps configuration
pipeline_steps: List[Dict[str, Any]] = field(default_factory=lambda: [
{"name": "FrontMatterParser", "use_page_response_class": True},
{"name": "PDFRenderer", "target_longest_image_dim": 1024},
{"name": "StaticLengthDocumentAnchoring", "target_anchor_text_len": 6000},
{"name": "FinetuningPrompt"},
{"name": "FrontMatterOutputFormat"},
{"name": "InstructUserMessages"},
{"name": "Tokenizer", "masking_index": -100, "end_of_message_token": "<|im_end|>"}
])
# Optional dataset sampling
max_samples: Optional[int] = None
validation_split: float = 0.1
# Global seed
seed: int = 42
# Train/validation split
train_indices: Optional[List[int]] = None
val_indices: Optional[List[int]] = None
# Caching
cache_dir: Optional[str] = None
use_cache: bool = False
# Data augmentation (future extension)
augmentation: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ModelConfig:
@@ -280,9 +271,14 @@ class Config:
def validate(self) -> None:
"""Validate configuration values."""
# Dataset validation
if not os.path.exists(self.dataset.root_dir):
raise ValueError(f"Dataset root directory does not exist: {self.dataset.root_dir}")
# Dataset validation - check all train and eval datasets
for split_name, datasets in [("train", self.dataset.train), ("eval", self.dataset.eval)]:
for i, dataset_cfg in enumerate(datasets):
root_dir = dataset_cfg.get('root_dir')
if not root_dir:
raise ValueError(f"Missing root_dir for {split_name} dataset {i}")
if not os.path.exists(root_dir):
raise ValueError(f"Dataset root directory does not exist: {root_dir}")
# Training validation
if self.training.warmup_steps is not None and self.training.warmup_ratio > 0:
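Note: a minimal sketch of the new per-dataset validation, assuming the dataclasses above can be built with their defaults via create_default_config(); the eval entry deliberately omits root_dir:
from olmocr.train.config import create_default_config

cfg = create_default_config()
cfg.dataset.eval = [{"pipeline": []}]  # root_dir intentionally missing
try:
    cfg.validate()
except ValueError as e:
    print(e)  # -> Missing root_dir for eval dataset 0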
@@ -303,8 +299,16 @@ class Config:
self.training.logging_dir = os.path.join(self.training.output_dir, "logs")
Path(self.training.logging_dir).mkdir(parents=True, exist_ok=True)
def get_pipeline_steps(self, processor=None):
"""Create actual pipeline step instances from configuration."""
def get_pipeline_steps(self, pipeline_config: List[Dict[str, Any]], processor=None):
"""Create actual pipeline step instances from pipeline configuration.
Args:
pipeline_config: List of pipeline step configurations
processor: The model processor (required for Tokenizer step)
Returns:
List of initialized pipeline step instances
"""
from olmocr.train.dataloader import (
FrontMatterParser,
PDFRenderer,
@@ -317,14 +321,18 @@ class Config:
from olmocr.prompts.prompts import PageResponse
steps = []
for step_config in self.dataset.pipeline_steps:
for step_config in pipeline_config:
if not step_config.get('enabled', True):
continue
step_name = step_config['name']
if step_name == 'FrontMatterParser':
front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
# Handle both old and new config format
if 'front_matter_class' in step_config:
front_matter_class = PageResponse if step_config['front_matter_class'] == 'PageResponse' else None
else:
front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
steps.append(FrontMatterParser(front_matter_class=front_matter_class))
elif step_name == 'PDFRenderer':
@@ -365,7 +373,7 @@ def create_default_config() -> Config:
"""Create a default configuration."""
return Config(
model=ModelConfig(),
dataset=DatasetConfig(root_dir="/path/to/dataset"),
dataset=DatasetConfig(),
training=TrainingConfig()
)
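Note: a minimal sketch of the backward-compatible FrontMatterParser handling in get_pipeline_steps, assuming a default Config; no processor is passed because neither pipeline contains a Tokenizer step:
from olmocr.train.config import create_default_config

config = create_default_config()
old_style = [{"name": "FrontMatterParser", "use_page_response_class": True}]
new_style = [{"name": "FrontMatterParser", "front_matter_class": "PageResponse"}]

steps_old = config.get_pipeline_steps(old_style)  # old flag-based config format
steps_new = config.get_pipeline_steps(new_style)  # new string-based config format
assert type(steps_old[0]).__name__ == type(steps_new[0]).__name__ == "FrontMatterParser"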

View File

@@ -27,8 +27,8 @@ model:
dataset:
train:
- root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
pipeline *basic_pipeline:
- root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
pipeline: &basic_pipeline
- name: FrontMatterParser
front_matter_class: PageResponse
- name: PDFRenderer
@@ -41,15 +41,14 @@ dataset:
- name: Tokenizer
masking_index: -100
end_of_message_token: "<|im_end|>"
- root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_books_train_s2pdf/
pipeline: *reuse basic_pipeline above*
- root_dir: /home/ubuntu/olmOCR-mix-0225/processed_01_books_eval_iabooks/
pipeline: *basic_pipeline
eval:
- root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
pipeline: *reuse basic_pipeline above*
- root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_books_eval_s2pdf/
pipeline: *reuse basic_pipeline above*
pipeline: *basic_pipeline
- root_dir: /home/ubuntu/olmOCR-mix-0225/processed_01_books_eval_iabooks/
pipeline: *basic_pipeline
# Training configuration
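Note: a quick sketch of what the &basic_pipeline anchor and *basic_pipeline aliases expand to once the YAML is parsed (uses PyYAML; paths are shortened placeholders):
import yaml

snippet = """
dataset:
  train:
    - root_dir: /data/documents_train
      pipeline: &basic_pipeline
        - name: FrontMatterParser
          front_matter_class: PageResponse
        - name: PDFRenderer
          target_longest_image_dim: 1024
  eval:
    - root_dir: /data/documents_eval
      pipeline: *basic_pipeline
"""

cfg = yaml.safe_load(snippet)
# The alias reuses the anchored pipeline list, so eval gets the same steps as train
assert cfg["dataset"]["eval"][0]["pipeline"] == cfg["dataset"]["train"][0]["pipeline"]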

View File

@@ -1,9 +1,148 @@
# TODO Overall, this code will read in a config yaml file with OmegaConf
# From that config, we are going to use HuggingFace Trainer to train a model
# TODOS:
# DONE Build a script to convert olmocr-mix to a new dataloader format
# DONE Write a new dataloader and collator (with tests) that bring in everything; only needs to support batch size 1 for this first version
# Get a basic config yaml file system working
# Get a basic hugging face trainer running, supporting Qwen2.5VL for now
# Saving and restoring training checkpoints
# Converting training checkpoints to vllm-compatible checkpoints
"""
Simple script to test OlmOCR dataset loading with YAML configuration.
"""
import argparse
import logging
from pathlib import Path
from pprint import pprint
from transformers import AutoProcessor
from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset
# Configure logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
def print_sample(sample, dataset_name):
"""Pretty print a dataset sample."""
print(f"\n{'='*80}")
print(f"Sample from: {dataset_name}")
print(f"{'='*80}")
# Print keys
print(f"\nAvailable keys: {list(sample.keys())}")
# Print path information
if 'markdown_path' in sample:
print(f"\nMarkdown path: {sample['markdown_path']}")
if 'pdf_path' in sample:
print(f"PDF path: {sample['pdf_path']}")
# Print page data
if 'page_data' in sample:
print(f"\nPage data:")
print(f" Primary language: {sample['page_data'].primary_language}")
print(f" Is rotation valid: {sample['page_data'].is_rotation_valid}")
print(f" Rotation correction: {sample['page_data'].rotation_correction}")
print(f" Is table: {sample['page_data'].is_table}")
print(f" Is diagram: {sample['page_data'].is_diagram}")
print(f" Natural text preview: {sample['page_data'].natural_text[:200]}..." if sample['page_data'].natural_text else " Natural text: None")
# Print image info
if 'image' in sample:
print(f"\nImage shape: {sample['image'].size}")
# Print anchor text preview
if 'anchor_text' in sample:
print(f"\nAnchor text preview: {sample['anchor_text'][:200]}...")
# Print instruction prompt preview
if 'instruction_prompt' in sample:
print(f"\nInstruction prompt preview: {sample['instruction_prompt'][:200]}...")
# Print response preview
if 'response' in sample:
print(f"\nResponse preview: {sample['response'][:200]}...")
# Print tokenization info
if 'input_ids' in sample:
print(f"\nTokenization info:")
print(f" Input IDs shape: {sample['input_ids'].shape}")
print(f" Attention mask shape: {sample['attention_mask'].shape}")
print(f" Labels shape: {sample['labels'].shape}")
if 'pixel_values' in sample:
print(f" Pixel values shape: {sample['pixel_values'].shape}")
if 'image_grid_thw' in sample:
print(f" Image grid THW: {sample['image_grid_thw']}")
def main():
parser = argparse.ArgumentParser(description="Test OlmOCR dataset loading")
parser.add_argument(
"--config",
type=str,
default="olmocr/train/configs/example_config.yaml",
help="Path to YAML configuration file"
)
args = parser.parse_args()
# Load configuration
logger.info(f"Loading configuration from: {args.config}")
config = Config.from_yaml(args.config)
# Validate configuration
try:
config.validate()
except ValueError as e:
logger.error(f"Configuration validation failed: {e}")
return
# Load processor for tokenization
logger.info(f"Loading processor: {config.model.name}")
processor = AutoProcessor.from_pretrained(
config.model.name,
trust_remote_code=config.model.processor_trust_remote_code
)
# Process training datasets
print(f"\n{'='*80}")
print("TRAINING DATASETS")
print(f"{'='*80}")
for i, dataset_cfg in enumerate(config.dataset.train):
root_dir = dataset_cfg['root_dir']
pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
logger.info(f"\nCreating training dataset {i+1} from: {root_dir}")
dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
logger.info(f"Found {len(dataset)} samples")
if len(dataset) > 0:
# Get first sample
sample = dataset[0]
print_sample(sample, f"Training Dataset {i+1}: {Path(root_dir).name}")
# Process evaluation datasets
print(f"\n\n{'='*80}")
print("EVALUATION DATASETS")
print(f"{'='*80}")
for i, dataset_cfg in enumerate(config.dataset.eval):
root_dir = dataset_cfg['root_dir']
pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
logger.info(f"\nCreating evaluation dataset {i+1} from: {root_dir}")
dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
logger.info(f"Found {len(dataset)} samples")
if len(dataset) > 0:
# Get first sample
sample = dataset[0]
print_sample(sample, f"Evaluation Dataset {i+1}: {Path(root_dir).name}")
print(f"\n{'='*80}")
print("Dataset loading test completed!")
print(f"{'='*80}")
if __name__ == "__main__":
main()
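Note: the DataLoader fields on DatasetConfig (batch_size, pin_memory, prefetch_factor) are not exercised by this script yet; a rough sketch of wiring one dataset into torch's DataLoader, assuming config and dataset objects built as in main() above and a placeholder collate function, since only batch size 1 is supported so far:
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,                               # a BaseMarkdownPDFDataset built as above
    batch_size=config.dataset.batch_size,  # defaults to 1 in DatasetConfig
    collate_fn=lambda batch: batch[0],     # placeholder collate; the real collator lives in olmocr.train.dataloader
)
batch = next(iter(loader))
print(batch["input_ids"].shape)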

View File

@@ -105,7 +105,6 @@ train = [
"s3fs",
"necessary",
"einops",
"transformers>=4.45.1"
]
elo = [