mirror of https://github.com/allenai/olmocr.git (synced 2025-10-12 16:52:20 +00:00)

Getting things ready for a bit more augmentation

commit 7dca33db60
parent 55f8ba0ac0
@@ -386,7 +386,7 @@ class Config:
 
         elif step_name == "PDFRenderer":
             steps.append(
-                PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024), image_transform=None)  # Can be extended later
+                PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024))
             )
 
         elif step_name == "StaticLengthDocumentAnchoring":
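For context, the hunk above is part of a step factory in Config: each elif branch builds one pipeline step, pulling its parameters from step_config with a fallback default. Below is a minimal, self-contained sketch of that dispatch pattern; the stand-in dataclasses and the build_steps helper are illustrative only, not the repo's actual code.

# Illustrative sketch: stand-in dataclasses mimic the real pipeline steps, and
# build_steps mirrors the elif-dispatch shown in the diff above.
from dataclasses import dataclass

@dataclass(frozen=True)
class PDFRenderer:
    target_longest_image_dim: int = 1024

@dataclass(frozen=True)
class StaticLengthDocumentAnchoring:
    target_anchor_text_len: int = 6000

def build_steps(step_configs: list[dict]) -> list:
    steps = []
    for step_config in step_configs:
        step_name = step_config["name"]
        if step_name == "PDFRenderer":
            # Missing keys fall back to defaults via dict.get, as in the hunk.
            steps.append(PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024)))
        elif step_name == "StaticLengthDocumentAnchoring":
            steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000)))
        else:
            raise ValueError(f"Unknown pipeline step: {step_name}")
    return steps

print(build_steps([{"name": "PDFRenderer", "target_longest_image_dim": 1288}]))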
@@ -0,0 +1,107 @@
+# Example OlmOCR Training Configuration with Torch Compile
+
+# Project metadata
+project_name: olmocr-qwen-vl-training
+run_name: qwen2.5-vl-7b-olmocrv3_1epoch_prompt_first_rotation_tokflip
+
+# Model configuration
+model:
+  name: Qwen/Qwen2.5-VL-7B-Instruct
+  trust_remote_code: true
+  torch_dtype: bfloat16
+  use_flash_attention: true
+  attn_implementation: flash_attention_2
+
+  # LoRA settings (disabled by default)
+  use_lora: false
+  # lora_rank: 8
+  # lora_alpha: 32
+  # lora_dropout: 0.1
+  # lora_target_modules:
+  #   - q_proj
+  #   - v_proj
+  #   - k_proj
+  #   - o_proj
+
+# Dataset configuration
+dataset:
+
+  train:
+    - name: processed_01_books_train_iabooks
+      root_dir: /data/olmOCR-mix-0225/processed_01_books_train_iabooks/
+      pipeline: &basic_pipeline
+        - name: FrontMatterParser
+          front_matter_class: PageResponse
+        - name: FilterOutRotatedDocuments
+        - name: PDFRenderer
+          target_longest_image_dim: 1288
+        - name: RotationAugmentation
+          probability: 0.002
+        - name: NewYamlFinetuningPromptWithNoAnchoring
+        - name: FrontMatterOutputFormat
+        - name: InstructUserMessages
+          prompt_first: true
+        - name: Tokenizer
+          masking_index: -100
+          end_of_message_token: "<|im_end|>"
+        - name: RandomTokenFlipper
+          token_flip_rate: 0.0001
+    - name: processed_00_documents_train_s2pdf
+      root_dir: /data/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
+      pipeline: *basic_pipeline
+
+  eval:
+    - name: processed_00_documents_eval_s2pdf
+      root_dir: /data/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
+      pipeline: *basic_pipeline
+    - name: processed_01_books_eval_iabooks
+      root_dir: /data/olmOCR-mix-0225/processed_01_books_eval_iabooks/
+      pipeline: *basic_pipeline
+
+
+# Training configuration
+training:
+  output_dir: /weka/oe-data-default/jakep/olmocr-trainer/
+  num_train_epochs: 1
+
+  # Batch size and accumulation
+  per_device_train_batch_size: 1
+  per_device_eval_batch_size: 1
+  gradient_accumulation_steps: 32
+
+  gradient_checkpointing: False
+
+  collator_max_token_len: 8192
+
+  # Learning rate
+  learning_rate: 2e-5
+  lr_scheduler_type: linear
+  warmup_ratio: 0.1
+
+  # Optimization
+  optim: adamw_torch
+  weight_decay: 0.01
+  max_grad_norm: 1.0
+
+  # Torch compile settings
+  torch_compile: true
+  torch_compile_backend: inductor
+  torch_compile_mode: default
+  torch_compile_fullgraph: false
+  torch_compile_dynamic: false
+
+  seed: 300
+  data_seed: 301
+
+  # Evaluation and checkpointing
+  evaluation_strategy: steps
+  eval_steps: 500
+  save_strategy: steps
+  save_steps: 500
+  save_total_limit: 5
+  load_best_model_at_end: false  # Needs to be false because it has a problem restoring checkpoints for some reason
+  metric_for_best_model: eval_processed_00_documents_eval_s2pdf_loss
+  greater_is_better: false
+
+  report_to:
+    - wandb
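One detail worth noting in the new config file: pipeline: &basic_pipeline defines a YAML anchor that every other dataset entry reuses via *basic_pipeline, so all splits share a single pipeline definition. Below is a minimal sketch of how a standard YAML loader resolves that; the inline snippet uses hypothetical entry names and PyYAML's safe_load, while the repo's own config loader may differ.

# Minimal sketch: YAML anchors/aliases are expanded by the loader itself, so
# every dataset entry ends up with the same fully-resolved pipeline list.
import yaml

config_text = """
train:
  - name: books            # hypothetical entry names for illustration
    pipeline: &basic_pipeline
      - name: PDFRenderer
        target_longest_image_dim: 1288
      - name: RandomTokenFlipper
        token_flip_rate: 0.0001
  - name: documents
    pipeline: *basic_pipeline
"""

config = yaml.safe_load(config_text)
assert config["train"][0]["pipeline"] == config["train"][1]["pipeline"]
print(config["train"][1]["pipeline"][0]["name"])  # PDFRenderer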
@@ -281,7 +281,6 @@ class PDFRenderer(PipelineStep):
     """Pipeline step that renders PDF to image."""
 
     target_longest_image_dim: int
-    image_transform: Optional[Callable] = None
 
     def __call__(self, sample: Sample) -> Sample:
         """Render PDF to image."""
@@ -290,10 +289,6 @@
         png_bytes = base64.b64decode(base64_png)
         image = Image.open(BytesIO(png_bytes))
 
-        # Apply transform if provided
-        if self.image_transform:
-            image = self.image_transform(image)
-
         # Update sample
         sample["image"] = image
 
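The two hunks above drop the optional image_transform hook from PDFRenderer, so the step now only decodes the rendered page and stores it on the sample. Here is a self-contained sketch of that decode-and-store contract; the sample key name and the tiny generated PNG are assumptions for illustration, since the real step renders an actual PDF page.

# Sketch of the contract suggested by the diff: read base64-encoded PNG bytes,
# decode them with Pillow, and store the image back under sample["image"].
import base64
from io import BytesIO
from PIL import Image

def render_step(sample: dict) -> dict:
    png_bytes = base64.b64decode(sample["base64_png"])  # key name is assumed
    sample["image"] = Image.open(BytesIO(png_bytes))
    return sample

# Build a throwaway 8x8 PNG so the sketch runs end to end.
buf = BytesIO()
Image.new("RGB", (8, 8), "white").save(buf, format="PNG")
sample = render_step({"base64_png": base64.b64encode(buf.getvalue())})
print(sample["image"].size)  # (8, 8)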
@@ -524,6 +519,7 @@ class FilterOutRotatedDocuments(PipelineStep):
         return sample
 
+
 
 @dataclass(frozen=True, slots=True)
 class InstructUserMessages(PipelineStep):
     """Creates instruction-following messages format for training."""
@@ -670,20 +666,19 @@ class RandomTokenFlipper(PipelineStep):
 class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
     """Dataset that includes front matter parsing and PDF rendering by default."""
 
-    def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
+    def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, front_matter_class=None):
         """
         Initialize the dataset with default pipeline steps.
 
         Args:
             root_dir: Path to the root folder containing processed markdown and PDF files
             target_longest_image_dim: Target dimension for the longest side of the image
-            image_transform: Optional transform to apply to the PDF images
            front_matter_class: Optional dataclass type to validate front matter against
         """
         # Create default pipeline steps
         pipeline_steps = [
             FrontMatterParser(front_matter_class),
-            PDFRenderer(target_longest_image_dim, image_transform),
+            PDFRenderer(target_longest_image_dim),
             StaticLengthDocumentAnchoring(target_anchor_text_len=6000),
             FinetuningPrompt(),
             FrontMatterOutputFormat(),
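With image_transform removed from the default pipeline, the dataset constructor now takes only the root directory, the target image dimension, and an optional front-matter class. A hedged usage sketch follows; the import paths and the location of PageResponse are assumptions and may not match the repository layout.

# Usage sketch of the updated constructor; the import paths below are assumed,
# not verified against the repository.
from olmocr.train.dataloader import MarkdownPDFDocumentDataset  # assumed path
from olmocr.prompts import PageResponse                          # assumed path

dataset = MarkdownPDFDocumentDataset(
    root_dir="/data/olmOCR-mix-0225/processed_00_documents_train_s2pdf/",
    target_longest_image_dim=1288,
    front_matter_class=PageResponse,
)
print(len(dataset))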
@@ -106,6 +106,7 @@ train = [
     "s3fs",
     "necessary",
     "einops",
+    "augraphy",
 ]
 
 elo = [
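The new augraphy dependency lines up with the commit message: augraphy generates document-noise augmentations (ink bleed, paper texture, scanner artifacts) that an image-augmentation step could use later. A heavily hedged sketch of its stock pipeline follows; default_augraphy_pipeline is augraphy's documented entry point, but the exact return type can differ between versions, so treat this as an assumption rather than the repo's usage.

# Hedged sketch: run augraphy's stock augmentation pipeline over a blank page.
# Depending on the installed augraphy version, the pipeline call may return the
# augmented array directly or a dict of intermediate results.
import numpy as np
from augraphy import default_augraphy_pipeline

page = np.full((512, 512, 3), 255, dtype=np.uint8)  # plain white "scan"
pipeline = default_augraphy_pipeline()
augmented = pipeline(page)
print(type(augmented))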