Basic train config loader for datasets

Jake Poznanski 2025-06-24 22:48:36 +00:00
parent b93c262dca
commit 0ebc35cf1f
4 changed files with 193 additions and 48 deletions


@@ -63,10 +63,21 @@ class TokenizerStepConfig(PipelineStepConfig):
     end_of_message_token: str = "<|im_end|>"


+@dataclass
+class DatasetItemConfig:
+    """Configuration for a single dataset item."""
+    root_dir: str
+    pipeline: List[Dict[str, Any]] = field(default_factory=list)
+
+    # Optional sampling
+    max_samples: Optional[int] = None
+
+
 @dataclass
 class DatasetConfig:
     """Configuration for dataset and data loading."""
-    root_dir: str
+    train: List[Dict[str, Any]] = field(default_factory=list)
+    eval: List[Dict[str, Any]] = field(default_factory=list)

     # DataLoader configuration
     batch_size: int = 1
@@ -76,33 +87,13 @@ class DatasetConfig:
     pin_memory: bool = True
     prefetch_factor: int = 2

-    # Pipeline steps configuration
-    pipeline_steps: List[Dict[str, Any]] = field(default_factory=lambda: [
-        {"name": "FrontMatterParser", "use_page_response_class": True},
-        {"name": "PDFRenderer", "target_longest_image_dim": 1024},
-        {"name": "StaticLengthDocumentAnchoring", "target_anchor_text_len": 6000},
-        {"name": "FinetuningPrompt"},
-        {"name": "FrontMatterOutputFormat"},
-        {"name": "InstructUserMessages"},
-        {"name": "Tokenizer", "masking_index": -100, "end_of_message_token": "<|im_end|>"}
-    ])
-
-    # Optional dataset sampling
-    max_samples: Optional[int] = None
-    validation_split: float = 0.1
+    # Global seed
     seed: int = 42

-    # Train/validation split
-    train_indices: Optional[List[int]] = None
-    val_indices: Optional[List[int]] = None
-
     # Caching
     cache_dir: Optional[str] = None
     use_cache: bool = False

-    # Data augmentation (future extension)
-    augmentation: Dict[str, Any] = field(default_factory=dict)
-

 @dataclass
 class ModelConfig:
@@ -280,9 +271,14 @@ class Config:
     def validate(self) -> None:
         """Validate configuration values."""
-        # Dataset validation
-        if not os.path.exists(self.dataset.root_dir):
-            raise ValueError(f"Dataset root directory does not exist: {self.dataset.root_dir}")
+        # Dataset validation - check all train and eval datasets
+        for split_name, datasets in [("train", self.dataset.train), ("eval", self.dataset.eval)]:
+            for i, dataset_cfg in enumerate(datasets):
+                root_dir = dataset_cfg.get('root_dir')
+                if not root_dir:
+                    raise ValueError(f"Missing root_dir for {split_name} dataset {i}")
+                if not os.path.exists(root_dir):
+                    raise ValueError(f"Dataset root directory does not exist: {root_dir}")

         # Training validation
         if self.training.warmup_steps is not None and self.training.warmup_ratio > 0:
@@ -303,8 +299,16 @@ class Config:
             self.training.logging_dir = os.path.join(self.training.output_dir, "logs")
             Path(self.training.logging_dir).mkdir(parents=True, exist_ok=True)

-    def get_pipeline_steps(self, processor=None):
-        """Create actual pipeline step instances from configuration."""
+    def get_pipeline_steps(self, pipeline_config: List[Dict[str, Any]], processor=None):
+        """Create actual pipeline step instances from pipeline configuration.
+
+        Args:
+            pipeline_config: List of pipeline step configurations
+            processor: The model processor (required for Tokenizer step)
+
+        Returns:
+            List of initialized pipeline step instances
+        """
         from olmocr.train.dataloader import (
             FrontMatterParser,
             PDFRenderer,
@@ -317,13 +321,17 @@ class Config:
         from olmocr.prompts.prompts import PageResponse

         steps = []
-        for step_config in self.dataset.pipeline_steps:
+        for step_config in pipeline_config:
             if not step_config.get('enabled', True):
                 continue

             step_name = step_config['name']

             if step_name == 'FrontMatterParser':
-                front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
+                # Handle both old and new config format
+                if 'front_matter_class' in step_config:
+                    front_matter_class = PageResponse if step_config['front_matter_class'] == 'PageResponse' else None
+                else:
+                    front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
                 steps.append(FrontMatterParser(front_matter_class=front_matter_class))
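Note: the compatibility branch above means both config spellings map to the same parser class. A minimal illustrative sketch (keys taken from this diff, values are example data, not part of the commit):

# Both entries should resolve to front_matter_class=PageResponse in the branch above.
new_style = {"name": "FrontMatterParser", "front_matter_class": "PageResponse"}
old_style = {"name": "FrontMatterParser", "use_page_response_class": True}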
@@ -365,7 +373,7 @@ def create_default_config() -> Config:
     """Create a default configuration."""
     return Config(
         model=ModelConfig(),
-        dataset=DatasetConfig(root_dir="/path/to/dataset"),
+        dataset=DatasetConfig(),
         training=TrainingConfig()
     )
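For orientation, here is a minimal sketch of how the reworked dataset section might be built directly in Python instead of from YAML. The field names follow the dataclasses in this diff; the path and pipeline values are placeholders, not values from the commit, and DatasetConfig is assumed importable alongside Config:

from olmocr.train.config import DatasetConfig  # assumed to live next to Config

dataset = DatasetConfig(
    train=[{
        "root_dir": "/path/to/train_set/",  # placeholder path
        "pipeline": [
            {"name": "FrontMatterParser", "front_matter_class": "PageResponse"},
            {"name": "PDFRenderer", "target_longest_image_dim": 1024},
        ],
    }],
    eval=[],  # eval entries use the same shape
)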


@@ -27,8 +27,8 @@ model:
 dataset:
   train:
-    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
-      pipeline *basic_pipeline:
+    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
+      pipeline: &basic_pipeline
         - name: FrontMatterParser
           front_matter_class: PageResponse
         - name: PDFRenderer
@@ -41,15 +41,14 @@ dataset:
         - name: Tokenizer
           masking_index: -100
           end_of_message_token: "<|im_end|>"
-    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_books_train_s2pdf/
-      pipeline: *reuse basic_pipeline above*
+    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_01_books_eval_iabooks/
+      pipeline: *basic_pipeline
   eval:
     - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
-      pipeline: *reuse basic_pipeline above*
-    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_books_eval_s2pdf/
-      pipeline: *reuse basic_pipeline above*
+      pipeline: *basic_pipeline
+    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_01_books_eval_iabooks/
+      pipeline: *basic_pipeline

 # Training configuration
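The change above swaps placeholder text for a real YAML anchor/alias pair: "&basic_pipeline" names the first pipeline list and each "*basic_pipeline" reuses it unchanged. A small sketch, with illustrative paths and a trimmed pipeline and using PyYAML, showing that an alias loads as the same list:

import yaml  # PyYAML

doc = """
train:
  - root_dir: /data/a/            # placeholder
    pipeline: &basic_pipeline
      - name: FrontMatterParser
      - name: PDFRenderer
  - root_dir: /data/b/            # placeholder
    pipeline: *basic_pipeline     # alias: reuses the anchored list
"""

cfg = yaml.safe_load(doc)
assert cfg["train"][0]["pipeline"] == cfg["train"][1]["pipeline"]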


@@ -1,9 +1,148 @@
-# TODO Overall, this code will read in a config yaml file with omega conf
-# From that config, we are going to use HuggingFace Trainer to train a model
-# TODOS:
-# DONE Build a script to convert olmocr-mix to a new dataloader format
-# DONE Write a new dataloader and collator, with tests that brings in everything, only needs to support batch size 1 for this first version
-# Get a basic config yaml file system working
-# Get a basic hugging face trainer running, supporting Qwen2.5VL for now
-# Saving and restoring training checkpoints
-# Converting training checkpoints to vllm compatible checkpoinst
+"""
+Simple script to test OlmOCR dataset loading with YAML configuration.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from pprint import pprint
+
+from transformers import AutoProcessor
+
+from olmocr.train.config import Config
+from olmocr.train.dataloader import BaseMarkdownPDFDataset
+
+# Configure logging
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+def print_sample(sample, dataset_name):
+    """Pretty print a dataset sample."""
+    print(f"\n{'='*80}")
+    print(f"Sample from: {dataset_name}")
+    print(f"{'='*80}")
+
+    # Print keys
+    print(f"\nAvailable keys: {list(sample.keys())}")
+
+    # Print path information
+    if 'markdown_path' in sample:
+        print(f"\nMarkdown path: {sample['markdown_path']}")
+    if 'pdf_path' in sample:
+        print(f"PDF path: {sample['pdf_path']}")
+
+    # Print page data
+    if 'page_data' in sample:
+        print(f"\nPage data:")
+        print(f" Primary language: {sample['page_data'].primary_language}")
+        print(f" Is rotation valid: {sample['page_data'].is_rotation_valid}")
+        print(f" Rotation correction: {sample['page_data'].rotation_correction}")
+        print(f" Is table: {sample['page_data'].is_table}")
+        print(f" Is diagram: {sample['page_data'].is_diagram}")
+        print(f" Natural text preview: {sample['page_data'].natural_text[:200]}..." if sample['page_data'].natural_text else " Natural text: None")
+
+    # Print image info
+    if 'image' in sample:
+        print(f"\nImage shape: {sample['image'].size}")
+
+    # Print anchor text preview
+    if 'anchor_text' in sample:
+        print(f"\nAnchor text preview: {sample['anchor_text'][:200]}...")
+
+    # Print instruction prompt preview
+    if 'instruction_prompt' in sample:
+        print(f"\nInstruction prompt preview: {sample['instruction_prompt'][:200]}...")
+
+    # Print response preview
+    if 'response' in sample:
+        print(f"\nResponse preview: {sample['response'][:200]}...")
+
+    # Print tokenization info
+    if 'input_ids' in sample:
+        print(f"\nTokenization info:")
+        print(f" Input IDs shape: {sample['input_ids'].shape}")
+        print(f" Attention mask shape: {sample['attention_mask'].shape}")
+        print(f" Labels shape: {sample['labels'].shape}")
+        if 'pixel_values' in sample:
+            print(f" Pixel values shape: {sample['pixel_values'].shape}")
+        if 'image_grid_thw' in sample:
+            print(f" Image grid THW: {sample['image_grid_thw']}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Test OlmOCR dataset loading")
+    parser.add_argument(
+        "--config",
+        type=str,
+        default="olmocr/train/configs/example_config.yaml",
+        help="Path to YAML configuration file"
+    )
+    args = parser.parse_args()
+
+    # Load configuration
+    logger.info(f"Loading configuration from: {args.config}")
+    config = Config.from_yaml(args.config)
+
+    # Validate configuration
+    try:
+        config.validate()
+    except ValueError as e:
+        logger.error(f"Configuration validation failed: {e}")
+        return
+
+    # Load processor for tokenization
+    logger.info(f"Loading processor: {config.model.name}")
+    processor = AutoProcessor.from_pretrained(
+        config.model.name,
+        trust_remote_code=config.model.processor_trust_remote_code
+    )
+
+    # Process training datasets
+    print(f"\n{'='*80}")
+    print("TRAINING DATASETS")
+    print(f"{'='*80}")
+
+    for i, dataset_cfg in enumerate(config.dataset.train):
+        root_dir = dataset_cfg['root_dir']
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+
+        logger.info(f"\nCreating training dataset {i+1} from: {root_dir}")
+        dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
+        logger.info(f"Found {len(dataset)} samples")
+
+        if len(dataset) > 0:
+            # Get first sample
+            sample = dataset[0]
+            print_sample(sample, f"Training Dataset {i+1}: {Path(root_dir).name}")
+
+    # Process evaluation datasets
+    print(f"\n\n{'='*80}")
+    print("EVALUATION DATASETS")
+    print(f"{'='*80}")
+
+    for i, dataset_cfg in enumerate(config.dataset.eval):
+        root_dir = dataset_cfg['root_dir']
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+
+        logger.info(f"\nCreating evaluation dataset {i+1} from: {root_dir}")
+        dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
+        logger.info(f"Found {len(dataset)} samples")
+
+        if len(dataset) > 0:
+            # Get first sample
+            sample = dataset[0]
+            print_sample(sample, f"Evaluation Dataset {i+1}: {Path(root_dir).name}")
+
+    print(f"\n{'='*80}")
+    print("Dataset loading test completed!")
+    print(f"{'='*80}")
+
+
+if __name__ == "__main__":
+    main()


@@ -105,7 +105,6 @@ train = [
     "s3fs",
     "necessary",
     "einops",
-    "transformers>=4.45.1"
 ]

 elo = [