mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-11 08:12:22 +00:00
Adding a standard JSON output option
This commit is contained in:
parent
6f2a426986
commit
210d170b15
@ -69,6 +69,13 @@ class FrontMatterOutputFormatConfig(PipelineStepConfig):
|
||||
name: str = "FrontMatterOutputFormat"
|
||||
|
||||
|
||||
@dataclass
|
||||
class JSONOutputFormatConfig(PipelineStepConfig):
|
||||
"""Configuration for JSONOutputFormat step."""
|
||||
|
||||
name: str = "JSONOutputFormat"
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstructUserMessagesConfig(PipelineStepConfig):
|
||||
"""Configuration for InstructUserMessages step."""
|
||||
@ -301,6 +308,7 @@ class Config:
|
||||
InstructUserMessages,
|
||||
NewYamlFinetuningPromptWithAnchoring,
|
||||
NewYamlFinetuningPromptWithNoAnchoring,
|
||||
JSONOutputFormat,
|
||||
PDFRenderer,
|
||||
StaticLengthDocumentAnchoring,
|
||||
Tokenizer,
|
||||
@ -338,6 +346,9 @@ class Config:
|
||||
elif step_name == "NewYamlFinetuningPromptWithNoAnchoring":
|
||||
steps.append(NewYamlFinetuningPromptWithNoAnchoring())
|
||||
|
||||
elif step_name == "JSONOutputFormat":
|
||||
steps.append(JSONOutputFormat())
|
||||
|
||||
elif step_name == "FrontMatterOutputFormat":
|
||||
steps.append(FrontMatterOutputFormat())
|
||||
|
||||
|
93
olmocr/train/configs/qwen25_vl_b100_x1_default_json.yaml
Normal file
93
olmocr/train/configs/qwen25_vl_b100_x1_default_json.yaml
Normal file
@ -0,0 +1,93 @@
|
||||
# Example OlmOCR Training Configuration
|
||||
|
||||
# Project metadata
|
||||
project_name: olmocr-qwen-vl-training
|
||||
run_name: qwen2.5-vl-7b-finetune-default
|
||||
|
||||
# Model configuration
|
||||
model:
|
||||
name: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
trust_remote_code: true
|
||||
torch_dtype: bfloat16
|
||||
use_flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# LoRA settings (disabled by default)
|
||||
use_lora: false
|
||||
# lora_rank: 8
|
||||
# lora_alpha: 32
|
||||
# lora_dropout: 0.1
|
||||
# lora_target_modules:
|
||||
# - q_proj
|
||||
# - v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
|
||||
# Dataset configuration
|
||||
dataset:
|
||||
|
||||
train:
|
||||
- name: processed_01_books_train_iabooks
|
||||
root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_train_iabooks/
|
||||
pipeline: &basic_pipeline
|
||||
- name: FrontMatterParser
|
||||
front_matter_class: PageResponse
|
||||
- name: PDFRenderer
|
||||
target_longest_image_dim: 1024
|
||||
- name: StaticLengthDocumentAnchoring
|
||||
target_anchor_text_len: 6000
|
||||
- name: FinetuningPrompt
|
||||
- name: JSONOutputFormat
|
||||
- name: InstructUserMessages
|
||||
- name: Tokenizer
|
||||
masking_index: -100
|
||||
end_of_message_token: "<|im_end|>"
|
||||
- name: processed_00_documents_train_s2pdf
|
||||
root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
|
||||
pipeline: *basic_pipeline
|
||||
|
||||
eval:
|
||||
- name: processed_00_documents_eval_s2pdf
|
||||
root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
|
||||
pipeline: *basic_pipeline
|
||||
- name: processed_01_books_eval_iabooks
|
||||
root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_eval_iabooks/
|
||||
pipeline: *basic_pipeline
|
||||
|
||||
|
||||
# Training configuration
|
||||
training:
|
||||
output_dir: /weka/oe-data-default/jakep/olmocr-trainer/
|
||||
num_train_epochs: 1
|
||||
|
||||
# Batch size and accumulation
|
||||
per_device_train_batch_size: 1
|
||||
per_device_eval_batch_size: 1
|
||||
gradient_accumulation_steps: 32
|
||||
|
||||
gradient_checkpointing: False
|
||||
|
||||
# Learning rate
|
||||
learning_rate: 1e-6
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
|
||||
# Optimization
|
||||
optim: adamw_torch
|
||||
weight_decay: 0.01
|
||||
max_grad_norm: 1.0
|
||||
|
||||
|
||||
# Evaluation and checkpointing
|
||||
evaluation_strategy: steps
|
||||
eval_steps: 500
|
||||
save_strategy: steps
|
||||
save_steps: 500
|
||||
save_total_limit: 5
|
||||
load_best_model_at_end: true
|
||||
metric_for_best_model: eval_processed_00_documents_eval_s2pdf_loss
|
||||
greater_is_better: false
|
||||
|
||||
report_to:
|
||||
- wandb
|
||||
|
@ -7,6 +7,7 @@ from functools import reduce
|
||||
from io import BytesIO
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
import json
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeAlias, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -338,6 +339,25 @@ is_diagram: {page_data.is_diagram}
|
||||
return sample
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class JSONOutputFormat(PipelineStep):
|
||||
"""Takes the output and applies the standard JSON formatting to it"""
|
||||
|
||||
def __call__(self, sample: Sample) -> Sample:
|
||||
page_data = sample["page_data"]
|
||||
assert type(page_data) == PageResponse
|
||||
|
||||
sample["response"] = json.dumps({
|
||||
"primary_language": page_data.primary_language,
|
||||
"is_rotation_valid": page_data.is_rotation_valid,
|
||||
"rotation_correction": page_data.rotation_correction,
|
||||
"is_table": page_data.is_table,
|
||||
"is_diagram": page_data.is_diagram,
|
||||
"natural_text": page_data.natural_text
|
||||
}, ensure_ascii=True)
|
||||
|
||||
return sample
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class InstructUserMessages(PipelineStep):
|
||||
"""Creates instruction-following messages format for training."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user