Mirror of https://github.com/allenai/olmocr.git (synced 2025-11-02 02:54:53 +00:00)
First attempt at new trainer code
parent 3eda2c04c1, commit f0d8ff7bd3
.gitignore (vendored), 1 change
@@ -14,6 +14,7 @@ localworkspace/*
math_data/*
math_data_big/*
gpt4otestset/*
old_train/
gpt4otestset_output/*
pdfs/*
olmOCR-bench/*

@@ -1,31 +0,0 @@
# pip install llmcompressor

from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration

MODEL_ID = "/home/ubuntu/olmocr/olmOCR-7B-0225-preview"

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Configure the simple PTQ quantization
# recipe = QuantizationModifier(
#     targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

# Configure the pre-defined qwen2vl recipe: quantize the Linear layers of the
# language model, but keep the LM head and the vision tower in full precision.
recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8_DYNAMIC",
    ignore=["re:.*lm_head", "re:visual.*"],
)

# Apply the quantization algorithm.
oneshot(model=model, recipe=recipe)

# Save the model, named after the model directory (last path component).
SAVE_DIR = MODEL_ID.split("/")[-1] + "-FP8-Dynamic-Recipe"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)

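For context on how such an FP8 checkpoint is normally consumed: llmcompressor-quantized models are usually served for inference rather than fine-tuned further. A minimal sketch, assuming vLLM is installed and using the save directory name computed by the script above; the prompt and paths are illustrative, not part of this commit:

from vllm import LLM, SamplingParams

# Hypothetical: the directory written by model.save_pretrained() above.
llm = LLM(model="olmOCR-7B-0225-preview-FP8-Dynamic-Recipe")
params = SamplingParams(temperature=0.0, max_tokens=64)
out = llm.generate(["Transcribe the attached page."], params)
print(out[0].outputs[0].text)
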
@@ -1,87 +0,0 @@
|
||||
model:
|
||||
name_or_path: allenai/Molmo-7B-O-0924
|
||||
arch: causal
|
||||
use_flash_attn: true
|
||||
|
||||
wandb:
|
||||
project: pdelfin
|
||||
entity: ai2-llm
|
||||
|
||||
generate:
|
||||
max_length: 8192
|
||||
|
||||
train_data:
|
||||
seed: 1337
|
||||
cache_location: /data/jakep/pdfdata/pdelfin_cache
|
||||
sources:
|
||||
- name: openai_batch_data_v5_1_train
|
||||
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
- name: openai_batch_data_v5_1_iabooks_train
|
||||
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
|
||||
valid_data:
|
||||
cache_location: /data/jakep/pdfdata/pdelfin_cache
|
||||
metric_for_best_model: openai_batch_data_v5_1_eval_loss
|
||||
sources:
|
||||
# These tend to be small, so loading them from S3 is no big deal
|
||||
- name: openai_batch_data_v5_1_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
- name: openai_batch_data_v5_1_iabooks_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
|
||||
|
||||
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
|
||||
hparams:
|
||||
batch_size: 1
|
||||
eval_batch_size: 1
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
find_unused_parameters: true
|
||||
clip_grad_norm: 1.0
|
||||
learning_rate: 3e-4
|
||||
max_steps: 10000
|
||||
pad_multiple_of: 16
|
||||
log_every_steps: 10
|
||||
eval_every_steps: 100
|
||||
optim: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
weight_decay: 0.01
|
||||
warmup_ratio: 0.03
|
||||
|
||||
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
|
||||
lora:
|
||||
rank: 32
|
||||
alpha: 32
|
||||
dropout: 0.05
|
||||
task_type: CAUSAL_LM
|
||||
target_modules:
|
||||
# attention layers in main transformer
|
||||
- att_proj
|
||||
- ff_proj
|
||||
- attn_out
|
||||
- ff_out
|
||||
# vision transformer attention and FF
|
||||
- attention.wq
|
||||
- attention.wk
|
||||
- attention.wv
|
||||
- attention.wo
|
||||
- feed_forward.w1
|
||||
- feed_forward.w2
|
||||
# vision image projector
|
||||
- vision_backbone.image_projector.w1
|
||||
- vision_backbone.image_projector.w2
|
||||
- vision_backbone.image_projector.w3
|
||||
|
||||
save:
|
||||
path: s3://ai2-oe-data/jakep/experiments/molmo-o-0924/v1/models/
|
||||
save_every_steps: 1000
|
||||
|
||||
max_workers: 10
|
||||
@@ -1,89 +0,0 @@
|
||||
model:
|
||||
name_or_path: allenai/Molmo-7B-O-0924
|
||||
arch: causal
|
||||
use_flash_attn: true
|
||||
|
||||
wandb:
|
||||
project: pdelfin
|
||||
entity: ai2-llm
|
||||
|
||||
generate:
|
||||
max_length: 4096
|
||||
|
||||
train_data:
|
||||
seed: 1337
|
||||
cache_location: /data/jakep/pdfdata/pdelfin_cache
|
||||
sources:
|
||||
- name: openai_batch_data_v5_1_train
|
||||
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
- name: openai_batch_data_v5_1_iabooks_train
|
||||
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
|
||||
valid_data:
|
||||
cache_location: /data/jakep/pdfdata/pdelfin_cache
|
||||
metric_for_best_model: openai_batch_data_v5_1_eval_loss
|
||||
sources:
|
||||
# These tend to be small, so loading them from S3 is no big deal
|
||||
- name: openai_batch_data_v5_1_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
- name: openai_batch_data_v5_1_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
|
||||
|
||||
|
||||
|
||||
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
|
||||
hparams:
|
||||
batch_size: 1
|
||||
eval_batch_size: 1
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
find_unused_parameters: true
|
||||
clip_grad_norm: 1.0
|
||||
learning_rate: 1e-4
|
||||
max_steps: 10000
|
||||
pad_multiple_of: 16
|
||||
log_every_steps: 10
|
||||
eval_every_steps: 100
|
||||
optim: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
weight_decay: 0.01
|
||||
warmup_ratio: 0.03
|
||||
|
||||
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
|
||||
lora:
|
||||
rank: 32
|
||||
alpha: 32
|
||||
dropout: 0.05
|
||||
task_type: CAUSAL_LM
|
||||
target_modules:
|
||||
# attention layers in main transformer
|
||||
- att_proj
|
||||
- ff_proj
|
||||
- attn_out
|
||||
- ff_out
|
||||
# vision transformer attention and FF
|
||||
- attention.wq
|
||||
- attention.wk
|
||||
- attention.wv
|
||||
- attention.wo
|
||||
- feed_forward.w1
|
||||
- feed_forward.w2
|
||||
# vision image projector
|
||||
- vision_backbone.image_projector.w1
|
||||
- vision_backbone.image_projector.w2
|
||||
- vision_backbone.image_projector.w3
|
||||
|
||||
save:
|
||||
path: s3://ai2-oe-data/jakep/experiments/molmo-o-0924/v1/models/
|
||||
save_every_steps: 1000
|
||||
|
||||
max_workers: 10
|
||||
@@ -1,73 +0,0 @@
|
||||
model:
|
||||
name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
arch: causal
|
||||
use_flash_attn: true
|
||||
|
||||
wandb:
|
||||
project: pdelfin
|
||||
entity: ai2-llm
|
||||
|
||||
generate:
|
||||
max_length: 8192
|
||||
|
||||
train_data:
|
||||
seed: 1337
|
||||
cache_location: /data/jakep/pdfdata/pdelfin_cache
|
||||
sources:
|
||||
# These tend to be small, so loading them from S3 is no big deal
|
||||
- name: openai_batch_data_v5_1_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
- name: openai_batch_data_v5_1_iabooks_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
# - name: openai_batch_data_v5_1_train
|
||||
# response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
|
||||
# target_longest_image_dim: [1024]
|
||||
# target_anchor_text_len: [6000]
|
||||
# - name: openai_batch_data_v5_1_iabooks_train
|
||||
# response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
|
||||
# target_longest_image_dim: [1024]
|
||||
# target_anchor_text_len: [6000]
|
||||
|
||||
valid_data:
|
||||
cache_location: /data/jakep/pdfdata/pdelfin_cache
|
||||
metric_for_best_model: openai_batch_data_v5_1_eval_loss
|
||||
sources:
|
||||
# These tend to be small, so loading them from S3 is no big deal
|
||||
- name: openai_batch_data_v5_1_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
- name: openai_batch_data_v5_1_iabooks_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
|
||||
|
||||
|
||||
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
|
||||
hparams:
|
||||
batch_size: 1
|
||||
eval_batch_size: 1
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
clip_grad_norm: 1.0
|
||||
learning_rate: 1e-6
|
||||
max_steps: 10000
|
||||
pad_multiple_of: 16
|
||||
log_every_steps: 10
|
||||
eval_every_steps: 100
|
||||
optim: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
weight_decay: 0.01
|
||||
warmup_ratio: 0.03
|
||||
|
||||
|
||||
save:
|
||||
path: s3://ai2-oe-data/jakep/experiments/qwen25vl-pdf/v1/models/
|
||||
save_every_steps: 9500
|
||||
|
||||
max_workers: 10
|
||||
@@ -1,89 +0,0 @@
|
||||
model:
|
||||
name_or_path: Qwen/Qwen2-VL-2B-Instruct
|
||||
arch: causal
|
||||
use_flash_attn: true
|
||||
|
||||
wandb:
|
||||
project: pdelfin
|
||||
entity: ai2-llm
|
||||
|
||||
# TODO This is not used
|
||||
format:
|
||||
instruction_template: "Original:"
|
||||
response_template: "Rewritten:"
|
||||
# Template from here: https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py#L30
|
||||
chat_template: |
|
||||
{% for message in messages %}
|
||||
{{'<|im_start|>' + message['role'] + '\n' + message['content']}}
|
||||
{% if loop.last %}
|
||||
{{ '<|im_end|>'}}
|
||||
{% else %}
|
||||
{{ '<|im_end|>\n' }}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
generate:
|
||||
max_length: 4096
|
||||
|
||||
train_data:
|
||||
seed: 1337
|
||||
sources:
|
||||
- name: openai_batch_data_v2
|
||||
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
|
||||
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
|
||||
backend:
|
||||
- openai
|
||||
size: 100_000
|
||||
|
||||
valid_data:
|
||||
sources:
|
||||
- name: openai_batch_data_eval_mini
|
||||
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
|
||||
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
|
||||
backend:
|
||||
- openai
|
||||
size: 100_000
|
||||
|
||||
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
|
||||
hparams:
|
||||
batch_size: 1
|
||||
eval_batch_size: 1
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: false
|
||||
clip_grad_norm: 1.0
|
||||
learning_rate: 3e-4
|
||||
max_steps: 2000
|
||||
pad_multiple_of: 16
|
||||
log_every_steps: 50
|
||||
eval_every_steps: 1000
|
||||
optim: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
weight_decay: 0.01
|
||||
warmup_ratio: 0.03
|
||||
|
||||
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
|
||||
lora:
|
||||
rank: 32
|
||||
alpha: 32
|
||||
dropout: 0.05
|
||||
task_type: causal_lm
|
||||
target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- gate_proj
|
||||
- up_proj
|
||||
- down_proj
|
||||
- visual.blocks.[0-9]+.attn.qkv
|
||||
- visual.blocks.[0-9]+.attn.proj
|
||||
- visual.blocks.[0-9]+.mlp.fc1
|
||||
- visual.blocks.[0-9]+.mlp.fc2
|
||||
- visual.merger.mlp.0
|
||||
- visual.merger.mlp.2
|
||||
|
||||
save:
|
||||
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
|
||||
save_every_steps: 1000
|
||||
|
||||
max_workers: 10
|
||||
@@ -1,84 +0,0 @@
|
||||
model:
|
||||
name_or_path: Qwen/Qwen2-VL-2B-Instruct
|
||||
arch: causal
|
||||
use_flash_attn: true
|
||||
|
||||
wandb:
|
||||
project: pdelfin
|
||||
entity: ai2-llm
|
||||
|
||||
# TODO This is not used
|
||||
format:
|
||||
instruction_template: "Original:"
|
||||
response_template: "Rewritten:"
|
||||
# Template from here: https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py#L30
|
||||
chat_template: |
|
||||
{% for message in messages %}
|
||||
{{'<|im_start|>' + message['role'] + '\n' + message['content']}}
|
||||
{% if loop.last %}
|
||||
{{ '<|im_end|>'}}
|
||||
{% else %}
|
||||
{{ '<|im_end|>\n' }}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
generate:
|
||||
max_length: 4096
|
||||
|
||||
train_data:
|
||||
seed: 1337
|
||||
sources:
|
||||
- name: openai_batch_data_v2
|
||||
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
|
||||
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
|
||||
backend:
|
||||
- openai
|
||||
size: 100_000
|
||||
|
||||
valid_data:
|
||||
sources:
|
||||
- name: openai_batch_data_eval_mini
|
||||
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
|
||||
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
|
||||
backend:
|
||||
- openai
|
||||
size: 100_000
|
||||
|
||||
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
|
||||
hparams:
|
||||
batch_size: 1
|
||||
eval_batch_size: 1
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: false
|
||||
clip_grad_norm: 1.0
|
||||
learning_rate: 3e-4
|
||||
max_steps: 2000
|
||||
pad_multiple_of: 16
|
||||
log_every_steps: 50
|
||||
eval_every_steps: 1000
|
||||
optim: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
weight_decay: 0.01
|
||||
warmup_ratio: 0.03
|
||||
|
||||
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
|
||||
# Disable LORA for now, because we want the visual network to get trained too
|
||||
# lora:
|
||||
# rank: 32
|
||||
# alpha: 32
|
||||
# dropout: 0.05
|
||||
# task_type: causal_lm
|
||||
# target_modules:
|
||||
# - q_proj
|
||||
# - k_proj
|
||||
# - v_proj
|
||||
# - o_proj
|
||||
# - gate_proj
|
||||
# - up_proj
|
||||
# - down_proj
|
||||
|
||||
save:
|
||||
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
|
||||
save_every_steps: 1000
|
||||
|
||||
max_workers: 10
|
||||
@@ -1,84 +0,0 @@
|
||||
model:
|
||||
name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||
arch: causal
|
||||
use_flash_attn: true
|
||||
|
||||
wandb:
|
||||
project: pdelfin
|
||||
entity: ai2-llm
|
||||
|
||||
generate:
|
||||
max_length: 8192
|
||||
|
||||
train_data:
|
||||
seed: 1337
|
||||
cache_location: /data/jakep/pdfdata/pdelfin_cache
|
||||
sources:
|
||||
- name: openai_batch_data_v5_1_train
|
||||
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
|
||||
target_longest_image_dim: 1024
|
||||
target_anchor_text_len: 6000
|
||||
- name: openai_batch_data_v5_1_iabooks_train
|
||||
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
|
||||
target_longest_image_dim: 1024
|
||||
target_anchor_text_len: 6000
|
||||
|
||||
valid_data:
|
||||
cache_location: /data/jakep/pdfdata/pdelfin_cache
|
||||
metric_for_best_model: openai_batch_data_v5_1_eval_loss
|
||||
sources:
|
||||
# These tend to be small, so loading them from S3 is no big deal
|
||||
- name: openai_batch_data_v5_1_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
|
||||
target_longest_image_dim: 1024
|
||||
target_anchor_text_len: 6000
|
||||
- name: openai_batch_data_v5_1_iabooks_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
|
||||
target_longest_image_dim: 1024
|
||||
target_anchor_text_len: 6000
|
||||
|
||||
|
||||
|
||||
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
|
||||
hparams:
|
||||
batch_size: 1
|
||||
eval_batch_size: 1
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
clip_grad_norm: 1.0
|
||||
learning_rate: 1e-4
|
||||
max_steps: 10000
|
||||
pad_multiple_of: 16
|
||||
log_every_steps: 10
|
||||
eval_every_steps: 100
|
||||
optim: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
weight_decay: 0.01
|
||||
warmup_ratio: 0.03
|
||||
|
||||
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
|
||||
lora:
|
||||
rank: 32
|
||||
alpha: 32
|
||||
dropout: 0.05
|
||||
task_type: causal_lm
|
||||
target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- gate_proj
|
||||
- up_proj
|
||||
- down_proj
|
||||
- visual.blocks.[0-9]+.attn.qkv
|
||||
- visual.blocks.[0-9]+.attn.proj
|
||||
- visual.blocks.[0-9]+.mlp.fc1
|
||||
- visual.blocks.[0-9]+.mlp.fc2
|
||||
- visual.merger.mlp.0
|
||||
- visual.merger.mlp.2
|
||||
|
||||
save:
|
||||
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
|
||||
save_every_steps: 1000
|
||||
|
||||
max_workers: 10
|
||||
@@ -1,64 +0,0 @@
|
||||
model:
|
||||
name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||
arch: causal
|
||||
use_flash_attn: true
|
||||
|
||||
wandb:
|
||||
project: pdelfin
|
||||
entity: ai2-llm
|
||||
|
||||
generate:
|
||||
max_length: 8192
|
||||
|
||||
train_data:
|
||||
seed: 1337
|
||||
cache_location: /data/jakep/pdfdata/pdelfin_cache
|
||||
sources:
|
||||
- name: openai_batch_data_v5_1_train
|
||||
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
- name: openai_batch_data_v5_1_iabooks_train
|
||||
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
|
||||
valid_data:
|
||||
cache_location: /data/jakep/pdfdata/pdelfin_cache
|
||||
metric_for_best_model: openai_batch_data_v5_1_eval_loss
|
||||
sources:
|
||||
# These tend to be small, so loading them from S3 is no big deal
|
||||
- name: openai_batch_data_v5_1_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
- name: openai_batch_data_v5_1_iabooks_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
|
||||
|
||||
|
||||
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
|
||||
hparams:
|
||||
batch_size: 1
|
||||
eval_batch_size: 1
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
clip_grad_norm: 1.0
|
||||
learning_rate: 1e-6
|
||||
max_steps: 10000
|
||||
pad_multiple_of: 16
|
||||
log_every_steps: 10
|
||||
eval_every_steps: 100
|
||||
optim: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
weight_decay: 0.01
|
||||
warmup_ratio: 0.03
|
||||
|
||||
|
||||
save:
|
||||
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
|
||||
save_every_steps: 9500
|
||||
|
||||
max_workers: 10
|
||||
@@ -1,96 +0,0 @@
|
||||
import json
|
||||
from logging import Logger
|
||||
from typing import Optional, Type
|
||||
|
||||
import smart_open
|
||||
import torch
|
||||
from peft.peft_model import PeftModel
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelWithLMHead,
|
||||
AutoTokenizer,
|
||||
)
|
||||
|
||||
from .config import ModelConfig
|
||||
from .loggers import get_logger
|
||||
from .paths import cached_path, exists, get_cache_dir, join_path, resource_to_filename
|
||||
|
||||
__all__ = ["load_model", "cache_merged_model"]
|
||||
|
||||
|
||||
def get_model_cls(config: ModelConfig) -> Type[AutoModelWithLMHead]:
|
||||
if config.arch == "seq2seq":
|
||||
return AutoModelForSeq2SeqLM # pyright: ignore
|
||||
elif config.arch == "causal" or config.arch == "vllm":
|
||||
return AutoModelForCausalLM # pyright: ignore
|
||||
else:
|
||||
raise ValueError(f"Unsupported model architecture: {config.arch}")
|
||||
|
||||
|
||||
def get_adapter_config(config: ModelConfig) -> dict:
|
||||
local_path = cached_path(config.name_or_path)
|
||||
if exists(adapter_config_path := join_path("", local_path, "adapter_config.json")):
|
||||
with smart_open.open(adapter_config_path, "rt", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
return {}
|
||||
|
||||
|
||||
def load_model(config: ModelConfig, logger: Optional[Logger] = None) -> AutoModelWithLMHead:
|
||||
logger = logger or get_logger(__file__, level="INFO")
|
||||
|
||||
logger.info(f"Loading model from {config.name_or_path}")
|
||||
local_path = cached_path(config.name_or_path)
|
||||
if local_path != config.name_or_path:
|
||||
logger.info(f"Model cached at {local_path}")
|
||||
|
||||
if exists(adapter_config_path := join_path("", local_path, "adapter_config.json")):
|
||||
logger.info(f"Loading LoRA adapter from {adapter_config_path}")
|
||||
with smart_open.open(adapter_config_path) as f:
|
||||
adapter_config = json.load(f)
|
||||
base_model_name_or_path = adapter_config["base_model_name_or_path"]
|
||||
enable_lora = True
|
||||
else:
|
||||
base_model_name_or_path = local_path
|
||||
enable_lora = False
|
||||
|
||||
model = get_model_cls(config).from_pretrained(
|
||||
base_model_name_or_path,
|
||||
device_map="auto",
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
# low_cpu_mem_usage=model_config.low_cpu_mem_usage,
|
||||
use_flash_attention_2=True if config.use_flash_attn else False,
|
||||
revision=config.model_revision,
|
||||
torch_dtype=torch.bfloat16 if config.use_flash_attn else getattr(torch, config.dtype),
|
||||
)
|
||||
logger.info(f"Successfully loaded base model from {base_model_name_or_path}")
|
||||
|
||||
if enable_lora:
|
||||
peft_model = PeftModel.from_pretrained(model, local_path)
|
||||
model = peft_model.merge_and_unload()
|
||||
logger.info(f"Successfully loaded LoRA adapter from base model: {base_model_name_or_path}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def cache_merged_model(config: ModelConfig, logger: Optional[Logger] = None) -> str:
|
||||
logger = logger or get_logger(__file__, level="INFO")
|
||||
|
||||
base_local_path = cached_path(config.name_or_path)
|
||||
adapter_config = get_adapter_config(config)
|
||||
if not adapter_config:
|
||||
logger.info("No adapter config found; using base model")
|
||||
return base_local_path
|
||||
|
||||
local_fn = resource_to_filename(json.dumps({"adapter": adapter_config, "model": config.name_or_path}))
|
||||
merged_local_path = f"{get_cache_dir()}/{local_fn}"
|
||||
|
||||
if not exists(merged_local_path):
|
||||
model = load_model(config=config, logger=logger)
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_local_path)
|
||||
|
||||
logger.info(f"Saving merged model to {merged_local_path}")
|
||||
model.save_pretrained(merged_local_path)
|
||||
tokenizer.save_pretrained(merged_local_path)
|
||||
|
||||
return merged_local_path
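For context, a rough sketch of how load_model and cache_merged_model might be called; the module path and config values are assumptions for illustration, not taken from this commit:

from .config import ModelConfig
from .utils import cache_merged_model, load_model  # hypothetical module name

config = ModelConfig(name_or_path="Qwen/Qwen2-VL-7B-Instruct", arch="causal", use_flash_attn=True)
model = load_model(config)               # loads the base weights and merges a LoRA adapter if one is found
merged_dir = cache_merged_model(config)  # or: write the merged weights to the local cache and return the path
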
@@ -1,327 +0,0 @@
|
||||
"""
|
||||
Utilities to work with an OmegaConf structured config object
|
||||
|
||||
From Dolma Toolkit: https://github.com/allenai/dolma/blob/64886d9db15bd99acea9e28740ae20a510875dfb/python/dolma/cli/__init__.py
|
||||
|
||||
Author: Luca Soldaini (@soldni)
|
||||
""" # noqa: E501
|
||||
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from collections.abc import Iterable
|
||||
from copy import deepcopy
|
||||
from dataclasses import Field
|
||||
from dataclasses import field as dataclass_field
|
||||
from dataclasses import is_dataclass
|
||||
from logging import warning
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
import smart_open
|
||||
from necessary import necessary
|
||||
from omegaconf import MISSING, DictConfig, ListConfig
|
||||
from omegaconf import OmegaConf as om
|
||||
from omegaconf.errors import OmegaConfBaseException
|
||||
from rich.console import Console
|
||||
from rich.syntax import Syntax
|
||||
from yaml import safe_load # type: ignore
|
||||
|
||||
from .errors import DolmaRefineError
|
||||
|
||||
__all__ = ["field", "namespace_to_nested_omegaconf", "print_config", "make_cli", "read_config", "to_native_types"]
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Any)
|
||||
D = TypeVar("D", bound="DataClass")
|
||||
A = TypeVar("A", bound="ArgumentParser")
|
||||
|
||||
|
||||
def _field_nargs(default: Any) -> Union[Literal["?"], Literal["*"]]:
|
||||
# return '*' if default is iterable but not string/bytes, else '?'
|
||||
if isinstance(default, (str, bytes)):
|
||||
return "?"
|
||||
|
||||
if isinstance(default, Iterable):
|
||||
return "*"
|
||||
|
||||
return "?"
|
||||
|
||||
|
||||
def field(default: T = MISSING, help: Optional[str] = None, **extra: Any) -> T:
|
||||
metadata = {"help": help, "type": type(default), "default": default, "nargs": _field_nargs(default), **extra}
|
||||
return dataclass_field(default_factory=lambda: deepcopy(default), metadata=metadata)
|
||||
|
||||
|
||||
class DataClass(Protocol):
|
||||
__dataclass_fields__: Dict[str, Field]
|
||||
|
||||
|
||||
def read_config(path: Union[None, str]) -> Dict[str, Any]:
|
||||
"""Read a configuration file if it exists"""
|
||||
if path is None:
|
||||
return {}
|
||||
|
||||
try:
|
||||
with smart_open.open(path, mode="rt") as f:
|
||||
return dict(safe_load(f))
|
||||
except FileNotFoundError as ex:
|
||||
raise DolmaRefineError(f"Config file not found: {path}") from ex
|
||||
except Exception as ex:
|
||||
raise DolmaRefineError(f"Error while reading config file: {path}") from ex
|
||||
|
||||
|
||||
def save_config(config: Union[dict, DictConfig, list, ListConfig, DataClass], path: str) -> None:
|
||||
"""Save a configuration to a file"""
|
||||
if isinstance(config, (list, dict)):
|
||||
config = om.create(config)
|
||||
elif is_dataclass(config):
|
||||
config = om.structured(config)
|
||||
|
||||
with smart_open.open(path, mode="wt") as f:
|
||||
f.write(om.to_yaml(config))
|
||||
|
||||
|
||||
def _make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = None) -> A:
|
||||
for field_name, dt_field in config.__dataclass_fields__.items():
|
||||
# get type from annotations or metadata
|
||||
typ_ = config.__annotations__.get(field_name, dt_field.metadata.get("type", MISSING))
|
||||
|
||||
if typ_ is MISSING:
|
||||
warning(f"No type annotation for field {field_name} in {config.__name__}")
|
||||
continue
|
||||
|
||||
# join prefix and field name
|
||||
field_name = f"{prefix}.{field_name}" if prefix else field_name
|
||||
|
||||
# This section here is to handle Optional[T] types; we only care for cases where T is a dataclass
|
||||
# So we first check if type is Union since Optional[T] is just a shorthand for Union[T, None]
|
||||
# and that the union contains only one non-None type
|
||||
if get_origin(typ_) == Union:
|
||||
# get all non-None types
|
||||
args = [a for a in get_args(typ_) if a is not type(None)] # noqa: E721
|
||||
|
||||
if len(args) == 1:
|
||||
# simple Optional[T] type
|
||||
typ_ = args[0]
|
||||
|
||||
# here's where we check if T is a dataclass
|
||||
if is_dataclass(typ_):
|
||||
# recursively add subparsers
|
||||
_make_parser(parser, typ_, prefix=field_name) # type: ignore
|
||||
continue
|
||||
|
||||
if typ_ is bool:
|
||||
# for boolean values, we add two arguments: --field_name and --no-field_name
|
||||
parser.add_argument(
|
||||
f"--{field_name}",
|
||||
help=dt_field.metadata.get("help"),
|
||||
dest=field_name,
|
||||
action="store_true",
|
||||
default=MISSING,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--no-{field_name}",
|
||||
help=f"Disable {field_name}",
|
||||
dest=field_name,
|
||||
action="store_false",
|
||||
default=MISSING,
|
||||
)
|
||||
else:
|
||||
# else it's just a normal argument
|
||||
parser.add_argument(
|
||||
f"--{field_name}",
|
||||
help=dt_field.metadata.get("help"),
|
||||
nargs=dt_field.metadata.get("nargs", "?"),
|
||||
default=MISSING,
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def make_nested_dict(key: str, value: Any, d: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
d = d or {}
|
||||
|
||||
if "." in key:
|
||||
key, rest = key.split(".", 1)
|
||||
value = make_nested_dict(rest, value, d.get(key))
|
||||
|
||||
# the value was provided (is not MISSING constant) and is not an empty dict or list
|
||||
if value != MISSING and (not isinstance(value, (dict, list)) or len(value) > 0):
|
||||
d[key] = value
|
||||
|
||||
return d
|
||||
|
||||
|
||||
def to_native_types(obj: Any, resolve: bool = True, throw_on_missing: bool = True, enum_to_str: bool = True) -> Any:
|
||||
"""Converts an OmegaConf object to native types (dicts, lists, etc.)"""
|
||||
|
||||
# convert dataclass to structured config
|
||||
if hasattr(obj, "to_dict"):
|
||||
# huggingface objects have a to_dict method, we prefer that
|
||||
obj = obj.to_dict()
|
||||
elif is_dataclass(obj):
|
||||
# we go through structured config instead and hope for the best
|
||||
obj = om.to_container(obj)
|
||||
|
||||
if isinstance(obj, DictConfig) or isinstance(obj, ListConfig):
|
||||
obj = om.to_container(obj, resolve=resolve, throw_on_missing=throw_on_missing, enum_to_str=enum_to_str)
|
||||
|
||||
if isinstance(obj, dict):
|
||||
return {k: to_native_types(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [to_native_types(v) for v in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
def namespace_to_nested_omegaconf(args: Namespace, structured: Type[T], config: Optional[dict] = None) -> T:
|
||||
nested_config_dict: Dict[str, Any] = {}
|
||||
for key, value in vars(args).items():
|
||||
nested_config_dict = make_nested_dict(key, value, nested_config_dict)
|
||||
|
||||
untyped_config: DictConfig = om.merge(
|
||||
om.create(config or {}), om.create(nested_config_dict)
|
||||
) # pyright: ignore (pylance is confused because om.create might return a DictConfig or a ListConfig)
|
||||
|
||||
# resolve any interpolations in the config
|
||||
om.resolve(untyped_config)
|
||||
|
||||
# create structured config from cli dataclass
|
||||
base_structured_config: DictConfig = om.structured(structured)
|
||||
|
||||
# merge with options parsed from config file and
|
||||
merged_config = om.merge(base_structured_config, untyped_config)
|
||||
|
||||
# check for type
|
||||
if not isinstance(merged_config, DictConfig):
|
||||
raise DolmaRefineError(f"Expected a DictConfig, got {type(merged_config).__name__}")
|
||||
|
||||
# try resolving all cross references in the config, raise a DolmaConfigError if it fails
|
||||
try:
|
||||
om.resolve(merged_config)
|
||||
except OmegaConfBaseException as ex:
|
||||
raise DolmaRefineError(f"Invalid error while parsing key `{ex.full_key}`: {type(ex).__name__}") from ex
|
||||
|
||||
return merged_config # pyright: ignore
|
||||
|
||||
|
||||
def print_config(config: Any, console: Optional[Console] = None) -> None:
|
||||
if not isinstance(config, (DictConfig, ListConfig)):
|
||||
config = om.create(config)
|
||||
|
||||
# print the config as yaml using a rich syntax highlighter
|
||||
console = console or Console()
|
||||
yaml_config = om.to_yaml(config, sort_keys=True).strip()
|
||||
highlighted = Syntax(code=yaml_config, lexer="yaml", theme="ansi_dark")
|
||||
console.print(highlighted)
|
||||
|
||||
|
||||
def _patch_old_omegaconf():
|
||||
"""Monkey patch omegaconf below version 2.3.0 to support custom resolver returning
|
||||
lists or dicts. Applies patch https://github.com/omry/omegaconf/pull/1093"""
|
||||
|
||||
if necessary(("omegaconf", "2.4.0"), soft=True):
|
||||
# no need to patch
|
||||
return
|
||||
|
||||
if getattr(_patch_old_omegaconf, "__patched__", False):
|
||||
# already patched
|
||||
return
|
||||
|
||||
from omegaconf import _impl # pylint: disable=import-outside-toplevel
|
||||
from omegaconf import ( # pylint: disable=import-outside-toplevel
|
||||
Container,
|
||||
Node,
|
||||
ValueNode,
|
||||
)
|
||||
from omegaconf._utils import ( # noqa: F401 # pylint: disable=import-outside-toplevel
|
||||
_ensure_container,
|
||||
_get_value,
|
||||
is_primitive_container,
|
||||
is_structured_config,
|
||||
)
|
||||
from omegaconf.errors import ( # pylint: disable=import-outside-toplevel
|
||||
InterpolationToMissingValueError,
|
||||
)
|
||||
from omegaconf.nodes import ( # pylint: disable=import-outside-toplevel
|
||||
InterpolationResultNode,
|
||||
)
|
||||
|
||||
def _resolve_container_value(cfg: Container, key: Any) -> None:
|
||||
node = cfg._get_child(key) # pylint: disable=protected-access
|
||||
assert isinstance(node, Node)
|
||||
if node._is_interpolation(): # pylint: disable=protected-access
|
||||
try:
|
||||
resolved = node._dereference_node() # pylint: disable=protected-access
|
||||
except InterpolationToMissingValueError:
|
||||
node._set_value(MISSING) # pylint: disable=protected-access
|
||||
else:
|
||||
if isinstance(resolved, Container):
|
||||
_impl._resolve(resolved) # pylint: disable=protected-access
|
||||
if isinstance(resolved, InterpolationResultNode):
|
||||
resolved_value = _get_value(resolved)
|
||||
if is_primitive_container(resolved_value) or is_structured_config(resolved_value):
|
||||
resolved = _ensure_container(resolved_value)
|
||||
if isinstance(resolved, Container) and isinstance(node, ValueNode):
|
||||
cfg[key] = resolved
|
||||
else:
|
||||
node._set_value(_get_value(resolved)) # pylint: disable=protected-access
|
||||
else:
|
||||
_impl._resolve(node) # pylint: disable=protected-access
|
||||
|
||||
# set new function and mark as patched
|
||||
setattr(_impl, "_resolve_container_value", _resolve_container_value)
|
||||
setattr(_patch_old_omegaconf, "__patched__", True)
|
||||
|
||||
|
||||
# actually executes the patch
|
||||
_patch_old_omegaconf()
|
||||
|
||||
|
||||
def make_cli(config_cls: Type[D], _config_flag: str = "config", _dryrun_flag: str = "dryrun") -> D:
|
||||
"""Create a CLI parser for a dataclass and parse the arguments into a structured config object."""
|
||||
|
||||
if hasattr(config_cls, _config_flag):
|
||||
raise DolmaRefineError(f"`{_config_flag}` is a reserved attribute; remove it from `{config_cls.__name__}`")
|
||||
|
||||
if hasattr(config_cls, _dryrun_flag):
|
||||
raise DolmaRefineError(f"`{_dryrun_flag}` is a reserved attribute; remove it from `{config_cls.__name__}`")
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument(f"-{_config_flag[0]}", f"--{_config_flag}", help="Path to config file", default=None, type=str)
|
||||
parser.add_argument(
|
||||
f"-{_dryrun_flag[0]}",
|
||||
f"--{_dryrun_flag}",
|
||||
help="Dry run mode: print config and exit",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
|
||||
parser = _make_parser(parser, config_cls)
|
||||
args = parser.parse_args()
|
||||
|
||||
parsed_config: Dict[str, Any] = {}
|
||||
if (config_path := getattr(args, _config_flag)) is not None:
|
||||
parsed_config = read_config(config_path)
|
||||
delattr(args, _config_flag)
|
||||
|
||||
only_dryrun = getattr(args, _dryrun_flag, False)
|
||||
delattr(args, _dryrun_flag)
|
||||
|
||||
full_config = namespace_to_nested_omegaconf(args, config_cls, parsed_config)
|
||||
|
||||
print_config(full_config)
|
||||
|
||||
if only_dryrun:
|
||||
exit(0)
|
||||
|
||||
return full_config
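A minimal usage sketch for make_cli, with an illustrative dataclass and script name that are not part of this commit:

from dataclasses import dataclass

from .cli import field, make_cli  # hypothetical module name

@dataclass
class DemoConfig:
    model_name: str = field(default="Qwen/Qwen2-VL-7B-Instruct", help="Model to fine-tune")
    learning_rate: float = field(default=1e-4, help="Peak learning rate")

# `python train.py -c demo.yaml --learning_rate 3e-4` would read the YAML first,
# apply the command-line override on top, print the merged config, and return it;
# `-d/--dryrun` prints the config and exits instead.
config = make_cli(DemoConfig)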

@@ -1,16 +0,0 @@
from smart_open import register_compressor

__all__ = ["mk_compression"]


def mk_compression():
    def _handle_zst(file_obj, mode):
        try:
            import zstandard as zstd
        except ImportError:
            raise ImportError("zstandard is required for zstd support")

        return zstd.open(file_obj, mode)

    register_compressor(".zstd", _handle_zst)
    register_compressor(".zst", _handle_zst)

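The intended use is to call mk_compression() once at startup so smart_open can read and write .zst/.zstd files transparently; a short sketch with an illustrative module path and file name:

import smart_open

from .compression import mk_compression  # hypothetical module name

mk_compression()  # registers the zstd handlers with smart_open
with smart_open.open("responses.jsonl.zst", "rt") as f:  # illustrative path
    for line in f:
        print(line, end="")
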
@@ -1,131 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from peft import TaskType # pyright: ignore
|
||||
|
||||
from .cli import field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
"""Configuration for loading a model; includes model name and type."""
|
||||
|
||||
name_or_path: str = field(help="The model name or path to load; must be compatible with huggingface transformers.")
|
||||
arch: str = field(help="The model type to load; can be 'seq2seq', 'causal', or 'vllm'")
|
||||
dtype: str = field(help="The precision to use for the model", default="bfloat16")
|
||||
use_flash_attn: bool = field(help="Whether to use the flash attention for the model.", default=False)
|
||||
trust_remote_code: bool = field(help="Whether to trust remote code for the model.", default=False)
|
||||
low_cpu_mem_usage: bool = field(help="Whether to use low cpu memory usage for the model.", default=False)
|
||||
fast_tokenizer: bool = field(help="Whether to use the fast tokenizer for the model.", default=True)
|
||||
model_revision: Optional[str] = field(help="The model revision to use for the model.", default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerateConfig:
|
||||
max_length: int = field(help="The maximum length of the generated text", default=4096)
|
||||
temperature: float = field(default=0.2, help="The temperature to use for generation")
|
||||
top_k: int = field(default=50, help="The top k to use for generation")
|
||||
top_p: float = field(default=1.0, help="The top p to use for generation")
|
||||
num_beams: int = field(default=1, help="The number of beams to use for generation")
|
||||
truncate_prompt_tokens: bool = field(default=True, help="Whether to truncate the prompt tokens for generation")
|
||||
max_num_seqs: int = field(default=16, help="The maximum number of sequences to generate")
|
||||
|
||||
|
||||
@dataclass
|
||||
class WandbConfig:
|
||||
entity: str = field(help="The wandb entity to use for logging", default="ai2-llm")
|
||||
project: str = field(help="The wandb project to use for logging", default="pdf-qwen2vl")
|
||||
wandb_api_key: Optional[str] = field(help="The wandb api key to use for logging", default=None)
|
||||
mode: str = field(help="The wandb mode to use for logging; set it to `offline` to log locally without syncing", default="online")
|
||||
watch: str = field(help="The wandb watch to use for logging", default="false")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AwsConfig:
|
||||
profile: Optional[str] = field(help="The aws profile to use for s3 access", default=None)
|
||||
access_key_id: Optional[str] = field(help="The aws access key id to use for s3 access", default=None)
|
||||
secret_access_key: Optional[str] = field(help="The aws secret access key to use for s3 access", default=None)
|
||||
default_region: Optional[str] = field(help="The default region to use for s3 access", default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SourceConfig:
|
||||
name: str = field(help="The name of the source")
|
||||
response_glob_path: str = field(help="The S3 glob path pointing to the batch API response JSONs sent back from OpenAI")
|
||||
target_longest_image_dim: list[int] = field(help="Dimensions to render the pdf page image to")
|
||||
target_anchor_text_len: list[int] = field(help="Maximum amount of anchor text (aka prompt hint)")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataConfig:
|
||||
seed: int = field(default=42, help="The seed to use for data loading")
|
||||
cache_location: Optional[str] = field(help="Local location to cache PDFs from S3 that are needed to compute page images", default=None)
|
||||
metric_for_best_model: Optional[str] = field(help="metric to pass to trainer args to use for picking best model checkpoint at end", default=None)
|
||||
sources: List[SourceConfig] = field(help="The source configurations")
|
||||
|
||||
|
||||
@dataclass
|
||||
class HyperparamConfig:
|
||||
batch_size: int = field(default=8, help="The batch size to use for training")
|
||||
eval_batch_size: Optional[int] = field(default=None, help="The batch size to use for evaluation; default is the same as the training batch size")
|
||||
learning_rate: float = field(default=2e-5, help="The learning rate to use for training")
|
||||
max_steps: int = field(default=-1, help="The maximum number of steps to train the model")
|
||||
pad_multiple_of: int = field(default=16, help="The padding multiple to use for the model")
|
||||
log_every_steps: int = field(default=5, help="The number of steps to log training metrics")
|
||||
eval_every_steps: int = field(default=100, help="The number of steps to evaluate the model")
|
||||
weight_decay: float = field(default=0.0, help="The weight decay to use for training")
|
||||
warmup_steps: int = field(default=0, help="The number of warmup steps to use for training")
|
||||
warmup_ratio: float = field(default=0.0, help="The ratio of warmup steps to use for training")
|
||||
lr_scheduler: str = field(default="linear", help="The learning rate scheduler to use for training")
|
||||
gradient_accumulation_steps: int = field(default=1, help="The number of gradient accumulation steps to use for training")
|
||||
gradient_checkpointing: bool = field(default=False, help="Whether to use gradient checkpointing for training")
|
||||
seed: int = field(default=42, help="The seed to use for training")
|
||||
reduce_loss: str = field(default="mean", help="The loss reduction to use for training")
|
||||
clip_grad_norm: float = field(default=0.0, help="The gradient norm to clip to for training")
|
||||
optim: str = field(default="adamw_torch", help="The optimizer to use for training")
|
||||
find_unused_parameters: bool = field(default=False, help="Whether to find unused parameters for training")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaveConfig:
|
||||
path: str = field(default="./results", help="The output directory to save the model")
|
||||
limit: Optional[int] = field(default=None, help="The number of checkpoints to save")
|
||||
save_every_steps: int = field(default="${hparams.eval_every_steps}", help="The number of steps to save the model") # type: ignore
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoraConfig:
|
||||
rank: int = field(default=16, help="The rank of the LoRA attention")
|
||||
alpha: int = field(default=16, help="The alpha parameter for LoRA scaling")
|
||||
dropout: float = field(default=0.05, help="The dropout probability for LoRA layers")
|
||||
bias: str = field(default="none", help="The bias training mode for LoRA layers (none, all, or lora_only)")
|
||||
task_type: str = field(default=TaskType.CAUSAL_LM, help="The task type for the model")
|
||||
target_modules: List[str] = field(
|
||||
default=["k_proj", "q_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
|
||||
help="The target modules in the model that will be replaced with LoRA layers",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
model: ModelConfig = field(default=ModelConfig(), help="The model configuration")
|
||||
lora: Optional[LoraConfig] = field(default=None, help="The LoRA configuration")
|
||||
aws: AwsConfig = field(default=AwsConfig(), help="Configuration for AWS S3")
|
||||
wandb: WandbConfig = field(default=WandbConfig(), help="Configuration for Weights and Biases")
|
||||
train_data: DataConfig = field(default=DataConfig(), help="Configuration for the training data")
|
||||
valid_data: DataConfig = field(default=DataConfig(), help="Configuration for the validation data")
|
||||
generate: GenerateConfig = field(default=GenerateConfig(), help="Configuration for text generation")
|
||||
num_proc: int = field(default=1, help="The maximum number of workers to use for data processing")
|
||||
max_workers: int = field(default=1, help="The maximum number of workers to use for data loaders")
|
||||
hparams: HyperparamConfig = field(default=HyperparamConfig(), help="Hyperparameters for training")
|
||||
save: SaveConfig = field(default=SaveConfig(), help="Configuration for saving the model")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DemoConfig:
|
||||
title: str = field(default="# Dolma Rewriter Demo")
|
||||
description: str = field(default="Internal use only, **DO NOT SHARE OUTSIDE AI2**.")
|
||||
share: bool = field(default=False, help="Share the demo publicly.")
|
||||
|
||||
model: ModelConfig = field(default=ModelConfig())
|
||||
generate: GenerateConfig = field(default=GenerateConfig())
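Note that SaveConfig.save_every_steps defaults to the OmegaConf interpolation string "${hparams.eval_every_steps}"; it only becomes an integer once the merged config is resolved (see om.resolve in the CLI utilities above). A tiny stand-alone sketch of that mechanism, independent of these classes:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "hparams": {"eval_every_steps": 100},
    "save": {"save_every_steps": "${hparams.eval_every_steps}"},
})
OmegaConf.resolve(cfg)
assert cfg.save.save_every_steps == 100  # the interpolation picks up the hparams value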

@@ -1 +0,0 @@
class DolmaRefineError(RuntimeError): ...

@@ -1,52 +0,0 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
from typing import Union
|
||||
|
||||
LOGGER_PREFIX = "dolma-refine"
|
||||
|
||||
|
||||
def get_logger(name: str, level: Union[int, str] = logging.WARN) -> logging.Logger:
|
||||
if (proc_name := multiprocessing.current_process().name) == "MainProcess":
|
||||
proc_name = "main"
|
||||
proc_name = proc_name.replace(" ", "_")
|
||||
|
||||
# set the log level
|
||||
level = level if isinstance(level, int) else getattr(logging, level.strip().upper(), logging.WARN)
|
||||
|
||||
# set name
|
||||
name = f"{LOGGER_PREFIX}.{proc_name}.{name}"
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(level)
|
||||
|
||||
# add handler
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def reset_level(level: Union[int, str]) -> None:
|
||||
"""
|
||||
Reset the log level for all Dolma loggers.
|
||||
|
||||
Args:
|
||||
level (Union[int, str]): The log level to set. It can be either an integer
|
||||
representing the log level (e.g., logging.DEBUG) or a string
|
||||
representing the log level name (e.g., 'debug').
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if isinstance(level, str):
|
||||
if (level_tmp := getattr(logging, level.strip().upper(), None)) is not None:
|
||||
level = level_tmp
|
||||
else:
|
||||
raise ValueError(f"Invalid log level: {level}")
|
||||
|
||||
for logger in logging.Logger.manager.loggerDict.values():
|
||||
if isinstance(logger, logging.Logger):
|
||||
if logger.name.startswith(LOGGER_PREFIX):
|
||||
logger.setLevel(level)
|
||||
@@ -1,615 +0,0 @@
|
||||
import glob
|
||||
import os
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial, reduce
|
||||
from hashlib import sha256
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from shutil import copyfileobj
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import platformdirs
|
||||
import smart_open
|
||||
from fsspec import AbstractFileSystem, get_filesystem_class
|
||||
from smart_open.compression import get_supported_extensions
|
||||
|
||||
from .loggers import LOGGER_PREFIX, get_logger
|
||||
|
||||
__all__ = [
|
||||
"glob_path",
|
||||
"sub_prefix",
|
||||
"add_suffix",
|
||||
"sub_suffix",
|
||||
"make_relative",
|
||||
"mkdir_p",
|
||||
"split_path",
|
||||
"join_path",
|
||||
"is_glob",
|
||||
"split_glob",
|
||||
"partition_path",
|
||||
]
|
||||
|
||||
|
||||
FS_KWARGS: Dict[str, Dict[str, Any]] = {
|
||||
"": {"auto_mkdir": True},
|
||||
}
|
||||
|
||||
|
||||
RE_ANY_ESCAPE = re.compile(r"(?<!\\)(\*\?\[\])")
|
||||
RE_GLOB_STAR_ESCAPE = re.compile(r"(?<!\\)\*")
|
||||
RE_GLOB_ONE_ESCAPE = re.compile(r"(?<!\\)\?")
|
||||
RE_GLOB_OPEN_ESCAPE = re.compile(r"(?<!\\)\[")
|
||||
RE_GLOB_CLOSE_ESCAPE = re.compile(r"(?<!\\)\]")
|
||||
ESCAPE_SYMBOLS_MAP = {"*": "\u2581", "?": "\u2582", "[": "\u2583", "]": "\u2584"}
|
||||
REVERSE_ESCAPE_SYMBOLS_MAP = {v: k for k, v in ESCAPE_SYMBOLS_MAP.items()}
|
||||
PATCHED_GLOB = False
|
||||
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
def get_fs(path: Union[Path, str]) -> AbstractFileSystem:
|
||||
"""
|
||||
Get the filesystem class for a given path.
|
||||
"""
|
||||
path = str(path)
|
||||
protocol = urlparse(path).scheme
|
||||
fs = get_filesystem_class(protocol)(**FS_KWARGS.get(protocol, {}))
|
||||
|
||||
global PATCHED_GLOB # pylint: disable=global-statement
|
||||
|
||||
# patch glob method to support recursive globbing
|
||||
if protocol == "" and not PATCHED_GLOB:
|
||||
fs.glob = partial(glob.glob, recursive=True)
|
||||
|
||||
# only patch once
|
||||
PATCHED_GLOB = True
|
||||
|
||||
return fs
|
||||
|
||||
|
||||
def _escape_glob(s: Union[str, Path]) -> str:
|
||||
"""
|
||||
Escape glob characters in a string.
|
||||
"""
|
||||
s = str(s)
|
||||
s = RE_GLOB_STAR_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["*"], s)
|
||||
s = RE_GLOB_ONE_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["?"], s)
|
||||
s = RE_GLOB_OPEN_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["["], s)
|
||||
s = RE_GLOB_CLOSE_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["]"], s)
|
||||
return s
|
||||
|
||||
|
||||
def _unescape_glob(s: Union[str, Path]) -> str:
|
||||
"""
|
||||
Unescape glob characters in a string.
|
||||
"""
|
||||
s = str(s)
|
||||
for k, v in REVERSE_ESCAPE_SYMBOLS_MAP.items():
|
||||
s = s.replace(k, v)
|
||||
return s
|
||||
|
||||
|
||||
def _pathify(path: Union[Path, str]) -> Tuple[str, Path]:
|
||||
"""
|
||||
Return the protocol and path of a given path.
|
||||
"""
|
||||
path = _escape_glob(str(path))
|
||||
parsed = urlparse(path)
|
||||
path = Path(f"{parsed.netloc}/{parsed.path}") if parsed.netloc else Path(parsed.path)
|
||||
return parsed.scheme, path
|
||||
|
||||
|
||||
def _unpathify(protocol: str, path: Path) -> str:
|
||||
"""
|
||||
Return a path from its protocol and path components.
|
||||
"""
|
||||
path_str = _unescape_glob(str(path))
|
||||
if protocol:
|
||||
path_str = f"{protocol}://{path_str.lstrip('/')}"
|
||||
return path_str
|
||||
|
||||
|
||||
def remove_params(path: str) -> str:
|
||||
"""
|
||||
Remove parameters from a path.
|
||||
"""
|
||||
parsed = urlparse(path)
|
||||
return (f"{parsed.scheme}://" if parsed.scheme else "") + f"{parsed.netloc}{parsed.path}"
|
||||
|
||||
|
||||
def is_local(path: str) -> bool:
|
||||
"""
|
||||
Check if a path is local.
|
||||
"""
|
||||
prot, _ = _pathify(path)
|
||||
return prot == "" or prot == "file"
|
||||
|
||||
|
||||
def copy_file(src: str, dest: str) -> None:
|
||||
"""Copy a file using shutil.copyfileobj for efficient chunked copying."""
|
||||
with smart_open.open(src, "rb") as src_file, smart_open.open(dest, "wb") as dest_file:
|
||||
copyfileobj(src_file, dest_file)
|
||||
|
||||
|
||||
def copy_dir(src: str, dst: str, src_fs: Optional[AbstractFileSystem] = None, dst_fs: Optional[AbstractFileSystem] = None):
|
||||
"""Copy a directory using a ThreadPoolExecutor for parallel file copying."""
|
||||
src_fs = src_fs or get_fs(src)
|
||||
dst_fs = dst_fs or get_fs(dst)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
futures = []
|
||||
|
||||
for src_path in glob_path(src, yield_dirs=True, fs=src_fs):
|
||||
rel_path = sub_prefix(src_path, src)
|
||||
dest_path = join_path("", dst, rel_path)
|
||||
|
||||
if is_dir(src_path, fs=src_fs):
|
||||
# Recursively copy directories
|
||||
copy_dir(src=src_path, dst=dest_path, src_fs=src_fs, dst_fs=dst_fs)
|
||||
else:
|
||||
# File; copy over using the executor for parallelism
|
||||
logger.info(f"Copying {src_path} to {dest_path}")
|
||||
futures.append(executor.submit(copy_file, src_path, dest_path))
|
||||
|
||||
# Wait for all futures to complete
|
||||
for future in futures:
|
||||
future.result() # This will raise an exception if any of the threads failed
|
||||
|
||||
|
||||
def delete_file(path: str, ignore_missing: bool = False, fs: Optional[AbstractFileSystem] = None) -> bool:
|
||||
"""Delete a file."""
|
||||
|
||||
fs = fs or get_fs(path)
|
||||
try:
|
||||
fs.rm(path)
|
||||
deleted = True
|
||||
except FileNotFoundError as ex:
|
||||
if not ignore_missing:
|
||||
raise ex
|
||||
deleted = False
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
def get_size(path: str, fs: Optional[AbstractFileSystem] = None) -> int:
|
||||
"""Get the size of a file"""
|
||||
|
||||
fs = fs or get_fs(path)
|
||||
|
||||
if not exists(path, fs=fs):
|
||||
raise ValueError(f"Path {path} does not exist")
|
||||
if is_dir(path, fs=fs):
|
||||
raise ValueError(f"Path {path} is a directory")
|
||||
|
||||
return fs.info(path)["size"]
|
||||
|
||||
|
||||
def delete_dir(path: str, ignore_missing: bool = False, fs: Optional[AbstractFileSystem] = None) -> bool:
|
||||
"""Delete a directory."""
|
||||
|
||||
fs = fs or get_fs(path)
|
||||
try:
|
||||
fs.rm(path, recursive=True)
|
||||
deleted = True
|
||||
except FileNotFoundError as ex:
|
||||
if not ignore_missing:
|
||||
raise ex
|
||||
deleted = False
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
def partition_path(path: str) -> Tuple[str, Tuple[str, ...], Tuple[str, ...]]:
|
||||
"""Partition a path into its protocol, symbols before a glob, and symbols after a glob."""
|
||||
# split the path into its protocol and path components
|
||||
prot, path_obj = _pathify(path)
|
||||
|
||||
# we need to first figure out if this path has a glob by checking if any of the escaped symbols for
|
||||
# globs are in the path.
|
||||
glob_locs = [i for i, p in enumerate(path_obj.parts) if any(c in p for c in REVERSE_ESCAPE_SYMBOLS_MAP)]
|
||||
|
||||
# make the path components before the glob
|
||||
pre_glob_path = path_obj.parts[: glob_locs[0]] if glob_locs else path_obj.parts
|
||||
pre_glob_path = tuple(_unescape_glob(p) for p in pre_glob_path)
|
||||
|
||||
# make the path components after the glob
|
||||
post_glob_path = path_obj.parts[glob_locs[0] + 1 :] if glob_locs else ()
|
||||
post_glob_path = tuple(_unescape_glob(p) for p in post_glob_path)
|
||||
|
||||
return prot, pre_glob_path, post_glob_path
|
||||
|
||||
|
||||
def split_path(path: str) -> Tuple[str, Tuple[str, ...]]:
|
||||
"""
|
||||
Split a path into its protocol and path components.
|
||||
"""
|
||||
protocol, _path = _pathify(path)
|
||||
return protocol, tuple(_unescape_glob(p) for p in _path.parts)
|
||||
|
||||
|
||||
def join_path(protocol: Union[str, None], *parts: Union[str, Iterable[str]]) -> str:
|
||||
"""
|
||||
Join a path from its protocol and path components.
|
||||
"""
|
||||
all_prots, all_parts = zip(*(_pathify(p) for p in chain.from_iterable([p] if isinstance(p, str) else p for p in parts)))
|
||||
path = str(Path(*all_parts)).rstrip("/")
|
||||
protocol = protocol or str(all_prots[0])
|
||||
|
||||
if protocol:
|
||||
path = f"{protocol}://{path.lstrip('/')}"
|
||||
return _unescape_glob(path)
|
||||
|
||||
|
||||
def glob_path(
|
||||
path: Union[Path, str],
|
||||
hidden_files: bool = False,
|
||||
autoglob_dirs: bool = True,
|
||||
recursive_dirs: bool = False,
|
||||
yield_dirs: bool = True,
|
||||
fs: Optional[AbstractFileSystem] = None,
|
||||
) -> Iterator[str]:
|
||||
"""
|
||||
Expand a glob path into a list of paths.
|
||||
"""
|
||||
protocol, parsed_path = _pathify(path)
|
||||
fs = fs or get_fs(path)
|
||||
|
||||
if autoglob_dirs and fs.isdir(path):
|
||||
path = join_path(protocol, _unescape_glob(parsed_path), "*")
|
||||
|
||||
if "*" not in str(path):
|
||||
# nothing to glob
|
||||
yield str(path)
|
||||
return
|
||||
|
||||
for gl in fs.glob(path):
|
||||
gl = str(gl)
|
||||
|
||||
if not hidden_files and Path(gl).name.startswith("."):
|
||||
continue
|
||||
|
||||
if fs.isdir(gl):
|
||||
if recursive_dirs:
|
||||
yield from glob_path(
|
||||
gl,
|
||||
hidden_files=hidden_files,
|
||||
autoglob_dirs=autoglob_dirs,
|
||||
recursive_dirs=recursive_dirs,
|
||||
yield_dirs=yield_dirs,
|
||||
fs=fs,
|
||||
)
|
||||
if yield_dirs:
|
||||
yield join_path(protocol, gl)
|
||||
else:
|
||||
yield join_path(protocol, gl)
|
||||
|
||||
|
||||
def sub_prefix(a: str, b: str) -> str:
    """
    Return the relative path of a with respect to b.
    """
    prot_a, path_a = _pathify(a)
    prot_b, path_b = _pathify(b)

    if prot_a != prot_b:
        raise ValueError(f"Protocols of {a} and {b} do not match")

    try:
        diff = str(path_a.relative_to(path_b))
    except ValueError:
        diff = join_path(prot_a, path_a.parts)

    return _unescape_glob(diff)


def sub_suffix(a: str, b: str) -> str:
    """
    Remove b from the end of a.
    """
    prot_a, path_a = _pathify(a)
    prot_b, path_b = _pathify(b)

    if prot_b:
        raise ValueError(f"{b} is not a relative path")

    sub_path = re.sub(f"{path_b}$", "", str(path_a))
    sub_prot = f"{prot_a}://" if prot_a else ""

    # need to trim '/' from the end if (a) '/' is not the only symbol in the path or
    # (b) there is a protocol so absolute paths don't make sense
    if sub_path != "/" or sub_prot:
        sub_path = sub_path.rstrip("/")

    return _unescape_glob(sub_prot + sub_path)


def add_suffix(a: str, b: str) -> str:
    """
    Return the path of a joined with b.
    """
    prot_a, path_a = _pathify(a)
    prot_b, path_b = _pathify(b)

    if prot_b:
        raise ValueError(f"{b} is not a relative path")

    return join_path(prot_a, str(path_a / path_b))


def exists(path: str, fs: Optional[AbstractFileSystem] = None) -> bool:
    """Check if a path exists."""
    fs = fs or get_fs(path)
    return fs.exists(path)

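
# Illustrative example (added for clarity; not in the original module). Given the definitions
# above, sub_prefix strips a shared prefix and add_suffix re-joins it; values are a sketch:
#
#   >>> sub_prefix("s3://bucket/data/train/part-0.json", "s3://bucket/data")
#   'train/part-0.json'
#   >>> add_suffix("s3://bucket/data", "train/part-0.json")
#   's3://bucket/data/train/part-0.json'
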
def is_dir(path: str, fs: Optional[AbstractFileSystem] = None) -> bool:
    """Check if a path is a directory."""
    fs = fs or get_fs(path)
    if exists(path, fs=fs):
        return fs.isdir(path)
    return False


def is_file(path: str, fs: Optional[AbstractFileSystem] = None) -> bool:
    """Check if a path is a file."""
    fs = fs or get_fs(path)
    if exists(path, fs=fs):
        return fs.isfile(path)
    return False


def parent(path: str) -> str:
    """Get the parent directory of a path; if the parent is the root, return the root."""
    prot, parts = split_path(path)
    if len(parts) == 1:
        return path
    return join_path(prot, *parts[:-1])


def mkdir_p(path: str, fs: Optional[AbstractFileSystem] = None) -> None:
    """
    Create a directory if it does not exist.
    """
    if is_glob(path):
        raise ValueError(f"Cannot create directory with glob pattern: {path}")

    fs = fs or get_fs(path)
    fs.makedirs(path, exist_ok=True)


def make_relative(paths: List[str]) -> Tuple[str, List[str]]:
    """Find the longest common root shared among all paths."""
    if len(paths) == 0:
        raise ValueError("Cannot make relative path of empty list")

    common_prot, common_parts, _ = partition_path(paths[0])

    for path in paths:
        current_prot, current_parts, _ = partition_path(path)
        if current_prot != common_prot:
            raise ValueError(f"Protocols of {path} and {paths[0]} do not match")

        for i in range(min(len(common_parts), len(current_parts))):
            if common_parts[i] != current_parts[i]:
                common_parts = common_parts[:i]
                break

    if len(common_parts) > 0:
        common_path = (f"{common_prot}://" if common_prot else "") + str(Path(*common_parts))
        relative_paths = [sub_prefix(path, common_path) for path in paths]
    else:
        common_path = f"{common_prot}://" if common_prot else ""
        relative_paths = [_unpathify("", _pathify(path)[1]) for path in paths]

    return common_path, relative_paths


def is_glob(path: str) -> bool:
    """
    Check if a path contains a glob wildcard.
    """
    return bool(re.search(r"(?<!\\)[*?[\]]", path))

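
# Illustrative example (added for clarity; not in the original module; paths are hypothetical).
# make_relative factors out the shared root, roughly:
#
#   >>> make_relative(["s3://bucket/data/a/1.json", "s3://bucket/data/b/2.json"])
#   ('s3://bucket/data', ['a/1.json', 'b/2.json'])
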
def split_glob(path: str) -> Tuple[str, str]:
    """
    Partition a path on the first wildcard.
    """
    if not is_glob(path):
        # it's not a glob, so it's all path
        return path, ""

    if path[0] == "*":
        # starts with a glob, so it's all glob
        return "", path

    protocol, parts = split_path(path)

    i = min(i for i, c in enumerate(parts) if is_glob(c))

    if i == 0:
        # no path, so it's all glob
        return protocol, join_path("", *parts)

    path = join_path(protocol, *parts[:i])
    rest = join_path("", *parts[i:])
    return path, rest


def get_cache_dir() -> str:
    """
    Returns the path to the cache directory for the Dolma toolkit.
    If the directory does not exist, it will be created.

    Returns:
        str: The path to the cache directory.
    """
    loc = platformdirs.user_cache_dir(LOGGER_PREFIX)
    mkdir_p(loc)
    return loc

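
# Illustrative example (added for clarity; not in the original module). split_glob keeps
# everything before the first wildcard as the concrete prefix, roughly:
#
#   >>> split_glob("s3://bucket/data/*/page_*.json")
#   ('s3://bucket/data', '*/page_*.json')
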
def resource_to_filename(resource: Union[str, bytes]) -> str:
    """
    Convert a ``resource`` into a hashed filename in a repeatable way. Preserves the file extensions.
    """
    _, (*_, orig_filename) = split_path(remove_params(str(resource)))
    _, extensions = split_basename_and_extension(orig_filename)

    resource_bytes = str(resource).encode("utf-8")
    resource_hash = sha256(resource_bytes)
    hash_filename = resource_hash.hexdigest() + extensions

    return hash_filename


def cached_path(path: str, fs: Optional[AbstractFileSystem] = None) -> str:
    """
    Returns the cached path for a given resource.

    If the resource is already available locally, the function returns the path as is.
    Otherwise, it downloads the resource from the specified path and saves it in the cache directory.

    Args:
        path (str): The path to the resource.

    Returns:
        str: The cached path of the resource.
    """
    if is_local(path):
        # local resources are returned as-is, no caching needed
        return path

    destination = f"{get_cache_dir()}/{resource_to_filename(path)}"

    remote_fs = fs or get_fs(path)
    local_fs = get_fs(destination)

    if exists(destination, fs=local_fs):
        LOGGER.info(f"Using cached file {destination} for {path}")
        return destination

    if is_dir(path, fs=remote_fs):
        for sub_path in glob_path(path, fs=remote_fs):
            rel_path = sub_prefix(sub_path, path)
            dest_path = join_path("", destination, rel_path)
            mkdir_p(parent(dest_path), fs=local_fs)
            LOGGER.info(f"Downloading {sub_path} to {dest_path}")
            with smart_open.open(sub_path, "rb") as src, smart_open.open(dest_path, "wb") as dest:
                dest.write(src.read())
    else:
        LOGGER.info(f"Downloading {path} to {destination}")
        with smart_open.open(path, "rb") as src, smart_open.open(destination, "wb") as dest:
            dest.write(src.read())

    return destination

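
# Illustrative usage (added for clarity; not in the original module; the S3 path is a placeholder):
#
#   local_copy = cached_path("s3://my-bucket/models/config.json")
#
# The first call downloads the object into get_cache_dir(); later calls return the cached copy.
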
def split_basename_and_extension(path: str) -> Tuple[str, str]:
    """
    Get the path and extension from a given file path. If a file has multiple
    extensions, they will be joined with a period, e.g. "foo/bar/baz.tar.gz"
    will return ("foo/bar/baz", ".tar.gz"). If the file has no extension, the
    second element of the tuple will be an empty string. Works with both local
    and remote (e.g. s3://) paths.

    Args:
        path (str): The file path.

    Returns:
        Tuple[str, str]: A tuple containing the path and extension.
    """
    prot, (*parts, filename) = split_path(path)
    base, *ext_parts = filename.split(".")
    ext = ("." + ".".join(ext_parts)) if ext_parts else ""
    return join_path(prot, *parts, base), ext


def decompress_path(path: str, dest: Optional[str] = None) -> str:
    """
    Decompresses a file at the given path and returns the path to the decompressed file.

    Args:
        path (str): The path to the file to be decompressed.
        dest (str, optional): The destination path for the decompressed file.
            If not provided, a destination path will be computed based on the original
            file name and the cache directory.

    Returns:
        str: The path to the decompressed file. If the file cannot be decompressed,
            the original path will be returned.
    """
    for supported_ext in get_supported_extensions():
        # not the supported extension
        if not path.endswith(supported_ext):
            continue

        if dest is None:
            # compute the name for the decompressed file; to do this, we first hash the
            # resource and then remove the extension.
            base_fn, ext = split_basename_and_extension(resource_to_filename(path))

            # to get the decompressed file name, we remove the bit of the extension that
            # indicates the compression type.
            decompressed_fn = base_fn + ext.replace(supported_ext, "")

            # finally, we get the cache directory and join the decompressed file name to it
            dest = join_path("", get_cache_dir(), decompressed_fn)

        # here we do the actual decompression
        with smart_open.open(path, "rb") as fr, smart_open.open(dest, "wb") as fw:
            fw.write(fr.read())

        # return the path to the decompressed file
        return dest

    # already decompressed or can't be decompressed
    return path


def split_ext(path: str) -> Tuple[str, Tuple[str, ...], str]:
    """
    Split a path into its protocol and extensions.
    """
    prot, parts = split_path(path)
    if not parts:
        return prot, (), ""

    filename = parts[-1]
    extensions = []
    while True:
        filename, ext = os.path.splitext(filename)
        if not ext:
            break
        extensions.append(ext)

    return prot, (*parts[:-1], filename), "".join(reversed(extensions))


def get_unified_path(paths: List[str]) -> str:
    """Get a unified path for a list of paths."""
    if len(paths) == 1:
        # if there is only one path, we don't need to unify anything
        return paths[0]

    # get shared root for all paths; we will put the unified path here
    root, relative = make_relative(paths)

    # get the extension from the first path; assume all paths have the same extension
    _, _, ext = split_ext(relative[0])

    # hash all the sorted relative paths in order to get a unique name
    # the type: ignore is needed because mypy fails to infer the type of the lambda
    # (the "or" ensures that the lambda returns the same type as the first argument, which is a hash)
    h = reduce(lambda h, p: h.update(p.encode()) or h, sorted(relative), sha256())  # type: ignore

    # return the unified path
    return join_path(root, h.hexdigest() + ext)
@ -1,27 +0,0 @@
import os
from dataclasses import dataclass
from typing import Optional


@dataclass
class BeakerState:
    job_id: Optional[str] = None
    job_kind: Optional[str] = None
    task_id: Optional[str] = None
    experiment_id: Optional[str] = None
    replica_rank: Optional[str] = None
    leader_replica_hostname: Optional[str] = None
    leader_replica_node_id: Optional[str] = None
    user_id: Optional[str] = None

    def __post_init__(self):
        for key, value in os.environ.items():
            if not key.startswith("BEAKER_"):
                continue
            # Strip the "BEAKER_" prefix explicitly; str.lstrip treats its argument as a
            # character set and would also eat leading letters of the remaining name.
            setattr(self, key[len("BEAKER_"):].lower(), value)

    @property
    def url(self) -> Optional[str]:
        if self.job_id:
            return f"https://beaker.org/jobs/{self.job_id}"
        return None
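
# Illustrative usage (added for clarity; not part of the original file). On a Beaker job the
# runtime sets BEAKER_* environment variables, which __post_init__ copies onto the dataclass:
#
#   state = BeakerState()
#   print(state.job_id)  # populated from BEAKER_JOB_ID when present, otherwise None
#   print(state.url)     # https://beaker.org/jobs/<job_id> or None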
@ -1,165 +0,0 @@
import glob
import logging
import os
import re
from typing import Optional

import boto3
from datasets import Dataset, load_dataset
from filelock import FileLock

from olmocr.data.renderpdf import get_pdf_media_box_width_height
from olmocr.prompts.anchor import get_anchor_text
from olmocr.s3_utils import parse_custom_id, parse_s3_path

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Quiet logs from pypdf and smart open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)


def list_dataset_files(s3_glob_path: str):
    """
    Lists files in the specified S3 path that match the glob pattern.
    """
    if s3_glob_path.startswith("s3://"):
        s3 = boto3.client("s3")
        match = re.match(r"s3://([^/]+)/(.+)", s3_glob_path)
        if not match:
            logger.error(f"Invalid S3 path: {s3_glob_path}")
            raise ValueError(f"Invalid S3 path: {s3_glob_path}")

        bucket, prefix_pattern = match.groups()
        prefix = prefix_pattern.split("*")[0]  # Extract prefix before the wildcard
        paginator = s3.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

        files = []
        pattern = re.compile(prefix_pattern.replace("*", ".*"))
        for page in pages:
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if pattern.fullmatch(key):
                    files.append(f"s3://{bucket}/{key}")
        return files
    else:
        return glob.glob(s3_glob_path)


def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> Dataset:
    """
    Loads JSONL files from the specified S3 path into a Hugging Face Dataset.
    """
    all_json_files = list_dataset_files(s3_glob_path)

    if first_n_files:
        all_json_files = all_json_files[:first_n_files]

    # Use datasets library to load JSON files from S3
    dataset = load_dataset(
        "json",
        data_files=all_json_files,
    )

    return dataset

def extract_openai_batch_response(example):
    custom_id = example.get("custom_id", None)

    # Parse the custom id into an s3 document path and page number (1-indexed)
    s3_path, page_num = parse_custom_id(custom_id)

    response_body = example.get("response", {}).get("body", {})
    choices = response_body.get("choices", [])
    response = ""
    finish_reason = ""
    if choices:
        first_choice = choices[0]
        message = first_choice.get("message", {})
        response = message.get("content", "")
        finish_reason = first_choice.get("finish_reason", "")

    # TODO Maybe in the future we can parse the response (which is a structured JSON document itself)
    # into its own columns

    return {"s3_path": s3_path, "page_num": page_num, "response": response, "finish_reason": finish_reason}


def _cache_s3_file(s3_path: str, local_cache_dir: str):
    """
    Downloads an S3 object to a local cache directory, ensuring no two writers corrupt the same file.
    """
    bucket, key = parse_s3_path(s3_path)

    # Define the local file path
    local_file_path = os.path.join(local_cache_dir, bucket + "__" + key.replace("/", "_"))
    os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
    lock_file = f"{local_file_path}.lock"

    # Use a file lock to prevent concurrent writes
    with FileLock(lock_file):
        if not os.path.exists(local_file_path):
            logger.info(f"Downloading {s3_path} to {local_file_path}")
            s3_client = boto3.client("s3", aws_access_key_id=os.getenv("DS_AWS_ACCESS_KEY_ID"), aws_secret_access_key=os.getenv("DS_AWS_SECRET_ACCESS_KEY"))
            s3_client.download_file(bucket, key, local_file_path)
        else:
            # File is already cached locally; skip the download.
            pass

    return local_file_path

def cache_s3_files(dataset: Dataset, pdf_cache_location: str, num_proc: int = 32) -> Dataset:
    """
    Caches all S3 paths in the dataset to the local cache directory.
    """

    # Define the download function to use in parallel processing
    def cache_file(example):
        s3_path = example["s3_path"]
        if s3_path:
            # Download the file and cache it locally
            local_path = _cache_s3_file(s3_path, pdf_cache_location)
            return {"local_pdf_path": local_path}
        return {"local_pdf_path": None}

    # Map the caching function to the dataset (with parallelism if needed)
    dataset = dataset.map(cache_file, num_proc=num_proc, load_from_cache_file=False)

    return dataset


def build_finetuning_dataset(response_glob_path: str, pdf_cache_location: Optional[str] = None, num_proc: int = 32) -> Dataset:
    if pdf_cache_location is None:
        pdf_cache_location = os.path.join(os.path.expanduser("~"), ".cache", "olmocr_pdfs")

    logger.info("Loading fine tuning dataset from OpenAI style batch responses")
    response_data = load_jsonl_into_ds(response_glob_path)
    response_data = response_data["train"]

    response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names, num_proc=num_proc)

    # Don't include data where the model cut off due to a length issue, or moderation issue
    logger.info("Filtering on finish_reason == stop")
    final_dataset = response_data.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)

    # Cache all the s3_paths that were accessed to a local storage location
    final_dataset = cache_s3_files(final_dataset, pdf_cache_location, num_proc)

    # Filter out pages where you cannot get an anchor text generated, to prevent errors during actual training
    def _can_create_anchor_text(example):
        try:
            anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=4000)
            _ = get_pdf_media_box_width_height(example["local_pdf_path"], example["page_num"])
            return anchor_text is not None
        except Exception:
            logger.exception("Could not generate anchor text for file, be sure you have all dependencies installed")
            return False

    final_dataset = final_dataset.filter(_can_create_anchor_text, num_proc=num_proc)

    return final_dataset
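
# Sketch of intended usage (added for clarity; not part of the original file; the glob below is
# a placeholder path, and the column names come from extract_openai_batch_response above):
#
#   ds = build_finetuning_dataset("/data/pdfdata/openai_batch_data_done/*.json")
#   print(ds[0]["local_pdf_path"], ds[0]["page_num"], ds[0]["response"][:80])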
@ -1,164 +0,0 @@
import base64
import random
from io import BytesIO
from typing import Union

import numpy as np
import torch  # Make sure to import torch as it's used in the DataCollator
from PIL import Image

from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text


def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim: Union[int, list[int]], target_anchor_text_len: Union[int, list[int]]):
    if isinstance(target_longest_image_dim, list):
        target_longest_image_dim = random.choice(target_longest_image_dim)

    if isinstance(target_anchor_text_len, list):
        target_anchor_text_len = random.choice(target_anchor_text_len)

    anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
    base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)

    # Prepare messages
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": base64_page_image},
                {"type": "text", "text": build_finetuning_prompt(anchor_text)},
            ],
        }
    ]
    # Apply chat template to get the text
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Decode image from base64
    main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))

    # Process inputs using processor
    inputs = processor(
        text=[text],
        images=[main_image],
        padding=True,
        return_tensors="np",
    )

    # Get labels by tokenizing the output text
    labels = processor(text=[example["response"]], padding=True, return_tensors="np")

    # Append an "<|im_end|>\n" to the labels, because this is what it would look like
    # if we passed the whole message stream in there
    im_end_tokens = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]
    im_end_tokens = np.array(im_end_tokens, dtype=inputs.input_ids.dtype)  # Ensure correct dtype

    # Handle the case where labels['input_ids'] is empty
    if labels["input_ids"].shape[1] == 0:
        labels_input_ids_0 = np.array([], dtype=inputs.input_ids.dtype)
    else:
        labels_input_ids_0 = labels["input_ids"][0].astype(inputs.input_ids.dtype)

    labels["input_ids"] = np.concatenate([labels_input_ids_0, im_end_tokens])
    labels["input_ids"] = np.expand_dims(labels["input_ids"], axis=0)

    # Concatenate input_ids and labels
    input_ids = np.concatenate([inputs.input_ids[0], labels.input_ids[0]], axis=0)

    # All columns will participate in attention fully
    attention_mask = np.ones_like(input_ids)

    # Create labels, masking the input portion with -100
    labels_full = np.full_like(input_ids, fill_value=-100)
    labels_full[len(inputs.input_ids[0]) :] = labels.input_ids[0]

    # TODO Maybe cap the max length

    # Return as dict, including pixel_values
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels_full,
        "pixel_values": inputs.pixel_values,
        "image_grid_thw": inputs["image_grid_thw"][0],
    }

def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]):
    # Process each example in the batch using the helper function
    processed_examples = []
    for i in range(len(batch["response"])):
        example = {"local_pdf_path": batch["local_pdf_path"][i], "page_num": batch["page_num"][i], "response": batch["response"][i]}
        processed_example = prepare_data_for_qwen2_training(
            example, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
        )
        processed_examples.append(processed_example)

    return {
        "input_ids": [x["input_ids"] for x in processed_examples],
        "attention_mask": [x["attention_mask"] for x in processed_examples],
        "labels": [x["labels"] for x in processed_examples],
        "pixel_values": [x["pixel_values"] for x in processed_examples],
        "image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
    }


def prepare_data_for_molmo_training(example, processor, target_longest_image_dim: Union[int, list[int]], target_anchor_text_len: Union[int, list[int]]):
    if isinstance(target_longest_image_dim, list):
        target_longest_image_dim = random.choice(target_longest_image_dim)

    if isinstance(target_anchor_text_len, list):
        target_anchor_text_len = random.choice(target_anchor_text_len)

    anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
    base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)

    # Decode image from base64
    main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))

    # Process the input text and image
    inputs = processor.process(
        images=[main_image],
        text=build_finetuning_prompt(anchor_text),
    )

    # Get labels by tokenizing the output text
    labels = processor.tokenizer(example["response"], return_tensors="np")["input_ids"][0]

    # Concatenate input_ids and labels
    full_input_ids = torch.cat([inputs["input_ids"], torch.from_numpy(labels)], dim=0)

    labels_full = torch.cat([torch.ones_like(inputs["input_ids"]) * -100, torch.from_numpy(labels)], dim=0)

    # Create a full attention mask
    attention_mask = torch.ones_like(full_input_ids)

    # image_input_idx does not need adjustment as images are inserted before labels
    image_input_idx = inputs["image_input_idx"]

    return {
        "input_ids": full_input_ids,
        "labels": labels_full,
        "images": inputs["images"],
        "image_input_idx": image_input_idx,
        "image_masks": inputs["image_masks"],
        "attention_mask": attention_mask,
    }


def batch_prepare_data_for_molmo_training(batch, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]):
    # Assume batch size 1 and process the single example
    example = {"local_pdf_path": batch["local_pdf_path"][0], "page_num": batch["page_num"][0], "response": batch["response"][0]}
    processed_example = prepare_data_for_molmo_training(
        example, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
    )

    # Return in the same format as the qwen2 function
    return {
        "input_ids": [processed_example["input_ids"]],
        "attention_mask": [processed_example["attention_mask"]],
        "labels": [processed_example["labels"]],
        "images": [processed_example["images"]],
        "image_input_idx": [processed_example["image_input_idx"]],
        "image_masks": [processed_example["image_masks"]],
    }
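
# Sketch of intended usage (added for clarity; not part of the original file). These batch
# functions are meant to be applied lazily to the dataset produced by build_finetuning_dataset;
# the model name and parameter values below are illustrative assumptions:
#
#   from functools import partial
#   from transformers import AutoProcessor
#
#   processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
#   train_ds = dataset.with_transform(
#       partial(
#           batch_prepare_data_for_qwen2_training,
#           processor=processor,
#           target_longest_image_dim=[1024],
#           target_anchor_text_len=[6000],
#       )
#   )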
@ -1,122 +0,0 @@
import argparse
import concurrent.futures
import json
import os

import boto3
import torch
from smart_open import smart_open
from tqdm import tqdm
from transformers import Qwen2_5_VLForConditionalGeneration

from olmocr.s3_utils import parse_s3_path

s3_client = boto3.client("s3")


def download_file_from_s3(bucket_name, key, local_file_path):
    """Download a single file from S3."""
    s3_client.download_file(bucket_name, key, local_file_path)
    print(f"Downloaded {key} to {local_file_path}")


def download_model_from_s3(bucket_name, model_s3_key, local_model_dir):
    if not os.path.exists(local_model_dir):
        os.makedirs(local_model_dir)

    # List objects in the S3 model path
    response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=model_s3_key)
    objects = response.get("Contents", [])

    # Prepare list of download tasks
    download_tasks = []
    for obj in objects:
        key = obj["Key"]
        if key.endswith("/"):
            continue  # Skip directories

        local_file_path = os.path.join(local_model_dir, os.path.basename(key))
        download_tasks.append((bucket_name, key, local_file_path))

    # Use a ThreadPoolExecutor to download files in parallel
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(download_file_from_s3, bucket_name, key, local_file_path) for bucket_name, key, local_file_path in download_tasks]

        # Wait for all downloads to complete and handle any exceptions
        for future in tqdm(concurrent.futures.as_completed(futures)):
            try:
                future.result()  # This will raise any exceptions encountered during download
            except Exception as e:
                print(f"Error downloading file: {e}")


def upload_file_to_s3(local_file_path, bucket_name, s3_key):
    """Upload a single file to S3."""
    try:
        s3_client.upload_file(local_file_path, bucket_name, s3_key)
        print(f"Uploaded {local_file_path} to s3://{bucket_name}/{s3_key}")
    except Exception as e:
        print(f"Error uploading {local_file_path} to s3://{bucket_name}/{s3_key}: {e}")

def save_model_to_s3(local_model_dir, bucket_name, s3_model_key):
    """Upload the model directory to S3 in parallel."""
    # Collect all file paths to be uploaded
    upload_tasks = []
    for root, dirs, files in os.walk(local_model_dir):
        for file in files:
            local_file_path = os.path.join(root, file)
            s3_key = os.path.join(s3_model_key, file)
            upload_tasks.append((local_file_path, bucket_name, s3_key))

    # Use a ThreadPoolExecutor to upload files in parallel
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(upload_file_to_s3, local_file_path, bucket_name, s3_key) for local_file_path, bucket_name, s3_key in upload_tasks]

        # Wait for all uploads to complete and handle any exceptions
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()  # This will raise any exceptions encountered during upload
            except Exception as e:
                print(f"Error during upload: {e}")


def main():
    parser = argparse.ArgumentParser(description="Fix up a Qwen2VL checkpoint saved on s3 or otherwise, so that it will load properly in vllm/birr")
    parser.add_argument("s3_path", type=str, help="S3 path to the Hugging Face checkpoint.")
    args = parser.parse_args()

    # Now, download the config.json from the original path and verify the architectures
    config_path = os.path.join(args.s3_path, "config.json")

    with smart_open(config_path, "r") as f:
        config_data = json.load(f)

    assert config_data["architectures"] == ["Qwen2_5_VLForConditionalGeneration"]

    if config_data["torch_dtype"] == "float32":
        print("Detected model is float32, this is probably an FSDP checkpoint")
        print("Saving to _bf16 location with adjusted parameters")

        bucket, prefix = parse_s3_path(args.s3_path)
        td = "/tmp/qwen2_checkpoint_saving"
        download_model_from_s3(bucket, prefix, td)

        print("Downloaded entire model from s3, resaving as bfloat16")
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(td)
        model = model.to(torch.bfloat16)
        os.makedirs(os.path.join(td, "bf16_checkpoint"), exist_ok=True)

        print("Saving...")
        model.save_pretrained(os.path.join(td, "bf16_checkpoint"))

        print("Uploading")
        save_model_to_s3(os.path.join(td, "bf16_checkpoint"), bucket, prefix.rstrip("/") + "/bf16")

        args.s3_path = args.s3_path.rstrip("/") + "/bf16"

    print("Model updated successfully.")


if __name__ == "__main__":
    main()
@ -1,371 +0,0 @@
# Script to generate parquet dataset files to upload to Hugging Face.
# Input is a dataset location /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
# Each json line has a custom id that looks like {"custom_id": "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1", ... more data}

# The script takes a path to an input dataset and an sqlite database location, and builds a
# parquet file with rows of the form "id", "url", "page_number", "response", where:
#   - id is the output of parse_pdf_hash plus "-" plus the page number,
#   - url is the result of get_uri_from_db,
#   - response is NormalizedEntry.text.
import argparse
import concurrent.futures
import glob
import json
import multiprocessing
import os
import re
import sqlite3
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse

import boto3
import pandas as pd
from pypdf import PdfReader, PdfWriter
from tqdm import tqdm


def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
    """
    Extracts a hash from a pretty PDF S3 URL.
    For example, given:
        s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1
    it will return "de80a57e6c57b45796d2e020173227f7eae44232".
    """
    # Allow an optional "-<number>" at the end.
    if pretty_pdf_path.startswith("s3://ai2-s2-pdfs/"):
        pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf(?:-\d+)?$"
        match = re.match(pattern, pretty_pdf_path)
        if match:
            return match.group(1) + match.group(2)
        return None
    elif pretty_pdf_path.startswith("s3://ai2-oe-data/reganh/iabooks/"):
        return urlparse(pretty_pdf_path).path.split("/")[-1]
    else:
        raise NotImplementedError()

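
# Doctest-style illustration (added for clarity; it mirrors the example already given in the
# docstring above and introduces no new behavior):
#
#   >>> parse_pdf_hash("s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1")
#   'de80a57e6c57b45796d2e020173227f7eae44232'
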
def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
    """
    Looks up the URL for the given pdf_hash in the sqlite database.
    Assumes there is a table called 'pdf_mapping' with a column 'uri'.
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
    result = cursor.fetchone()
    conn.close()
    return result[0].strip() if result and result[0] else None


@dataclass(frozen=True)
class NormalizedEntry:
    s3_path: str
    pagenum: int
    text: Optional[str]
    finish_reason: Optional[str]
    error: Optional[str] = None

    @staticmethod
    def from_goldkey(goldkey: str, **kwargs):
        """
        Constructs a NormalizedEntry from a goldkey string.
        The goldkey is expected to be of the format:
            <s3_path>-<page_number>
        """
        s3_path = goldkey[: goldkey.rindex("-")]
        page_num = int(goldkey[goldkey.rindex("-") + 1 :])
        return NormalizedEntry(s3_path, page_num, **kwargs)

    @property
    def goldkey(self):
        return f"{self.s3_path}-{self.pagenum}"

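
# Illustrative example (added for clarity; values are hypothetical). A goldkey splits on the
# last "-", so "s3://bucket/doc.pdf-3" becomes s3_path="s3://bucket/doc.pdf" and pagenum=3,
# and the goldkey property reassembles the same string.
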
def normalize_json_entry(data: dict) -> NormalizedEntry:
    """
    Normalizes a JSON entry from any of the supported formats.
    It supports:
    - Birr: looks for an "outputs" field.
    - Already normalized entries: if they contain s3_path, pagenum, etc.
    - OpenAI: where the response is in data["response"]["body"]["choices"].
    - SGLang: where the response is in data["response"]["choices"].
    """
    if "outputs" in data:
        # Birr case
        if data["outputs"] is None:
            text = None
            finish_reason = None
        else:
            text = data["outputs"][0]["text"]
            finish_reason = data["outputs"][0]["finish_reason"]

        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=text,
            finish_reason=finish_reason,
            error=data.get("completion_error", None),
        )
    elif all(field in data for field in ["s3_path", "pagenum", "text", "error", "finish_reason"]):
        # Already normalized
        return NormalizedEntry(**data)
    elif "response" in data and "body" in data["response"] and "choices" in data["response"]["body"]:
        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=data["response"]["body"]["choices"][0]["message"]["content"],
            finish_reason=data["response"]["body"]["choices"][0]["finish_reason"],
        )
    else:
        raise ValueError("Unsupported JSON format")


def parse_s3_url(s3_url: str) -> Tuple[str, str]:
    """
    Parses an S3 URL of the form s3://bucket/key and returns (bucket, key).
    """
    if not s3_url.startswith("s3://"):
        raise ValueError(f"Invalid S3 URL: {s3_url}")
    s3_path = s3_url[5:]
    bucket, key = s3_path.split("/", 1)
    return bucket, key


def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
    """
    Downloads the PDF from the given S3 URL into the specified cache directory.
    The destination filename is based on the parsed PDF hash.
    Returns the path to the downloaded PDF.
    """
    try:
        bucket, key = parse_s3_url(s3_url)
        s3_client = boto3.client("s3")
        pdf_hash = parse_pdf_hash(s3_url)
        if not pdf_hash:
            # Fallback: use a sanitized version of the s3_url
            pdf_hash = re.sub(r"\W+", "_", s3_url)
        dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
        # Avoid re-downloading if already exists
        if not os.path.exists(dest_path):
            s3_client.download_file(bucket, key, dest_path)
        return dest_path
    except Exception as e:
        print(f"Error downloading {s3_url}: {e}")
        return None

def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]:
    """
    Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url.
    Writes a new single-page PDF to the output_pdf_dir using the combined_id as the filename.
    Returns the relative path to the new PDF (e.g., "pdfs/<combined_id>.pdf").
    """
    try:
        local_cached_pdf = pdf_cache.get(s3_url)
        if not local_cached_pdf or not os.path.exists(local_cached_pdf):
            print(f"Cached PDF not found for {s3_url}")
            return None
        reader = PdfReader(local_cached_pdf)
        # pypdf uses 0-indexed page numbers
        page_index = page_number - 1
        if page_index < 0 or page_index >= len(reader.pages):
            print(f"Page number {page_number} out of range for PDF {s3_url}")
            return None
        writer = PdfWriter()
        writer.add_page(reader.pages[page_index])
        output_filename = f"{combined_id}.pdf"
        output_path = os.path.join(output_pdf_dir, output_filename)
        with open(output_path, "wb") as f_out:
            writer.write(f_out)
        # Return the relative path (assuming pdfs/ folder is relative to the parquet file location)
        return os.path.join("pdfs", output_filename)
    except Exception as e:
        print(f"Error processing PDF page for {s3_url} page {page_number}: {e}")
        return None

def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]:
    """
    Process a single file and return a tuple:
    (list of valid rows, number of rows skipped due to missing URL or PDF extraction/filtering).

    For each JSON entry, the function:
    - Normalizes the JSON.
    - Skips entries whose response contains the word "resume" (any case) along with either an email address or a phone number.
    - Extracts the PDF hash and builds the combined id.
    - Looks up the corresponding URL from the sqlite database.
    - Extracts the specified page from the cached PDF and writes it to output_pdf_dir.
    - Outputs a row with "id", "url", "page_number", "response".
    """
    rows = []
    missing_count = 0
    email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    phone_regex = r"\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b"

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Skipping invalid JSON at {file_path}:{line_num} - {e}")
                    continue

                try:
                    normalized = normalize_json_entry(data)
                except Exception as e:
                    print(f"Error normalizing entry at {file_path}:{line_num} - {e}")
                    continue

                # Apply filter: skip if response contains "resume" (any case) and an email or phone number.
                response_text = normalized.text if normalized.text else ""
                if re.search(r"resume", response_text, re.IGNORECASE) and (re.search(email_regex, response_text) or re.search(phone_regex, response_text)):
                    print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}")
                    continue

                # Extract the PDF hash from the s3_path.
                pdf_hash = parse_pdf_hash(normalized.s3_path)
                if pdf_hash is None:
                    print(f"Could not parse pdf hash from {normalized.s3_path} at {file_path}:{line_num}")
                    continue

                # The output id is the pdf hash plus '-' plus the page number.
                combined_id = f"{pdf_hash}-{normalized.pagenum}"

                # Look up the corresponding URL from the sqlite database.
                url = get_uri_from_db(db_path, pdf_hash)
                if not url:
                    print(f"Missing URL for pdf hash {pdf_hash} at {file_path}:{line_num}")
                    missing_count += 1
                    continue

                # Process PDF: extract the specified page from the cached PDF.
                local_pdf_path = process_pdf_page(normalized.s3_path, normalized.pagenum, combined_id, output_pdf_dir, pdf_cache)
                if local_pdf_path is None:
                    print(f"Skipping entry because PDF processing failed for {normalized.s3_path} page {normalized.pagenum} at {file_path}:{line_num}")
                    missing_count += 1
                    continue

                row = {
                    "id": combined_id,
                    "url": url,
                    "page_number": normalized.pagenum,
                    "response": normalized.text,
                }
                rows.append(row)
    except Exception as e:
        print(f"Error processing file {file_path}: {e}")
    return rows, missing_count

def scan_file_for_s3_urls(file_path: str) -> Set[str]:
    """
    Scans a single file and returns a set of unique S3 URLs found in the JSON entries.
    """
    urls = set()
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    normalized = normalize_json_entry(data)
                    urls.add(normalized.s3_path)
                except Exception:
                    # Skip entries that cannot be normalized
                    continue
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
    return urls

def main():
    parser = argparse.ArgumentParser(description="Generate a Parquet dataset file for HuggingFace upload.")
    parser.add_argument(
        "input_dataset",
        help="Input dataset file pattern (e.g., '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json')",
    )
    parser.add_argument("db_path", help="Path to the SQLite database file.")
    parser.add_argument("--output", default="output.parquet", help="Output Parquet file path.")

    args = parser.parse_args()

    files = glob.glob(args.input_dataset)
    print(f"Found {len(files)} files matching pattern: {args.input_dataset}")

    # Determine output directory and create 'pdfs' subfolder.
    output_abs_path = os.path.abspath(args.output)
    output_dir = os.path.dirname(output_abs_path)
    pdfs_dir = os.path.join(output_dir, "pdfs")
    os.makedirs(pdfs_dir, exist_ok=True)

    # Create a temporary directory for caching PDFs.
    pdf_cache_dir = "/tmp/pdf_cache"
    os.makedirs(pdf_cache_dir, exist_ok=True)

    print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")

    # ---------------------------------------------------------------------
    # Step 1: Scan input files to collect all unique S3 URLs using a ProcessPoolExecutor.
    unique_s3_urls: Set[str] = set()
    print("Scanning input files to collect unique PDF URLs...")
    num_cpus = multiprocessing.cpu_count()
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 4) as executor:
        results = list(tqdm(executor.map(scan_file_for_s3_urls, files), total=len(files), desc="Scanning files"))
        for url_set in results:
            unique_s3_urls |= url_set

    print(f"Found {len(unique_s3_urls)} unique PDF URLs.")

    # ---------------------------------------------------------------------
    # Step 2: Download all unique PDFs to the cache directory.
    pdf_cache: Dict[str, str] = {}
    print("Caching PDFs from S3...")
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 8) as executor:
        future_to_url = {executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url for s3_url in unique_s3_urls}
        for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(future_to_url), desc="Downloading PDFs"):
            s3_url = future_to_url[future]
            try:
                local_path = future.result()
                if local_path:
                    pdf_cache[s3_url] = local_path
                else:
                    print(f"Failed to cache PDF for {s3_url}")
            except Exception as e:
                print(f"Error caching PDF for {s3_url}: {e}")

    # ---------------------------------------------------------------------
    # Step 3: Process input files using the precached PDFs.
    all_rows = []
    total_missing = 0
    print("Processing files...")
    with concurrent.futures.ProcessPoolExecutor() as executor:
        futures = {executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path for file_path in files}
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing files"):
            file_path = futures[future]
            try:
                rows, missing_count = future.result()
                all_rows.extend(rows)
                total_missing += missing_count
            except Exception as e:
                print(f"Error processing file {file_path}: {e}")

    if all_rows:
        df = pd.DataFrame(all_rows)
        # Set the "id" column as the index.
        df.set_index("id", inplace=True)
        df.to_parquet(args.output)

        valid_count = len(df)
        total_processed = valid_count + total_missing
        print(f"Successfully wrote {valid_count} rows to {args.output}")
        print(f"Rows skipped due to missing URL/PDF or filtering: {total_missing} out of {total_processed} processed rows")
    else:
        print("No valid rows to write. Exiting.")


if __name__ == "__main__":
    main()
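
# Example invocation (added for clarity; the script filename and paths below are placeholders,
# since the original file name is not shown in this diff):
#
#   python build_hf_parquet.py "/data/pdfdata/openai_batch_data_done/*.json" pdf_mapping.sqlite --output train.parquet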
@ -1,97 +0,0 @@
import logging
import os
import tarfile
from math import ceil

from huggingface_hub import HfApi

# Configuration
pdf_dir = "pdfs"  # Directory with PDF files (flat structure)
tarball_dir = "tarballs"  # Directory where tar.gz files will be saved
os.makedirs(tarball_dir, exist_ok=True)
repo_id = "allenai/olmOCR-mix-0225"  # Hugging Face dataset repo ID

# Set up logging to file
logging.basicConfig(filename="upload.log", level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


def process_chunk(args):
    """
    Worker function to create a tar.gz file for a given chunk.
    Returns a tuple: (chunk_index, success (bool), message).
    """
    chunk_index, chunk_files = args
    tarball_name = f"pdf_chunk_{chunk_index:04d}.tar.gz"
    tarball_path = os.path.join(tarball_dir, tarball_name)

    try:
        with tarfile.open(tarball_path, "w:gz") as tar:
            for pdf_filename in chunk_files:
                pdf_path = os.path.join(pdf_dir, pdf_filename)
                # Add the file with its basename to maintain a flat structure
                tar.add(pdf_path, arcname=pdf_filename)
        logging.info(f"Chunk {chunk_index:04d}: Created '{tarball_name}' with {len(chunk_files)} PDFs.")
        return chunk_index, True, "Success"
    except Exception as e:
        error_msg = f"Chunk {chunk_index:04d}: Error creating '{tarball_name}': {e}"
        logging.error(error_msg)
        return chunk_index, False, error_msg

def main():
    # List all PDF files (assuming a flat directory)
    try:
        pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith(".pdf")])
    except Exception as e:
        logging.error(f"Error listing PDFs in '{pdf_dir}': {e}")
        return

    total_files = len(pdf_files)
    chunk_size = 5000
    total_chunks = ceil(total_files / chunk_size)
    logging.info(f"Found {total_files} PDFs; dividing into {total_chunks} chunks of up to {chunk_size} files each.")

    # # Enumerate chunks (starting at 0000)
    # chunks = []
    # for idx in range(total_chunks):
    #     start = idx * chunk_size
    #     end = start + chunk_size
    #     chunk_files = pdf_files[start:end]
    #     chunks.append((idx, chunk_files))

    # # Create tarballs in parallel
    # results = []
    # with ProcessPoolExecutor() as executor:
    #     futures = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
    #     for future in tqdm(as_completed(futures), total=len(futures), desc="Creating tarballs"):
    #         try:
    #             result = future.result()
    #             results.append(result)
    #             chunk_index, success, message = result
    #             if not success:
    #                 logging.error(f"Chunk {chunk_index:04d} failed: {message}")
    #         except Exception as e:
    #             logging.error(f"Unexpected error processing a chunk: {e}")

    # # Abort upload if any tarball creation failed
    # failed_chunks = [r for r in results if not r[1]]
    # if failed_chunks:
    #     logging.error(f"{len(failed_chunks)} chunk(s) failed to create. Aborting upload.")
    #     return

    # All tarballs created successfully; now upload the entire tarball directory

    api = HfApi()
    logging.info("Starting upload of tarballs folder to Hugging Face Hub...")
    # This will upload all files in tarball_dir to the repo under "pdf_tarballs"
    api.upload_large_folder(
        folder_path=tarball_dir,
        repo_id=repo_id,
        # path_in_repo="pdf_tarballs",
        repo_type="dataset",
    )
    logging.info("Successfully uploaded tarballs folder to Hugging Face Hub.")


if __name__ == "__main__":
    main()
@ -1,138 +0,0 @@
#!/usr/bin/env python3
import argparse
import sqlite3
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import boto3
from tqdm import tqdm
from warcio.archiveiterator import ArchiveIterator


def parse_s3_path(s3_path):
    """
    Parses an S3 path of the form s3://bucket/prefix and returns the bucket and prefix.
    """
    if not s3_path.startswith("s3://"):
        raise ValueError("S3 path must start with s3://")
    without_prefix = s3_path[5:]
    parts = without_prefix.split("/", 1)
    bucket = parts[0]
    prefix = parts[1] if len(parts) > 1 else ""
    return bucket, prefix


def list_s3_warc_objects(s3_path, suffix=".warc.gz"):
    """
    Lists all objects under the given S3 path that end with the provided suffix.
    Uses a paginator to handle large result sets.
    """
    bucket, prefix = parse_s3_path(s3_path)
    s3_client = boto3.client("s3")
    paginator = s3_client.get_paginator("list_objects_v2")
    warc_keys = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        if "Contents" in page:
            for obj in page["Contents"]:
                key = obj["Key"]
                if key.endswith(suffix):
                    warc_keys.append(key)
    return bucket, warc_keys, s3_client

def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
    """
    Retrieves the first head_bytes bytes (1 MB by default) from the S3 object using a range request,
    and extracts the first response record's target URI from the HTTP headers.
    """
    target_uri = None
    try:
        response = s3_client.get_object(Bucket=bucket, Key=key, Range=f"bytes=0-{head_bytes-1}")
        stream = response["Body"]
        for record in ArchiveIterator(stream):
            for name, value in record.rec_headers.headers:
                if name == "WARC-Target-URI":
                    target_uri = value
                    break
            if target_uri:
                break  # Only use the first valid response record
    except Exception as e:
        tqdm.write(f"Error processing s3://{bucket}/{key}: {e}")
    return target_uri


def create_db(db_path):
    """
    Creates (or opens) the SQLite database and ensures that the pdf_mapping table exists,
    including an index on pdf_hash.
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS pdf_mapping (
            pdf_hash TEXT PRIMARY KEY,
            uri TEXT
        )
        """
    )
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_pdf_hash ON pdf_mapping (pdf_hash)
        """
    )
    conn.commit()
    return conn


def process_warc_file(key, bucket, s3_client):
    """
    Processes a single WARC file from S3 and returns a tuple (pdf_hash, uri)
    if successful, otherwise returns None.
    """
    uri = extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576)
    if uri:
        # Derive pdf_hash as the file's basename with .warc.gz replaced by .pdf.
        pdf_hash = key.split("/")[-1].replace(".warc.gz", ".pdf")
        return (pdf_hash, uri)
    else:
        tqdm.write(f"Warning: No valid response record found in s3://{bucket}/{key}")
        return None


def process_s3_folder(s3_path, db_path):
    """
    Lists all .warc.gz files under the provided S3 path, then processes each file in parallel
    to extract the target URI from the HTTP headers. The resulting mapping (derived from the file's
    basename with .warc.gz replaced by .pdf) is stored in the SQLite database.
    """
    bucket, warc_keys, s3_client = list_s3_warc_objects(s3_path, suffix=".warc.gz")
    conn = create_db(db_path)
    cursor = conn.cursor()

    # Process WARC files concurrently using ThreadPoolExecutor.
    results = []
    func = partial(process_warc_file, bucket=bucket, s3_client=s3_client)
    with ThreadPoolExecutor() as executor:
        for result in tqdm(executor.map(func, warc_keys), total=len(warc_keys), desc="Processing S3 WARC files"):
            if result is not None:
                results.append(result)

    # Bulk insert into the database.
    conn.execute("BEGIN")
    for pdf_hash, uri in results:
        cursor.execute("INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", (pdf_hash, uri))
    conn.commit()
    conn.close()


def main():
    parser = argparse.ArgumentParser(description="Create an SQLite database mapping PDF file names to target URIs from S3 WARC files.")
    parser.add_argument("s3_path", help="S3 path (e.g., s3://bucket/prefix) containing .warc.gz files")
    parser.add_argument("db_file", help="Path for the output SQLite database file")
    args = parser.parse_args()
    process_s3_folder(args.s3_path, args.db_file)


if __name__ == "__main__":
    main()
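
# Example invocation (added for clarity; the script filename, bucket, and prefix are placeholders):
#
#   python build_pdf_url_db.py s3://my-bucket/warc-dumps/ pdf_mapping.sqlite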
@ -1,66 +0,0 @@
import base64
from io import BytesIO

import torch
import torch.distributed
from PIL import Image
from transformers import AutoConfig, AutoProcessor, Qwen2_5_VLForConditionalGeneration

from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import build_openai_silver_data_prompt


@torch.no_grad()
def run_inference(model_name: str):
    config = AutoConfig.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    # If it doesn't load, change the type:mrope key to "default"

    # model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
    model.eval()

    # local_pdf_path = os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "horribleocr.pdf")
    local_pdf_path = "/root/brochure.pdf"
    page = 1

    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
            ],
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    main_image = Image.open(BytesIO(base64.b64decode(image_base64)))

    inputs = processor(
        text=[text],
        images=[main_image],
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    output_ids = model.generate(**inputs, temperature=0.8, do_sample=True, max_new_tokens=1500)
    generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(inputs["input_ids"], output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    print(output_text[0])


def main():
    run_inference(model_name="Qwen/Qwen2.5-VL-7B-Instruct")


if __name__ == "__main__":
    main()
@ -1,50 +0,0 @@
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from olmocr.train.core.cli import make_cli
|
||||
from olmocr.train.core.config import TrainConfig
|
||||
|
||||
from .utils import make_dataset
|
||||
|
||||
|
||||
def main():
|
||||
train_config = make_cli(TrainConfig) # pyright: ignore
|
||||
|
||||
processor = AutoProcessor.from_pretrained(train_config.model.name_or_path, trust_remote_code=True)
|
||||
train_dataset, valid_dataset = make_dataset(train_config, processor)
|
||||
|
||||
print("Training dataset........")
|
||||
print(train_dataset)
|
||||
|
||||
train_example = train_dataset[0]
|
||||
print(train_example)
|
||||
print({(x, y.shape) for x, y in train_example.items()})
|
||||
print("\nTokens")
|
||||
print(processor.tokenizer.batch_decode(train_example["input_ids"]))
|
||||
|
||||
print("\n\n")
|
||||
|
||||
print("Validation dataset........")
|
||||
print(valid_dataset)
|
||||
print(valid_dataset[list(valid_dataset.keys())[0]][0])
|
||||
print("\n\n")
|
||||
|
||||
print("Datasets loaded into hugging face cache directory")
|
||||
|
||||
# data_collator = TruncatingCollator(
|
||||
# max_length=4096
|
||||
# )
|
||||
|
||||
# train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False, collate_fn=data_collator)
|
||||
# max_seen_len = 0
|
||||
# for index, entry in tqdm(enumerate(train_dataloader)):
|
||||
# if index == 0:
|
||||
# print(entry)
|
||||
|
||||
# num_input_tokens = entry["input_ids"].shape[1]
|
||||
# max_seen_len = max(max_seen_len, num_input_tokens)
|
||||
|
||||
# print(max_seen_len)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,59 +0,0 @@
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class MolmoConfig(PretrainedConfig):
|
||||
model_type = "molmo"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50304,
|
||||
embedding_size=50304,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
use_cache=True,
|
||||
layer_norm_eps: float = 1e-5,
|
||||
rope_theta=10000.0,
|
||||
clip_qkv=None,
|
||||
qkv_bias: bool = False,
|
||||
weight_tying: bool = False,
|
||||
use_position_ids: bool = True,
|
||||
tie_word_embeddings: bool = True,
|
||||
attention_layer_norm: bool = False,
|
||||
norm_after: bool = False,
|
||||
layer_norm_type: str = "rms",
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_size = embedding_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.weight_tying = weight_tying
|
||||
self.use_position_ids = use_position_ids
|
||||
self.attention_layer_norm = attention_layer_norm
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.clip_qkv = clip_qkv
|
||||
self.qkv_bias = qkv_bias
|
||||
self.norm_after = norm_after
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.layer_norm_type = layer_norm_type
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
MolmoConfig.register_for_auto_class()
|
||||
@ -1,497 +0,0 @@
|
||||
"""Image processor class for Molmo"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from torchvision.transforms.functional import convert_image_dtype
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput
|
||||
from transformers.processing_utils import ImagesKwargs
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width, value=0):
|
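# Places `image` inside a (target_height, target_width) canvas with its top-left corner at (offset_height, offset_width); the surrounding border is filled with `value`.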
||||
height, width = image.shape[:2]
|
||||
after_padding_width = target_width - offset_width - width
|
||||
after_padding_height = target_height - offset_height - height
|
||||
return np.pad(image, [[offset_height, after_padding_height], [offset_width, after_padding_width], [0, 0]], constant_values=value)
|
||||
|
||||
|
||||
def normalize_image(image, offset, scale):
|
||||
image -= np.array(offset, dtype=np.float32)[None, None, :]
|
||||
image /= np.array(scale, dtype=np.float32)[None, None, :]
|
||||
return image
|
||||
|
||||
|
||||
def resize_and_pad(
|
||||
image,
|
||||
desired_output_size,
|
||||
resize_method="torch-bilinear",
|
||||
pad_value=0,
|
||||
normalize=True,
|
||||
image_mean=OPENAI_CLIP_MEAN,
|
||||
image_std=OPENAI_CLIP_STD,
|
||||
):
|
||||
desired_height, desired_width = desired_output_size
|
||||
height, width = image.shape[:2]
|
||||
|
||||
# Cast into float32 since the training code did this in float32 and it (very rarely) affects
|
||||
# the results after rounding.
|
||||
image_scale_y = np.array(desired_height, np.float32) / np.array(height, np.float32)
|
||||
image_scale_x = np.array(desired_width, np.float32) / np.array(width, np.float32)
|
||||
image_scale = min(image_scale_x, image_scale_y)
|
||||
scaled_height = int(np.array(height, np.float32) * image_scale)
|
||||
scaled_width = int(np.array(width, np.float32) * image_scale)
|
||||
|
||||
if resize_method == "tensorflow":
|
||||
# This is how the original training code did resizing; it can produce slightly different
|
||||
# results than torch resize, so we keep it just in case
|
||||
import tensorflow as tf
|
||||
|
||||
image = tf.image.convert_image_dtype(tf.constant(image), dtype=tf.float32)
|
||||
image = tf.image.resize(
|
||||
image,
|
||||
[scaled_height, scaled_width],
|
||||
method=tf.image.ResizeMethod.BILINEAR,
|
||||
antialias=True,
|
||||
)
|
||||
image = tf.clip_by_value(image, 0.0, 1.0)
|
||||
image = image.numpy()
|
||||
elif resize_method == "torch-bilinear":
|
||||
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
||||
image = convert_image_dtype(image) # resize in float32 to match the training code
|
||||
image = torchvision.transforms.Resize([scaled_height, scaled_width], InterpolationMode.BILINEAR, antialias=True)(image)
|
||||
image = torch.clip(image, 0.0, 1.0)
|
||||
image = torch.permute(image, [1, 2, 0]).numpy()
|
||||
else:
|
||||
raise NotImplementedError(resize_method)
|
||||
|
||||
top_pad = (desired_height - scaled_height) // 2
|
||||
left_pad = (desired_width - scaled_width) // 2
|
||||
padding = [[top_pad, desired_height - scaled_height - top_pad], [left_pad, desired_width - scaled_width - left_pad], [0, 0]]
|
||||
image_mask = np.pad(np.ones_like(image[:, :, 0], dtype=bool), padding[:2])
|
||||
image = np.pad(image, padding, constant_values=pad_value)
|
||||
if normalize:
|
||||
image = normalize_image(image, offset=image_mean, scale=image_std)
|
||||
return image, image_mask
|
||||
|
||||
|
||||
def select_tiling(h, w, patch_size, max_num_patches):
|
||||
"""Decide how best to divide in image of size [w, h] in up to max_num_patches of size patch_size"""
|
||||
original_size = np.stack([h, w]) # [1, 2]
|
||||
original_res = h * w
|
||||
tilings = []
|
||||
for i in range(1, max_num_patches + 1):
|
||||
for j in range(1, max_num_patches + 1):
|
||||
if i * j <= max_num_patches:
|
||||
tilings.append((i, j))
|
||||
# sort so argmin and argmax favour smaller tilings in the event of a tie
|
||||
tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
|
||||
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
|
||||
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
|
||||
|
||||
# How much we would need to scale the image to fit exactly in each tiling
|
||||
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
|
||||
required_scale_d = candidate_resolutions.astype(np.float32) / original_size
|
||||
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
|
||||
if np.all(required_scale < 1):
|
||||
# We are forced to downscale, so try to minimize the amount of downscaling
|
||||
ix = np.argmax(required_scale)
|
||||
else:
|
||||
# Pick the resolution that required the least upscaling so that it most closely fits the image
|
||||
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
|
||||
ix = np.argmin(required_scale)
|
||||
return candidate_tilings[ix]
|
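# Illustrative expectation: select_tiling(1000, 1500, patch_size=336, max_num_patches=12) should return [3, 4] (rows, cols), since no tiling within the budget fits the image without downscaling and the 3x4 grid shrinks it the least.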
||||
|
||||
|
||||
class MolmoImagesKwargs(ImagesKwargs, total=False):
|
||||
max_crops: Optional[int]
|
||||
overlap_margins: Optional[List[int]]
|
||||
base_image_input_size: Optional[List[int]]
|
||||
image_token_length_w: Optional[int]
|
||||
image_token_length_h: Optional[int]
|
||||
image_patch_size: Optional[int]
|
||||
image_padding_mask: Optional[bool]
|
||||
|
||||
|
||||
class MolmoImageProcessor(BaseImageProcessor):
|
||||
"""Preprocess images and multi-model inputs"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_crops: int = 12,
|
||||
overlap_margins: List[int] = (4, 4),
|
||||
base_image_input_size: List[int] = (336, 336),
|
||||
image_token_length_w: int = 12,
|
||||
image_token_length_h: int = 12,
|
||||
image_patch_size: int = 14,
|
||||
image_padding_mask: bool = True,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.max_crops = max_crops
|
||||
self.overlap_margins = overlap_margins
|
||||
self.base_image_input_size = base_image_input_size
|
||||
self.image_token_length_w = image_token_length_w
|
||||
self.image_token_length_h = image_token_length_h
|
||||
self.image_patch_size = image_patch_size
|
||||
self.image_padding_mask = image_padding_mask
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
||||
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
||||
|
||||
def image_to_patches_and_tokens(
|
||||
self,
|
||||
image: ImageInput,
|
||||
image_patch_token_id: int,
|
||||
image_col_token_id: int,
|
||||
image_start_token_id: int,
|
||||
image_end_token_id: int,
|
||||
max_crops: Optional[int] = None,
|
||||
overlap_margins: Optional[List[int]] = None,
|
||||
base_image_input_size: Optional[Union[int, List[int]]] = None,
|
||||
image_token_length_w: Optional[int] = None,
|
||||
image_token_length_h: Optional[int] = None,
|
||||
image_patch_size: Optional[int] = None,
|
||||
):
|
||||
if isinstance(base_image_input_size, int):
|
||||
base_image_input_size = (base_image_input_size, base_image_input_size)
|
||||
|
||||
base_image_input_d = image_patch_size
|
||||
tokens_per_image = image_token_length_w * image_token_length_h
|
||||
image_base_patch_w = base_image_input_size[1] // base_image_input_d
|
||||
image_base_patch_h = base_image_input_size[0] // base_image_input_d
|
||||
|
||||
original_image_h, original_image_w = image.shape[:2]
|
||||
crop_size = base_image_input_size[0]
|
||||
|
||||
# Discard this many patches from the (left/top, right/bottom) of crops
|
||||
left_margin, right_margin = overlap_margins
|
||||
# left_margin, right_margin = 2, 2
|
||||
assert left_margin % 2 == 0 # Required for compatibility with 2x2 pooling
|
||||
total_margin_pixels = base_image_input_d * (right_margin + left_margin) # pixels removed per dim
|
||||
crop_patches = base_image_input_size[0] // base_image_input_d # patches per crop dim
|
||||
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
|
||||
crop_window_size = crop_window_patches * base_image_input_d
|
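# Illustrative, using the class defaults (336 px base crops, 14 px patches, (4, 4) overlap margins): 24 patches per crop side, a 16-patch / 224 px usable window, and 112 overlap pixels per dimension.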
||||
tiling = select_tiling(original_image_h - total_margin_pixels, original_image_w - total_margin_pixels, crop_window_size, max_crops)
|
||||
src, img_mask = resize_and_pad(image, [tiling[0] * crop_window_size + total_margin_pixels, tiling[1] * crop_window_size + total_margin_pixels])
|
||||
|
||||
# Now we have to split the image into crops, while keeping track of how each patch in
|
||||
# each crop should be ordered in the global image; this requires a lot of tricky bookkeeping
|
||||
n_crops = tiling[0] * tiling[1]
|
||||
patches_arr = []
|
||||
mask_arr = []
|
||||
patch_ordering_arr = []
|
||||
|
||||
# We assume 2x2 pooling, but can allow padding the right/bottom with extra
|
||||
# patches if the number of patches per side is not even
|
||||
assert (crop_patches + 1) // 2 == image_token_length_h
|
||||
assert (crop_patches + 1) // 2 == image_token_length_w
|
||||
on = 0
|
||||
on_patch = 0
|
||||
for i in range(tiling[0]):
|
||||
y0 = i * crop_window_size
|
||||
if i == 0:
|
||||
crop_y0 = 0
|
||||
else:
|
||||
crop_y0 = left_margin // 2
|
||||
|
||||
crop_h = image_base_patch_h - (right_margin + left_margin)
|
||||
if i == 0:
|
||||
crop_h += left_margin
|
||||
if i == (tiling[0] - 1):
|
||||
crop_h += right_margin
|
||||
for j in range(tiling[1]):
|
||||
x0 = j * crop_window_size
|
||||
if j == 0:
|
||||
crop_x0 = 0
|
||||
else:
|
||||
crop_x0 = left_margin // 2
|
||||
|
||||
crop_w = image_base_patch_w - (right_margin + left_margin)
|
||||
if j == 0:
|
||||
crop_w += left_margin
|
||||
if j == (tiling[1] - 1):
|
||||
crop_w += right_margin
|
||||
|
||||
pooled_w = (crop_w + 1) // 2
|
||||
pooled_h = (crop_h + 1) // 2
|
||||
patch_ordering_arr.append(
|
||||
pad_to_bounding_box(
|
||||
np.reshape(np.arange(on, on + pooled_h * pooled_w, dtype=np.int32), (pooled_h, pooled_w, 1)),
|
||||
crop_y0,
|
||||
crop_x0,
|
||||
image_token_length_h,
|
||||
image_token_length_w,
|
||||
value=-1,
|
||||
)[:, :, 0]
|
||||
)
|
||||
patches_arr.append(src[y0 : y0 + crop_size, x0 : x0 + crop_size])
|
||||
mask_arr.append(img_mask[y0 : y0 + crop_size, x0 : x0 + crop_size])
|
||||
|
||||
on += pooled_h * pooled_w
|
||||
on_patch += 1
|
||||
patches = np.stack(patches_arr)
|
||||
patch_ordering = np.stack(patch_ordering_arr)
|
||||
img_mask = np.stack(mask_arr)
|
||||
|
||||
# Switch to [n_crops, n_patches, pixels_per_patch] format
|
||||
image_layout_impatch_w, image_layout_impatch_h = tiling[0], tiling[1]
|
||||
patches = einops.rearrange(
|
||||
patches, "p (h dh) (w dw) c -> p (h w) (dh dw c)", dh=base_image_input_d, dw=base_image_input_d, h=image_base_patch_h, w=image_base_patch_w
|
||||
)
|
||||
img_mask = einops.rearrange(
|
||||
img_mask, "p (h dh) (w dw) -> p (h w) (dh dw)", dh=base_image_input_d, dw=base_image_input_d, h=image_base_patch_h, w=image_base_patch_w
|
||||
)
|
||||
|
||||
img_mask = img_mask.astype(np.float32).mean(axis=-1)
|
||||
patch_ordering = np.reshape(patch_ordering, [-1])
|
||||
valid = patch_ordering >= 0
|
||||
|
||||
# Transpose order, to get left-to-right order instead of crop-by-crop order
|
||||
patch_ordering_rh = np.reshape(patch_ordering, [tiling[0], tiling[1], image_token_length_h, image_token_length_w])
|
||||
patch_ordering_rh = np.transpose(patch_ordering_rh, [0, 2, 1, 3])
|
||||
patch_ordering_rh = np.reshape(patch_ordering_rh, [-1])
|
||||
|
||||
# The transpose will screw up which patches are masked, project the
|
||||
# new order into the sparse structure of `patch_ordering` to fix this
|
||||
patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
|
||||
|
||||
# Now build the output tokens
|
||||
h = tiling[0] * crop_window_patches + (right_margin + left_margin)
|
||||
w = tiling[1] * crop_window_patches + (right_margin + left_margin)
|
||||
per_row = np.full(
|
||||
((w + 1) // 2,),
|
||||
image_patch_token_id,
|
||||
)
|
||||
per_row = np.concatenate([per_row, [image_col_token_id]], 0)
|
||||
|
||||
joint = np.tile(per_row, [(h + 1) // 2])
|
||||
joint = [[image_start_token_id], joint, [image_end_token_id]]
|
||||
|
||||
# Finally do the same for the global image
|
||||
resized, _ = resize_and_pad(image, base_image_input_size)
|
||||
resized = einops.rearrange(
|
||||
resized, "(h dh) (w dw) c -> (h w) (dh dw c)", dh=base_image_input_d, dw=base_image_input_d, h=image_base_patch_h, w=image_base_patch_w
|
||||
)
|
||||
patches = np.concatenate([np.expand_dims(resized, 0), patches], 0)
|
||||
|
||||
# Global image goes first, so the order of patches in previous crops gets increased
|
||||
patch_ordering = np.where(patch_ordering >= 0, patch_ordering + tokens_per_image, -1)
|
||||
patch_ordering = np.concatenate([np.arange(0, tokens_per_image), patch_ordering], 0)
|
||||
per_row = np.full(
|
||||
(image_token_length_w,),
|
||||
image_patch_token_id,
|
||||
)
|
||||
per_row = np.concatenate([per_row, [image_col_token_id]], 0)
|
||||
extra_tokens = np.tile(per_row, [image_token_length_h])
|
||||
joint = [
|
||||
[image_start_token_id],
|
||||
extra_tokens,
|
||||
[image_end_token_id],
|
||||
] + joint
|
||||
|
||||
joint = np.concatenate(joint, 0)
|
||||
img_mask = np.pad(img_mask, [[0, 1], [0, 0]], constant_values=-1)
|
||||
return patches, joint, patch_ordering, img_mask
|
||||
|
||||
def build_image_input_idx(
|
||||
self,
|
||||
image_tokens: np.ndarray,
|
||||
patch_order: np.ndarray,
|
||||
image_patch_token_id: int,
|
||||
no_image: Optional[bool] = None,
|
||||
image_token_length_w: Optional[int] = None,
|
||||
image_token_length_h: Optional[int] = None,
|
||||
):
|
||||
"""Converts `patch_order` into a mapping of token_id -> patch_id"""
|
||||
|
||||
tokens_per_image = image_token_length_w * image_token_length_h
|
||||
if no_image is not None and no_image:
|
||||
return np.zeros((0, tokens_per_image), np.int32)
|
||||
|
||||
# Indices to insert the patches
|
||||
image_input_idx = image_tokens == image_patch_token_id
|
||||
image_input_idx = np.nonzero(image_input_idx)[0].astype(np.int32)
|
||||
|
||||
if patch_order is not None:
|
||||
n_tokens = image_input_idx.shape[0]
|
||||
patch_order = np.reshape(patch_order, [-1])
|
||||
n_patches = patch_order.shape[0]
|
||||
|
||||
valid = patch_order >= 0
|
||||
n_valid_patches = valid.sum()
|
||||
assert len(image_input_idx) == n_valid_patches
|
||||
|
||||
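# Invert the ordering: sorted_patch_ixs[k] holds the position (among the valid patches) of the patch that should fill the k-th image-patch token.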
sorted_patch_ixs = np.zeros([n_tokens], np.int32)
|
||||
sorted_patch_ixs[patch_order[valid]] = np.arange(n_valid_patches, dtype=np.int32)
|
||||
|
||||
# Project the inverted mapping into same sparse structure
|
||||
sorted_patch_ixs_ex = np.full(np.shape(patch_order), -1)
|
||||
sorted_patch_ixs_ex[valid] = sorted_patch_ixs
|
||||
|
||||
# Do the gather and then re-mask outputs that were masked in `sorted_patch_ixs`
|
||||
valid = (sorted_patch_ixs_ex >= 0).astype(np.int32)
|
||||
image_input_idx = image_input_idx[sorted_patch_ixs_ex * valid]
|
||||
image_input_idx = image_input_idx * valid - 100 * (1 - valid)
|
||||
image_input_idx = np.reshape(image_input_idx, [-1, tokens_per_image])
|
||||
return image_input_idx
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
image_patch_token_id: int,
|
||||
image_col_token_id: int,
|
||||
image_start_token_id: int,
|
||||
image_end_token_id: int,
|
||||
max_crops: Optional[int] = None,
|
||||
overlap_margins: Optional[List[int]] = None,
|
||||
base_image_input_size: Optional[Union[int, List[int]]] = None,
|
||||
image_token_length_w: Optional[int] = None,
|
||||
image_token_length_h: Optional[int] = None,
|
||||
image_patch_size: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Preprocesses an image
|
||||
|
||||
Returns:
|
||||
crops: (n_crops, n_patches, patch_dim) individual crops, `n_crops` might
|
||||
change between images but the other dimensions are fixed
|
||||
tokens: (n_tokens,) int32 tokens, pad tokens indicate where to insert the
|
||||
patch features, might include other special tokens as well
|
||||
image_idx: (n_crops, n_patches) index in `tokens` to put the patch features from the
|
||||
crops after pooling, negative values indicate patch features to exclude
|
||||
padding_mask: (n_crops, n_patches) what percent of each crop is padding, can be None
|
||||
if the image mask is not being used.
|
||||
"""
|
||||
|
||||
max_crops = max_crops or self.max_crops
|
||||
overlap_margins = overlap_margins or self.overlap_margins
|
||||
base_image_input_size = base_image_input_size or self.base_image_input_size
|
||||
image_token_length_w = image_token_length_w or self.image_token_length_w
|
||||
image_token_length_h = image_token_length_h or self.image_token_length_h
|
||||
image_patch_size = image_patch_size or self.image_patch_size
|
||||
|
||||
crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens(
|
||||
image,
|
||||
image_patch_token_id,
|
||||
image_col_token_id,
|
||||
image_start_token_id,
|
||||
image_end_token_id,
|
||||
max_crops,
|
||||
overlap_margins,
|
||||
base_image_input_size,
|
||||
image_token_length_w,
|
||||
image_token_length_h,
|
||||
image_patch_size,
|
||||
)
|
||||
patch_idx = self.build_image_input_idx(
|
||||
image_tokens,
|
||||
patch_ordering,
|
||||
image_patch_token_id,
|
||||
image_token_length_w=image_token_length_w,
|
||||
image_token_length_h=image_token_length_h,
|
||||
)
|
||||
return crops, image_tokens, patch_idx, img_mask
|
||||
|
||||
def multimodal_preprocess(
|
||||
self,
|
||||
images: np.ndarray,
|
||||
tokens: List[int],
|
||||
image_idx: np.ndarray,
|
||||
sequence_length: int,
|
||||
image_patch_token_id: int,
|
||||
image_col_token_id: int,
|
||||
image_start_token_id: int,
|
||||
image_end_token_id: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""Merge images and text tokens into multi-modal features for the model
|
||||
|
||||
:param images: images to use as input
|
||||
:param tokens: input text tokens
|
||||
:param image_idx: where to insert the images into `tokens`
|
||||
:param image_patch_token_id: token id for tokens that will contain image features
|
||||
:param image_col_token_id: token id for image column special tokens
|
||||
:param image_start_token_id: token id for image start special tokens
|
||||
:param image_end_token_id: token id for image end special tokens
|
||||
:param kwargs: override preprocessor default args
|
||||
"""
|
||||
max_total_crops = kwargs.get("max_crops") or self.max_crops
|
||||
image_token_length_w = kwargs.get("image_token_length_w") or self.image_token_length_w
|
||||
image_token_length_h = kwargs.get("image_token_length_h") or self.image_token_length_h
|
||||
image_patch_size = kwargs.get("image_patch_size") or self.image_patch_size
|
||||
base_image_input_size = kwargs.get("base_image_input_size") or self.base_image_input_size
|
||||
image_num_patch = (
|
||||
base_image_input_size[0] // image_patch_size,
|
||||
base_image_input_size[1] // image_patch_size,
|
||||
)
|
||||
image_padding_mask = kwargs.get("image_padding_mask") or self.image_padding_mask
|
||||
|
||||
tokens_per_image = image_token_length_w * image_token_length_h
|
||||
n_pixels = image_patch_size * image_patch_size * 3
|
||||
n_patches = image_num_patch[0] * image_num_patch[1]
|
||||
|
||||
if images is None:
|
||||
return {
|
||||
"input_ids": tokens,
|
||||
}
|
||||
else:
|
||||
n = len(images)
|
||||
all_crops = []
|
||||
all_image_idx = []
|
||||
out_tokens = []
|
||||
all_crop_masks = []
|
||||
|
||||
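# Splice each image's token block into the text stream at image_idx[ix] (-1 means before any text), shifting its patch-to-token mapping by the insertion offset.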
for ix in range(n):
|
||||
token_ix = image_idx[ix]
|
||||
crops, image_tokens, patch_idx, img_mask = self.preprocess(
|
||||
images[ix],
|
||||
image_patch_token_id,
|
||||
image_col_token_id,
|
||||
image_start_token_id,
|
||||
image_end_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if token_ix == -1: # -1 is an image inserted at the very start
|
||||
start = 0
|
||||
token_ix = 0
|
||||
end = 0
|
||||
else:
|
||||
start = 0 if ix == 0 else image_idx[ix - 1] + 1
|
||||
end = token_ix + 1
|
||||
|
||||
all_image_idx.append(patch_idx + token_ix)
|
||||
all_crops.append(crops)
|
||||
out_tokens.append(tokens[start:token_ix])
|
||||
out_tokens.append(image_tokens)
|
||||
if ix == (n - 1):
|
||||
out_tokens.append(tokens[end:])
|
||||
if image_padding_mask:
|
||||
all_crop_masks.append(img_mask)
|
||||
|
||||
input_ids = np.concatenate(out_tokens, 0)
|
||||
images = np.concatenate(all_crops, 0)
|
||||
image_input_idx = np.concatenate(all_image_idx, 0)
|
||||
if image_padding_mask:
|
||||
image_masks = np.concatenate(all_crop_masks, 0)
|
||||
else:
|
||||
image_masks = None
|
||||
|
||||
out = {"input_ids": input_ids, "images": images, "image_input_idx": image_input_idx}
|
||||
if image_masks is not None:
|
||||
out["image_masks"] = image_masks
|
||||
return out
|
||||
|
||||
|
||||
MolmoImageProcessor.register_for_auto_class()
|
||||
File diff suppressed because it is too large
@ -1,184 +0,0 @@
|
||||
"""
|
||||
Processor class for Molmo.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from PIL import ImageOps
|
||||
from PIL.Image import Image
|
||||
|
||||
try:
|
||||
from typing import Unpack
|
||||
except ImportError:
|
||||
from typing_extensions import Unpack
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from transformers.utils import logging
|
||||
|
||||
from .image_preprocessing_molmo import MolmoImageProcessor, MolmoImagesKwargs
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
||||
DEFAULT_IM_START_TOKEN = "<im_start>"
|
||||
DEFAULT_IM_END_TOKEN = "<im_end>"
|
||||
DEFAULT_IM_COL_TOKEN = "<im_col>"
|
||||
IMAGE_PROMPT = "<|image|>"
|
||||
|
||||
EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)
|
||||
|
||||
|
||||
def get_special_token_ids(tokenizer):
|
||||
ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
|
||||
assert len(ids) == len(EXTRA_TOKENS)
|
||||
return {k: i for k, i in zip(EXTRA_TOKENS, ids)}
|
||||
|
||||
|
||||
class MolmoTextKwargs(TextKwargs, total=False):
|
||||
style: Optional[str]
|
||||
system_prompt: Optional[str]
|
||||
message_format: Optional[str]
|
||||
always_start_with_space: Optional[bool]
|
||||
sequence_length: Optional[int]
|
||||
|
||||
|
||||
class MolmoProcessorKwargs(ProcessingKwargs, total=False):
|
||||
text_kwargs: MolmoTextKwargs
|
||||
images_kwargs: MolmoImagesKwargs
|
||||
_defaults = {
|
||||
"images_kwargs": {
|
||||
"max_crops": 12,
|
||||
"overlap_margins": [4, 4],
|
||||
"base_image_input_size": [336, 336],
|
||||
"image_token_length_w": 12,
|
||||
"image_token_length_h": 12,
|
||||
"image_patch_size": 14,
|
||||
"image_padding_mask": True,
|
||||
},
|
||||
"text_kwargs": {
|
||||
"style": "long_caption",
|
||||
"system_prompt": "none",
|
||||
"message_format": "role",
|
||||
"always_start_with_space": True,
|
||||
"sequence_length": 1536,
|
||||
"padding": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class MolmoProcessor(ProcessorMixin):
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast")
|
||||
|
||||
def __init__(self, image_processor: MolmoImageProcessor = None, tokenizer: AutoTokenizer = None, **kwargs):
|
||||
# self.image_processor = image_processor
|
||||
# self.tokenizer = tokenizer
|
||||
super().__init__(image_processor, tokenizer)
|
||||
self._special_tokens = None
|
||||
|
||||
@property
|
||||
def special_token_ids(self):
|
||||
if self._special_tokens is None:
|
||||
self._special_tokens = get_special_token_ids(self.tokenizer)
|
||||
return self._special_tokens
|
||||
|
||||
def get_tokens_input(self, prompt, message_format, always_start_with_space):
|
||||
if message_format == "none" or message_format is None:
|
||||
pass
|
||||
elif message_format == "role":
|
||||
prompt = "User: " + prompt + " Assistant:"
|
||||
else:
|
||||
raise NotImplementedError(f"Message format {message_format} not implemented")
|
||||
|
||||
if always_start_with_space:
|
||||
prompt = " " + prompt
|
||||
|
||||
tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
|
||||
|
||||
return tokens
|
||||
|
||||
def process(
|
||||
self,
|
||||
text: TextInput = None,
|
||||
images: ImageInput = None,
|
||||
*,
|
||||
tokens: Optional[PreTokenizedInput] = None,
|
||||
**kwargs: Unpack[MolmoProcessorKwargs],
|
||||
):
|
||||
output_kwargs = self._merge_kwargs(
|
||||
MolmoProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if tokens is None:
|
||||
tokens = self.get_tokens_input(
|
||||
text,
|
||||
output_kwargs["text_kwargs"]["message_format"],
|
||||
output_kwargs["text_kwargs"]["always_start_with_space"],
|
||||
)
|
||||
|
||||
image_token_id = self.special_token_ids[IMAGE_PROMPT]
|
||||
|
||||
if images is not None:
|
||||
if not isinstance(images, (list, tuple)):
|
||||
images = [images]
|
||||
image_arrays = []
|
||||
for image in images:
|
||||
if isinstance(image, Image):
|
||||
image = image.convert("RGB")
|
||||
# Handle images with EXIF orientation tags, which PIL will ignore by default
|
||||
# https://github.com/python-pillow/Pillow/issues/4703
|
||||
img = ImageOps.exif_transpose(image)
|
||||
image_arrays.append(np.array(img))
|
||||
else:
|
||||
assert len(image.shape) == 3 and image.shape[-1] == 3
|
||||
image_arrays.append(image.astype(np.uint8))
|
||||
images = image_arrays
|
||||
# For now only support inserting images at the start
|
||||
image_idx = [-1] * len(images)
|
||||
else:
|
||||
image_idx = None
|
||||
|
||||
sequence_length = output_kwargs["text_kwargs"]["sequence_length"]
|
||||
|
||||
image_patch_token_id = self.special_token_ids[DEFAULT_IMAGE_PATCH_TOKEN]
|
||||
image_col_token_id = self.special_token_ids[DEFAULT_IM_COL_TOKEN]
|
||||
image_start_token_id = self.special_token_ids[DEFAULT_IM_START_TOKEN]
|
||||
image_end_token_id = self.special_token_ids[DEFAULT_IM_END_TOKEN]
|
||||
out = self.image_processor.multimodal_preprocess(
|
||||
images=images,
|
||||
image_idx=image_idx,
|
||||
tokens=np.asarray(tokens).astype(np.int32),
|
||||
sequence_length=sequence_length,
|
||||
image_patch_token_id=image_patch_token_id,
|
||||
image_col_token_id=image_col_token_id,
|
||||
image_start_token_id=image_start_token_id,
|
||||
image_end_token_id=image_end_token_id,
|
||||
**output_kwargs["images_kwargs"],
|
||||
)
|
||||
|
||||
# Prepend BOS
|
||||
# qwen2 and olmo do not have a BOS, and instead use EOS as a generic separator token.
|
||||
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
|
||||
decoder_input_tokens = np.pad(out["input_ids"], [[1, 0]], constant_values=bos)
|
||||
out["input_ids"] = decoder_input_tokens
|
||||
if "image_input_idx" in out:
|
||||
# Shift patch mapping up by one since we added BOS
|
||||
image_input_idx = out["image_input_idx"]
|
||||
out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)
|
||||
|
||||
for k, v in out.items():
|
||||
out[k] = torch.from_numpy(v)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
MolmoProcessor.register_for_auto_class()
|
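# Minimal usage sketch (assumes a checkpoint that ships this processor code; names are illustrative):
#   processor = AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)
#   batch = processor.process(text="Describe this page.", images=[np.asarray(pil_image)])
#   batch = {k: v.unsqueeze(0) for k, v in batch.items()}  # add a batch dimension before the model call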
||||
170
olmocr/train/prepare_olmocrmix.py
Normal file
@ -0,0 +1,170 @@
|
||||
import argparse
|
||||
import json
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import pandas as pd
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination: str | PathLike, max_examples: Optional[int] = None) -> str:
|
||||
"""
|
||||
Prepare OLMoCR mix dataset by downloading from HuggingFace and organizing into a folder structure.
|
||||
|
||||
Args:
|
||||
dataset_path: HuggingFace dataset path
|
||||
subset: Dataset subset name
|
||||
split: Dataset split (train/validation/test)
|
||||
destination: Destination directory path
|
||||
max_examples: Maximum number of examples to process (None for all)
|
||||
"""
|
||||
# Step 1: Download dataset using hugging face hub snapshot_download to destination/hugging_face folder
|
||||
dest_path = Path(destination)
|
||||
hugging_face_dir = dest_path / "hugging_face"
|
||||
hugging_face_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Downloading dataset {dataset_path} to {hugging_face_dir}...")
|
||||
|
||||
# Download the entire repository including PDFs and parquet files
|
||||
local_dir = snapshot_download(
|
||||
repo_id=dataset_path,
|
||||
repo_type="dataset",
|
||||
local_dir=hugging_face_dir,
|
||||
)
|
||||
|
||||
print(f"Downloaded to: {local_dir}")
|
||||
|
||||
# Step 2: Create destination folder structure for processed markdown files
|
||||
processed_dir = dest_path / f"processed_{subset}_{split}"
|
||||
processed_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Manual map to parquet files for now
|
||||
assert dataset_path == "allenai/olmOCR-mix-0225", "Only supporting the olmocr-mix for now, later will support other training sets"
|
||||
if subset == "00_documents" and split == "train_s2pdf":
|
||||
parquet_files = [dest_path / "hugging_face" / "train-s2pdf.parquet"]
|
||||
elif subset == "00_documents" and split == "eval_s2pdf":
|
||||
parquet_files = [dest_path / "hugging_face" / "eval-s2pdf.parquet"]
|
||||
elif subset == "01_books" and split == "train_s2pdf":
|
||||
parquet_files = [dest_path / "hugging_face" / "train-iabooks.parquet"]
|
||||
elif subset == "01_books" and split == "train_s2pdf":
|
||||
parquet_files = [dest_path / "hugging_face" / "eval-iabooks.parquet"]
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
# Step 3: Process parquet files
|
||||
total_processed = 0
|
||||
total_errors = 0
|
||||
|
||||
for parquet_file in parquet_files:
|
||||
print(f"Processing {parquet_file.name}...")
|
||||
df = pd.read_parquet(parquet_file)
|
||||
|
||||
# Process each row
|
||||
for idx, row in df.iterrows():
|
||||
if max_examples and total_processed >= max_examples:
|
||||
break
|
||||
|
||||
try:
|
||||
|
||||
# Extract fields from the row
|
||||
# The rows in the parquet will look like url, page_number, response (json format), and id
|
||||
response = row.get('response', '')
|
||||
doc_id = str(idx)
|
||||
|
||||
assert len(doc_id) > 4
|
||||
|
||||
# Parse response if it's a JSON string
|
||||
response_data = json.loads(response)
|
||||
response = response_data
|
||||
|
||||
# Create folder structure using first 4 digits of id
|
||||
# Make a folder structure to prevent a huge number of files in one directory, using the first 4 digits of the id, e.g. id[:4]/id[4:].md
|
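# Illustrative (hypothetical id): doc_id '012345678' -> '0123/45678.md'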
||||
folder_name = doc_id[:4]
|
||||
file_name = f"{doc_id[4:]}.md"
|
||||
|
||||
# Create directory
|
||||
output_dir = processed_dir / folder_name
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Write markdown file with front matter and natural text
|
||||
output_file = output_dir / file_name
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
# Extract natural_text and other fields for front matter
|
||||
natural_text = response.get('natural_text', '')
|
||||
# Create front matter from other fields
|
||||
front_matter = {k: v for k, v in response.items() if k != 'natural_text'}
|
||||
|
||||
# Write front matter
|
||||
f.write("---\n")
|
||||
for k, v in front_matter.items():
|
||||
f.write(f"{k}: {v}\n")
|
||||
f.write("---\n")
|
||||
|
||||
# Write natural text
|
||||
f.write(natural_text)
|
||||
|
||||
total_processed += 1
|
||||
if total_processed % 1000 == 0:
|
||||
print(f"Processed {total_processed} examples...")
|
||||
except Exception as ex:
|
||||
print(f"Error processing line: {ex}")
|
||||
total_errors += 1
|
||||
|
||||
if max_examples and total_processed >= max_examples:
|
||||
break
|
||||
|
||||
print(f"Completed! Processed {total_processed} examples to {processed_dir}")
|
||||
print(f"Total errors: {total_errors}")
|
||||
return str(processed_dir)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Prepare OLMoCR mix dataset")
|
||||
parser.add_argument(
|
||||
"--dataset-path",
|
||||
type=str,
|
||||
default="allenai/olmOCR-mix-0225",
|
||||
help="HuggingFace dataset path (e.g., 'allenai/olmocr-mix')"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subset",
|
||||
type=str,
|
||||
default="00_documents",
|
||||
required=True,
|
||||
help="Dataset subset name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split",
|
||||
type=str,
|
||||
default="eval_s2pdf",
|
||||
required=True,
|
||||
help="Dataset split ex eval_s2pdf"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--destination",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Destination directory path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-examples",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of examples to process (default: all)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
prepare_olmocr_mix(
|
||||
dataset_path=args.dataset_path,
|
||||
subset=args.subset,
|
||||
split=args.split,
|
||||
destination=args.destination,
|
||||
max_examples=args.max_examples
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,228 +1,9 @@
|
||||
import logging
|
||||
import os
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import wandb
|
||||
from datasets.utils import disable_progress_bars
|
||||
from datasets.utils.logging import set_verbosity
|
||||
from peft import LoraConfig, get_peft_model # pyright: ignore
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.integrations import WandbCallback
|
||||
from transformers.trainer_callback import TrainerControl, TrainerState
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from olmocr.train.core.cli import make_cli, save_config, to_native_types
|
||||
from olmocr.train.core.config import TrainConfig
|
||||
from olmocr.train.core.loggers import get_logger
|
||||
from olmocr.train.core.paths import copy_dir, join_path
|
||||
from olmocr.train.core.state import BeakerState
|
||||
|
||||
from .utils import (
|
||||
RunName,
|
||||
TruncatingCollator,
|
||||
get_local_dir,
|
||||
log_trainable_parameters,
|
||||
make_dataset,
|
||||
setup_environment,
|
||||
)
|
||||
|
||||
|
||||
class CheckpointUploadCallback(TrainerCallback):
|
||||
def __init__(self, save_path: str, logger: Optional[Logger] = None):
|
||||
self.save_path = save_path
|
||||
self.logger = logger or get_logger(self.__class__.__name__)
|
||||
|
||||
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
if state.is_local_process_zero:
|
||||
latest_checkpoint = get_last_checkpoint(args.output_dir)
|
||||
if not latest_checkpoint:
|
||||
return
|
||||
|
||||
dir_name = Path(latest_checkpoint).name
|
||||
copy_dir(str(latest_checkpoint), f"{self.save_path}/{dir_name}")
|
||||
self.logger.info("Saved checkpoint to %s", f"{self.save_path}/{dir_name}")
|
||||
|
||||
|
||||
def update_wandb_config(config: TrainConfig, trainer: Trainer, model: torch.nn.Module):
|
||||
# finding wandb callback
|
||||
callbacks = [c for c in trainer.callback_handler.callbacks if isinstance(c, WandbCallback)] # pyright: ignore
|
||||
if not callbacks:
|
||||
raise ValueError("WandbCallback not found in trainer callbacks")
|
||||
|
||||
wandb_callback = callbacks[0]
|
||||
peft_config = to_native_types(getattr(model, "peft_config", {}))
|
||||
script_config = to_native_types(config)
|
||||
beaker_envs = {k: v for k, v in os.environ.items() if k.lower().startswith("beaker")}
|
||||
|
||||
on_setup_fn = wandb_callback.setup
|
||||
|
||||
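# Wrap the WandbCallback's setup hook so the extra config sections below are pushed to wandb right after the run is initialized.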
def setup_and_update(args, state, model, **kwargs):
|
||||
on_setup_fn(args=args, state=state, model=model, **kwargs)
|
||||
wandb.config.update({"peft": peft_config}, allow_val_change=True)
|
||||
wandb.config.update({"script": script_config}, allow_val_change=True)
|
||||
wandb.config.update({"beaker": beaker_envs}, allow_val_change=True)
|
||||
if (run := wandb.run) and (beaker_url := BeakerState().url):
|
||||
run.notes = beaker_url
|
||||
|
||||
wandb_callback.setup = setup_and_update
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||
return torch.distributed.get_rank()
|
||||
return 0
|
||||
|
||||
|
||||
def run_train(config: TrainConfig):
|
||||
if get_rank() == 0:
|
||||
logger_level = logging.INFO
|
||||
else:
|
||||
logger_level = logging.WARN
|
||||
disable_progress_bars()
|
||||
|
||||
logger = get_logger(__name__, level=logger_level)
|
||||
set_verbosity(logger_level)
|
||||
|
||||
run_name = RunName.get(config)
|
||||
|
||||
setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
|
||||
|
||||
processor = AutoProcessor.from_pretrained(config.model.name_or_path, trust_remote_code=True)
|
||||
train_dataset, valid_dataset = make_dataset(config, processor)
|
||||
logger.info(train_dataset)
|
||||
logger.info(valid_dataset)
|
||||
|
||||
if "qwen" in config.model.name_or_path.lower():
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
config.model.name_or_path, torch_dtype=torch.bfloat16, _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
|
||||
)
|
||||
else:
|
||||
from .molmo.config_molmo import MolmoConfig
|
||||
from .molmo.modeling_molmo import MolmoForCausalLM
|
||||
|
||||
model_config = MolmoConfig.from_pretrained(config.model.name_or_path, trust_remote_code=True)
|
||||
|
||||
if model_config.max_position_embeddings < config.generate.max_length:
|
||||
logger.warning(
|
||||
f"ALERT, force adjusting model config max_position_embeddings upwards from {model_config.max_position_embeddings} to {config.generate.max_length}"
|
||||
)
|
||||
model_config.max_position_embeddings = config.generate.max_length
|
||||
|
||||
if config.model.use_flash_attn:
|
||||
model_config.attention_type = "flash"
|
||||
|
||||
model = MolmoForCausalLM.from_pretrained(config.model.name_or_path, torch_dtype=torch.bfloat16, config=model_config, trust_remote_code=True)
|
||||
|
||||
logger.info(model)
|
||||
|
||||
if config.lora is not None:
|
||||
peft_config = LoraConfig(
|
||||
r=config.lora.rank,
|
||||
lora_alpha=config.lora.alpha,
|
||||
lora_dropout=config.lora.dropout,
|
||||
bias=config.lora.bias, # pyright: ignore
|
||||
task_type=config.lora.task_type,
|
||||
target_modules=list(config.lora.target_modules),
|
||||
)
|
||||
model = get_peft_model(model=model, peft_config=peft_config)
|
||||
log_trainable_parameters(model=model, logger=logger)
|
||||
|
||||
save_path = join_path("", config.save.path, run_name.run)
|
||||
|
||||
# Make sure directory exists if local
|
||||
if not save_path.startswith("s3://"):
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
|
||||
save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore
|
||||
|
||||
with TemporaryDirectory() as output_dir:
|
||||
training_args = TrainingArguments(
|
||||
run_name=run_name.run,
|
||||
logging_steps=config.hparams.log_every_steps,
|
||||
output_dir=output_dir,
|
||||
eval_strategy="steps",
|
||||
report_to="wandb",
|
||||
# report_to=[], # disable logging to wandb, we will use a custom callback
|
||||
optim=config.hparams.optim,
|
||||
eval_steps=config.hparams.eval_every_steps,
|
||||
learning_rate=config.hparams.learning_rate,
|
||||
per_device_train_batch_size=config.hparams.batch_size,
|
||||
per_device_eval_batch_size=config.hparams.eval_batch_size or config.hparams.batch_size,
|
||||
gradient_checkpointing=config.hparams.gradient_checkpointing,
|
||||
gradient_checkpointing_kwargs=(
|
||||
dict(use_reentrant=False) # from this issue: https://github.com/huggingface/peft/issues/1142
|
||||
if config.hparams.gradient_checkpointing and config.lora is not None
|
||||
else {}
|
||||
),
|
||||
gradient_accumulation_steps=config.hparams.gradient_accumulation_steps,
|
||||
max_steps=config.hparams.max_steps,
|
||||
weight_decay=config.hparams.weight_decay,
|
||||
dataloader_num_workers=config.max_workers,
|
||||
load_best_model_at_end=True,
|
||||
save_strategy="steps",
|
||||
ddp_find_unused_parameters=config.hparams.find_unused_parameters,
|
||||
save_steps=config.save.save_every_steps,
|
||||
warmup_steps=config.hparams.warmup_steps,
|
||||
warmup_ratio=config.hparams.warmup_ratio,
|
||||
bf16=True,
|
||||
label_names=["labels"], # fix from https://github.com/huggingface/transformers/issues/22885
|
||||
max_grad_norm=config.hparams.clip_grad_norm,
|
||||
remove_unused_columns=False,
|
||||
eval_on_start=True,
|
||||
metric_for_best_model=config.valid_data.metric_for_best_model,
|
||||
)
|
||||
|
||||
data_collator = TruncatingCollator(max_length=config.generate.max_length)
|
||||
|
||||
checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)
|
||||
|
||||
# Initialize Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=valid_dataset,
|
||||
tokenizer=processor.tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=[checkpoint_callback],
|
||||
)
|
||||
|
||||
# Train the model
|
||||
trainer.train() # pyright: ignore
|
||||
|
||||
if get_rank() == 0:
|
||||
with get_local_dir(join_path("", save_path, "best")) as best_dir:
|
||||
if config.lora is not None:
|
||||
logger.info("Merging LoRA adapters into the base model...")
|
||||
model = model.merge_and_unload()
|
||||
logger.info("LoRA adapters merged successfully.")
|
||||
|
||||
model.save_pretrained(best_dir)
|
||||
|
||||
logger.info("Saved best model to %s", best_dir)
|
||||
|
||||
# Uncomment to test speed of data loader
|
||||
# train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
|
||||
# for entry in tqdm(train_dataloader):
|
||||
# print("Step!")
|
||||
# model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})
|
||||
|
||||
|
||||
def main():
|
||||
train_config = make_cli(TrainConfig) # pyright: ignore
|
||||
run_train(train_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# TODO Overall, this code will read in a config yaml file with OmegaConf
|
||||
# From that config, we are going to use HuggingFace Trainer to train a model
|
||||
# TODOS:
|
||||
# Build a script to convert olmocr-mix to a new dataloader format
|
||||
# Write a new dataloader and collator, with tests, that brings in everything; it only needs to support batch size 1 for this first version
|
||||
# Get a basic config yaml file system working
|
||||
# Get a basic hugging face trainer running, supporting Qwen2.5VL for now
|
||||
# Saving and restoring training checkpoints
|
||||
# Converting training checkpoints to vLLM-compatible checkpoints
|
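# A minimal sketch of the intended flow (assumed, not yet implemented; paths and names are illustrative):
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("olmocr/train/configs/example.yaml")
#   args = TrainingArguments(output_dir=cfg.save.path, max_steps=cfg.hparams.max_steps)
#   Trainer(model=model, args=args, train_dataset=train_ds, data_collator=collator).train()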
||||
|
||||
@ -1,232 +0,0 @@
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from hashlib import sha1
|
||||
from logging import Logger
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Dict, Generator, List, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import PrecisionType
|
||||
from datasets import Dataset, DatasetDict, concatenate_datasets
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from olmocr.train.dataloader import build_finetuning_dataset
|
||||
from olmocr.train.dataprep import (
|
||||
batch_prepare_data_for_molmo_training,
|
||||
batch_prepare_data_for_qwen2_training,
|
||||
)
|
||||
|
||||
from .core.cli import to_native_types
|
||||
from .core.config import AwsConfig, DataConfig, SourceConfig, TrainConfig, WandbConfig
|
||||
from .core.loggers import get_logger
|
||||
from .core.paths import copy_dir, is_local
|
||||
from .core.state import BeakerState
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
|
||||
pt = PrecisionType(accelerator.mixed_precision)
|
||||
if pt == PrecisionType.FP16:
|
||||
return torch.float16
|
||||
elif pt == PrecisionType.BF16:
|
||||
return torch.bfloat16
|
||||
elif pt == PrecisionType.FP8:
|
||||
return torch.float8_e4m3fn
|
||||
return torch.float32
|
||||
|
||||
|
||||
def get_rawdataset_from_source(data_config: DataConfig, source: SourceConfig) -> Dataset:
|
||||
return build_finetuning_dataset(source.response_glob_path, pdf_cache_location=data_config.cache_location)
|
||||
|
||||
|
||||
def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]:
|
||||
random.seed(config.train_data.seed)
|
||||
|
||||
if "qwen" in config.model.name_or_path.lower():
|
||||
batch_fn = batch_prepare_data_for_qwen2_training
|
||||
elif "molmo" in config.model.name_or_path.lower():
|
||||
batch_fn = batch_prepare_data_for_molmo_training
|
||||
else:
|
||||
raise NotImplementedError("Model format not supported")
|
||||
|
||||
# Retrieve the two target lengths from the first source for comparison
|
||||
first_source = config.train_data.sources[0]
|
||||
target_longest_image_dim = first_source.target_longest_image_dim
|
||||
target_anchor_text_len = first_source.target_anchor_text_len
|
||||
|
||||
# Verify that all sources have the same target lengths
|
||||
for source in config.train_data.sources:
|
||||
if source.target_longest_image_dim != target_longest_image_dim:
|
||||
raise ValueError(f"Inconsistent target_longest_image_dim found in source {source}")
|
||||
if source.target_anchor_text_len != target_anchor_text_len:
|
||||
raise ValueError(f"Inconsistent target_anchor_text_len found in source {source}")
|
||||
|
||||
# Concatenate datasets first; unfortunately the library doesn't let you apply the transform before concatenation
|
||||
train_dataset = concatenate_datasets([get_rawdataset_from_source(config.train_data, source) for source in config.train_data.sources])
|
||||
|
||||
# Apply the transform to the concatenated dataset
|
||||
train_dataset = train_dataset.with_transform(
|
||||
partial(
|
||||
batch_fn,
|
||||
processor=processor,
|
||||
target_longest_image_dim=list(target_longest_image_dim),
|
||||
target_anchor_text_len=list(target_anchor_text_len),
|
||||
)
|
||||
)
|
||||
|
||||
# Validation sets get put into a datasetdict so each can report a loss separately
|
||||
valid_dataset = DatasetDict(
|
||||
**{
|
||||
source.name: get_rawdataset_from_source(config.valid_data, source).with_transform(
|
||||
partial(
|
||||
batch_fn,
|
||||
processor=processor,
|
||||
target_longest_image_dim=list(source.target_longest_image_dim),
|
||||
target_anchor_text_len=list(source.target_anchor_text_len),
|
||||
)
|
||||
)
|
||||
for source in config.valid_data.sources
|
||||
}
|
||||
)
|
||||
|
||||
return train_dataset, valid_dataset
|
||||
|
||||
|
||||
def setup_environment(aws_config: Optional[AwsConfig] = None, wandb_config: Optional[WandbConfig] = None, **kwargs: str):
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "false"
|
||||
|
||||
if wandb_config:
|
||||
os.environ["WANDB_WATCH"] = "false"
|
||||
|
||||
for key, value in to_native_types(wandb_config or {}).items():
|
||||
if value is not None:
|
||||
os.environ[f"WANDB_{key.upper()}"] = str(value)
|
||||
|
||||
for key, value in to_native_types(aws_config or {}).items():
|
||||
if value is not None:
|
||||
os.environ[f"AWS_{key.upper()}"] = str(value)
|
||||
|
||||
os.environ.update(kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunName:
|
||||
run: str
|
||||
group: str
|
||||
|
||||
@classmethod
|
||||
def get(cls, config: TrainConfig, accelerator: Optional[Accelerator] = None) -> "RunName":
|
||||
job_rank = f"-{accelerator.process_index}" if accelerator else ""
|
||||
|
||||
if beaker_job_id := BeakerState().job_id:
|
||||
job_id = f"-{beaker_job_id}"
|
||||
else:
|
||||
job_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
(config_hash := sha1()).update(json.dumps(to_native_types(config)).encode())
|
||||
model_name = config.model.name_or_path.replace("/", "_")
|
||||
|
||||
group_name = f"{model_name}-{config_hash.hexdigest()[:6]}"
|
||||
run_name = f"{group_name}{job_id}{job_rank}"
|
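# Resulting format: '<model name>-<6-char config hash>' for the group, with the Beaker job id (or a timestamp) plus an optional '-<rank>' suffix appended for the run name.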
||||
return cls(group=group_name, run=run_name)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def override_torch_threads(n: int):
|
||||
torch_num_threads = torch.get_num_threads()
|
||||
torch.set_num_threads(n)
|
||||
|
||||
yield
|
||||
|
||||
torch.set_num_threads(torch_num_threads)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temp_args(obj: T, **kwargs) -> Generator[T, None, None]:
|
||||
orig = {k: getattr(obj, k) for k in kwargs.keys()}
|
||||
for k, v in kwargs.items():
|
||||
setattr(obj, k, v)
|
||||
|
||||
yield obj
|
||||
|
||||
for k, v in orig.items():
|
||||
setattr(obj, k, v)
|
||||
|
||||
|
||||
def log_trainable_parameters(model: torch.nn.Module, logger: Optional[Logger] = None):
|
||||
"""
|
||||
Logs the number of trainable parameters in the model.
|
||||
"""
|
||||
trainable_params = 0
|
||||
all_param = 0
|
||||
for name, param in model.named_parameters():
|
||||
all_param += param.numel()
|
||||
if param.requires_grad:
|
||||
(logger or get_logger(__name__)).info(f"training with {name}")
|
||||
trainable_params += param.numel()
|
||||
|
||||
(logger or get_logger(__name__)).info(
|
||||
"trainable params: %s || all params: %s || trainable%%: %s",
|
||||
f"{trainable_params:,}",
|
||||
f"{all_param:,}",
|
||||
f"{trainable_params / all_param:.2%}",
|
||||
)
|
||||
|
||||
|
||||
class TruncatingCollator:
|
||||
def __init__(self, max_length: int):
|
||||
self.max_length = max_length
|
||||
|
||||
def __call__(self, batch: List[Dict]) -> Dict:
|
||||
# Assert that we are only handling batch size 1 for now
|
||||
assert len(batch) == 1, "Only batch size 1 is supported for now"
|
||||
|
||||
if "pixel_values" in batch[0]:
|
||||
# Qwen2 case
|
||||
truncated_input_ids = torch.tensor(batch[0]["input_ids"][: self.max_length]).unsqueeze(0)
|
||||
truncated_attention_mask = torch.tensor(batch[0]["attention_mask"][: self.max_length]).unsqueeze(0)
|
||||
truncated_labels = torch.tensor(batch[0]["labels"][: self.max_length]).unsqueeze(0)
|
||||
|
||||
return {
|
||||
"input_ids": truncated_input_ids,
|
||||
"attention_mask": truncated_attention_mask,
|
||||
"labels": truncated_labels,
|
||||
"pixel_values": torch.tensor(batch[0]["pixel_values"]).unsqueeze(0),
|
||||
"image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]).unsqueeze(0),
|
||||
}
|
||||
elif "image_input_idx" in batch[0]:
|
||||
# molmo case
|
||||
truncated_input_ids = batch[0]["input_ids"][: self.max_length].unsqueeze(0)
|
||||
truncated_attention_mask = batch[0]["attention_mask"][: self.max_length].unsqueeze(0)
|
||||
truncated_labels = batch[0]["labels"][: self.max_length].unsqueeze(0)
|
||||
|
||||
return {
|
||||
"input_ids": truncated_input_ids,
|
||||
"attention_mask": truncated_attention_mask,
|
||||
"labels": truncated_labels,
|
||||
"images": batch[0]["images"].unsqueeze(0),
|
||||
"image_input_idx": batch[0]["image_input_idx"].unsqueeze(0),
|
||||
"image_masks": batch[0]["image_masks"].unsqueeze(0),
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError()
|
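# Minimal usage sketch (mirrors the commented-out dataloader test in the training script):
#   collator = TruncatingCollator(max_length=config.generate.max_length)
#   loader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False, collate_fn=collator)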
||||
|
||||
|
||||
@contextmanager
|
||||
def get_local_dir(output_dir: str):
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
if is_local(output_dir):
|
||||
yield output_dir
|
||||
else:
|
||||
yield tmp_dir
|
||||
copy_dir(tmp_dir, output_dir)
|
||||