Mirror of https://github.com/allenai/olmocr.git, synced 2025-10-27 16:12:13 +00:00
Basic train config loader for datasets
parent b93c262dca
commit 0ebc35cf1f
@@ -63,10 +63,21 @@ class TokenizerStepConfig(PipelineStepConfig):
     end_of_message_token: str = "<|im_end|>"


+@dataclass
+class DatasetItemConfig:
+    """Configuration for a single dataset item."""
+    root_dir: str
+    pipeline: List[Dict[str, Any]] = field(default_factory=list)
+
+    # Optional sampling
+    max_samples: Optional[int] = None
+
+
 @dataclass
 class DatasetConfig:
     """Configuration for dataset and data loading."""
-    root_dir: str
+    train: List[Dict[str, Any]] = field(default_factory=list)
+    eval: List[Dict[str, Any]] = field(default_factory=list)

     # DataLoader configuration
     batch_size: int = 1
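To make the reworked structure concrete, here is a minimal sketch (not part of the patch) of building the new DatasetConfig directly in Python. It assumes DatasetConfig is importable from olmocr.train.config alongside Config; the paths and pipeline entries are placeholders. Each entry in train/eval is a plain dict whose keys mirror DatasetItemConfig.

# Hypothetical usage of the per-split dataset layout introduced above.
from olmocr.train.config import DatasetConfig  # assumed import location

dataset_cfg = DatasetConfig(
    train=[
        {
            "root_dir": "/data/documents_train",  # placeholder path
            "pipeline": [
                {"name": "FrontMatterParser", "front_matter_class": "PageResponse"},
                {"name": "PDFRenderer", "target_longest_image_dim": 1024},
            ],
        },
    ],
    eval=[
        {"root_dir": "/data/documents_eval", "pipeline": []},  # placeholder path
    ],
    batch_size=1,
)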
@@ -76,33 +87,13 @@ class DatasetConfig:
     pin_memory: bool = True
     prefetch_factor: int = 2

-    # Pipeline steps configuration
-    pipeline_steps: List[Dict[str, Any]] = field(default_factory=lambda: [
-        {"name": "FrontMatterParser", "use_page_response_class": True},
-        {"name": "PDFRenderer", "target_longest_image_dim": 1024},
-        {"name": "StaticLengthDocumentAnchoring", "target_anchor_text_len": 6000},
-        {"name": "FinetuningPrompt"},
-        {"name": "FrontMatterOutputFormat"},
-        {"name": "InstructUserMessages"},
-        {"name": "Tokenizer", "masking_index": -100, "end_of_message_token": "<|im_end|>"}
-    ])
-
-    # Optional dataset sampling
-    max_samples: Optional[int] = None
-    validation_split: float = 0.1
+    # Global seed
     seed: int = 42

-    # Train/validation split
-    train_indices: Optional[List[int]] = None
-    val_indices: Optional[List[int]] = None
-
     # Caching
     cache_dir: Optional[str] = None
     use_cache: bool = False

-    # Data augmentation (future extension)
-    augmentation: Dict[str, Any] = field(default_factory=dict)
-

 @dataclass
 class ModelConfig:
@@ -280,9 +271,14 @@ class Config:

     def validate(self) -> None:
         """Validate configuration values."""
-        # Dataset validation
-        if not os.path.exists(self.dataset.root_dir):
-            raise ValueError(f"Dataset root directory does not exist: {self.dataset.root_dir}")
+        # Dataset validation - check all train and eval datasets
+        for split_name, datasets in [("train", self.dataset.train), ("eval", self.dataset.eval)]:
+            for i, dataset_cfg in enumerate(datasets):
+                root_dir = dataset_cfg.get('root_dir')
+                if not root_dir:
+                    raise ValueError(f"Missing root_dir for {split_name} dataset {i}")
+                if not os.path.exists(root_dir):
+                    raise ValueError(f"Dataset root directory does not exist: {root_dir}")

         # Training validation
         if self.training.warmup_steps is not None and self.training.warmup_ratio > 0:
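As a quick usage sketch of the stricter validation (not part of the patch; it relies on Config.from_yaml and the example config path that appear later in this commit), every configured train and eval entry is now checked, and the first missing or nonexistent root_dir raises a ValueError:

# Sketch: exercising the per-dataset validation added above.
from olmocr.train.config import Config

config = Config.from_yaml("olmocr/train/configs/example_config.yaml")
try:
    config.validate()
except ValueError as err:
    # e.g. "Missing root_dir for train dataset 0" or
    # "Dataset root directory does not exist: ..."
    print(f"Config rejected: {err}")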
@@ -303,8 +299,16 @@ class Config:
             self.training.logging_dir = os.path.join(self.training.output_dir, "logs")
         Path(self.training.logging_dir).mkdir(parents=True, exist_ok=True)

-    def get_pipeline_steps(self, processor=None):
-        """Create actual pipeline step instances from configuration."""
+    def get_pipeline_steps(self, pipeline_config: List[Dict[str, Any]], processor=None):
+        """Create actual pipeline step instances from pipeline configuration.
+
+        Args:
+            pipeline_config: List of pipeline step configurations
+            processor: The model processor (required for Tokenizer step)
+
+        Returns:
+            List of initialized pipeline step instances
+        """
         from olmocr.train.dataloader import (
             FrontMatterParser,
             PDFRenderer,
@@ -317,13 +321,17 @@ class Config:
         from olmocr.prompts.prompts import PageResponse

         steps = []
-        for step_config in self.dataset.pipeline_steps:
+        for step_config in pipeline_config:
             if not step_config.get('enabled', True):
                 continue

             step_name = step_config['name']

             if step_name == 'FrontMatterParser':
-                front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
+                # Handle both old and new config format
+                if 'front_matter_class' in step_config:
+                    front_matter_class = PageResponse if step_config['front_matter_class'] == 'PageResponse' else None
+                else:
+                    front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
                 steps.append(FrontMatterParser(front_matter_class=front_matter_class))

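Because get_pipeline_steps no longer reads a single global dataset.pipeline_steps list, callers now pass each dataset's own pipeline configuration. A minimal sketch of the new call pattern (mirroring the test script added later in this commit; the processor load is the standard Hugging Face call):

# Sketch: building per-dataset pipeline steps with the new signature.
from transformers import AutoProcessor
from olmocr.train.config import Config

config = Config.from_yaml("olmocr/train/configs/example_config.yaml")
processor = AutoProcessor.from_pretrained(config.model.name)

for dataset_cfg in config.dataset.train:
    # Each dataset carries its own pipeline; the processor is needed for the Tokenizer step.
    steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)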
@@ -365,7 +373,7 @@ def create_default_config() -> Config:
     """Create a default configuration."""
     return Config(
         model=ModelConfig(),
-        dataset=DatasetConfig(root_dir="/path/to/dataset"),
+        dataset=DatasetConfig(),
         training=TrainingConfig()
     )

@@ -27,8 +27,8 @@ model:

 dataset:
   train:
-    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
-      pipeline *basic_pipeline:
+    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
+      pipeline: &basic_pipeline
       - name: FrontMatterParser
         front_matter_class: PageResponse
       - name: PDFRenderer
@@ -41,15 +41,14 @@ dataset:
       - name: Tokenizer
         masking_index: -100
         end_of_message_token: "<|im_end|>"
-    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_books_train_s2pdf/
-      pipeline: *reuse basic_pipeline above*
+    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_01_books_eval_iabooks/
+      pipeline: *basic_pipeline

   eval:
     - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
-      pipeline: *reuse basic_pipeline above*
-    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_books_eval_s2pdf/
-      pipeline: *reuse basic_pipeline above*
+      pipeline: *basic_pipeline
+    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_01_books_eval_iabooks/
+      pipeline: *basic_pipeline

-

 # Training configuration
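The YAML change above fixes the anchor syntax: the first pipeline is declared once with &basic_pipeline and the remaining datasets reference it with *basic_pipeline, so the step list is defined in a single place. A small standalone sketch (plain PyYAML, not olmocr-specific; paths and step names are placeholders) showing how the anchor and alias resolve to the same list:

# Demonstrates the YAML anchor/alias reuse used in the config above.
import yaml  # PyYAML

doc = """
train:
  - root_dir: /data/a            # placeholder
    pipeline: &basic_pipeline
      - name: FrontMatterParser
      - name: PDFRenderer
eval:
  - root_dir: /data/b            # placeholder
    pipeline: *basic_pipeline    # alias reuses the same step list
"""

cfg = yaml.safe_load(doc)
assert cfg["eval"][0]["pipeline"] == cfg["train"][0]["pipeline"]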
@@ -1,9 +1,148 @@
-# TODO Overall, this code will read in a config yaml file with omega conf
-# From that config, we are going to use HuggingFace Trainer to train a model
-# TODOS:
-# DONE Build a script to convert olmocr-mix to a new dataloader format
-# DONE Write a new dataloader and collator, with tests that brings in everything, only needs to support batch size 1 for this first version
-# Get a basic config yaml file system working
-# Get a basic hugging face trainer running, supporting Qwen2.5VL for now
-# Saving and restoring training checkpoints
-# Converting training checkpoints to vllm compatible checkpoinst
+"""
+Simple script to test OlmOCR dataset loading with YAML configuration.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from pprint import pprint
+
+from transformers import AutoProcessor
+
+from olmocr.train.config import Config
+from olmocr.train.dataloader import BaseMarkdownPDFDataset
+
+# Configure logging
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+def print_sample(sample, dataset_name):
+    """Pretty print a dataset sample."""
+    print(f"\n{'='*80}")
+    print(f"Sample from: {dataset_name}")
+    print(f"{'='*80}")
+
+    # Print keys
+    print(f"\nAvailable keys: {list(sample.keys())}")
+
+    # Print path information
+    if 'markdown_path' in sample:
+        print(f"\nMarkdown path: {sample['markdown_path']}")
+    if 'pdf_path' in sample:
+        print(f"PDF path: {sample['pdf_path']}")
+
+    # Print page data
+    if 'page_data' in sample:
+        print(f"\nPage data:")
+        print(f"  Primary language: {sample['page_data'].primary_language}")
+        print(f"  Is rotation valid: {sample['page_data'].is_rotation_valid}")
+        print(f"  Rotation correction: {sample['page_data'].rotation_correction}")
+        print(f"  Is table: {sample['page_data'].is_table}")
+        print(f"  Is diagram: {sample['page_data'].is_diagram}")
+        print(f"  Natural text preview: {sample['page_data'].natural_text[:200]}..." if sample['page_data'].natural_text else "  Natural text: None")
+
+    # Print image info
+    if 'image' in sample:
+        print(f"\nImage shape: {sample['image'].size}")
+
+    # Print anchor text preview
+    if 'anchor_text' in sample:
+        print(f"\nAnchor text preview: {sample['anchor_text'][:200]}...")
+
+    # Print instruction prompt preview
+    if 'instruction_prompt' in sample:
+        print(f"\nInstruction prompt preview: {sample['instruction_prompt'][:200]}...")
+
+    # Print response preview
+    if 'response' in sample:
+        print(f"\nResponse preview: {sample['response'][:200]}...")
+
+    # Print tokenization info
+    if 'input_ids' in sample:
+        print(f"\nTokenization info:")
+        print(f"  Input IDs shape: {sample['input_ids'].shape}")
+        print(f"  Attention mask shape: {sample['attention_mask'].shape}")
+        print(f"  Labels shape: {sample['labels'].shape}")
+        if 'pixel_values' in sample:
+            print(f"  Pixel values shape: {sample['pixel_values'].shape}")
+        if 'image_grid_thw' in sample:
+            print(f"  Image grid THW: {sample['image_grid_thw']}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Test OlmOCR dataset loading")
+    parser.add_argument(
+        "--config",
+        type=str,
+        default="olmocr/train/configs/example_config.yaml",
+        help="Path to YAML configuration file"
+    )
+
+    args = parser.parse_args()
+
+    # Load configuration
+    logger.info(f"Loading configuration from: {args.config}")
+    config = Config.from_yaml(args.config)
+
+    # Validate configuration
+    try:
+        config.validate()
+    except ValueError as e:
+        logger.error(f"Configuration validation failed: {e}")
+        return
+
+    # Load processor for tokenization
+    logger.info(f"Loading processor: {config.model.name}")
+    processor = AutoProcessor.from_pretrained(
+        config.model.name,
+        trust_remote_code=config.model.processor_trust_remote_code
+    )
+
+    # Process training datasets
+    print(f"\n{'='*80}")
+    print("TRAINING DATASETS")
+    print(f"{'='*80}")
+
+    for i, dataset_cfg in enumerate(config.dataset.train):
+        root_dir = dataset_cfg['root_dir']
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+
+        logger.info(f"\nCreating training dataset {i+1} from: {root_dir}")
+        dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
+        logger.info(f"Found {len(dataset)} samples")
+
+        if len(dataset) > 0:
+            # Get first sample
+            sample = dataset[0]
+            print_sample(sample, f"Training Dataset {i+1}: {Path(root_dir).name}")
+
+    # Process evaluation datasets
+    print(f"\n\n{'='*80}")
+    print("EVALUATION DATASETS")
+    print(f"{'='*80}")
+
+    for i, dataset_cfg in enumerate(config.dataset.eval):
+        root_dir = dataset_cfg['root_dir']
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+
+        logger.info(f"\nCreating evaluation dataset {i+1} from: {root_dir}")
+        dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
+        logger.info(f"Found {len(dataset)} samples")
+
+        if len(dataset) > 0:
+            # Get first sample
+            sample = dataset[0]
+            print_sample(sample, f"Evaluation Dataset {i+1}: {Path(root_dir).name}")
+
+    print(f"\n{'='*80}")
+    print("Dataset loading test completed!")
+    print(f"{'='*80}")
+
+
+if __name__ == "__main__":
+    main()
@@ -105,7 +105,6 @@ train = [
     "s3fs",
     "necessary",
     "einops",
-    "transformers>=4.45.1"
 ]

 elo = [