Adding a standard JSON output option

This commit is contained in:
Jake Poznanski 2025-07-01 22:13:06 +00:00
parent 6f2a426986
commit 210d170b15
3 changed files with 124 additions and 0 deletions

View File

@ -69,6 +69,13 @@ class FrontMatterOutputFormatConfig(PipelineStepConfig):
name: str = "FrontMatterOutputFormat"
@dataclass
class JSONOutputFormatConfig(PipelineStepConfig):
    """Settings entry for the ``JSONOutputFormat`` pipeline step.

    Declares no fields beyond ``name``, whose default is the identifier
    used to select this step (e.g. a ``- name: JSONOutputFormat`` entry in
    a dataset pipeline list).
    """

    name: str = "JSONOutputFormat"
@dataclass
class InstructUserMessagesConfig(PipelineStepConfig):
"""Configuration for InstructUserMessages step."""
@ -301,6 +308,7 @@ class Config:
InstructUserMessages,
NewYamlFinetuningPromptWithAnchoring,
NewYamlFinetuningPromptWithNoAnchoring,
JSONOutputFormat,
PDFRenderer,
StaticLengthDocumentAnchoring,
Tokenizer,
@ -338,6 +346,9 @@ class Config:
elif step_name == "NewYamlFinetuningPromptWithNoAnchoring":
steps.append(NewYamlFinetuningPromptWithNoAnchoring())
elif step_name == "JSONOutputFormat":
steps.append(JSONOutputFormat())
elif step_name == "FrontMatterOutputFormat":
steps.append(FrontMatterOutputFormat())

View File

@ -0,0 +1,93 @@
# Example OlmOCR Training Configuration

# Project metadata
project_name: olmocr-qwen-vl-training
run_name: qwen2.5-vl-7b-finetune-default

# Model configuration
model:
  name: Qwen/Qwen2.5-VL-7B-Instruct
  trust_remote_code: true
  torch_dtype: bfloat16
  use_flash_attention: true
  attn_implementation: flash_attention_2

  # LoRA settings (disabled by default)
  use_lora: false
  # lora_rank: 8
  # lora_alpha: 32
  # lora_dropout: 0.1
  # lora_target_modules:
  #   - q_proj
  #   - v_proj
  #   - k_proj
  #   - o_proj

# Dataset configuration
dataset:
  train:
    - name: processed_01_books_train_iabooks
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_train_iabooks/
      # The &basic_pipeline anchor names this step list so the other datasets
      # below can reuse it unchanged through the *basic_pipeline alias.
      pipeline: &basic_pipeline
        - name: FrontMatterParser
          front_matter_class: PageResponse
        - name: PDFRenderer
          target_longest_image_dim: 1024
        - name: StaticLengthDocumentAnchoring
          target_anchor_text_len: 6000
        - name: FinetuningPrompt
        # Emits the target response as a JSON string.
        - name: JSONOutputFormat
        - name: InstructUserMessages
        - name: Tokenizer
          masking_index: -100
          end_of_message_token: "<|im_end|>"
    - name: processed_00_documents_train_s2pdf
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
      pipeline: *basic_pipeline

  eval:
    - name: processed_00_documents_eval_s2pdf
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
      pipeline: *basic_pipeline
    - name: processed_01_books_eval_iabooks
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_eval_iabooks/
      pipeline: *basic_pipeline

# Training configuration
training:
  output_dir: /weka/oe-data-default/jakep/olmocr-trainer/
  num_train_epochs: 1

  # Batch size and accumulation
  per_device_train_batch_size: 1
  per_device_eval_batch_size: 1
  gradient_accumulation_steps: 32
  gradient_checkpointing: false

  # Learning rate
  # Written with an explicit mantissa dot: YAML 1.1 loaders (e.g. PyYAML)
  # resolve a bare `1e-6` as a string rather than a float.
  learning_rate: 1.0e-6
  lr_scheduler_type: cosine
  warmup_ratio: 0.1

  # Optimization
  optim: adamw_torch
  weight_decay: 0.01
  max_grad_norm: 1.0

  # Evaluation and checkpointing
  evaluation_strategy: steps
  eval_steps: 500
  save_strategy: steps
  save_steps: 500
  save_total_limit: 5
  load_best_model_at_end: true
  # NOTE(review): looks like eval_<eval dataset name>_loss for the first eval
  # dataset above -- confirm the trainer reports a metric under this name.
  metric_for_best_model: eval_processed_00_documents_eval_s2pdf_loss
  greater_is_better: false

  report_to:
    - wandb

View File

@ -7,6 +7,7 @@ from functools import reduce
from io import BytesIO
from os import PathLike
from pathlib import Path
import json
from typing import Any, Callable, Dict, List, Optional, Type, TypeAlias, Tuple
import numpy as np
@ -338,6 +339,25 @@ is_diagram: {page_data.is_diagram}
return sample
@dataclass(frozen=True, slots=True)
class JSONOutputFormat(PipelineStep):
    """Takes the page data and stores it on the sample as a JSON string.

    Reads ``sample["page_data"]`` and writes a JSON object containing the
    page's metadata fields and its natural text to ``sample["response"]``.
    """

    def __call__(self, sample: Sample) -> Sample:
        """Set ``sample["response"]`` to the JSON encoding of the page data.

        Args:
            sample: Pipeline sample holding a ``PageResponse`` under the
                ``"page_data"`` key.

        Returns:
            The same ``sample`` object, with ``"response"`` assigned.

        Raises:
            AssertionError: If ``sample["page_data"]`` is not a
                ``PageResponse``.
        """
        page_data = sample["page_data"]
        # isinstance rather than an exact type comparison, so subclasses of
        # PageResponse are accepted as well.
        assert isinstance(page_data, PageResponse), (
            f"expected PageResponse, got {type(page_data).__name__}"
        )

        # Key order is part of the emitted string, so it is kept fixed here.
        payload = {
            "primary_language": page_data.primary_language,
            "is_rotation_valid": page_data.is_rotation_valid,
            "rotation_correction": page_data.rotation_correction,
            "is_table": page_data.is_table,
            "is_diagram": page_data.is_diagram,
            "natural_text": page_data.natural_text,
        }
        # ensure_ascii=True escapes non-ASCII characters as \uXXXX sequences.
        sample["response"] = json.dumps(payload, ensure_ascii=True)
        return sample
@dataclass(frozen=True, slots=True)
class InstructUserMessages(PipelineStep):
"""Creates instruction-following messages format for training."""