mirror of https://github.com/allenai/olmocr.git (synced 2025-11-25 14:52:56 +00:00)

Trying a few more configs

This commit is contained in:
parent 384a1b19c7
commit a5a0cd7478
@@ -83,6 +83,13 @@ class InstructUserMessagesConfig(PipelineStepConfig):
     name: str = "InstructUserMessages"
 
 
+@dataclass
+class LatexBracketNormalizerConfig(PipelineStepConfig):
+    """Configuration for LatexBracketNormalizer step."""
+
+    name: str = "LatexBracketNormalizer"
+
+
 @dataclass
 class TokenizerStepConfig(PipelineStepConfig):
     """Configuration for Tokenizer step."""
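Note: the "name" default on each of these config dataclasses is the string that the dataset pipeline entries in the YAML configs below (e.g. "- name: LatexBracketNormalizer") are matched against; the dispatch hunk two hunks down does the matching.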
@@ -307,6 +314,7 @@ class Config:
     FrontMatterOutputFormat,
     FrontMatterParser,
     InstructUserMessages,
+    LatexBracketNormalizer,
     NewYamlFinetuningPromptWithAnchoring,
     NewYamlFinetuningPromptWithNoAnchoring,
     JSONOutputFormat,
@@ -356,6 +364,9 @@ class Config:
         elif step_name == "InstructUserMessages":
             steps.append(InstructUserMessages())
 
+        elif step_name == "LatexBracketNormalizer":
+            steps.append(LatexBracketNormalizer())
+
         elif step_name == "Tokenizer":
             if processor is None:
                 raise ValueError("Processor must be provided for Tokenizer step")
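Aside: the three hunks above follow the file's existing recipe for adding a pipeline step: a *Config dataclass whose name field identifies the step, an import, and an elif branch that maps the configured name to a step instance. A minimal self-contained sketch of that name-to-step dispatch, with a toy step standing in for the real ones (illustrative names only, not olmocr's actual API):

import re
from dataclasses import dataclass


@dataclass(frozen=True)
class UppercaseStep:
    """Toy pipeline step; stands in for steps like LatexBracketNormalizer."""

    def __call__(self, sample: dict) -> dict:
        sample["text"] = sample.get("text", "").upper()
        return sample


def build_steps(step_names: list[str]) -> list:
    """Mirrors the elif chain above: map each configured name to an instance."""
    steps = []
    for step_name in step_names:
        if step_name == "Uppercase":
            steps.append(UppercaseStep())
        else:
            raise ValueError(f"Unknown pipeline step: {step_name}")
    return steps


sample = {"text": "hello"}
for step in build_steps(["Uppercase"]):
    sample = step(sample)
print(sample)  # {'text': 'HELLO'}

A dict-based registry mapping names to step classes would be an equivalent alternative to the elif chain for steps that take no constructor arguments.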
@@ -0,0 +1,97 @@
+# Example OlmOCR Training Configuration
+
+# Project metadata
+project_name: olmocr-qwen-vl-training
+run_name: qwen2.5-vl-7b-finetune-day3-yaml-1280-noanchor-latexnormalize
+
+# Model configuration
+model:
+  name: Qwen/Qwen2.5-VL-7B-Instruct
+  trust_remote_code: true
+  torch_dtype: bfloat16
+  use_flash_attention: true
+  attn_implementation: flash_attention_2
+
+  # LoRA settings (disabled by default)
+  use_lora: false
+  # lora_rank: 8
+  # lora_alpha: 32
+  # lora_dropout: 0.1
+  # lora_target_modules:
+  #   - q_proj
+  #   - v_proj
+  #   - k_proj
+  #   - o_proj
+
+# Dataset configuration
+dataset:
+
+  train:
+    - name: processed_01_books_train_iabooks
+      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_train_iabooks/
+      pipeline: &basic_pipeline
+        - name: FrontMatterParser
+          front_matter_class: PageResponse
+        - name: PDFRenderer
+          target_longest_image_dim: 1280
+        - name: LatexBracketNormalizer
+        - name: StaticLengthDocumentAnchoring
+          target_anchor_text_len: -1
+        - name: FinetuningPrompt
+        - name: FrontMatterOutputFormat
+        - name: InstructUserMessages
+        - name: Tokenizer
+          masking_index: -100
+          end_of_message_token: "<|im_end|>"
+    - name: processed_00_documents_train_s2pdf
+      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
+      pipeline: *basic_pipeline
+
+  eval:
+    - name: processed_00_documents_eval_s2pdf
+      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
+      pipeline: *basic_pipeline
+    - name: processed_01_books_eval_iabooks
+      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_eval_iabooks/
+      pipeline: *basic_pipeline
+
+
+
+# Training configuration
+training:
+  output_dir: /weka/oe-data-default/jakep/olmocr-trainer/
+  num_train_epochs: 1
+
+  # Batch size and accumulation
+  per_device_train_batch_size: 1
+  per_device_eval_batch_size: 1
+  gradient_accumulation_steps: 32
+
+  gradient_checkpointing: False
+
+  collator_max_token_len: 8192
+
+  # Learning rate
+  learning_rate: 2e-5
+  lr_scheduler_type: linear
+  warmup_ratio: 0.1
+
+  # Optimization
+  optim: adamw_torch
+  weight_decay: 0.01
+  max_grad_norm: 1.0
+
+
+  # Evaluation and checkpointing
+  evaluation_strategy: steps
+  eval_steps: 500
+  save_strategy: steps
+  save_steps: 500
+  save_total_limit: 5
+  load_best_model_at_end: false # Needs to be false because it has a problem restoring checkpoints for some reason
+  metric_for_best_model: eval_processed_00_documents_eval_s2pdf_loss
+  greater_is_better: false
+
+  report_to:
+    - wandb
+
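Aside: the dataset section leans on a standard YAML anchor/alias pair: &basic_pipeline names the pipeline list defined on the first train dataset, and *basic_pipeline reuses it verbatim for every other dataset, so the pipeline is written once. A minimal sketch of how that resolves at load time (toy document, assuming a PyYAML-style loader; olmocr's actual loading code is not shown in this diff):

import yaml

doc = """
train:
  - name: a
    pipeline: &basic_pipeline
      - name: FrontMatterParser
      - name: Tokenizer
eval:
  - name: b
    pipeline: *basic_pipeline
"""

cfg = yaml.safe_load(doc)
# The alias resolves to the very same parsed list object, not a copy.
assert cfg["train"][0]["pipeline"] is cfg["eval"][0]["pipeline"]
print(cfg["eval"][0]["pipeline"])

One consequence worth knowing: because the alias is the same object, mutating the parsed pipeline list for one dataset would affect all of them.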
@@ -0,0 +1,94 @@
+# Example OlmOCR Training Configuration
+
+# Project metadata
+project_name: olmocr-qwen-vl-training
+run_name: qwen2.5-vl-7b-finetune-day3-yaml-1280-noanchor-newprompt
+
+# Model configuration
+model:
+  name: Qwen/Qwen2.5-VL-7B-Instruct
+  trust_remote_code: true
+  torch_dtype: bfloat16
+  use_flash_attention: true
+  attn_implementation: flash_attention_2
+
+  # LoRA settings (disabled by default)
+  use_lora: false
+  # lora_rank: 8
+  # lora_alpha: 32
+  # lora_dropout: 0.1
+  # lora_target_modules:
+  #   - q_proj
+  #   - v_proj
+  #   - k_proj
+  #   - o_proj
+
+# Dataset configuration
+dataset:
+
+  train:
+    - name: processed_01_books_train_iabooks
+      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_train_iabooks/
+      pipeline: &basic_pipeline
+        - name: FrontMatterParser
+          front_matter_class: PageResponse
+        - name: PDFRenderer
+          target_longest_image_dim: 1280
+        - name: NewYamlFinetuningPromptWithNoAnchoring
+        - name: FrontMatterOutputFormat
+        - name: InstructUserMessages
+        - name: Tokenizer
+          masking_index: -100
+          end_of_message_token: "<|im_end|>"
+    - name: processed_00_documents_train_s2pdf
+      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
+      pipeline: *basic_pipeline
+
+  eval:
+    - name: processed_00_documents_eval_s2pdf
+      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
+      pipeline: *basic_pipeline
+    - name: processed_01_books_eval_iabooks
+      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_eval_iabooks/
+      pipeline: *basic_pipeline
+
+
+
+# Training configuration
+training:
+  output_dir: /weka/oe-data-default/jakep/olmocr-trainer/
+  num_train_epochs: 1
+
+  # Batch size and accumulation
+  per_device_train_batch_size: 1
+  per_device_eval_batch_size: 1
+  gradient_accumulation_steps: 32
+
+  gradient_checkpointing: False
+
+  collator_max_token_len: 8192
+
+  # Learning rate
+  learning_rate: 2e-5
+  lr_scheduler_type: linear
+  warmup_ratio: 0.1
+
+  # Optimization
+  optim: adamw_torch
+  weight_decay: 0.01
+  max_grad_norm: 1.0
+
+
+  # Evaluation and checkpointing
+  evaluation_strategy: steps
+  eval_steps: 500
+  save_strategy: steps
+  save_steps: 500
+  save_total_limit: 5
+  load_best_model_at_end: false # Needs to be false because it has a problem restoring checkpoints for some reason
+  metric_for_best_model: eval_processed_00_documents_eval_s2pdf_loss
+  greater_is_better: false
+
+  report_to:
+    - wandb
+
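Note: in both configs above, per_device_train_batch_size: 1 with gradient_accumulation_steps: 32 gives an effective batch of 1 × 32 = 32 samples per device per optimizer step (multiplied by the number of data-parallel workers, if any).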
olmocr/train/configs/qwen2_vl_b100_x1_day3_yaml.yaml (new file, 96 lines)
@@ -0,0 +1,96 @@
+# Example OlmOCR Training Configuration
+
+# Project metadata
+project_name: olmocr-qwen-vl-training
+run_name: qwen2-vl-7b-finetune-day3-yaml
+
+# Model configuration
+model:
+  name: Qwen/Qwen2-VL-7B-Instruct
+  trust_remote_code: true
+  torch_dtype: bfloat16
+  use_flash_attention: true
+  attn_implementation: flash_attention_2
+
+  # LoRA settings (disabled by default)
+  use_lora: false
+  # lora_rank: 8
+  # lora_alpha: 32
+  # lora_dropout: 0.1
+  # lora_target_modules:
+  #   - q_proj
+  #   - v_proj
+  #   - k_proj
+  #   - o_proj
+
+# Dataset configuration
+dataset:
+
+  train:
+    - name: processed_01_books_train_iabooks
+      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_train_iabooks/
+      pipeline: &basic_pipeline
+        - name: FrontMatterParser
+          front_matter_class: PageResponse
+        - name: PDFRenderer
+          target_longest_image_dim: 1280
+        - name: StaticLengthDocumentAnchoring
+          target_anchor_text_len: -1
+        - name: FinetuningPrompt
+        - name: FrontMatterOutputFormat
+        - name: InstructUserMessages
+        - name: Tokenizer
+          masking_index: -100
+          end_of_message_token: "<|im_end|>"
+    - name: processed_00_documents_train_s2pdf
+      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
+      pipeline: *basic_pipeline
+
+  eval:
+    - name: processed_00_documents_eval_s2pdf
+      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
+      pipeline: *basic_pipeline
+    - name: processed_01_books_eval_iabooks
+      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_eval_iabooks/
+      pipeline: *basic_pipeline
+
+
+
+# Training configuration
+training:
+  output_dir: /weka/oe-data-default/jakep/olmocr-trainer/
+  num_train_epochs: 1
+
+  # Batch size and accumulation
+  per_device_train_batch_size: 1
+  per_device_eval_batch_size: 1
+  gradient_accumulation_steps: 32
+
+  gradient_checkpointing: False
+
+  collator_max_token_len: 8192
+
+  # Learning rate
+  learning_rate: 2e-5
+  lr_scheduler_type: linear
+  warmup_ratio: 0.1
+
+  # Optimization
+  optim: adamw_torch
+  weight_decay: 0.01
+  max_grad_norm: 1.0
+
+
+  # Evaluation and checkpointing
+  evaluation_strategy: steps
+  eval_steps: 500
+  save_strategy: steps
+  save_steps: 500
+  save_total_limit: 5
+  load_best_model_at_end: false # Needs to be false because it has a problem restoring checkpoints for some reason
+  metric_for_best_model: eval_processed_00_documents_eval_s2pdf_loss
+  greater_is_better: false
+
+  report_to:
+    - wandb
+
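Taken together, the three new configs share the same dataset roots and training hyperparameters and differ only in the model and the prompt-related pipeline steps:

run_name                                                       model                        prompt-related steps
qwen2.5-vl-7b-finetune-day3-yaml-1280-noanchor-latexnormalize  Qwen/Qwen2.5-VL-7B-Instruct  LatexBracketNormalizer, StaticLengthDocumentAnchoring, FinetuningPrompt
qwen2.5-vl-7b-finetune-day3-yaml-1280-noanchor-newprompt       Qwen/Qwen2.5-VL-7B-Instruct  NewYamlFinetuningPromptWithNoAnchoring
qwen2-vl-7b-finetune-day3-yaml                                 Qwen/Qwen2-VL-7B-Instruct    StaticLengthDocumentAnchoring, FinetuningPrompt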
@@ -1,5 +1,6 @@
 import base64
 import logging
+import re
 from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from dataclasses import dataclass, fields
@@ -358,6 +359,49 @@ class JSONOutputFormat(PipelineStep):
 
         return sample
 
 
+@dataclass(frozen=True, slots=True)
+class LatexBracketNormalizer(PipelineStep):
+    """Normalizes LaTeX brackets in natural text field."""
+
+    def __call__(self, sample: Sample) -> Sample:
+        """Normalize LaTeX brackets in the natural text field."""
+        # Get the page_data object
+        if "page_data" not in sample:
+            return sample
+
+        page_data = sample["page_data"]
+        if not hasattr(page_data, "natural_text") or not page_data.natural_text:
+            return sample
+
+        text = page_data.natural_text
+
+        # Define patterns for LaTeX normalization
+        # Order matters: process display math first, then inline
+        patterns = [
+            (r"\$\$(.+?)\$\$", r"\[\1\]"),  # $$...$$ to \[...\]
+            (r"\$(.+?)\$", r"\(\1\)"),  # $...$ to \(...\)
+        ]
+
+        # Apply replacements
+        for pattern, replacement in patterns:
+            text = re.sub(pattern, replacement, text, flags=re.DOTALL)
+
+        # Update the page_data with normalized text
+        # Since PageResponse is frozen, we need to create a new instance
+        from olmocr.prompts.prompts import PageResponse
+
+        new_page_data = PageResponse(
+            primary_language=page_data.primary_language,
+            is_rotation_valid=page_data.is_rotation_valid,
+            rotation_correction=page_data.rotation_correction,
+            is_table=page_data.is_table,
+            is_diagram=page_data.is_diagram,
+            natural_text=text
+        )
+
+        sample["page_data"] = new_page_data
+        return sample
+
+
 @dataclass(frozen=True, slots=True)
 class InstructUserMessages(PipelineStep):
     """Creates instruction-following messages format for training."""
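Aside: a standalone check of the two substitution patterns above, using only the stdlib re module outside the pipeline. The ordering comment in the code matters: if the inline pattern ran first, its non-greedy \$(.+?)\$ would consume the first dollar of a $$ delimiter and mangle display math (e.g. $$e$$ would come out as \($e\)$).

import re

text = "Euler: $$e^{i\\pi} + 1 = 0$$ and inline $x^2$."

# Display math first, then inline, matching LatexBracketNormalizer's order.
text = re.sub(r"\$\$(.+?)\$\$", r"\[\1\]", text, flags=re.DOTALL)
text = re.sub(r"\$(.+?)\$", r"\(\1\)", text, flags=re.DOTALL)

print(text)  # Euler: \[e^{i\pi} + 1 = 0\] and inline \(x^2\).

Note also that with re.DOTALL and no boundary checks, stray currency dollar signs in the text would be rewritten too; the step assumes dollar signs only delimit math.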