diff --git a/olmocr/train/config.py b/olmocr/train/config.py index 45aacc9..5c4760f 100644 --- a/olmocr/train/config.py +++ b/olmocr/train/config.py @@ -69,6 +69,13 @@ class FrontMatterOutputFormatConfig(PipelineStepConfig): name: str = "FrontMatterOutputFormat" +@dataclass +class JSONOutputFormatConfig(PipelineStepConfig): + """Configuration for JSONOutputFormat step.""" + + name: str = "JSONOutputFormat" + + @dataclass class InstructUserMessagesConfig(PipelineStepConfig): """Configuration for InstructUserMessages step.""" @@ -301,6 +308,7 @@ class Config: InstructUserMessages, NewYamlFinetuningPromptWithAnchoring, NewYamlFinetuningPromptWithNoAnchoring, + JSONOutputFormat, PDFRenderer, StaticLengthDocumentAnchoring, Tokenizer, @@ -338,6 +346,9 @@ class Config: elif step_name == "NewYamlFinetuningPromptWithNoAnchoring": steps.append(NewYamlFinetuningPromptWithNoAnchoring()) + elif step_name == "JSONOutputFormat": + steps.append(JSONOutputFormat()) + elif step_name == "FrontMatterOutputFormat": steps.append(FrontMatterOutputFormat()) diff --git a/olmocr/train/configs/qwen25_vl_b100_x1_default_json.yaml b/olmocr/train/configs/qwen25_vl_b100_x1_default_json.yaml new file mode 100644 index 0000000..81fa259 --- /dev/null +++ b/olmocr/train/configs/qwen25_vl_b100_x1_default_json.yaml @@ -0,0 +1,93 @@ +# Example OlmOCR Training Configuration + +# Project metadata +project_name: olmocr-qwen-vl-training +run_name: qwen2.5-vl-7b-finetune-default + +# Model configuration +model: + name: Qwen/Qwen2.5-VL-7B-Instruct + trust_remote_code: true + torch_dtype: bfloat16 + use_flash_attention: true + attn_implementation: flash_attention_2 + + # LoRA settings (disabled by default) + use_lora: false + # lora_rank: 8 + # lora_alpha: 32 + # lora_dropout: 0.1 + # lora_target_modules: + # - q_proj + # - v_proj + # - k_proj + # - o_proj + +# Dataset configuration +dataset: + + train: + - name: processed_01_books_train_iabooks + root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_train_iabooks/ + pipeline: &basic_pipeline + - name: FrontMatterParser + front_matter_class: PageResponse + - name: PDFRenderer + target_longest_image_dim: 1024 + - name: StaticLengthDocumentAnchoring + target_anchor_text_len: 6000 + - name: FinetuningPrompt + - name: JSONOutputFormat + - name: InstructUserMessages + - name: Tokenizer + masking_index: -100 + end_of_message_token: "<|im_end|>" + - name: processed_00_documents_train_s2pdf + root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_train_s2pdf/ + pipeline: *basic_pipeline + + eval: + - name: processed_00_documents_eval_s2pdf + root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/ + pipeline: *basic_pipeline + - name: processed_01_books_eval_iabooks + root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_eval_iabooks/ + pipeline: *basic_pipeline + + +# Training configuration +training: + output_dir: /weka/oe-data-default/jakep/olmocr-trainer/ + num_train_epochs: 1 + + # Batch size and accumulation + per_device_train_batch_size: 1 + per_device_eval_batch_size: 1 + gradient_accumulation_steps: 32 + + gradient_checkpointing: False + + # Learning rate + learning_rate: 1e-6 + lr_scheduler_type: cosine + warmup_ratio: 0.1 + + # Optimization + optim: adamw_torch + weight_decay: 0.01 + max_grad_norm: 1.0 + + + # Evaluation and checkpointing + evaluation_strategy: steps + eval_steps: 500 + save_strategy: steps + save_steps: 500 + save_total_limit: 5 + load_best_model_at_end: true + metric_for_best_model: eval_processed_00_documents_eval_s2pdf_loss + greater_is_better: false + + report_to: + - wandb + \ No newline at end of file diff --git a/olmocr/train/dataloader.py b/olmocr/train/dataloader.py index 61ce6b0..a3363db 100644 --- a/olmocr/train/dataloader.py +++ b/olmocr/train/dataloader.py @@ -7,6 +7,7 @@ from functools import reduce from io import BytesIO from os import PathLike from pathlib import Path +import json from typing import Any, Callable, Dict, List, Optional, Type, TypeAlias, Tuple import numpy as np @@ -338,6 +339,25 @@ is_diagram: {page_data.is_diagram} return sample +@dataclass(frozen=True, slots=True) +class JSONOutputFormat(PipelineStep): + """Takes the output and applies the standard yaml formatting to it""" + + def __call__(self, sample: Sample) -> Sample: + page_data = sample["page_data"] + assert type(page_data) == PageResponse + + sample["response"] = json.dumps({ + "primary_language": page_data.primary_language, + "is_rotation_valid": page_data.is_rotation_valid, + "rotation_correction": page_data.rotation_correction, + "is_table": page_data.is_table, + "is_diagram": page_data.is_diagram, + "natural_text": page_data.natural_text + }, ensure_ascii=True) + + return sample + @dataclass(frozen=True, slots=True) class InstructUserMessages(PipelineStep): """Creates instruction-following messages format for training."""