First attempt at new trainer code

This commit is contained in:
Jake Poznanski 2025-06-11 16:56:16 +00:00
parent 3eda2c04c1
commit f0d8ff7bd3
54 changed files with 180 additions and 6558 deletions

1
.gitignore vendored
View File

@@ -14,6 +14,7 @@ localworkspace/*
math_data/*
math_data_big/*
gpt4otestset/*
old_train/
gpt4otestset_output/*
pdfs/*
olmOCR-bench/*

View File

@@ -1,31 +0,0 @@
# pip install llmcompressor
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration
MODEL_ID = "/home/ubuntu/olmocr/olmOCR-7B-0225-preview"
model = Qwen2VLForConditionalGeneration.from_pretrained(
MODEL_ID, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
# Configure the simple PTQ quantization
# recipe = QuantizationModifier(
# targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])
# Configure pre-defined qwen2vl recipe
recipe = QuantizationModifier(
targets="Linear",
scheme="FP8_DYNAMIC",
ignore=["re:.*lm_head", "re:visual.*"],
)
# Apply the quantization algorithm.
oneshot(model=model, recipe=recipe)
# Save the model.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-Dynamic-Recipe"  # basename of the checkpoint directory
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
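
For reference, a minimal sketch of how the saved FP8 checkpoint could then be served, assuming vLLM is installed; the directory name simply mirrors the SAVE_DIR computed above and the prompt is illustrative only.

from vllm import LLM, SamplingParams

# Load the FP8-dynamic checkpoint written by save_pretrained() above.
llm = LLM(model="olmOCR-7B-0225-preview-FP8-Dynamic-Recipe")
params = SamplingParams(temperature=0.0, max_tokens=256)
outputs = llm.generate(["Say hello."], params)
print(outputs[0].outputs[0].text)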

View File

@@ -1,87 +0,0 @@
model:
name_or_path: allenai/Molmo-7B-O-0924
arch: causal
use_flash_attn: true
wandb:
project: pdelfin
entity: ai2-llm
generate:
max_length: 8192
train_data:
seed: 1337
cache_location: /data/jakep/pdfdata/pdelfin_cache
sources:
- name: openai_batch_data_v5_1_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
valid_data:
cache_location: /data/jakep/pdfdata/pdelfin_cache
metric_for_best_model: openai_batch_data_v5_1_eval_loss
sources:
# These tend to be small, so it's no big deal to load them from S3
- name: openai_batch_data_v5_1_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
gradient_checkpointing: true
find_unused_parameters: true
clip_grad_norm: 1.0
learning_rate: 3e-4
max_steps: 10000
pad_multiple_of: 16
log_every_steps: 10
eval_every_steps: 100
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
lora:
rank: 32
alpha: 32
dropout: 0.05
task_type: CAUSAL_LM
target_modules:
# attention layers in main transformer
- att_proj
- ff_proj
- attn_out
- ff_out
# vision transformer attention and FF
- attention.wq
- attention.wk
- attention.wv
- attention.wo
- feed_forward.w1
- feed_forward.w2
# vision image projector
- vision_backbone.image_projector.w1
- vision_backbone.image_projector.w2
- vision_backbone.image_projector.w3
save:
path: s3://ai2-oe-data/jakep/experiments/molmo-o-0924/v1/models/
save_every_steps: 1000
max_workers: 10
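
For reference, a minimal sketch of how a config like the one above can be inspected with OmegaConf (the same library the deleted cli.py below builds on); the filename is hypothetical.

from omegaconf import OmegaConf as om

cfg = om.load("molmo-o-lora-8192.yaml")  # hypothetical filename for the YAML above
assert cfg.model.name_or_path == "allenai/Molmo-7B-O-0924"
assert cfg.lora.target_modules[0] == "att_proj"
print(om.to_yaml(cfg.hparams))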

View File

@@ -1,89 +0,0 @@
model:
name_or_path: allenai/Molmo-7B-O-0924
arch: causal
use_flash_attn: true
wandb:
project: pdelfin
entity: ai2-llm
generate:
max_length: 4096
train_data:
seed: 1337
cache_location: /data/jakep/pdfdata/pdelfin_cache
sources:
- name: openai_batch_data_v5_1_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
valid_data:
cache_location: /data/jakep/pdfdata/pdelfin_cache
metric_for_best_model: openai_batch_data_v5_1_eval_loss
sources:
# These tend to be small, so it's no big deal to load them from S3
- name: openai_batch_data_v5_1_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
gradient_checkpointing: true
find_unused_parameters: true
clip_grad_norm: 1.0
learning_rate: 1e-4
max_steps: 10000
pad_multiple_of: 16
log_every_steps: 10
eval_every_steps: 100
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
lora:
rank: 32
alpha: 32
dropout: 0.05
task_type: CAUSAL_LM
target_modules:
# attention layers in main transformer
- att_proj
- ff_proj
- attn_out
- ff_out
# vision transformer attention and FF
- attention.wq
- attention.wk
- attention.wv
- attention.wo
- feed_forward.w1
- feed_forward.w2
# vision image projector
- vision_backbone.image_projector.w1
- vision_backbone.image_projector.w2
- vision_backbone.image_projector.w3
save:
path: s3://ai2-oe-data/jakep/experiments/molmo-o-0924/v1/models/
save_every_steps: 1000
max_workers: 10

View File

@@ -1,73 +0,0 @@
model:
name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
arch: causal
use_flash_attn: true
wandb:
project: pdelfin
entity: ai2-llm
generate:
max_length: 8192
train_data:
seed: 1337
cache_location: /data/jakep/pdfdata/pdelfin_cache
sources:
# These tend to be small, so it's no big deal to load them from S3
- name: openai_batch_data_v5_1_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
# - name: openai_batch_data_v5_1_train
# response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
# target_longest_image_dim: [1024]
# target_anchor_text_len: [6000]
# - name: openai_batch_data_v5_1_iabooks_train
# response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
# target_longest_image_dim: [1024]
# target_anchor_text_len: [6000]
valid_data:
cache_location: /data/jakep/pdfdata/pdelfin_cache
metric_for_best_model: openai_batch_data_v5_1_eval_loss
sources:
# These tend to be small, so it's no big deal to load them from S3
- name: openai_batch_data_v5_1_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
gradient_checkpointing: true
clip_grad_norm: 1.0
learning_rate: 1e-6
max_steps: 10000
pad_multiple_of: 16
log_every_steps: 10
eval_every_steps: 100
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
save:
path: s3://ai2-oe-data/jakep/experiments/qwen25vl-pdf/v1/models/
save_every_steps: 9500
max_workers: 10

View File

@@ -1,89 +0,0 @@
model:
name_or_path: Qwen/Qwen2-VL-2B-Instruct
arch: causal
use_flash_attn: true
wandb:
project: pdelfin
entity: ai2-llm
# TODO This is not used
format:
instruction_template: "Original:"
response_template: "Rewritten:"
# Template from here: https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py#L30
chat_template: |
{% for message in messages %}
{{'<|im_start|>' + message['role'] + '\n' + message['content']}}
{% if loop.last %}
{{ '<|im_end|>'}}
{% else %}
{{ '<|im_end|>\n' }}
{% endif %}
{% endfor %}
generate:
max_length: 4096
train_data:
seed: 1337
sources:
- name: openai_batch_data_v2
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
backend:
- openai
size: 100_000
valid_data:
sources:
- name: openai_batch_data_eval_mini
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
backend:
- openai
size: 100_000
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
gradient_checkpointing: false
clip_grad_norm: 1.0
learning_rate: 3e-4
max_steps: 2000
pad_multiple_of: 16
log_every_steps: 50
eval_every_steps: 1000
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
lora:
rank: 32
alpha: 32
dropout: 0.05
task_type: causal_lm
target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
- visual.blocks.[0-9]+.attn.qkv
- visual.blocks.[0-9]+.attn.proj
- visual.blocks.[0-9]+.mlp.fc1
- visual.blocks.[0-9]+.mlp.fc2
- visual.merger.mlp.0
- visual.merger.mlp.2
save:
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
save_every_steps: 1000
max_workers: 10

View File

@@ -1,84 +0,0 @@
model:
name_or_path: Qwen/Qwen2-VL-2B-Instruct
arch: causal
use_flash_attn: true
wandb:
project: pdelfin
entity: ai2-llm
# TODO This is not used
format:
instruction_template: "Original:"
response_template: "Rewritten:"
# Template from here: https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py#L30
chat_template: |
{% for message in messages %}
{{'<|im_start|>' + message['role'] + '\n' + message['content']}}
{% if loop.last %}
{{ '<|im_end|>'}}
{% else %}
{{ '<|im_end|>\n' }}
{% endif %}
{% endfor %}
generate:
max_length: 4096
train_data:
seed: 1337
sources:
- name: openai_batch_data_v2
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
backend:
- openai
size: 100_000
valid_data:
sources:
- name: openai_batch_data_eval_mini
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
backend:
- openai
size: 100_000
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
gradient_checkpointing: false
clip_grad_norm: 1.0
learning_rate: 3e-4
max_steps: 2000
pad_multiple_of: 16
log_every_steps: 50
eval_every_steps: 1000
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
# Disable LORA for now, because we want the visual network to get trained too
# lora:
# rank: 32
# alpha: 32
# dropout: 0.05
# task_type: causal_lm
# target_modules:
# - q_proj
# - k_proj
# - v_proj
# - o_proj
# - gate_proj
# - up_proj
# - down_proj
save:
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
save_every_steps: 1000
max_workers: 10

View File

@@ -1,84 +0,0 @@
model:
name_or_path: Qwen/Qwen2-VL-7B-Instruct
arch: causal
use_flash_attn: true
wandb:
project: pdelfin
entity: ai2-llm
generate:
max_length: 8192
train_data:
seed: 1337
cache_location: /data/jakep/pdfdata/pdelfin_cache
sources:
- name: openai_batch_data_v5_1_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000
- name: openai_batch_data_v5_1_iabooks_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000
valid_data:
cache_location: /data/jakep/pdfdata/pdelfin_cache
metric_for_best_model: openai_batch_data_v5_1_eval_loss
sources:
# These tend to be small, so it's no big deal to load them from S3
- name: openai_batch_data_v5_1_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000
- name: openai_batch_data_v5_1_iabooks_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
gradient_checkpointing: true
clip_grad_norm: 1.0
learning_rate: 1e-4
max_steps: 10000
pad_multiple_of: 16
log_every_steps: 10
eval_every_steps: 100
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
lora:
rank: 32
alpha: 32
dropout: 0.05
task_type: causal_lm
target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
- visual.blocks.[0-9]+.attn.qkv
- visual.blocks.[0-9]+.attn.proj
- visual.blocks.[0-9]+.mlp.fc1
- visual.blocks.[0-9]+.mlp.fc2
- visual.merger.mlp.0
- visual.merger.mlp.2
save:
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
save_every_steps: 1000
max_workers: 10

View File

@@ -1,64 +0,0 @@
model:
name_or_path: Qwen/Qwen2-VL-7B-Instruct
arch: causal
use_flash_attn: true
wandb:
project: pdelfin
entity: ai2-llm
generate:
max_length: 8192
train_data:
seed: 1337
cache_location: /data/jakep/pdfdata/pdelfin_cache
sources:
- name: openai_batch_data_v5_1_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
valid_data:
cache_location: /data/jakep/pdfdata/pdelfin_cache
metric_for_best_model: openai_batch_data_v5_1_eval_loss
sources:
# These tend to be small, so it's no big deal to load them from S3
- name: openai_batch_data_v5_1_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
gradient_checkpointing: true
clip_grad_norm: 1.0
learning_rate: 1e-6
max_steps: 10000
pad_multiple_of: 16
log_every_steps: 10
eval_every_steps: 100
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
save:
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
save_every_steps: 9500
max_workers: 10

View File

@@ -1,96 +0,0 @@
import json
from logging import Logger
from typing import Optional, Type
import smart_open
import torch
from peft.peft_model import PeftModel
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelWithLMHead,
AutoTokenizer,
)
from .config import ModelConfig
from .loggers import get_logger
from .paths import cached_path, exists, get_cache_dir, join_path, resource_to_filename
__all__ = ["load_model", "cache_merged_model"]
def get_model_cls(config: ModelConfig) -> Type[AutoModelWithLMHead]:
if config.arch == "seq2seq":
return AutoModelForSeq2SeqLM # pyright: ignore
elif config.arch == "causal" or config.arch == "vllm":
return AutoModelForCausalLM # pyright: ignore
else:
raise ValueError(f"Unsupported model architecture: {config.arch}")
def get_adapter_config(config: ModelConfig) -> dict:
local_path = cached_path(config.name_or_path)
if exists(adapter_config_path := join_path("", local_path, "adapter_config.json")):
with smart_open.open(adapter_config_path, "rt", encoding="utf-8") as f:
return json.load(f)
return {}
def load_model(config: ModelConfig, logger: Optional[Logger] = None) -> AutoModelWithLMHead:
logger = logger or get_logger(__file__, level="INFO")
logger.info(f"Loading model from {config.name_or_path}")
local_path = cached_path(config.name_or_path)
if local_path != config.name_or_path:
logger.info(f"Model cached at {local_path}")
if exists(adapter_config_path := join_path("", local_path, "adapter_config.json")):
logger.info(f"Loading LoRA adapter from {adapter_config_path}")
with smart_open.open(adapter_config_path) as f:
adapter_config = json.load(f)
base_model_name_or_path = adapter_config["base_model_name_or_path"]
enable_lora = True
else:
base_model_name_or_path = local_path
enable_lora = False
model = get_model_cls(config).from_pretrained(
base_model_name_or_path,
device_map="auto",
trust_remote_code=config.trust_remote_code,
# low_cpu_mem_usage=model_config.low_cpu_mem_usage,
use_flash_attention_2=True if config.use_flash_attn else False,
revision=config.model_revision,
torch_dtype=torch.bfloat16 if config.use_flash_attn else getattr(torch, config.dtype),
)
logger.info(f"Successfully loaded base model from {base_model_name_or_path}")
if enable_lora:
peft_model = PeftModel.from_pretrained(model, local_path)
model = peft_model.merge_and_unload()
logger.info(f"Successfully loaded LoRA adapter from base model: {base_model_name_or_path}")
return model
def cache_merged_model(config: ModelConfig, logger: Optional[Logger] = None) -> str:
logger = logger or get_logger(__file__, level="INFO")
base_local_path = cached_path(config.name_or_path)
adapter_config = get_adapter_config(config)
if not adapter_config:
logger.info("No adapter config found; using base model")
return base_local_path
local_fn = resource_to_filename(json.dumps({"adapter": adapter_config, "model": config.name_or_path}))
merged_local_path = f"{get_cache_dir()}/{local_fn}"
if not exists(merged_local_path):
model = load_model(config=config, logger=logger)
tokenizer = AutoTokenizer.from_pretrained(base_local_path)
logger.info(f"Saving merged model to {merged_local_path}")
model.save_pretrained(merged_local_path)
tokenizer.save_pretrained(merged_local_path)
return merged_local_path
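
For reference, a minimal sketch of how load_model and cache_merged_model above are meant to be called; the import paths and checkpoint location are assumptions, not taken from this commit.

from config import ModelConfig                      # assumed import path for the deleted config module
from model import cache_merged_model, load_model    # assumed import path for the module above

cfg = ModelConfig(
    name_or_path="s3://my-bucket/experiments/qwen2vl-pdf/checkpoint-1000",  # hypothetical checkpoint
    arch="causal",
    use_flash_attn=True,
)
model = load_model(cfg)  # downloads from S3 if needed and merges a LoRA adapter when adapter_config.json is present
# cache_merged_model(cfg) would instead write the merged weights to the local cache and return that path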

View File

@@ -1,327 +0,0 @@
"""
Utilities to work with a OmegaConf structured config object
From Dolma Toolkit: https://github.com/allenai/dolma/blob/64886d9db15bd99acea9e28740ae20a510875dfb/python/dolma/cli/__init__.py
Author: Luca Soldaini (@soldni)
""" # noqa: E501
from argparse import ArgumentParser, Namespace
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import Field
from dataclasses import field as dataclass_field
from dataclasses import is_dataclass
from logging import warning
from typing import (
Any,
Dict,
Literal,
Optional,
Protocol,
Type,
TypeVar,
Union,
get_args,
get_origin,
)
import smart_open
from necessary import necessary
from omegaconf import MISSING, DictConfig, ListConfig
from omegaconf import OmegaConf as om
from omegaconf.errors import OmegaConfBaseException
from rich.console import Console
from rich.syntax import Syntax
from yaml import safe_load # type: ignore
from .errors import DolmaRefineError
__all__ = ["field", "namespace_to_nested_omegaconf", "print_config", "make_cli", "read_config", "to_native_types"]
T = TypeVar("T", bound=Any)
D = TypeVar("D", bound="DataClass")
A = TypeVar("A", bound="ArgumentParser")
def _field_nargs(default: Any) -> Union[Literal["?"], Literal["*"]]:
# return '*' if default is iterable but not string/bytes, else '?'
if isinstance(default, (str, bytes)):
return "?"
if isinstance(default, Iterable):
return "*"
return "?"
def field(default: T = MISSING, help: Optional[str] = None, **extra: Any) -> T:
metadata = {"help": help, "type": type(default), "default": default, "nargs": _field_nargs(default), **extra}
return dataclass_field(default_factory=lambda: deepcopy(default), metadata=metadata)
class DataClass(Protocol):
__dataclass_fields__: Dict[str, Field]
def read_config(path: Union[None, str]) -> Dict[str, Any]:
"""Read a configuration file if it exists"""
if path is None:
return {}
try:
with smart_open.open(path, mode="rt") as f:
return dict(safe_load(f))
except FileNotFoundError as ex:
raise DolmaRefineError(f"Config file not found: {path}") from ex
except Exception as ex:
raise DolmaRefineError(f"Error while reading config file: {path}") from ex
def save_config(config: Union[dict, DictConfig, list, ListConfig, DataClass], path: str) -> None:
"""Save a configuration to a file"""
if isinstance(config, (list, dict)):
config = om.create(config)
elif is_dataclass(config):
config = om.structured(config)
with smart_open.open(path, mode="wt") as f:
f.write(om.to_yaml(config))
def _make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = None) -> A:
for field_name, dt_field in config.__dataclass_fields__.items():
# get type from annotations or metadata
typ_ = config.__annotations__.get(field_name, dt_field.metadata.get("type", MISSING))
if typ_ is MISSING:
warning(f"No type annotation for field {field_name} in {config.__name__}")
continue
# join prefix and field name
field_name = f"{prefix}.{field_name}" if prefix else field_name
# This section here is to handle Optional[T] types; we only care for cases where T is a dataclass
# So we first check if type is Union since Optional[T] is just a shorthand for Union[T, None]
# and that the union contains only one non-None type
if get_origin(typ_) == Union:
# get all non-None types
args = [a for a in get_args(typ_) if a is not type(None)] # noqa: E721
if len(args) == 1:
# simple Optional[T] type
typ_ = args[0]
# here's where we check if T is a dataclass
if is_dataclass(typ_):
# recursively add subparsers
_make_parser(parser, typ_, prefix=field_name) # type: ignore
continue
if typ_ is bool:
# for boolean values, we add two arguments: --field_name and --no-field_name
parser.add_argument(
f"--{field_name}",
help=dt_field.metadata.get("help"),
dest=field_name,
action="store_true",
default=MISSING,
)
parser.add_argument(
f"--no-{field_name}",
help=f"Disable {field_name}",
dest=field_name,
action="store_false",
default=MISSING,
)
else:
# else it's just a normal argument
parser.add_argument(
f"--{field_name}",
help=dt_field.metadata.get("help"),
nargs=dt_field.metadata.get("nargs", "?"),
default=MISSING,
)
return parser
def make_nested_dict(key: str, value: Any, d: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
d = d or {}
if "." in key:
key, rest = key.split(".", 1)
value = make_nested_dict(rest, value, d.get(key))
# the value was provided (is not MISSING constant) and is not an empty dict or list
if value != MISSING and (not isinstance(value, (dict, list)) or len(value) > 0):
d[key] = value
return d
def to_native_types(obj: Any, resolve: bool = True, throw_on_missing: bool = True, enum_to_str: bool = True) -> Any:
"""Converts an OmegaConf object to native types (dicts, lists, etc.)"""
# convert dataclass to structured config
if hasattr(obj, "to_dict"):
# huggingface objects have a to_dict method, we prefer that
obj = obj.to_dict()
elif is_dataclass(obj):
# we go through structured config instead and hope for the best
obj = om.to_container(obj)
if isinstance(obj, DictConfig) or isinstance(obj, ListConfig):
obj = om.to_container(obj, resolve=resolve, throw_on_missing=throw_on_missing, enum_to_str=enum_to_str)
if isinstance(obj, dict):
return {k: to_native_types(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [to_native_types(v) for v in obj]
else:
return obj
def namespace_to_nested_omegaconf(args: Namespace, structured: Type[T], config: Optional[dict] = None) -> T:
nested_config_dict: Dict[str, Any] = {}
for key, value in vars(args).items():
nested_config_dict = make_nested_dict(key, value, nested_config_dict)
untyped_config: DictConfig = om.merge(
om.create(config or {}), om.create(nested_config_dict)
) # pyright: ignore (pylance is confused because om.create might return a DictConfig or a ListConfig)
# resolve any interpolations in the config
om.resolve(untyped_config)
# create structured config from cli dataclass
base_structured_config: DictConfig = om.structured(structured)
# merge with options parsed from config file and CLI arguments
merged_config = om.merge(base_structured_config, untyped_config)
# check for type
if not isinstance(merged_config, DictConfig):
raise DolmaRefineError(f"Expected a DictConfig, got {type(merged_config).__name__}")
# try resolving all cross references in the config, raise a DolmaConfigError if it fails
try:
om.resolve(merged_config)
except OmegaConfBaseException as ex:
raise DolmaRefineError(f"Invalid error while parsing key `{ex.full_key}`: {type(ex).__name__}") from ex
return merged_config # pyright: ignore
def print_config(config: Any, console: Optional[Console] = None) -> None:
if not isinstance(config, (DictConfig, ListConfig)):
config = om.create(config)
# print the config as yaml using a rich syntax highlighter
console = console or Console()
yaml_config = om.to_yaml(config, sort_keys=True).strip()
highlighted = Syntax(code=yaml_config, lexer="yaml", theme="ansi_dark")
console.print(highlighted)
def _patch_old_omegaconf():
"""Monkey patch omegaconf below version 2.3.0 to support custom resolver returning
lists or dicts. Applies patch https://github.com/omry/omegaconf/pull/1093"""
if necessary(("omegaconf", "2.4.0"), soft=True):
# no need to patch
return
if getattr(_patch_old_omegaconf, "__patched__", False):
# already patched
return
from omegaconf import _impl # pylint: disable=import-outside-toplevel
from omegaconf import ( # pylint: disable=import-outside-toplevel
Container,
Node,
ValueNode,
)
from omegaconf._utils import ( # noqa: F401 # pylint: disable=import-outside-toplevel
_ensure_container,
_get_value,
is_primitive_container,
is_structured_config,
)
from omegaconf.errors import ( # pylint: disable=import-outside-toplevel
InterpolationToMissingValueError,
)
from omegaconf.nodes import ( # pylint: disable=import-outside-toplevel
InterpolationResultNode,
)
def _resolve_container_value(cfg: Container, key: Any) -> None:
node = cfg._get_child(key) # pylint: disable=protected-access
assert isinstance(node, Node)
if node._is_interpolation(): # pylint: disable=protected-access
try:
resolved = node._dereference_node() # pylint: disable=protected-access
except InterpolationToMissingValueError:
node._set_value(MISSING) # pylint: disable=protected-access
else:
if isinstance(resolved, Container):
_impl._resolve(resolved) # pylint: disable=protected-access
if isinstance(resolved, InterpolationResultNode):
resolved_value = _get_value(resolved)
if is_primitive_container(resolved_value) or is_structured_config(resolved_value):
resolved = _ensure_container(resolved_value)
if isinstance(resolved, Container) and isinstance(node, ValueNode):
cfg[key] = resolved
else:
node._set_value(_get_value(resolved)) # pylint: disable=protected-access
else:
_impl._resolve(node) # pylint: disable=protected-access
# set new function and mark as patched
setattr(_impl, "_resolve_container_value", _resolve_container_value)
setattr(_patch_old_omegaconf, "__patched__", True)
# actually executes the patch
_patch_old_omegaconf()
def make_cli(config_cls: Type[D], _config_flag: str = "config", _dryrun_flag: str = "dryrun") -> D:
"""Create a CLI parser for a dataclass and parse the arguments into a structured config object."""
if hasattr(config_cls, _config_flag):
raise DolmaRefineError(f"`{_config_flag}` is a reserved attribute; remove it from `{config_cls.__name__}`")
if hasattr(config_cls, _dryrun_flag):
raise DolmaRefineError(f"`{_dryrun_flag}` is a reserved attribute; remove it from `{config_cls.__name__}`")
parser = ArgumentParser()
parser.add_argument(f"-{_config_flag[0]}", f"--{_config_flag}", help="Path to config file", default=None, type=str)
parser.add_argument(
f"-{_dryrun_flag[0]}",
f"--{_dryrun_flag}",
help="Dry run mode: print config and exit",
action="store_true",
default=False,
)
parser = _make_parser(parser, config_cls)
args = parser.parse_args()
parsed_config: Dict[str, Any] = {}
if (config_path := getattr(args, _config_flag)) is not None:
parsed_config = read_config(config_path)
delattr(args, _config_flag)
only_dryrun = getattr(args, _dryrun_flag, False)
delattr(args, _dryrun_flag)
full_config = namespace_to_nested_omegaconf(args, config_cls, parsed_config)
print_config(full_config)
if only_dryrun:
exit(0)
return full_config
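
For reference, a minimal sketch of the make_cli flow defined above: declare a dataclass with field(), then parse a --config YAML plus dotted CLI overrides into one typed OmegaConf object. The import path, script name, and config path here are illustrative assumptions.

from dataclasses import dataclass

from cli import field, make_cli  # assumed import path for the module above


@dataclass
class DemoConfig:
    run_name: str = field(default="qwen2vl-sft", help="Name of the run")
    learning_rate: float = field(default=1e-4, help="Peak learning rate")


if __name__ == "__main__":
    # e.g. `python demo.py -c base.yaml --learning_rate 3e-5`
    cfg = make_cli(DemoConfig)
    print(cfg.run_name, cfg.learning_rate)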

View File

@@ -1,16 +0,0 @@
from smart_open import register_compressor
__all__ = ["mk_compression"]
def mk_compression():
def _handle_zst(file_obj, mode):
try:
import zstandard as zstd
except ImportError:
raise ImportError("zstandard is required for zstd support")
return zstd.open(file_obj, mode)
register_compressor(".zstd", _handle_zst)
register_compressor(".zst", _handle_zst)
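
For reference, a minimal sketch of the compression hook above in use: once mk_compression() has registered the handlers, smart_open reads and writes .zst files transparently. The zstandard package must be installed; the filename and import path are assumptions.

import smart_open

from compression import mk_compression  # assumed import path for the module above

mk_compression()
with smart_open.open("batch_responses.jsonl.zst", "wt", encoding="utf-8") as f:
    f.write('{"custom_id": "example"}\n')
with smart_open.open("batch_responses.jsonl.zst", "rt", encoding="utf-8") as f:
    print(f.read())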

View File

@@ -1,131 +0,0 @@
from dataclasses import dataclass
from typing import List, Optional
from peft import TaskType # pyright: ignore
from .cli import field
@dataclass
class ModelConfig:
"""Configuration for loading a model; includes model name and type."""
name_or_path: str = field(help="The model name or path to load; must be compatible with huggingface transformers.")
arch: str = field(help="The model type to load; can be 'seq2seq', 'causal', or 'vllm'")
dtype: str = field(help="The precision to use for the model", default="bfloat16")
use_flash_attn: bool = field(help="Whether to use the flash attention for the model.", default=False)
trust_remote_code: bool = field(help="Whether to trust remote code for the model.", default=False)
low_cpu_mem_usage: bool = field(help="Whether to use low cpu memory usage for the model.", default=False)
fast_tokenizer: bool = field(help="Whether to use the fast tokenizer for the model.", default=True)
model_revision: Optional[str] = field(help="The model revision to use for the model.", default=None)
@dataclass
class GenerateConfig:
max_length: int = field(help="The maximum length of the generated text", default=4096)
temperature: float = field(default=0.2, help="The temperature to use for generation")
top_k: int = field(default=50, help="The top k to use for generation")
top_p: float = field(default=1.0, help="The top p to use for generation")
num_beams: int = field(default=1, help="The number of beams to use for generation")
truncate_prompt_tokens: bool = field(default=True, help="Whether to truncate the prompt tokens for generation")
max_num_seqs: int = field(default=16, help="The maximum number of sequences to generate")
@dataclass
class WandbConfig:
entity: str = field(help="The wandb entity to use for logging", default="ai2-llm")
project: str = field(help="The wandb project to use for logging", default="pdf-qwen2vl")
wandb_api_key: Optional[str] = field(help="The wandb api key to use for logging", default=None)
mode: str = field(help="The wandb mode to use for logging; set it to `offline` to disable syncing", default="online")
watch: str = field(help="The wandb watch to use for logging", default="false")
@dataclass
class AwsConfig:
profile: Optional[str] = field(help="The aws profile to use for s3 access", default=None)
access_key_id: Optional[str] = field(help="The aws access key id to use for s3 access", default=None)
secret_access_key: Optional[str] = field(help="The aws secret access key to use for s3 access", default=None)
default_region: Optional[str] = field(help="The default region to use for s3 access", default=None)
@dataclass
class SourceConfig:
name: str = field(help="The name of the source")
response_glob_path: str = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai")
target_longest_image_dim: list[int] = field(help="Dimensions to render the pdf page image to")
target_anchor_text_len: list[int] = field(help="Maximum amount of anchor text (aka prompt hint)")
@dataclass
class DataConfig:
seed: int = field(default=42, help="The seed to use for data loading")
cache_location: Optional[str] = field(help="Location to store s3 pdfs that need to be used to compute page images", default=None)
metric_for_best_model: Optional[str] = field(help="metric to pass to trainer args to use for picking best model checkpoint at end", default=None)
sources: List[SourceConfig] = field(help="The source configurations")
@dataclass
class HyperparamConfig:
batch_size: int = field(default=8, help="The batch size to use for training")
eval_batch_size: Optional[int] = field(default=None, help="The batch size to use for evaluation; default is the same as the training batch size")
learning_rate: float = field(default=2e-5, help="The learning rate to use for training")
max_steps: int = field(default=-1, help="The maximum number of steps to train the model")
pad_multiple_of: int = field(default=16, help="The padding multiple to use for the model")
log_every_steps: int = field(default=5, help="The number of steps to log training metrics")
eval_every_steps: int = field(default=100, help="The number of steps to evaluate the model")
weight_decay: float = field(default=0.0, help="The weight decay to use for training")
warmup_steps: int = field(default=0, help="The number of warmup steps to use for training")
warmup_ratio: float = field(default=0.0, help="The ratio of warmup steps to use for training")
lr_scheduler: str = field(default="linear", help="The learning rate scheduler to use for training")
gradient_accumulation_steps: int = field(default=1, help="The number of gradient accumulation steps to use for training")
gradient_checkpointing: bool = field(default=False, help="Whether to use gradient checkpointing for training")
seed: int = field(default=42, help="The seed to use for training")
reduce_loss: str = field(default="mean", help="The loss reduction to use for training")
clip_grad_norm: float = field(default=0.0, help="The gradient norm to clip to for training")
optim: str = field(default="adamw_torch", help="The optimizer to use for training")
find_unused_parameters: bool = field(default=False, help="Whether to find unused parameters for training")
@dataclass
class SaveConfig:
path: str = field(default="./results", help="The output directory to save the model")
limit: Optional[int] = field(default=None, help="The number of checkpoints to save")
save_every_steps: int = field(default="${hparams.eval_every_steps}", help="The number of steps to save the model") # type: ignore
@dataclass
class LoraConfig:
rank: int = field(default=16, help="The rank of the LoRA attention")
alpha: int = field(default=16, help="The alpha parameter for LoRA scaling")
dropout: float = field(default=0.05, help="The dropout probability for LoRA layers")
bias: str = field(default="none", help="The bias to use for LoRA layers (none, all, or lora_only)")
task_type: str = field(default=TaskType.CAUSAL_LM, help="The task type for the model")
target_modules: List[str] = field(
default=["k_proj", "q_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
help="The target modules in the model that will be replaced with LoRA layers",
)
@dataclass
class TrainConfig:
model: ModelConfig = field(default=ModelConfig(), help="The model configuration")
lora: Optional[LoraConfig] = field(default=None, help="The LoRA configuration")
aws: AwsConfig = field(default=AwsConfig(), help="Configuration for AWS S3")
wandb: WandbConfig = field(default=WandbConfig(), help="Configuration for Weights and Biases")
train_data: DataConfig = field(default=DataConfig(), help="Configuration for the training data")
valid_data: DataConfig = field(default=DataConfig(), help="Configuration for the validation data")
generate: GenerateConfig = field(default=GenerateConfig(), help="Configuration for text generation")
num_proc: int = field(default=1, help="The maximum number of workers to use for data processing")
max_workers: int = field(default=1, help="The maximum number of workers to use for data loaders")
hparams: HyperparamConfig = field(default=HyperparamConfig(), help="Hyperparameters for training")
save: SaveConfig = field(default=SaveConfig(), help="Configuration for saving the model")
@dataclass
class DemoConfig:
title: str = field(default="# Dolma Rewriter Demo")
description: str = field(default="Internal use only, **DO NOT SHARE OUTSIDE AI2**.")
share: bool = field(default=False, help="Share the demo publicly.")
model: ModelConfig = field(default=ModelConfig())
generate: GenerateConfig = field(default=GenerateConfig())
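
For reference, a minimal sketch of how the dataclasses above compose through OmegaConf, including the ${hparams.eval_every_steps} interpolation that SaveConfig uses as its default; the import path is assumed.

from omegaconf import OmegaConf as om

from config import TrainConfig  # assumed import path for the module above

cfg = om.structured(TrainConfig)
cfg.hparams.eval_every_steps = 250
# save_every_steps defaults to "${hparams.eval_every_steps}", so it tracks the eval interval.
assert cfg.save.save_every_steps == 250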

View File

@@ -1 +0,0 @@
class DolmaRefineError(RuntimeError): ...

View File

@@ -1,52 +0,0 @@
import logging
import multiprocessing
from typing import Union
LOGGER_PREFIX = "dolma-refine"
def get_logger(name: str, level: Union[int, str] = logging.WARN) -> logging.Logger:
if (proc_name := multiprocessing.current_process().name) == "MainProcess":
proc_name = "main"
proc_name = proc_name.replace(" ", "_")
# set the log level
level = level if isinstance(level, int) else getattr(logging, level.strip().upper(), logging.WARN)
# set name
name = f"{LOGGER_PREFIX}.{proc_name}.{name}"
logger = logging.getLogger(name)
logger.setLevel(level)
# add handler
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter("[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def reset_level(level: Union[int, str]) -> None:
"""
Reset the log level for all Dolma loggers.
Args:
level (Union[int, str]): The log level to set. It can be either an integer
representing the log level (e.g., logging.DEBUG) or a string
representing the log level name (e.g., 'debug').
Returns:
None
"""
if isinstance(level, str):
if (level_tmp := getattr(logging, level.strip().upper(), None)) is not None:
level = level_tmp
else:
raise ValueError(f"Invalid log level: {level}")
for logger in logging.Logger.manager.loggerDict.values():
if isinstance(logger, logging.Logger):
if logger.name.startswith(LOGGER_PREFIX):
logger.setLevel(level)
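
For reference, a minimal sketch of the logging helpers above: every logger shares the dolma-refine prefix, so reset_level retunes all of them at once. The import path is assumed.

from loggers import get_logger, reset_level  # assumed import path for the module above

logger = get_logger(__name__, level="INFO")
logger.info("starting preprocessing")
reset_level("debug")  # bumps every dolma-refine.* logger to DEBUG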

View File

@@ -1,615 +0,0 @@
import glob
import os
import re
from concurrent.futures import ThreadPoolExecutor
from functools import partial, reduce
from hashlib import sha256
from itertools import chain
from pathlib import Path
from shutil import copyfileobj
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from urllib.parse import urlparse
import platformdirs
import smart_open
from fsspec import AbstractFileSystem, get_filesystem_class
from smart_open.compression import get_supported_extensions
from .loggers import LOGGER_PREFIX, get_logger
__all__ = [
"glob_path",
"sub_prefix",
"add_suffix",
"sub_suffix",
"make_relative",
"mkdir_p",
"split_path",
"join_path",
"is_glob",
"split_glob",
"partition_path",
]
FS_KWARGS: Dict[str, Dict[str, Any]] = {
"": {"auto_mkdir": True},
}
RE_ANY_ESCAPE = re.compile(r"(?<!\\)(\*\?\[\])")
RE_GLOB_STAR_ESCAPE = re.compile(r"(?<!\\)\*")
RE_GLOB_ONE_ESCAPE = re.compile(r"(?<!\\)\?")
RE_GLOB_OPEN_ESCAPE = re.compile(r"(?<!\\)\[")
RE_GLOB_CLOSE_ESCAPE = re.compile(r"(?<!\\)\]")
ESCAPE_SYMBOLS_MAP = {"*": "\u2581", "?": "\u2582", "[": "\u2583", "]": "\u2584"}
REVERSE_ESCAPE_SYMBOLS_MAP = {v: k for k, v in ESCAPE_SYMBOLS_MAP.items()}
PATCHED_GLOB = False
LOGGER = get_logger(__name__)
def get_fs(path: Union[Path, str]) -> AbstractFileSystem:
"""
Get the filesystem class for a given path.
"""
path = str(path)
protocol = urlparse(path).scheme
fs = get_filesystem_class(protocol)(**FS_KWARGS.get(protocol, {}))
global PATCHED_GLOB # pylint: disable=global-statement
# patch glob method to support recursive globbing
if protocol == "" and not PATCHED_GLOB:
fs.glob = partial(glob.glob, recursive=True)
# only patch once
PATCHED_GLOB = True
return fs
def _escape_glob(s: Union[str, Path]) -> str:
"""
Escape glob characters in a string.
"""
s = str(s)
s = RE_GLOB_STAR_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["*"], s)
s = RE_GLOB_ONE_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["?"], s)
s = RE_GLOB_OPEN_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["["], s)
s = RE_GLOB_CLOSE_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["]"], s)
return s
def _unescape_glob(s: Union[str, Path]) -> str:
"""
Unescape glob characters in a string.
"""
s = str(s)
for k, v in REVERSE_ESCAPE_SYMBOLS_MAP.items():
s = s.replace(k, v)
return s
def _pathify(path: Union[Path, str]) -> Tuple[str, Path]:
"""
Return the protocol and path of a given path.
"""
path = _escape_glob(str(path))
parsed = urlparse(path)
path = Path(f"{parsed.netloc}/{parsed.path}") if parsed.netloc else Path(parsed.path)
return parsed.scheme, path
def _unpathify(protocol: str, path: Path) -> str:
"""
Return a path from its protocol and path components.
"""
path_str = _unescape_glob(str(path))
if protocol:
path_str = f"{protocol}://{path_str.lstrip('/')}"
return path_str
def remove_params(path: str) -> str:
"""
Remove parameters from a path.
"""
parsed = urlparse(path)
return (f"{parsed.scheme}://" if parsed.scheme else "") + f"{parsed.netloc}{parsed.path}"
def is_local(path: str) -> bool:
"""
Check if a path is local.
"""
prot, _ = _pathify(path)
return prot == "" or prot == "file"
def copy_file(src: str, dest: str) -> None:
"""Copy a file using shutil.copyfileobj for efficient chunked copying."""
with smart_open.open(src, "rb") as src_file, smart_open.open(dest, "wb") as dest_file:
copyfileobj(src_file, dest_file)
def copy_dir(src: str, dst: str, src_fs: Optional[AbstractFileSystem] = None, dst_fs: Optional[AbstractFileSystem] = None):
"""Copy a directory using a ThreadPoolExecutor for parallel file copying."""
src_fs = src_fs or get_fs(src)
dst_fs = dst_fs or get_fs(dst)
logger = get_logger(__name__)
with ThreadPoolExecutor(max_workers=8) as executor:
futures = []
for src_path in glob_path(src, yield_dirs=True, fs=src_fs):
rel_path = sub_prefix(src_path, src)
dest_path = join_path("", dst, rel_path)
if is_dir(src_path, fs=src_fs):
# Recursively copy directories
copy_dir(src=src_path, dst=dest_path, src_fs=src_fs, dst_fs=dst_fs)
else:
# File; copy over using the executor for parallelism
logger.info(f"Copying {src_path} to {dest_path}")
futures.append(executor.submit(copy_file, src_path, dest_path))
# Wait for all futures to complete
for future in futures:
future.result() # This will raise an exception if any of the threads failed
def delete_file(path: str, ignore_missing: bool = False, fs: Optional[AbstractFileSystem] = None) -> bool:
"""Delete a file."""
fs = fs or get_fs(path)
try:
fs.rm(path)
deleted = True
except FileNotFoundError as ex:
if not ignore_missing:
raise ex
deleted = False
return deleted
def get_size(path: str, fs: Optional[AbstractFileSystem] = None) -> int:
"""Get the size of a file"""
fs = fs or get_fs(path)
if not exists(path, fs=fs):
raise ValueError(f"Path {path} does not exist")
if is_dir(path, fs=fs):
raise ValueError(f"Path {path} is a directory")
return fs.info(path)["size"]
def delete_dir(path: str, ignore_missing: bool = False, fs: Optional[AbstractFileSystem] = None) -> bool:
"""Delete a directory."""
fs = fs or get_fs(path)
try:
fs.rm(path, recursive=True)
deleted = True
except FileNotFoundError as ex:
if not ignore_missing:
raise ex
deleted = False
return deleted
def partition_path(path: str) -> Tuple[str, Tuple[str, ...], Tuple[str, ...]]:
"""Partition a path into its protocol, symbols before a glob, and symbols after a glob."""
# split the path into its protocol and path components
prot, path_obj = _pathify(path)
# we need to first figure out if this path has a glob by checking if any of the escaped symbols for
# globs are in the path.
glob_locs = [i for i, p in enumerate(path_obj.parts) if any(c in p for c in REVERSE_ESCAPE_SYMBOLS_MAP)]
# make the path components before the glob
pre_glob_path = path_obj.parts[: glob_locs[0]] if glob_locs else path_obj.parts
pre_glob_path = tuple(_unescape_glob(p) for p in pre_glob_path)
# make the path components after the glob
post_glob_path = path_obj.parts[glob_locs[0] + 1 :] if glob_locs else ()
post_glob_path = tuple(_unescape_glob(p) for p in post_glob_path)
return prot, pre_glob_path, post_glob_path
def split_path(path: str) -> Tuple[str, Tuple[str, ...]]:
"""
Split a path into its protocol and path components.
"""
protocol, _path = _pathify(path)
return protocol, tuple(_unescape_glob(p) for p in _path.parts)
def join_path(protocol: Union[str, None], *parts: Union[str, Iterable[str]]) -> str:
"""
Join a path from its protocol and path components.
"""
all_prots, all_parts = zip(*(_pathify(p) for p in chain.from_iterable([p] if isinstance(p, str) else p for p in parts)))
path = str(Path(*all_parts)).rstrip("/")
protocol = protocol or str(all_prots[0])
if protocol:
path = f"{protocol}://{path.lstrip('/')}"
return _unescape_glob(path)
def glob_path(
path: Union[Path, str],
hidden_files: bool = False,
autoglob_dirs: bool = True,
recursive_dirs: bool = False,
yield_dirs: bool = True,
fs: Optional[AbstractFileSystem] = None,
) -> Iterator[str]:
"""
Expand a glob path into a list of paths.
"""
protocol, parsed_path = _pathify(path)
fs = fs or get_fs(path)
if autoglob_dirs and fs.isdir(path):
path = join_path(protocol, _unescape_glob(parsed_path), "*")
if "*" not in str(path):
# nothing to glob
yield str(path)
return
for gl in fs.glob(path):
gl = str(gl)
if not hidden_files and Path(gl).name.startswith("."):
continue
if fs.isdir(gl):
if recursive_dirs:
yield from glob_path(
gl,
hidden_files=hidden_files,
autoglob_dirs=autoglob_dirs,
recursive_dirs=recursive_dirs,
yield_dirs=yield_dirs,
fs=fs,
)
if yield_dirs:
yield join_path(protocol, gl)
else:
yield join_path(protocol, gl)
def sub_prefix(a: str, b: str) -> str:
"""
Return the relative path of b from a.
"""
prot_a, path_a = _pathify(a)
prot_b, path_b = _pathify(b)
if prot_a != prot_b:
raise ValueError(f"Protocols of {a} and {b} do not match")
try:
diff = str(path_a.relative_to(path_b))
except ValueError:
diff = join_path(prot_a, path_a.parts)
return _unescape_glob(diff)
def sub_suffix(a: str, b: str) -> str:
"""
Remove b from the end of a.
"""
prot_a, path_a = _pathify(a)
prot_b, path_b = _pathify(b)
if prot_b:
raise ValueError(f"{b} is not a relative path")
sub_path = re.sub(f"{path_b}$", "", str(path_a))
sub_prot = f"{prot_a}://" if prot_a else ""
# need to trim '/' from the end if (a) '/' is not the only symbol in the path or
# (b) there is a protocol so absolute paths don't make sense
if sub_path != "/" or sub_prot:
sub_path = sub_path.rstrip("/")
return _unescape_glob(sub_prot + sub_path)
def add_suffix(a: str, b: str) -> str:
"""
Return the path of a joined with b.
"""
prot_a, path_a = _pathify(a)
prot_b, path_b = _pathify(b)
if prot_b:
raise ValueError(f"{b} is not a relative path")
return join_path(prot_a, str(path_a / path_b))
def exists(path: str, fs: Optional[AbstractFileSystem] = None) -> bool:
"""Check if a path exists."""
fs = fs or get_fs(path)
return fs.exists(path)
def is_dir(path: str, fs: Optional[AbstractFileSystem] = None) -> bool:
"""Check if a path is a directory."""
fs = fs or get_fs(path)
if exists(path, fs=fs):
return fs.isdir(path)
return False
def is_file(path: str, fs: Optional[AbstractFileSystem] = None) -> bool:
"""Check if a path is a file."""
fs = fs or get_fs(path)
if exists(path, fs=fs):
return fs.isfile(path)
return False
def parent(path: str) -> str:
"""Get the parent directory of a path; if the parent is the root, return the root."""
prot, parts = split_path(path)
if len(parts) == 1:
return path
return join_path(prot, *parts[:-1])
def mkdir_p(path: str, fs: Optional[AbstractFileSystem] = None) -> None:
"""
Create a directory if it does not exist.
"""
if is_glob(path):
raise ValueError(f"Cannot create directory with glob pattern: {path}")
fs = fs or get_fs(path)
fs.makedirs(path, exist_ok=True)
def make_relative(paths: List[str]) -> Tuple[str, List[str]]:
"""Find minimum longest root shared among all paths"""
if len(paths) == 0:
raise ValueError("Cannot make relative path of empty list")
common_prot, common_parts, _ = partition_path(paths[0])
for path in paths:
current_prot, current_parts, _ = partition_path(path)
if current_prot != common_prot:
raise ValueError(f"Protocols of {path} and {paths[0]} do not match")
for i in range(min(len(common_parts), len(current_parts))):
if common_parts[i] != current_parts[i]:
common_parts = common_parts[:i]
break
if len(common_parts) > 0:
common_path = (f"{common_prot}://" if common_prot else "") + str(Path(*common_parts))
relative_paths = [sub_prefix(path, common_path) for path in paths]
else:
common_path = f"{common_prot}://" if common_prot else ""
relative_paths = [_unpathify("", _pathify(path)[1]) for path in paths]
return common_path, relative_paths
def is_glob(path: str) -> bool:
"""
Check if a path contains a glob wildcard.
"""
return bool(re.search(r"(?<!\\)[*?[\]]", path))
def split_glob(path: str) -> Tuple[str, str]:
"""
Partition a path on the first wildcard.
"""
if not is_glob(path):
# it's not a glob, so it's all path
return path, ""
if path[0] == "*":
# starts with a glob, so it's all glob
return "", path
protocol, parts = split_path(path)
i = min(i for i, c in enumerate(parts) if is_glob(c))
if i == 0:
# no path, so it's all glob
return protocol, join_path("", *parts)
path = join_path(protocol, *parts[:i])
rest = join_path("", *parts[i:])
return path, rest
def get_cache_dir() -> str:
"""
Returns the path to the cache directory for the Dolma toolkit.
If the directory does not exist, it will be created.
Returns:
str: The path to the cache directory.
"""
loc = platformdirs.user_cache_dir(LOGGER_PREFIX)
mkdir_p(loc)
return loc
def resource_to_filename(resource: Union[str, bytes]) -> str:
"""
Convert a ``resource`` into a hashed filename in a repeatable way. Preserves the file extensions.
"""
_, (*_, orig_filename) = split_path(remove_params(str(resource)))
_, extensions = split_basename_and_extension(orig_filename)
resource_bytes = str(resource).encode("utf-8")
resource_hash = sha256(resource_bytes)
hash_filename = resource_hash.hexdigest() + extensions
return hash_filename
def cached_path(path: str, fs: Optional[AbstractFileSystem] = None) -> str:
"""
Returns the cached path for a given resource.
If the resource is already available locally, the function returns the path as is.
Otherwise, it downloads the resource from the specified path and saves it in the cache directory.
Args:
path (str): The path to the resource.
Returns:
str: The cached path of the resource.
"""
if is_local(path):
    # local paths are already on disk; no caching needed
    return path
destination = f"{get_cache_dir()}/{resource_to_filename(path)}"
remote_fs = fs or get_fs(path)
local_fs = get_fs(destination)
if exists(destination, fs=local_fs):
LOGGER.info(f"Using cached file {destination} for {path}")
return destination
if is_dir(path, fs=remote_fs):
for sub_path in glob_path(path, fs=remote_fs):
rel_path = sub_prefix(sub_path, path)
dest_path = join_path("", destination, rel_path)
mkdir_p(parent(dest_path), fs=local_fs)
LOGGER.info(f"Downloading {sub_path} to {dest_path}")
with smart_open.open(sub_path, "rb") as src, smart_open.open(dest_path, "wb") as dest:
dest.write(src.read())
else:
LOGGER.info(f"Downloading {path} to {destination}")
with smart_open.open(path, "rb") as src, smart_open.open(destination, "wb") as dest:
dest.write(src.read())
return destination
def split_basename_and_extension(path: str) -> Tuple[str, str]:
"""
Get the path and extension from a given file path. If a file has multiple
extensions, they will be joined with a period, e.g. "foo/bar/baz.tar.gz"
will return ("foo/bar/baz", ".tar.gz"). If the file has no extension, the
second element of the tuple will be an empty string. Works with both local
and remote (e.g. s3://) paths.
Args:
path (str): The file path.
Returns:
Tuple[str, str]: A tuple containing the path and extension.
"""
prot, (*parts, filename) = split_path(path)
base, *ext_parts = filename.split(".")
ext = ("." + ".".join(ext_parts)) if ext_parts else ""
return join_path(prot, *parts, base), ext
def decompress_path(path: str, dest: Optional[str] = None) -> str:
"""
Decompresses a file at the given path and returns the path to the decompressed file.
Args:
path (str): The path to the file to be decompressed.
dest (str, optional): The destination path for the decompressed file.
If not provided, a destination path will be computed based on the original
file name and the cache directory.
Returns:
str: The path to the decompressed file. If the file cannot be decompressed,
the original path will be returned.
"""
for supported_ext in get_supported_extensions():
# not the supported extension
if not path.endswith(supported_ext):
continue
if dest is None:
# compute the name for the decompressed file; to do this, we first hash for
# resource and then remove the extension.
base_fn, ext = split_basename_and_extension(resource_to_filename(path))
# to get the decompressed file name, we remove the bit of the extension that
# indicates the compression type.
decompressed_fn = base_fn + ext.replace(supported_ext, "")
# finally, we get cache directory and join the decompressed file name to it
dest = join_path("", get_cache_dir(), decompressed_fn)
# here we do the actual decompression
with smart_open.open(path, "rb") as fr, smart_open.open(dest, "wb") as fw:
fw.write(fr.read())
# return the path to the decompressed file
return dest
# already decompressed or can't be decompressed
return path
def split_ext(path: str) -> Tuple[str, Tuple[str, ...], str]:
"""
Split a path into its protocol and extensions.
"""
prot, parts = split_path(path)
if not parts:
return prot, (), ""
filename = parts[-1]
extensions = []
while True:
filename, ext = os.path.splitext(filename)
if not ext:
break
extensions.append(ext)
return prot, (*parts[:-1], filename), "".join(reversed(extensions))
def get_unified_path(paths: List[str]) -> str:
"""Get a unified path for a list of paths."""
if len(paths) == 1:
# if there is only one path, we don't need to unify anything
return paths[0]
# get shared root for all paths; we will put the unified path here
root, relative = make_relative(paths)
# get the extension from the first path; assume all paths have the same extension
_, _, ext = split_ext(relative[0])
# hash all the sorted relative paths in order to get a unique name
# the type: ignore is needed because mypy fails to infer the type of the lambda
# (the "or" ensures that the lambda returns the same type as the first argument, which is a hash)
h = reduce(lambda h, p: h.update(p.encode()) or h, sorted(relative), sha256()) # type: ignore
# return the unified path
return join_path(root, h.hexdigest() + ext)
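
For reference, a few assertions sketching the behaviour of the path helpers above; the S3 URIs are hypothetical and the import path is assumed.

from paths import is_glob, join_path, split_glob, split_path  # assumed import path for the module above

assert split_path("s3://bucket/data/file.json") == ("s3", ("bucket", "data", "file.json"))
assert join_path("s3", "bucket", "data", "*.json") == "s3://bucket/data/*.json"
assert is_glob("s3://bucket/data/*.json")
assert split_glob("s3://bucket/data/*.json") == ("s3://bucket/data", "*.json")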

View File

@@ -1,27 +0,0 @@
import os
from dataclasses import dataclass
from typing import Optional
@dataclass
class BeakerState:
job_id: Optional[str] = None
job_kind: Optional[str] = None
task_id: Optional[str] = None
experiment_id: Optional[str] = None
replica_rank: Optional[str] = None
leader_replica_hostname: Optional[str] = None
leader_replica_node_id: Optional[str] = None
user_id: Optional[str] = None
def __post_init__(self):
for key, value in os.environ.items():
if not key.startswith("BEAKER_"):
continue
setattr(self, key.removeprefix("BEAKER_").lower(), value)  # removeprefix, not lstrip, so leading field letters aren't stripped
@property
def url(self) -> Optional[str]:
if self.job_id:
return f"https://beaker.org/jobs/{self.job_id}"
return None
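
For reference, a minimal sketch of BeakerState above picking up the BEAKER_* environment variables that the cluster sets; the job id and import path are hypothetical.

import os

from beaker_state import BeakerState  # assumed import path for the module above

os.environ["BEAKER_JOB_ID"] = "01ABCDEF"  # hypothetical; normally injected by Beaker
state = BeakerState()
print(state.job_id, state.url)  # 01ABCDEF https://beaker.org/jobs/01ABCDEF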

View File

@@ -1,165 +0,0 @@
import glob
import logging
import os
import re
from typing import Optional
import boto3
from datasets import Dataset, load_dataset
from filelock import FileLock
from olmocr.data.renderpdf import get_pdf_media_box_width_height
from olmocr.prompts.anchor import get_anchor_text
from olmocr.s3_utils import parse_custom_id, parse_s3_path
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Quiet logs from pypdf and smart open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)
def list_dataset_files(s3_glob_path: str):
"""
Lists files in the specified S3 path that match the glob pattern.
"""
if s3_glob_path.startswith("s3://"):
s3 = boto3.client("s3")
match = re.match(r"s3://([^/]+)/(.+)", s3_glob_path)
if not match:
logger.error(f"Invalid S3 path: {s3_glob_path}")
raise ValueError(f"Invalid S3 path: {s3_glob_path}")
bucket, prefix_pattern = match.groups()
prefix = prefix_pattern.split("*")[0] # Extract prefix before the wildcard
paginator = s3.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
files = []
pattern = re.compile(prefix_pattern.replace("*", ".*"))
for page in pages:
for obj in page.get("Contents", []):
key = obj["Key"]
if pattern.fullmatch(key):
files.append(f"s3://{bucket}/{key}")
return files
else:
return glob.glob(s3_glob_path)
def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> Dataset:
"""
Loads JSONL files from the specified S3 path into a Hugging Face Dataset.
"""
all_json_files = list_dataset_files(s3_glob_path)
if first_n_files:
all_json_files = all_json_files[:first_n_files]
# Use datasets library to load JSON files from S3
dataset = load_dataset(
"json",
data_files=all_json_files,
)
return dataset
def extract_openai_batch_response(example):
custom_id = example.get("custom_id", None)
# Parse the custom id into an s3 document path and page number (1-indexed)
s3_path, page_num = parse_custom_id(custom_id)
response_body = example.get("response", {}).get("body", {})
choices = response_body.get("choices", [])
response = ""
finish_reason = ""
if choices:
first_choice = choices[0]
message = first_choice.get("message", {})
response = message.get("content", "")
finish_reason = first_choice.get("finish_reason", "")
# TODO Maybe in the future we can parse the response (which is a structured JSON document itself)
# into its own columns
return {"s3_path": s3_path, "page_num": page_num, "response": response, "finish_reason": finish_reason}
def _cache_s3_file(s3_path: str, local_cache_dir: str):
"""
Downloads an S3 object to a local cache directory, ensuring no two writers corrupt the same file.
"""
bucket, key = parse_s3_path(s3_path)
# Define the local file path
local_file_path = os.path.join(local_cache_dir, bucket + "__" + key.replace("/", "_"))
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
lock_file = f"{local_file_path}.lock"
# Use a file lock to prevent concurrent writes
with FileLock(lock_file):
if not os.path.exists(local_file_path):
logger.info(f"Downloading {s3_path} to {local_file_path}")
s3_client = boto3.client("s3", aws_access_key_id=os.getenv("DS_AWS_ACCESS_KEY_ID"), aws_secret_access_key=os.getenv("DS_AWS_SECRET_ACCESS_KEY"))
s3_client.download_file(bucket, key, local_file_path)
else:
pass
# logger.info(f"File {local_file_path} already exists, skipping download.")
return local_file_path
def cache_s3_files(dataset: Dataset, pdf_cache_location: str, num_proc: int = 32) -> Dataset:
"""
Caches all S3 paths in the dataset to the local cache directory.
"""
# Define the download function to use in parallel processing
def cache_file(example):
s3_path = example["s3_path"]
if s3_path:
# Download the file and cache it locally
local_path = _cache_s3_file(s3_path, pdf_cache_location)
return {"local_pdf_path": local_path}
return {"local_pdf_path": None}
# Map the caching function to the dataset (with parallelism if needed)
dataset = dataset.map(cache_file, num_proc=num_proc, load_from_cache_file=False)
return dataset
def build_finetuning_dataset(response_glob_path: str, pdf_cache_location: Optional[str] = None, num_proc: int = 32) -> Dataset:
if pdf_cache_location is None:
pdf_cache_location = os.path.join(os.path.expanduser("~"), ".cache", "olmocr_pdfs")
logger.info("Loading fine tuning dataset from OpenAI style batch responses")
response_data = load_jsonl_into_ds(response_glob_path)
response_data = response_data["train"]
response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names, num_proc=num_proc)
# Don't include data where the model cut off due to a length issue, or moderation issue
logger.info("Filtering on finish_reason == stop")
final_dataset = response_data.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)
# Cache all the s3_paths that were accessed to a local storage location,
final_dataset = cache_s3_files(final_dataset, pdf_cache_location, num_proc)
# Filter out pages where you cannot get an anchor text generated, to prevent errors during actual training
def _can_create_anchor_text(example):
try:
anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=4000)
_ = get_pdf_media_box_width_height(example["local_pdf_path"], example["page_num"])
return anchor_text is not None
except Exception:
logger.exception("Could not generate anchor text for file, be sure you have all dependencies installed")
return False
final_dataset = final_dataset.filter(_can_create_anchor_text, num_proc=num_proc)
return final_dataset
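# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# The glob below is hypothetical; any local or s3 path of OpenAI-style batch responses
# should work. PDFs referenced by the responses get cached under ~/.cache/olmocr_pdfs
# unless pdf_cache_location is given.
if __name__ == "__main__":
    ds = build_finetuning_dataset(
        response_glob_path="/data/example/openai_batch_done/*.json",  # hypothetical
        num_proc=8,
    )
    print(ds)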

View File

@ -1,164 +0,0 @@
import base64
import random
from io import BytesIO
from typing import Union
import numpy as np
import torch # Make sure to import torch as it's used in the DataCollator
from PIL import Image
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim: Union[int, list[int]], target_anchor_text_len: Union[int, list[int]]):
if isinstance(target_longest_image_dim, list):
target_longest_image_dim = random.choice(target_longest_image_dim)
if isinstance(target_anchor_text_len, list):
target_anchor_text_len = random.choice(target_anchor_text_len)
anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)
# Prepare messages
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": base64_page_image},
{"type": "text", "text": build_finetuning_prompt(anchor_text)},
],
}
]
# Apply chat template to get the text
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Decode image from base64
main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))
# Process inputs using processor
inputs = processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="np",
)
# Get labels by tokenizing the output text
labels = processor(text=[example["response"]], padding=True, return_tensors="np")
# Append an "<|im_end|>\n" to the labels, because this is what it would look like
# if we passed the whole message stream in there
im_end_tokens = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]
im_end_tokens = np.array(im_end_tokens, dtype=inputs.input_ids.dtype) # Ensure correct dtype
# Handle the case where labels['input_ids'] is empty
if labels["input_ids"].shape[1] == 0:
labels_input_ids_0 = np.array([], dtype=inputs.input_ids.dtype)
else:
labels_input_ids_0 = labels["input_ids"][0].astype(inputs.input_ids.dtype)
labels["input_ids"] = np.concatenate([labels_input_ids_0, im_end_tokens])
labels["input_ids"] = np.expand_dims(labels["input_ids"], axis=0)
# Concatenate input_ids and labels
input_ids = np.concatenate([inputs.input_ids[0], labels.input_ids[0]], axis=0)
# All columns will participate in attention fully
attention_mask = np.ones_like(input_ids)
# Create labels, masking the input portion with -100
labels_full = np.full_like(input_ids, fill_value=-100)
labels_full[len(inputs.input_ids[0]) :] = labels.input_ids[0]
# TODO Maybe cap the max length
# Return as dict, including pixel_values
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels_full,
"pixel_values": inputs.pixel_values,
"image_grid_thw": inputs["image_grid_thw"][0],
}
def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]):
# Process each example in the batch using the helper function
processed_examples = []
for i in range(len(batch["response"])):
example = {"local_pdf_path": batch["local_pdf_path"][i], "page_num": batch["page_num"][i], "response": batch["response"][i]}
processed_example = prepare_data_for_qwen2_training(
example, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
)
processed_examples.append(processed_example)
return {
"input_ids": [x["input_ids"] for x in processed_examples],
"attention_mask": [x["attention_mask"] for x in processed_examples],
"labels": [x["labels"] for x in processed_examples],
"pixel_values": [x["pixel_values"] for x in processed_examples],
"image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
}
def prepare_data_for_molmo_training(example, processor, target_longest_image_dim: Union[int, list[int]], target_anchor_text_len: Union[int, list[int]]):
if isinstance(target_longest_image_dim, list):
target_longest_image_dim = random.choice(target_longest_image_dim)
if isinstance(target_anchor_text_len, list):
target_anchor_text_len = random.choice(target_anchor_text_len)
anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)
# Decode image from base64
main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))
# Process the input text and image
inputs = processor.process(
images=[main_image],
text=build_finetuning_prompt(anchor_text),
)
# Get labels by tokenizing the output text
labels = processor.tokenizer(example["response"], return_tensors="np")["input_ids"][0]
# Concatenate input_ids and labels
full_input_ids = torch.cat([inputs["input_ids"], torch.from_numpy(labels)], dim=0)
labels_full = torch.cat([torch.ones_like(inputs["input_ids"]) * -100, torch.from_numpy(labels)], dim=0)
# Create a full attention mask
attention_mask = torch.ones_like(full_input_ids)
# image_input_idx does not need adjustment as images are inserted before labels
image_input_idx = inputs["image_input_idx"]
return {
"input_ids": full_input_ids,
"labels": labels_full,
"images": inputs["images"],
"image_input_idx": image_input_idx,
"image_masks": inputs["image_masks"],
"attention_mask": attention_mask,
}
def batch_prepare_data_for_molmo_training(batch, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]):
# Assume batch size 1 and process the single example
example = {"local_pdf_path": batch["local_pdf_path"][0], "page_num": batch["page_num"][0], "response": batch["response"][0]}
processed_example = prepare_data_for_molmo_training(
example, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
)
# Return in the same format as the qwen2 function
return {
"input_ids": [processed_example["input_ids"]],
"attention_mask": [processed_example["attention_mask"]],
"labels": [processed_example["labels"]],
"images": [processed_example["images"]],
"image_input_idx": [processed_example["image_input_idx"]],
"image_masks": [processed_example["image_masks"]],
}
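# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# Both batch_prepare_* functions are meant to be mapped over the dataset produced by
# build_finetuning_dataset (columns: local_pdf_path, page_num, response). The model id
# and target dims below are assumptions for illustration only.
if __name__ == "__main__":
    from functools import partial

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # assumed checkpoint
    prep_fn = partial(
        batch_prepare_data_for_qwen2_training,
        processor=processor,
        target_longest_image_dim=[1024],
        target_anchor_text_len=[6000],
    )
    # dataset.map(prep_fn, batched=True, batch_size=1, remove_columns=dataset.column_names)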

View File

@ -1,122 +0,0 @@
import argparse
import concurrent.futures
import json
import os
import boto3
import torch
from smart_open import smart_open
from tqdm import tqdm
from transformers import Qwen2_5_VLForConditionalGeneration
from olmocr.s3_utils import parse_s3_path
s3_client = boto3.client("s3")
def download_file_from_s3(bucket_name, key, local_file_path):
"""Download a single file from S3."""
s3_client.download_file(bucket_name, key, local_file_path)
print(f"Downloaded {key} to {local_file_path}")
def download_model_from_s3(bucket_name, model_s3_key, local_model_dir):
if not os.path.exists(local_model_dir):
os.makedirs(local_model_dir)
# List objects in the S3 model path
response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=model_s3_key)
objects = response.get("Contents", [])
# Prepare list of download tasks
download_tasks = []
for obj in objects:
key = obj["Key"]
if key.endswith("/"):
continue # Skip directories
local_file_path = os.path.join(local_model_dir, os.path.basename(key))
download_tasks.append((bucket_name, key, local_file_path))
# Use a ThreadPoolExecutor to download files in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [executor.submit(download_file_from_s3, bucket_name, key, local_file_path) for bucket_name, key, local_file_path in download_tasks]
# Wait for all downloads to complete and handle any exceptions
for future in tqdm(concurrent.futures.as_completed(futures)):
try:
future.result() # This will raise any exceptions encountered during download
except Exception as e:
print(f"Error downloading file: {e}")
def upload_file_to_s3(local_file_path, bucket_name, s3_key):
"""Upload a single file to S3."""
try:
s3_client.upload_file(local_file_path, bucket_name, s3_key)
print(f"Uploaded {local_file_path} to s3://{bucket_name}/{s3_key}")
except Exception as e:
print(f"Error uploading {local_file_path} to s3://{bucket_name}/{s3_key}: {e}")
def save_model_to_s3(local_model_dir, bucket_name, s3_model_key):
"""Upload the model directory to S3 in parallel."""
# Collect all file paths to be uploaded
upload_tasks = []
for root, dirs, files in os.walk(local_model_dir):
for file in files:
local_file_path = os.path.join(root, file)
s3_key = os.path.join(s3_model_key, file)
upload_tasks.append((local_file_path, bucket_name, s3_key))
# Use a ThreadPoolExecutor to upload files in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [executor.submit(upload_file_to_s3, local_file_path, bucket_name, s3_key) for local_file_path, bucket_name, s3_key in upload_tasks]
# Wait for all uploads to complete and handle any exceptions
for future in concurrent.futures.as_completed(futures):
try:
future.result() # This will raise any exceptions encountered during upload
except Exception as e:
print(f"Error during upload: {e}")
def main():
parser = argparse.ArgumentParser(description="Fix up a Qwen2VL checkpoint saved on s3 or otherwise, so that it will load properly in vllm/birr")
parser.add_argument("s3_path", type=str, help="S3 path to the Hugging Face checkpoint.")
args = parser.parse_args()
# Now, download the config.json from the original path and verify the architectures
config_path = os.path.join(args.s3_path, "config.json")
with smart_open(config_path, "r") as f:
config_data = json.load(f)
assert config_data["architectures"] == ["Qwen2_5_VLForConditionalGeneration"]
if config_data["torch_dtype"] == "float32":
print("Detected model is float32, this is probably an FSDP checkpoint")
print("Saving to _bf16 location with adjusted parameters")
bucket, prefix = parse_s3_path(args.s3_path)
td = "/tmp/qwen2_checkpoint_saving"
download_model_from_s3(bucket, prefix, td)
print("Downloaded entire model from s3, resaving as bfloat16")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(td)
model = model.to(torch.bfloat16)
os.makedirs(os.path.join(td, "bf16_checkpoint"), exist_ok=True)
print("Saving...")
model.save_pretrained(os.path.join(td, "bf16_checkpoint"))
print("Uploading")
save_model_to_s3(os.path.join(td, "bf16_checkpoint"), bucket, prefix.rstrip("/") + "/bf16")
args.s3_path = args.s3_path.rstrip("/") + "/bf16"
print("Model updated successfully.")
if __name__ == "__main__":
main()
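# Hedged usage note (added for illustration): the script is run with the s3 prefix of
# the saved Hugging Face checkpoint as its only positional argument, e.g.
#   python fixup_qwen25_checkpoint.py s3://my-bucket/experiments/qwen2_5vl/checkpoint-1000/
# (the script name and bucket are hypothetical; if the config reports float32 weights,
# a bf16 copy is written back under a /bf16 suffix of the same prefix).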

View File

@ -1,371 +0,0 @@
# Script to generate parquet dataset files to upload to hugging face
# Input is a dataset location /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
# Each json line has a custom id that looks like {"custom_id": "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1", ... more data}
# This script takes a path to an input dataset and an SQLite database location,
# and builds a parquet file with rows of the form: "id", "url", "page_number", "response"
# where "id" is the output of parse_pdf_hash plus "-" plus the page number,
# "url" is the result of get_uri_from_db,
# and "response" is NormalizedEntry.text
import argparse
import concurrent.futures
import glob
import json
import multiprocessing
import os
import re
import sqlite3
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse
import boto3
import pandas as pd
from pypdf import PdfReader, PdfWriter
from tqdm import tqdm
def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
"""
Extracts a hash from a pretty PDF S3 URL.
For example, given:
s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1
it will return "de80a57e6c57b45796d2e020173227f7eae44232".
"""
# Allow an optional "-<number>" at the end.
if pretty_pdf_path.startswith("s3://ai2-s2-pdfs/"):
pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf(?:-\d+)?$"
match = re.match(pattern, pretty_pdf_path)
if match:
return match.group(1) + match.group(2)
return None
elif pretty_pdf_path.startswith("s3://ai2-oe-data/reganh/iabooks/"):
return urlparse(pretty_pdf_path).path.split("/")[-1]
else:
raise NotImplementedError()
def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
"""
Looks up the URL for the given pdf_hash in the sqlite database.
Assumes there is a table called 'pdf_mapping' with a column 'uri'.
"""
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
result = cursor.fetchone()
conn.close()
return result[0].strip() if result and result[0] else None
@dataclass(frozen=True)
class NormalizedEntry:
s3_path: str
pagenum: int
text: Optional[str]
finish_reason: Optional[str]
error: Optional[str] = None
@staticmethod
def from_goldkey(goldkey: str, **kwargs):
"""
Constructs a NormalizedEntry from a goldkey string.
The goldkey is expected to be of the format:
<s3_path>-<page_number>
"""
s3_path = goldkey[: goldkey.rindex("-")]
page_num = int(goldkey[goldkey.rindex("-") + 1 :])
return NormalizedEntry(s3_path, page_num, **kwargs)
@property
def goldkey(self):
return f"{self.s3_path}-{self.pagenum}"
def normalize_json_entry(data: dict) -> NormalizedEntry:
"""
Normalizes a JSON entry from any of the supported formats.
It supports:
- Birr: looks for an "outputs" field.
- Already normalized entries: if they contain s3_path, pagenum, etc.
- OpenAI: where the response is in data["response"]["body"]["choices"].
- SGLang: where the response is in data["response"]["choices"].
"""
if "outputs" in data:
# Birr case
if data["outputs"] is None:
text = None
finish_reason = None
else:
text = data["outputs"][0]["text"]
finish_reason = data["outputs"][0]["finish_reason"]
return NormalizedEntry.from_goldkey(
goldkey=data["custom_id"],
text=text,
finish_reason=finish_reason,
error=data.get("completion_error", None),
)
elif all(field in data for field in ["s3_path", "pagenum", "text", "error", "finish_reason"]):
# Already normalized
return NormalizedEntry(**data)
elif "response" in data and "body" in data["response"] and "choices" in data["response"]["body"]:
return NormalizedEntry.from_goldkey(
goldkey=data["custom_id"],
text=data["response"]["body"]["choices"][0]["message"]["content"],
finish_reason=data["response"]["body"]["choices"][0]["finish_reason"],
)
else:
raise ValueError("Unsupported JSON format")
def parse_s3_url(s3_url: str) -> Tuple[str, str]:
"""
Parses an S3 URL of the form s3://bucket/key and returns (bucket, key).
"""
if not s3_url.startswith("s3://"):
raise ValueError(f"Invalid S3 URL: {s3_url}")
s3_path = s3_url[5:]
bucket, key = s3_path.split("/", 1)
return bucket, key
def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
"""
Downloads the PDF from the given S3 URL into the specified cache directory.
The destination filename is based on the parsed PDF hash.
Returns the path to the downloaded PDF.
"""
try:
bucket, key = parse_s3_url(s3_url)
s3_client = boto3.client("s3")
pdf_hash = parse_pdf_hash(s3_url)
if not pdf_hash:
# Fallback: use a sanitized version of the s3_url
pdf_hash = re.sub(r"\W+", "_", s3_url)
dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
# Avoid re-downloading if already exists
if not os.path.exists(dest_path):
s3_client.download_file(bucket, key, dest_path)
return dest_path
except Exception as e:
print(f"Error downloading {s3_url}: {e}")
return None
def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]:
"""
Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url.
Writes a new single-page PDF to the output_pdf_dir using the combined_id as the filename.
Returns the relative path to the new PDF (e.g., "pdfs/<combined_id>.pdf").
"""
try:
local_cached_pdf = pdf_cache.get(s3_url)
if not local_cached_pdf or not os.path.exists(local_cached_pdf):
print(f"Cached PDF not found for {s3_url}")
return None
reader = PdfReader(local_cached_pdf)
# pypdf uses 0-indexed page numbers
page_index = page_number - 1
if page_index < 0 or page_index >= len(reader.pages):
print(f"Page number {page_number} out of range for PDF {s3_url}")
return None
writer = PdfWriter()
writer.add_page(reader.pages[page_index])
output_filename = f"{combined_id}.pdf"
output_path = os.path.join(output_pdf_dir, output_filename)
with open(output_path, "wb") as f_out:
writer.write(f_out)
# Return the relative path (assuming pdfs/ folder is relative to the parquet file location)
return os.path.join("pdfs", output_filename)
except Exception as e:
print(f"Error processing PDF page for {s3_url} page {page_number}: {e}")
return None
def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]:
"""
Process a single file and return a tuple:
(list of valid rows, number of rows skipped due to missing URL or PDF extraction/filtering).
For each JSON entry, the function:
- Normalizes the JSON.
- Skips entries whose response contains the word "resume" (any case) along with either an email address or a phone number.
- Extracts the PDF hash and builds the combined id.
- Looks up the corresponding URL from the sqlite database.
- Extracts the specified page from the cached PDF and writes it to output_pdf_dir.
- Outputs a row with "id", "url", "page_number", "response".
"""
rows = []
missing_count = 0
email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
phone_regex = r"\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b"
try:
with open(file_path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, start=1):
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError as e:
print(f"Skipping invalid JSON at {file_path}:{line_num} - {e}")
continue
try:
normalized = normalize_json_entry(data)
except Exception as e:
print(f"Error normalizing entry at {file_path}:{line_num} - {e}")
continue
# Apply filter: skip if response contains "resume" (any case) and an email or phone number.
response_text = normalized.text if normalized.text else ""
if re.search(r"resume", response_text, re.IGNORECASE) and (re.search(email_regex, response_text) or re.search(phone_regex, response_text)):
print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}")
continue
# Extract the PDF hash from the s3_path.
pdf_hash = parse_pdf_hash(normalized.s3_path)
if pdf_hash is None:
print(f"Could not parse pdf hash from {normalized.s3_path} at {file_path}:{line_num}")
continue
# The output id is the pdf hash plus '-' plus the page number.
combined_id = f"{pdf_hash}-{normalized.pagenum}"
# Look up the corresponding URL from the sqlite database.
url = get_uri_from_db(db_path, pdf_hash)
if not url:
print(f"Missing URL for pdf hash {pdf_hash} at {file_path}:{line_num}")
missing_count += 1
continue
# Process PDF: extract the specified page from the cached PDF.
local_pdf_path = process_pdf_page(normalized.s3_path, normalized.pagenum, combined_id, output_pdf_dir, pdf_cache)
if local_pdf_path is None:
print(f"Skipping entry because PDF processing failed for {normalized.s3_path} page {normalized.pagenum} at {file_path}:{line_num}")
missing_count += 1
continue
row = {
"id": combined_id,
"url": url,
"page_number": normalized.pagenum,
"response": normalized.text,
}
rows.append(row)
except Exception as e:
print(f"Error processing file {file_path}: {e}")
return rows, missing_count
def scan_file_for_s3_urls(file_path: str) -> Set[str]:
"""
Scans a single file and returns a set of unique S3 URLs found in the JSON entries.
"""
urls = set()
try:
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
normalized = normalize_json_entry(data)
urls.add(normalized.s3_path)
except Exception:
# Skip entries that cannot be normalized
continue
except Exception as e:
print(f"Error reading file {file_path}: {e}")
return urls
def main():
parser = argparse.ArgumentParser(description="Generate a Parquet dataset file for HuggingFace upload.")
parser.add_argument(
"input_dataset",
help="Input dataset file pattern (e.g., '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json')",
)
parser.add_argument("db_path", help="Path to the SQLite database file.")
parser.add_argument("--output", default="output.parquet", help="Output Parquet file path.")
args = parser.parse_args()
files = glob.glob(args.input_dataset)
print(f"Found {len(files)} files matching pattern: {args.input_dataset}")
# Determine output directory and create 'pdfs' subfolder.
output_abs_path = os.path.abspath(args.output)
output_dir = os.path.dirname(output_abs_path)
pdfs_dir = os.path.join(output_dir, "pdfs")
os.makedirs(pdfs_dir, exist_ok=True)
# Create a temporary directory for caching PDFs.
pdf_cache_dir = "/tmp/pdf_cache"
os.makedirs(pdf_cache_dir, exist_ok=True)
print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")
# ---------------------------------------------------------------------
# Step 1: Scan input files to collect all unique S3 URLs using a ProcessPoolExecutor.
unique_s3_urls: Set[str] = set()
print("Scanning input files to collect unique PDF URLs...")
num_cpus = multiprocessing.cpu_count()
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 4) as executor:
results = list(tqdm(executor.map(scan_file_for_s3_urls, files), total=len(files), desc="Scanning files"))
for url_set in results:
unique_s3_urls |= url_set
print(f"Found {len(unique_s3_urls)} unique PDF URLs.")
# ---------------------------------------------------------------------
# Step 2: Download all unique PDFs to the cache directory.
pdf_cache: Dict[str, str] = {}
print("Caching PDFs from S3...")
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 8) as executor:
future_to_url = {executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url for s3_url in unique_s3_urls}
for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(future_to_url), desc="Downloading PDFs"):
s3_url = future_to_url[future]
try:
local_path = future.result()
if local_path:
pdf_cache[s3_url] = local_path
else:
print(f"Failed to cache PDF for {s3_url}")
except Exception as e:
print(f"Error caching PDF for {s3_url}: {e}")
# ---------------------------------------------------------------------
# Step 3: Process input files using the precached PDFs.
all_rows = []
total_missing = 0
print("Processing files...")
with concurrent.futures.ProcessPoolExecutor() as executor:
futures = {executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path for file_path in files}
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing files"):
file_path = futures[future]
try:
rows, missing_count = future.result()
all_rows.extend(rows)
total_missing += missing_count
except Exception as e:
print(f"Error processing file {file_path}: {e}")
if all_rows:
df = pd.DataFrame(all_rows)
# Set the "id" column as the index.
df.set_index("id", inplace=True)
df.to_parquet(args.output)
valid_count = len(df)
total_processed = valid_count + total_missing
print(f"Successfully wrote {valid_count} rows to {args.output}")
print(f"Rows skipped due to missing URL/PDF or filtering: {total_missing} out of {total_processed} processed rows")
else:
print("No valid rows to write. Exiting.")
if __name__ == "__main__":
main()
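# Hedged usage note (added for illustration; all paths are hypothetical):
#   python build_parquet_dataset.py '/data/example/openai_batch_done/*.json' pdf_mapping.sqlite --output train.parquet
# The extracted single-page PDFs are written to a pdfs/ folder next to the parquet file,
# and rows are skipped when the URL lookup or page extraction fails.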

View File

@ -1,97 +0,0 @@
import logging
import os
import tarfile
from math import ceil
from huggingface_hub import HfApi
# Configuration
pdf_dir = "pdfs" # Directory with PDF files (flat structure)
tarball_dir = "tarballs" # Directory where tar.gz files will be saved
os.makedirs(tarball_dir, exist_ok=True)
repo_id = "allenai/olmOCR-mix-0225" # Hugging Face dataset repo ID
# Set up logging to file
logging.basicConfig(filename="upload.log", level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
def process_chunk(args):
"""
Worker function to create a tar.gz file for a given chunk.
Returns a tuple: (chunk_index, success (bool), message).
"""
chunk_index, chunk_files = args
tarball_name = f"pdf_chunk_{chunk_index:04d}.tar.gz"
tarball_path = os.path.join(tarball_dir, tarball_name)
try:
with tarfile.open(tarball_path, "w:gz") as tar:
for pdf_filename in chunk_files:
pdf_path = os.path.join(pdf_dir, pdf_filename)
# Add the file with its basename to maintain a flat structure
tar.add(pdf_path, arcname=pdf_filename)
logging.info(f"Chunk {chunk_index:04d}: Created '{tarball_name}' with {len(chunk_files)} PDFs.")
return chunk_index, True, "Success"
except Exception as e:
error_msg = f"Chunk {chunk_index:04d}: Error creating '{tarball_name}': {e}"
logging.error(error_msg)
return chunk_index, False, error_msg
def main():
# List all PDF files (assuming a flat directory)
try:
pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith(".pdf")])
except Exception as e:
logging.error(f"Error listing PDFs in '{pdf_dir}': {e}")
return
total_files = len(pdf_files)
chunk_size = 5000
total_chunks = ceil(total_files / chunk_size)
logging.info(f"Found {total_files} PDFs; dividing into {total_chunks} chunks of up to {chunk_size} files each.")
# # Enumerate chunks (starting at 0000)
# chunks = []
# for idx in range(total_chunks):
# start = idx * chunk_size
# end = start + chunk_size
# chunk_files = pdf_files[start:end]
# chunks.append((idx, chunk_files))
# # Create tarballs in parallel
# results = []
# with ProcessPoolExecutor() as executor:
# futures = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
# for future in tqdm(as_completed(futures), total=len(futures), desc="Creating tarballs"):
# try:
# result = future.result()
# results.append(result)
# chunk_index, success, message = result
# if not success:
# logging.error(f"Chunk {chunk_index:04d} failed: {message}")
# except Exception as e:
# logging.error(f"Unexpected error processing a chunk: {e}")
# # Abort upload if any tarball creation failed
# failed_chunks = [r for r in results if not r[1]]
# if failed_chunks:
# logging.error(f"{len(failed_chunks)} chunk(s) failed to create. Aborting upload.")
# return
# All tarballs created successfully; now upload the entire tarball directory
api = HfApi()
logging.info("Starting upload of tarballs folder to Hugging Face Hub...")
# This will upload all files in tarball_dir to the repo under "pdf_tarballs"
api.upload_large_folder(
folder_path=tarball_dir,
repo_id=repo_id,
# path_in_repo="pdf_tarballs",
repo_type="dataset",
)
logging.info("Successfully uploaded tarballs folder to Hugging Face Hub.")
if __name__ == "__main__":
main()

View File

@ -1,138 +0,0 @@
#!/usr/bin/env python3
import argparse
import sqlite3
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import boto3
from tqdm import tqdm
from warcio.archiveiterator import ArchiveIterator
def parse_s3_path(s3_path):
"""
Parses an S3 path of the form s3://bucket/prefix and returns the bucket and prefix.
"""
if not s3_path.startswith("s3://"):
raise ValueError("S3 path must start with s3://")
without_prefix = s3_path[5:]
parts = without_prefix.split("/", 1)
bucket = parts[0]
prefix = parts[1] if len(parts) > 1 else ""
return bucket, prefix
def list_s3_warc_objects(s3_path, suffix=".warc.gz"):
"""
Lists all objects under the given S3 path that end with the provided suffix.
Uses a paginator to handle large result sets.
"""
bucket, prefix = parse_s3_path(s3_path)
s3_client = boto3.client("s3")
paginator = s3_client.get_paginator("list_objects_v2")
warc_keys = []
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
if "Contents" in page:
for obj in page["Contents"]:
key = obj["Key"]
if key.endswith(suffix):
warc_keys.append(key)
return bucket, warc_keys, s3_client
def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
"""
Retrieves the first head_bytes bytes (1 MB by default) from the S3 object using a range request,
and extracts the first response record's target URI from the HTTP headers.
"""
target_uri = None
try:
response = s3_client.get_object(Bucket=bucket, Key=key, Range=f"bytes=0-{head_bytes-1}")
stream = response["Body"]
for record in ArchiveIterator(stream):
for name, value in record.rec_headers.headers:
if name == "WARC-Target-URI":
target_uri = value
break
if target_uri:
break # Only use the first valid response record
except Exception as e:
tqdm.write(f"Error processing s3://{bucket}/{key}: {e}")
return target_uri
def create_db(db_path):
"""
Creates (or opens) the SQLite database and ensures that the pdf_mapping table exists,
including an index on pdf_hash.
"""
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS pdf_mapping (
pdf_hash TEXT PRIMARY KEY,
uri TEXT
)
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_pdf_hash ON pdf_mapping (pdf_hash)
"""
)
conn.commit()
return conn
def process_warc_file(key, bucket, s3_client):
"""
Processes a single WARC file from S3 and returns a tuple (pdf_hash, uri)
if successful, otherwise returns None.
"""
uri = extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576)
if uri:
# Derive pdf_hash as the file's basename with .warc.gz replaced by .pdf.
pdf_hash = key.split("/")[-1].replace(".warc.gz", ".pdf")
return (pdf_hash, uri)
else:
tqdm.write(f"Warning: No valid response record found in s3://{bucket}/{key}")
return None
def process_s3_folder(s3_path, db_path):
"""
Lists all .warc.gz files under the provided S3 path, then processes each file in parallel
to extract the target URI from the HTTP headers. The resulting mapping (derived from the file's
basename with .warc.gz replaced by .pdf) is stored in the SQLite database.
"""
bucket, warc_keys, s3_client = list_s3_warc_objects(s3_path, suffix=".warc.gz")
conn = create_db(db_path)
cursor = conn.cursor()
# Process WARC files concurrently using ThreadPoolExecutor.
results = []
func = partial(process_warc_file, bucket=bucket, s3_client=s3_client)
with ThreadPoolExecutor() as executor:
for result in tqdm(executor.map(func, warc_keys), total=len(warc_keys), desc="Processing S3 WARC files"):
if result is not None:
results.append(result)
# Bulk insert into the database.
conn.execute("BEGIN")
for pdf_hash, uri in results:
cursor.execute("INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", (pdf_hash, uri))
conn.commit()
conn.close()
def main():
parser = argparse.ArgumentParser(description="Create an SQLite database mapping PDF file names to target URIs from S3 WARC files.")
parser.add_argument("s3_path", help="S3 path (e.g., s3://bucket/prefix) containing .warc.gz files")
parser.add_argument("db_file", help="Path for the output SQLite database file")
args = parser.parse_args()
process_s3_folder(args.s3_path, args.db_file)
if __name__ == "__main__":
main()
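# Hedged usage note (added for illustration; the bucket/prefix are hypothetical):
#   python build_pdf_mapping.py s3://my-bucket/warc-prefix/ pdf_mapping.sqlite
# The resulting mapping can then be spot-checked with sqlite3, e.g.
#   sqlite3 pdf_mapping.sqlite "SELECT uri FROM pdf_mapping WHERE pdf_hash = 'example.pdf'"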

View File

@ -1,66 +0,0 @@
import base64
from io import BytesIO
import torch
import torch.distributed
from PIL import Image
from transformers import AutoConfig, AutoProcessor, Qwen2_5_VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import build_openai_silver_data_prompt
@torch.no_grad()
def run_inference(model_name: str):
config = AutoConfig.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)
# If it doesn't load, change the type:mrope key to "default"
# model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
model.eval()
# local_pdf_path = os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "horribleocr.pdf")
local_pdf_path = "/root/brochure.pdf"
page = 1
image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
],
}
]
# Preparation for inference
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
main_image = Image.open(BytesIO(base64.b64decode(image_base64)))
inputs = processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="pt",
)
inputs = inputs.to("cuda")
output_ids = model.generate(**inputs, temperature=0.8, do_sample=True, max_new_tokens=1500)
generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(inputs["input_ids"], output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(output_text[0])
def main():
run_inference(model_name="Qwen/Qwen2.5-VL-7B-Instruct")
if __name__ == "__main__":
main()

View File

@ -1,50 +0,0 @@
from transformers import AutoProcessor
from olmocr.train.core.cli import make_cli
from olmocr.train.core.config import TrainConfig
from .utils import make_dataset
def main():
train_config = make_cli(TrainConfig) # pyright: ignore
processor = AutoProcessor.from_pretrained(train_config.model.name_or_path, trust_remote_code=True)
train_dataset, valid_dataset = make_dataset(train_config, processor)
print("Training dataset........")
print(train_dataset)
train_example = train_dataset[0]
print(train_example)
print({(x, y.shape) for x, y in train_example.items()})
print("\nTokens")
print(processor.tokenizer.batch_decode(train_example["input_ids"]))
print("\n\n")
print("Validation dataset........")
print(valid_dataset)
print(valid_dataset[list(valid_dataset.keys())[0]][0])
print("\n\n")
print("Datasets loaded into hugging face cache directory")
# data_collator = TruncatingCollator(
# max_length=4096
# )
# train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False, collate_fn=data_collator)
# max_seen_len = 0
# for index, entry in tqdm(enumerate(train_dataloader)):
# if index == 0:
# print(entry)
# num_input_tokens = entry["input_ids"].shape[1]
# max_seen_len = max(max_seen_len, num_input_tokens)
# print(max_seen_len)
if __name__ == "__main__":
main()

View File

@ -1,59 +0,0 @@
from transformers import PretrainedConfig
class MolmoConfig(PretrainedConfig):
model_type = "molmo"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=50304,
embedding_size=50304,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
max_position_embeddings=2048,
initializer_range=0.02,
use_cache=True,
layer_norm_eps: float = 1e-5,
rope_theta=10000.0,
clip_qkv=None,
qkv_bias: bool = False,
weight_tying: bool = False,
use_position_ids: bool = True,
tie_word_embeddings: bool = True,
attention_layer_norm: bool = False,
norm_after: bool = False,
layer_norm_type: str = "rms",
**kwargs,
):
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.layer_norm_eps = layer_norm_eps
self.weight_tying = weight_tying
self.use_position_ids = use_position_ids
self.attention_layer_norm = attention_layer_norm
self.num_key_value_heads = num_key_value_heads
self.initializer_range = initializer_range
self.use_cache = use_cache
self.rope_theta = rope_theta
self.clip_qkv = clip_qkv
self.qkv_bias = qkv_bias
self.norm_after = norm_after
self.tie_word_embeddings = tie_word_embeddings
self.layer_norm_type = layer_norm_type
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
MolmoConfig.register_for_auto_class()
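# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# The config can be instantiated directly with overrides; the values below are
# arbitrary and only meant to show the constructor surface.
if __name__ == "__main__":
    cfg = MolmoConfig(hidden_size=2048, num_hidden_layers=16, num_attention_heads=16)
    print(cfg.model_type, cfg.hidden_size, cfg.num_hidden_layers)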

View File

@ -1,497 +0,0 @@
"""Image processor class for Molmo"""
from typing import List, Optional, Union
import einops
import numpy as np
import torch
import torchvision.transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import convert_image_dtype
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput
from transformers.processing_utils import ImagesKwargs
from transformers.utils import logging
logger = logging.get_logger(__name__)
def pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width, value=0):
height, width = image.shape[:2]
after_padding_width = target_width - offset_width - width
after_padding_height = target_height - offset_height - height
return np.pad(image, [[offset_height, after_padding_height], [offset_width, after_padding_width], [0, 0]], constant_values=value)
def normalize_image(image, offset, scale):
image -= np.array(offset, dtype=np.float32)[None, None, :]
image /= np.array(scale, dtype=np.float32)[None, None, :]
return image
def resize_and_pad(
image,
desired_output_size,
resize_method="torch-bilinear",
pad_value=0,
normalize=True,
image_mean=OPENAI_CLIP_MEAN,
image_std=OPENAI_CLIP_STD,
):
desired_height, desired_width = desired_output_size
height, width = image.shape[:2]
# Cast into float32 since the training code did this in float32 and it (very rarely) affects
# the results after rounding.
image_scale_y = np.array(desired_height, np.float32) / np.array(height, np.float32)
image_scale_x = np.array(desired_width, np.float32) / np.array(width, np.float32)
image_scale = min(image_scale_x, image_scale_y)
scaled_height = int(np.array(height, np.float32) * image_scale)
scaled_width = int(np.array(width, np.float32) * image_scale)
if resize_method == "tensorflow":
# This is how the original training code did resizing; it can produce slightly different
# results than torch resize, so we keep it just in case
import tensorflow as tf
image = tf.image.convert_image_dtype(tf.constant(image), dtype=tf.float32)
image = tf.image.resize(
image,
[scaled_height, scaled_width],
method=tf.image.ResizeMethod.BILINEAR,
antialias=True,
)
image = tf.clip_by_value(image, 0.0, 1.0)
image = image.numpy()
elif resize_method == "torch-bilinear":
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
image = convert_image_dtype(image) # resize in float32 to match the training code
image = torchvision.transforms.Resize([scaled_height, scaled_width], InterpolationMode.BILINEAR, antialias=True)(image)
image = torch.clip(image, 0.0, 1.0)
image = torch.permute(image, [1, 2, 0]).numpy()
else:
raise NotImplementedError(resize_method)
top_pad = (desired_height - scaled_height) // 2
left_pad = (desired_width - scaled_width) // 2
padding = [[top_pad, desired_height - scaled_height - top_pad], [left_pad, desired_width - scaled_width - left_pad], [0, 0]]
image_mask = np.pad(np.ones_like(image[:, :, 0], dtype=bool), padding[:2])
image = np.pad(image, padding, constant_values=pad_value)
if normalize:
image = normalize_image(image, offset=image_mean, scale=image_std)
return image, image_mask
def select_tiling(h, w, patch_size, max_num_patches):
"""Decide how best to divide in image of size [w, h] in up to max_num_patches of size patch_size"""
original_size = np.stack([h, w]) # [1, 2]
original_res = h * w
tilings = []
for i in range(1, max_num_patches + 1):
for j in range(1, max_num_patches + 1):
if i * j <= max_num_patches:
tilings.append((i, j))
# sort so argmin and argmax favour smaller tilings in the event of a tie
tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
# How much we would need to scale the image to fit exactly in each tiling
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
required_scale_d = candidate_resolutions.astype(np.float32) / original_size
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
if np.all(required_scale < 1):
# We are forced to downscale, so try to minimize the amount of downscaling
ix = np.argmax(required_scale)
else:
# Pick the resolution that required the least upscaling so that it most closely fits the image
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
ix = np.argmin(required_scale)
return candidate_tilings[ix]
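# Hedged illustration (added; not part of the original file): for an image of roughly
# 1024x768 pixels, a 336-pixel crop window, and a budget of 12 crops, a call like
#   select_tiling(1024, 768, 336, 12)
# returns a (rows, cols) array describing the grid whose resolution requires the least
# rescaling of the image.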
class MolmoImagesKwargs(ImagesKwargs, total=False):
max_crops: Optional[int]
overlap_margins: Optional[List[int]]
base_image_input_size: Optional[List[int]]
image_token_length_w: Optional[int]
image_token_length_h: Optional[int]
image_patch_size: Optional[int]
image_padding_mask: Optional[bool]
class MolmoImageProcessor(BaseImageProcessor):
"""Preprocess images and multi-model inputs"""
def __init__(
self,
max_crops: int = 12,
overlap_margins: List[int] = (4, 4),
base_image_input_size: List[int] = (336, 336),
image_token_length_w: int = 12,
image_token_length_h: int = 12,
image_patch_size: int = 14,
image_padding_mask: bool = True,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs,
):
super().__init__(**kwargs)
self.max_crops = max_crops
self.overlap_margins = overlap_margins
self.base_image_input_size = base_image_input_size
self.image_token_length_w = image_token_length_w
self.image_token_length_h = image_token_length_h
self.image_patch_size = image_patch_size
self.image_padding_mask = image_padding_mask
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
def image_to_patches_and_tokens(
self,
image: ImageInput,
image_patch_token_id: int,
image_col_token_id: int,
image_start_token_id: int,
image_end_token_id: int,
max_crops: Optional[int] = None,
overlap_margins: Optional[List[int]] = None,
base_image_input_size: Optional[Union[int, List[int]]] = None,
image_token_length_w: Optional[int] = None,
image_token_length_h: Optional[int] = None,
image_patch_size: Optional[int] = None,
):
if isinstance(base_image_input_size, int):
base_image_input_size = (base_image_input_size, base_image_input_size)
base_image_input_d = image_patch_size
tokens_per_image = image_token_length_w * image_token_length_h
image_base_patch_w = base_image_input_size[1] // base_image_input_d
image_base_patch_h = base_image_input_size[0] // base_image_input_d
original_image_h, original_image_w = image.shape[:2]
crop_size = base_image_input_size[0]
# Discard this many patches from the (left/top, right/bottom) of crops
left_margin, right_margin = overlap_margins
# left_margin, right_margin = 2, 2
assert left_margin % 2 == 0 # Required for compatibility with 2x2 pooling
total_margin_pixels = base_image_input_d * (right_margin + left_margin) # pixels removed per dim
crop_patches = base_image_input_size[0] // base_image_input_d # patches per crop dim
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
crop_window_size = crop_window_patches * base_image_input_d
tiling = select_tiling(original_image_h - total_margin_pixels, original_image_w - total_margin_pixels, crop_window_size, max_crops)
src, img_mask = resize_and_pad(image, [tiling[0] * crop_window_size + total_margin_pixels, tiling[1] * crop_window_size + total_margin_pixels])
# Now we have to split the image into crops, while keeping track of how each patch in
# each crop should be ordered in the global image; this requires a lot of tricky bookkeeping
n_crops = tiling[0] * tiling[1]
patches_arr = []
mask_arr = []
patch_ordering_arr = []
# We assume 2x2 pooling, but can allow padding the right/bottom with extra
# patches if the number of patches per side is not even
assert (crop_patches + 1) // 2 == image_token_length_h
assert (crop_patches + 1) // 2 == image_token_length_w
on = 0
on_patch = 0
for i in range(tiling[0]):
y0 = i * crop_window_size
if i == 0:
crop_y0 = 0
else:
crop_y0 = left_margin // 2
crop_h = image_base_patch_h - (right_margin + left_margin)
if i == 0:
crop_h += left_margin
if i == (tiling[0] - 1):
crop_h += right_margin
for j in range(tiling[1]):
x0 = j * crop_window_size
if j == 0:
crop_x0 = 0
else:
crop_x0 = left_margin // 2
crop_w = image_base_patch_w - (right_margin + left_margin)
if j == 0:
crop_w += left_margin
if j == (tiling[1] - 1):
crop_w += right_margin
pooled_w = (crop_w + 1) // 2
pooled_h = (crop_h + 1) // 2
patch_ordering_arr.append(
pad_to_bounding_box(
np.reshape(np.arange(on, on + pooled_h * pooled_w, dtype=np.int32), (pooled_h, pooled_w, 1)),
crop_y0,
crop_x0,
image_token_length_h,
image_token_length_w,
value=-1,
)[:, :, 0]
)
patches_arr.append(src[y0 : y0 + crop_size, x0 : x0 + crop_size])
mask_arr.append(img_mask[y0 : y0 + crop_size, x0 : x0 + crop_size])
on += pooled_h * pooled_w
on_patch += 1
patches = np.stack(patches_arr)
patch_ordering = np.stack(patch_ordering_arr)
img_mask = np.stack(mask_arr)
# Switch to [n_crops, n_patches, pixels_per_patch] format
image_layout_impatch_w, image_layout_impatch_h = tiling[0], tiling[1]
patches = einops.rearrange(
patches, "p (h dh) (w dw) c -> p (h w) (dh dw c)", dh=base_image_input_d, dw=base_image_input_d, h=image_base_patch_h, w=image_base_patch_w
)
img_mask = einops.rearrange(
img_mask, "p (h dh) (w dw) -> p (h w) (dh dw)", dh=base_image_input_d, dw=base_image_input_d, h=image_base_patch_h, w=image_base_patch_w
)
img_mask = img_mask.astype(np.float32).mean(axis=-1)
patch_ordering = np.reshape(patch_ordering, [-1])
valid = patch_ordering >= 0
# Transpose order, to get left-to-right order instead of crop-by-crop order
patch_ordering_rh = np.reshape(patch_ordering, [tiling[0], tiling[1], image_token_length_h, image_token_length_w])
patch_ordering_rh = np.transpose(patch_ordering_rh, [0, 2, 1, 3])
patch_ordering_rh = np.reshape(patch_ordering_rh, [-1])
# The transpose will scramble which patches are masked; project the
# new order into the sparse structure of `patch_ordering` to fix this
patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
# Now build the output tokens
h = tiling[0] * crop_window_patches + (right_margin + left_margin)
w = tiling[1] * crop_window_patches + (right_margin + left_margin)
per_row = np.full(
((w + 1) // 2,),
image_patch_token_id,
)
per_row = np.concatenate([per_row, [image_col_token_id]], 0)
joint = np.tile(per_row, [(h + 1) // 2])
joint = [[image_start_token_id], joint, [image_end_token_id]]
# Finally do the same for the global image
resized, _ = resize_and_pad(image, base_image_input_size)
resized = einops.rearrange(
resized, "(h dh) (w dw) c -> (h w) (dh dw c)", dh=base_image_input_d, dw=base_image_input_d, h=image_base_patch_h, w=image_base_patch_w
)
patches = np.concatenate([np.expand_dims(resized, 0), patches], 0)
# Global image goes first, so the order of patches in previous crops gets increased
patch_ordering = np.where(patch_ordering >= 0, patch_ordering + tokens_per_image, -1)
patch_ordering = np.concatenate([np.arange(0, tokens_per_image), patch_ordering], 0)
per_row = np.full(
(image_token_length_w,),
image_patch_token_id,
)
per_row = np.concatenate([per_row, [image_col_token_id]], 0)
extra_tokens = np.tile(per_row, [image_token_length_h])
joint = [
[image_start_token_id],
extra_tokens,
[image_end_token_id],
] + joint
joint = np.concatenate(joint, 0)
img_mask = np.pad(img_mask, [[0, 1], [0, 0]], constant_values=-1)
return patches, joint, patch_ordering, img_mask
def build_image_input_idx(
self,
image_tokens: np.ndarray,
patch_order: np.ndarray,
image_patch_token_id: int,
no_image: Optional[bool] = None,
image_token_length_w: Optional[int] = None,
image_token_length_h: Optional[int] = None,
):
"""Converts `patch_order` into a mapping of token_id -> patch_id"""
tokens_per_image = image_token_length_w * image_token_length_h
if no_image is not None and no_image:
return np.zeros((0, tokens_per_image), np.int32)
# Indices to insert the patches
image_input_idx = image_tokens == image_patch_token_id
image_input_idx = np.nonzero(image_input_idx)[0].astype(np.int32)
if patch_order is not None:
n_tokens = image_input_idx.shape[0]
patch_order = np.reshape(patch_order, [-1])
n_patches = patch_order.shape[0]
valid = patch_order >= 0
n_valid_patches = valid.sum()
assert len(image_input_idx) == n_valid_patches
sorted_patch_ixs = np.zeros([n_tokens], np.int32)
sorted_patch_ixs[patch_order[valid]] = np.arange(n_valid_patches, dtype=np.int32)
# Project the inverted mapping into same sparse structure
sorted_patch_ixs_ex = np.full(np.shape(patch_order), -1)
sorted_patch_ixs_ex[valid] = sorted_patch_ixs
# Do the gather and then re-mask outputs that were masked in `sorted_patch_ixs`
valid = (sorted_patch_ixs_ex >= 0).astype(np.int32)
image_input_idx = image_input_idx[sorted_patch_ixs_ex * valid]
image_input_idx = image_input_idx * valid - 100 * (1 - valid)
image_input_idx = np.reshape(image_input_idx, [-1, tokens_per_image])
return image_input_idx
def preprocess(
self,
image: np.ndarray,
image_patch_token_id: int,
image_col_token_id: int,
image_start_token_id: int,
image_end_token_id: int,
max_crops: Optional[int] = None,
overlap_margins: Optional[List[int]] = None,
base_image_input_size: Optional[Union[int, List[int]]] = None,
image_token_length_w: Optional[int] = None,
image_token_length_h: Optional[int] = None,
image_patch_size: Optional[int] = None,
**kwargs,
):
"""Preprocesses an image
Returns:
crops: (n_crops, n_patches, patch_dim) individual crops, `n_crops` might
change between images but the other dimensions are fixed
tokens: (n_tokens,) int32 tokens, pad tokens indicate where to insert the
patch features, might include other special tokens as well
image_idx: (n_crops, n_patches) index in `tokens` to put the patch features from the
crops after pooling; negative values indicate patch features to exclude
padding_mask: (n_crops, n_patches) what percent of each crop is padding, can be None
if the image mask is not being used.
"""
max_crops = max_crops or self.max_crops
overlap_margins = overlap_margins or self.overlap_margins
base_image_input_size = base_image_input_size or self.base_image_input_size
image_token_length_w = image_token_length_w or self.image_token_length_w
image_token_length_h = image_token_length_h or self.image_token_length_h
image_patch_size = image_patch_size or self.image_patch_size
crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens(
image,
image_patch_token_id,
image_col_token_id,
image_start_token_id,
image_end_token_id,
max_crops,
overlap_margins,
base_image_input_size,
image_token_length_w,
image_token_length_h,
image_patch_size,
)
patch_idx = self.build_image_input_idx(
image_tokens,
patch_ordering,
image_patch_token_id,
image_token_length_w=image_token_length_w,
image_token_length_h=image_token_length_h,
)
return crops, image_tokens, patch_idx, img_mask
def multimodal_preprocess(
self,
images: np.ndarray,
tokens: List[int],
image_idx: np.ndarray,
sequence_length: int,
image_patch_token_id: int,
image_col_token_id: int,
image_start_token_id: int,
image_end_token_id: int,
**kwargs,
):
"""Merge images and text tokens into multi-modal features for the model
:param images: images to use as input
:param tokens: input text tokens
:param image_idx: where to insert the images into `tokens`
:param image_patch_token_id: id to use for tokens that will contain image features
:param image_col_token_id: token id for image column special tokens
:param image_start_token_id: token id for image start special tokens
:param image_end_token_id: token id for image end special tokens
:param kwargs: override preprocessor default args
"""
max_total_crops = kwargs.get("max_crops") or self.max_crops
image_token_length_w = kwargs.get("image_token_length_w") or self.image_token_length_w
image_token_length_h = kwargs.get("image_token_length_h") or self.image_token_length_h
image_patch_size = kwargs.get("image_patch_size") or self.image_patch_size
base_image_input_size = kwargs.get("base_image_input_size") or self.base_image_input_size
image_num_patch = (
base_image_input_size[0] // image_patch_size,
base_image_input_size[1] // image_patch_size,
)
image_padding_mask = kwargs.get("image_padding_mask") or self.image_padding_mask
tokens_per_image = image_token_length_w * image_token_length_h
n_pixels = image_patch_size * image_patch_size * 3
n_patches = image_num_patch[0] * image_num_patch[1]
if images is None:
return {
"input_ids": tokens,
}
else:
n = len(images)
all_crops = []
all_image_idx = []
out_tokens = []
all_crop_masks = []
for ix in range(n):
token_ix = image_idx[ix]
crops, image_tokens, patch_idx, img_mask = self.preprocess(
images[ix],
image_patch_token_id,
image_col_token_id,
image_start_token_id,
image_end_token_id,
**kwargs,
)
if token_ix == -1: # -1 is an image inserted at the very start
start = 0
token_ix = 0
end = 0
else:
start = 0 if ix == 0 else image_idx[ix - 1] + 1
end = token_ix + 1
all_image_idx.append(patch_idx + token_ix)
all_crops.append(crops)
out_tokens.append(tokens[start:token_ix])
out_tokens.append(image_tokens)
if ix == (n - 1):
out_tokens.append(tokens[end:])
if image_padding_mask:
all_crop_masks.append(img_mask)
input_ids = np.concatenate(out_tokens, 0)
images = np.concatenate(all_crops, 0)
image_input_idx = np.concatenate(all_image_idx, 0)
if image_padding_mask:
image_masks = np.concatenate(all_crop_masks, 0)
else:
image_masks = None
out = {"input_ids": input_ids, "images": images, "image_input_idx": image_input_idx}
if image_masks is not None:
out["image_masks"] = image_masks
return out
MolmoImageProcessor.register_for_auto_class()
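For orientation, a minimal usage sketch of the image processor defined above; the checkpoint path is a hypothetical placeholder and the sketch assumes its tokenizer already carries the image special tokens, so treat it as an illustration rather than repository test code.

# Hedged sketch: exercising MolmoImageProcessor.multimodal_preprocess on one dummy image.
# MOLMO_CKPT is an assumed local path to a checkpoint that registers this image processor.
import numpy as np
from transformers import AutoImageProcessor, AutoTokenizer

MOLMO_CKPT = "/path/to/molmo-checkpoint"  # hypothetical path
image_processor = AutoImageProcessor.from_pretrained(MOLMO_CKPT, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MOLMO_CKPT)

# Ids for the image special tokens, assuming they are registered in the tokenizer.
special = {t: tokenizer.encode(t, add_special_tokens=False)[0]
           for t in ("<im_patch>", "<im_col>", "<im_start>", "<im_end>")}

image = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy RGB page image
prompt_tokens = tokenizer.encode(" User: Describe this page. Assistant:", add_special_tokens=False)

out = image_processor.multimodal_preprocess(
    images=np.asarray([image]),
    tokens=np.asarray(prompt_tokens, dtype=np.int32),
    image_idx=np.asarray([-1]),  # -1 inserts the image at the very start
    sequence_length=1536,
    image_patch_token_id=special["<im_patch>"],
    image_col_token_id=special["<im_col>"],
    image_start_token_id=special["<im_start>"],
    image_end_token_id=special["<im_end>"],
)
# input_ids interleaves the text tokens with the image special tokens; image_input_idx maps
# each pooled crop patch to its slot in input_ids (negative entries mean "drop this patch").
print(out["input_ids"].shape, out["images"].shape, out["image_input_idx"].shape)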

File diff suppressed because it is too large

View File

@ -1,184 +0,0 @@
"""
Processor class for Molmo.
"""
from typing import Optional
from PIL import ImageOps
from PIL.Image import Image
try:
from typing import Unpack
except ImportError:
from typing_extensions import Unpack
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from .image_preprocessing_molmo import MolmoImageProcessor, MolmoImagesKwargs
logger = logging.get_logger(__name__)
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_IM_COL_TOKEN = "<im_col>"
IMAGE_PROMPT = "<|image|>"
EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)
def get_special_token_ids(tokenizer):
ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
assert len(ids) == len(EXTRA_TOKENS)
return {k: i for k, i in zip(EXTRA_TOKENS, ids)}
class MolmoTextKwargs(TextKwargs, total=False):
style: Optional[str]
system_prompt: Optional[str]
message_format: Optional[str]
always_start_with_space: Optional[bool]
sequence_length: Optional[int]
class MolmoProcessorKwargs(ProcessingKwargs, total=False):
text_kwargs: MolmoTextKwargs
images_kwargs: MolmoImagesKwargs
_defaults = {
"images_kwargs": {
"max_crops": 12,
"overlap_margins": [4, 4],
"base_image_input_size": [336, 336],
"image_token_length_w": 12,
"image_token_length_h": 12,
"image_patch_size": 14,
"image_padding_mask": True,
},
"text_kwargs": {
"style": "long_caption",
"system_prompt": "none",
"message_format": "role",
"always_start_with_space": True,
"sequence_length": 1536,
"padding": False,
},
}
class MolmoProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast")
def __init__(self, image_processor: MolmoImageProcessor = None, tokenizer: AutoTokenizer = None, **kwargs):
# self.image_processor = image_processor
# self.tokenizer = tokenizer
super().__init__(image_processor, tokenizer)
self._special_tokens = None
@property
def special_token_ids(self):
if self._special_tokens is None:
self._special_tokens = get_special_token_ids(self.tokenizer)
return self._special_tokens
def get_tokens_input(self, prompt, message_format, always_start_with_space):
if message_format == "none" or message_format is None:
pass
elif message_format == "role":
prompt = "User: " + prompt + " Assistant:"
else:
raise NotImplementedError(f"Message format {message_format} not implemented")
if always_start_with_space:
prompt = " " + prompt
tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
return tokens
def process(
self,
text: TextInput = None,
images: ImageInput = None,
*,
tokens: Optional[PreTokenizedInput] = None,
**kwargs: Unpack[MolmoProcessorKwargs],
):
output_kwargs = self._merge_kwargs(
MolmoProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if tokens is None:
tokens = self.get_tokens_input(
text,
output_kwargs["text_kwargs"]["message_format"],
output_kwargs["text_kwargs"]["always_start_with_space"],
)
image_token_id = self.special_token_ids[IMAGE_PROMPT]
if images is not None:
if not isinstance(images, (list, tuple)):
images = [images]
image_arrays = []
for image in images:
if isinstance(image, Image):
image = image.convert("RGB")
# Handle images with EXIF orientation tags, which PIL will ignore by default
# https://github.com/python-pillow/Pillow/issues/4703
image = ImageOps.exif_transpose(image)
image_arrays.append(np.array(image))
else:
assert len(image.shape) == 3 and image.shape[-1] == 3
image_arrays.append(image.astype(np.uint8))
images = image_arrays
# For now only support inserting images at the start
image_idx = [-1] * len(images)
else:
image_idx = None
sequence_length = output_kwargs["text_kwargs"]["sequence_length"]
image_patch_token_id = self.special_token_ids[DEFAULT_IMAGE_PATCH_TOKEN]
image_col_token_id = self.special_token_ids[DEFAULT_IM_COL_TOKEN]
image_start_token_id = self.special_token_ids[DEFAULT_IM_START_TOKEN]
image_end_token_id = self.special_token_ids[DEFAULT_IM_END_TOKEN]
out = self.image_processor.multimodal_preprocess(
images=images,
image_idx=image_idx,
tokens=np.asarray(tokens).astype(np.int32),
sequence_length=sequence_length,
image_patch_token_id=image_patch_token_id,
image_col_token_id=image_col_token_id,
image_start_token_id=image_start_token_id,
image_end_token_id=image_end_token_id,
**output_kwargs["images_kwargs"],
)
# Prepend BOS
# qwen2 and olmo do not have a BOS, and instead use EOS as a generic separator token.
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
decoder_input_tokens = np.pad(out["input_ids"], [[1, 0]], constant_values=bos)
out["input_ids"] = decoder_input_tokens
if "image_input_idx" in out:
# Shift patch mapping up by one since we added BOS
image_input_idx = out["image_input_idx"]
out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)
for k, v in out.items():
out[k] = torch.from_numpy(v)
return out
MolmoProcessor.register_for_auto_class()
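To make the BOS handling above concrete, a minimal sketch that runs process on a blank PIL image; the checkpoint path is again a hypothetical placeholder.

# Hedged sketch: MolmoProcessor.process on a single image and prompt.
# MOLMO_CKPT is an assumed checkpoint path that registers MolmoProcessor via the auto classes.
import torch
from PIL import Image as PILImage
from transformers import AutoProcessor

MOLMO_CKPT = "/path/to/molmo-checkpoint"  # hypothetical path
processor = AutoProcessor.from_pretrained(MOLMO_CKPT, trust_remote_code=True)

image = PILImage.new("RGB", (640, 480), color="white")
out = processor.process(text="Describe this page.", images=image)

# The first token is the prepended BOS (or EOS for tokenizers without a BOS) ...
assert out["input_ids"][0].item() in {processor.tokenizer.bos_token_id, processor.tokenizer.eos_token_id}
# ... and every valid patch index was shifted right by one to account for it.
assert torch.all(out["image_input_idx"][out["image_input_idx"] >= 0] >= 1)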

View File

@ -0,0 +1,170 @@
import argparse
import json
from os import PathLike
from pathlib import Path
from typing import Optional
import pandas as pd
from huggingface_hub import snapshot_download
def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination: str | PathLike, max_examples: Optional[int] = None) -> str:
"""
Prepare OLMoCR mix dataset by downloading from HuggingFace and organizing into a folder structure.
Args:
dataset_path: HuggingFace dataset path
subset: Dataset subset name
split: Dataset split (train/validation/test)
destination: Destination directory path
max_examples: Maximum number of examples to process (None for all)
"""
# Step 1: Download dataset using hugging face hub snapshot_download to destination/hugging_face folder
dest_path = Path(destination)
hugging_face_dir = dest_path / "hugging_face"
hugging_face_dir.mkdir(parents=True, exist_ok=True)
print(f"Downloading dataset {dataset_path} to {hugging_face_dir}...")
# Download the entire repository including PDFs and parquet files
local_dir = snapshot_download(
repo_id=dataset_path,
repo_type="dataset",
local_dir=hugging_face_dir,
)
print(f"Downloaded to: {local_dir}")
# Step 2: Create destination folder structure for processed markdown files
processed_dir = dest_path / f"processed_{subset}_{split}"
processed_dir.mkdir(exist_ok=True)
# Manual map to parquet files for now
assert dataset_path == "allenai/olmOCR-mix-0225", "Only supporting the olmocr-mix for now, later will support other training sets"
if subset == "00_documents" and split == "train_s2pdf":
parquet_files = [dest_path / "hugging_face" / "train-s2pdf.parquet"]
elif subset == "00_documents" and split == "eval_s2pdf":
parquet_files = [dest_path / "hugging_face" / "eval-s2pdf.parquet"]
elif subset == "01_books" and split == "train_s2pdf":
parquet_files = [dest_path / "hugging_face" / "train-iabooks.parquet"]
elif subset == "01_books" and split == "train_s2pdf":
parquet_files = [dest_path / "hugging_face" / "eval-iabooks.parquet"]
else:
raise NotImplementedError()
# Step 3: Process parquet files
total_processed = 0
total_errors = 0
for parquet_file in parquet_files:
print(f"Processing {parquet_file.name}...")
df = pd.read_parquet(parquet_file)
# Process each row
for idx, row in df.iterrows():
if max_examples and total_processed >= max_examples:
break
try:
# Extract fields from the row
# The rows in the parquet will look like url, page_number, response (json format), and id
response = row.get('response', '')
doc_id = str(idx)
assert len(doc_id) > 4
# Parse response if it's a JSON string
response_data = json.loads(response)
response = response_data
# Create a nested folder structure using the first 4 digits of the id (id[:4]/id[4:].md) so no single folder holds a huge number of files
folder_name = doc_id[:4]
file_name = f"{doc_id[4:]}.md"
# Create directory
output_dir = processed_dir / folder_name
output_dir.mkdir(exist_ok=True)
# Write markdown file with front matter and natural text
output_file = output_dir / file_name
with open(output_file, 'w', encoding='utf-8') as f:
# Extract natural_text and other fields for front matter
natural_text = response.get('natural_text', '')
# Create front matter from other fields
front_matter = {k: v for k, v in response.items() if k != 'natural_text'}
# Write front matter
f.write("---\n")
for k, v in front_matter.items():
f.write(f"{k}: {v}\n")
f.write("---\n")
# Write natural text
f.write(natural_text)
total_processed += 1
if total_processed % 1000 == 0:
print(f"Processed {total_processed} examples...")
except Exception as ex:
print(f"Error processing line: {ex}")
total_errors += 1
if max_examples and total_processed >= max_examples:
break
print(f"Completed! Processed {total_processed} examples to {processed_dir}")
print(f"Total errors: {total_errors}")
return str(processed_dir)
def main():
parser = argparse.ArgumentParser(description="Prepare OLMoCR mix dataset")
parser.add_argument(
"--dataset-path",
type=str,
default="allenai/olmOCR-mix-0225",
help="HuggingFace dataset path (e.g., 'allenai/olmocr-mix')"
)
parser.add_argument(
"--subset",
type=str,
default="00_documents",
required=True,
help="Dataset subset name"
)
parser.add_argument(
"--split",
type=str,
default="eval_s2pdf",
required=True,
help="Dataset split ex eval_s2pdf"
)
parser.add_argument(
"--destination",
type=str,
required=True,
help="Destination directory path"
)
parser.add_argument(
"--max-examples",
type=int,
default=None,
help="Maximum number of examples to process (default: all)"
)
args = parser.parse_args()
prepare_olmocr_mix(
dataset_path=args.dataset_path,
subset=args.subset,
split=args.split,
destination=args.destination,
max_examples=args.max_examples
)
if __name__ == "__main__":
main()
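For reference, a hedged sketch of consuming the output of prepare_olmocr_mix; the destination path, script name, and front-matter handling below are assumptions, not repository defaults.

# Typical invocation (script name and destination are assumptions):
#   python prepare_olmocrmix.py --subset 00_documents --split eval_s2pdf --destination /data/olmocr-mix
# Hedged sketch: read back one processed example from the id[:4]/id[4:].md layout described above.
from pathlib import Path

processed = Path("/data/olmocr-mix") / "processed_00_documents_eval_s2pdf"  # placeholder path
sample = next(processed.glob("*/*.md"))

body = sample.read_text(encoding="utf-8")
# Crude split of the "---" front matter block from the natural text; a real reader
# would use a YAML front-matter parser instead.
_, front_matter, natural_text = body.split("---\n", 2)
print(sample.name, len(natural_text))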

View File

@ -1,228 +1,9 @@
import logging
import os
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
import torch
import torch.distributed
import wandb
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model # pyright: ignore
from transformers import (
AutoProcessor,
Qwen2VLForConditionalGeneration,
Trainer,
TrainerCallback,
TrainingArguments,
)
from transformers.integrations import WandbCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.trainer_utils import get_last_checkpoint
from olmocr.train.core.cli import make_cli, save_config, to_native_types
from olmocr.train.core.config import TrainConfig
from olmocr.train.core.loggers import get_logger
from olmocr.train.core.paths import copy_dir, join_path
from olmocr.train.core.state import BeakerState
from .utils import (
RunName,
TruncatingCollator,
get_local_dir,
log_trainable_parameters,
make_dataset,
setup_environment,
)
class CheckpointUploadCallback(TrainerCallback):
def __init__(self, save_path: str, logger: Optional[Logger] = None):
self.save_path = save_path
self.logger = logger or get_logger(self.__class__.__name__)
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if state.is_local_process_zero:
latest_checkpoint = get_last_checkpoint(args.output_dir)
if not latest_checkpoint:
return
dir_name = Path(latest_checkpoint).name
copy_dir(str(latest_checkpoint), f"{self.save_path}/{dir_name}")
self.logger.info("Saved checkpoint to %s", f"{self.save_path}/{dir_name}")
def update_wandb_config(config: TrainConfig, trainer: Trainer, model: torch.nn.Module):
# finding wandb callback
callbacks = [c for c in trainer.callback_handler.callbacks if isinstance(c, WandbCallback)] # pyright: ignore
if not callbacks:
raise ValueError("WandbCallback not found in trainer callbacks")
wandb_callback = callbacks[0]
peft_config = to_native_types(getattr(model, "peft_config", {}))
script_config = to_native_types(config)
beaker_envs = {k: v for k, v in os.environ.items() if k.lower().startswith("beaker")}
on_setup_fn = wandb_callback.setup
def setup_and_update(args, state, model, **kwargs):
on_setup_fn(args=args, state=state, model=model, **kwargs)
wandb.config.update({"peft": peft_config}, allow_val_change=True)
wandb.config.update({"script": script_config}, allow_val_change=True)
wandb.config.update({"beaker": beaker_envs}, allow_val_change=True)
if (run := wandb.run) and (beaker_url := BeakerState().url):
run.notes = beaker_url
wandb_callback.setup = setup_and_update
def get_rank() -> int:
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_rank()
return 0
def run_train(config: TrainConfig):
if get_rank() == 0:
logger_level = logging.INFO
else:
logger_level = logging.WARN
disable_progress_bars()
logger = get_logger(__name__, level=logger_level)
set_verbosity(logger_level)
run_name = RunName.get(config)
setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
processor = AutoProcessor.from_pretrained(config.model.name_or_path, trust_remote_code=True)
train_dataset, valid_dataset = make_dataset(config, processor)
logger.info(train_dataset)
logger.info(valid_dataset)
if "qwen" in config.model.name_or_path.lower():
model = Qwen2VLForConditionalGeneration.from_pretrained(
config.model.name_or_path, torch_dtype=torch.bfloat16, _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
)
else:
from .molmo.config_molmo import MolmoConfig
from .molmo.modeling_molmo import MolmoForCausalLM
model_config = MolmoConfig.from_pretrained(config.model.name_or_path, trust_remote_code=True)
if model_config.max_position_embeddings < config.generate.max_length:
logger.warning(
f"ALERT, force adjusting model config max_position_embeddings upwards from {model_config.max_position_embeddings} to {config.generate.max_length}"
)
model_config.max_position_embeddings = config.generate.max_length
if config.model.use_flash_attn:
model_config.attention_type = "flash"
model = MolmoForCausalLM.from_pretrained(config.model.name_or_path, torch_dtype=torch.bfloat16, config=model_config, trust_remote_code=True)
logger.info(model)
if config.lora is not None:
peft_config = LoraConfig(
r=config.lora.rank,
lora_alpha=config.lora.alpha,
lora_dropout=config.lora.dropout,
bias=config.lora.bias, # pyright: ignore
task_type=config.lora.task_type,
target_modules=list(config.lora.target_modules),
)
model = get_peft_model(model=model, peft_config=peft_config)
log_trainable_parameters(model=model, logger=logger)
save_path = join_path("", config.save.path, run_name.run)
# Make sure directory exists if local
if not save_path.startswith("s3://"):
os.makedirs(os.path.dirname(save_path), exist_ok=True)
save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore
with TemporaryDirectory() as output_dir:
training_args = TrainingArguments(
run_name=run_name.run,
logging_steps=config.hparams.log_every_steps,
output_dir=output_dir,
eval_strategy="steps",
report_to="wandb",
# report_to=[], # disable logging to wandb, we will use a custom callback
optim=config.hparams.optim,
eval_steps=config.hparams.eval_every_steps,
learning_rate=config.hparams.learning_rate,
per_device_train_batch_size=config.hparams.batch_size,
per_device_eval_batch_size=config.hparams.eval_batch_size or config.hparams.batch_size,
gradient_checkpointing=config.hparams.gradient_checkpointing,
gradient_checkpointing_kwargs=(
dict(use_reentrant=False) # from this issue: https://github.com/huggingface/peft/issues/1142
if config.hparams.gradient_checkpointing and config.lora is not None
else {}
),
gradient_accumulation_steps=config.hparams.gradient_accumulation_steps,
max_steps=config.hparams.max_steps,
weight_decay=config.hparams.weight_decay,
dataloader_num_workers=config.max_workers,
load_best_model_at_end=True,
save_strategy="steps",
ddp_find_unused_parameters=config.hparams.find_unused_parameters,
save_steps=config.save.save_every_steps,
warmup_steps=config.hparams.warmup_steps,
warmup_ratio=config.hparams.warmup_ratio,
bf16=True,
label_names=["labels"], # fix from https://github.com/huggingface/transformers/issues/22885
max_grad_norm=config.hparams.clip_grad_norm,
remove_unused_columns=False,
eval_on_start=True,
metric_for_best_model=config.valid_data.metric_for_best_model,
)
data_collator = TruncatingCollator(max_length=config.generate.max_length)
checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)
# Initialize Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=valid_dataset,
tokenizer=processor.tokenizer,
data_collator=data_collator,
callbacks=[checkpoint_callback],
)
# Train the model
trainer.train() # pyright: ignore
if get_rank() == 0:
with get_local_dir(join_path("", save_path, "best")) as best_dir:
if config.lora is not None:
logger.info("Merging LoRA adapters into the base model...")
model = model.merge_and_unload()
logger.info("LoRA adapters merged successfully.")
model.save_pretrained(best_dir)
logger.info("Saved best model to %s", best_dir)
# Uncomment to test speed of data loader
# train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
# for entry in tqdm(train_dataloader):
# print("Step!")
# model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})
def main():
train_config = make_cli(TrainConfig) # pyright: ignore
run_train(train_config)
if __name__ == "__main__":
main()
# TODO Overall, this code will read in a config yaml file with omega conf
# From that config, we are going to use HuggingFace Trainer to train a model
# TODOS:
# Build a script to convert olmocr-mix to a new dataloader format
# Write a new dataloader and collator, with tests, that bring in everything; only needs to support batch size 1 for this first version
# Get a basic config yaml file system working
# Get a basic hugging face trainer running, supporting Qwen2.5VL for now
# Saving and restoring training checkpoints
# Converting training checkpoints to vLLM-compatible checkpoints
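A minimal sketch of the flow the TODOs above describe, assuming an OmegaConf YAML exposing the same hparams/save fields that the deleted train.py referenced; the key names and paths are placeholders, not a final schema.

# Hedged sketch of the planned entry point: OmegaConf YAML -> HuggingFace TrainingArguments.
# "config.yaml" and the field names below are assumptions mirroring the config.hparams / config.save usage above.
from omegaconf import OmegaConf
from transformers import TrainingArguments

cfg = OmegaConf.load("config.yaml")  # placeholder path

training_args = TrainingArguments(
    output_dir="workspace/checkpoints",  # placeholder local dir; any S3 upload would happen separately
    per_device_train_batch_size=cfg.hparams.batch_size,
    gradient_accumulation_steps=cfg.hparams.gradient_accumulation_steps,
    learning_rate=cfg.hparams.learning_rate,
    max_steps=cfg.hparams.max_steps,
    warmup_ratio=cfg.hparams.warmup_ratio,
    weight_decay=cfg.hparams.weight_decay,
    logging_steps=cfg.hparams.log_every_steps,
    eval_strategy="steps",
    eval_steps=cfg.hparams.eval_every_steps,
    save_steps=cfg.save.save_every_steps,
    max_grad_norm=cfg.hparams.clip_grad_norm,
    bf16=True,
    remove_unused_columns=False,
)
# A Trainer would then be built with the Qwen2.5-VL model, the new dataset format
# produced by prepare_olmocrmix.py, and a batch-size-1 collator, per the TODO list above.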

View File

@ -1,232 +0,0 @@
import json
import multiprocessing
import os
import random
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from hashlib import sha1
from logging import Logger
from tempfile import TemporaryDirectory
from typing import Dict, Generator, List, Optional, TypeVar
import torch
from accelerate import Accelerator
from accelerate.utils import PrecisionType
from datasets import Dataset, DatasetDict, concatenate_datasets
from transformers import AutoProcessor
from olmocr.train.dataloader import build_finetuning_dataset
from olmocr.train.dataprep import (
batch_prepare_data_for_molmo_training,
batch_prepare_data_for_qwen2_training,
)
from .core.cli import to_native_types
from .core.config import AwsConfig, DataConfig, SourceConfig, TrainConfig, WandbConfig
from .core.loggers import get_logger
from .core.paths import copy_dir, is_local
from .core.state import BeakerState
T = TypeVar("T")
def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
pt = PrecisionType(accelerator.mixed_precision)
if pt == PrecisionType.FP16:
return torch.float16
elif pt == PrecisionType.BF16:
return torch.bfloat16
elif pt == PrecisionType.FP8:
return torch.float8_e4m3fn
return torch.float32
def get_rawdataset_from_source(data_config: DataConfig, source: SourceConfig) -> Dataset:
return build_finetuning_dataset(source.response_glob_path, pdf_cache_location=data_config.cache_location)
def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]:
random.seed(config.train_data.seed)
if "qwen" in config.model.name_or_path.lower():
batch_fn = batch_prepare_data_for_qwen2_training
elif "molmo" in config.model.name_or_path.lower():
batch_fn = batch_prepare_data_for_molmo_training
else:
raise NotImplementedError("Model format not supported")
# Retrieve the two target lengths from the first source for comparison
first_source = config.train_data.sources[0]
target_longest_image_dim = first_source.target_longest_image_dim
target_anchor_text_len = first_source.target_anchor_text_len
# Verify that all sources have the same target lengths
for source in config.train_data.sources:
if source.target_longest_image_dim != target_longest_image_dim:
raise ValueError(f"Inconsistent target_longest_image_dim found in source {source}")
if source.target_anchor_text_len != target_anchor_text_len:
raise ValueError(f"Inconsistent target_anchor_text_len found in source {source}")
# Concatenate datasets first, unfortunately you can't apply the transform before concatenation due to the library
train_dataset = concatenate_datasets([get_rawdataset_from_source(config.train_data, source) for source in config.train_data.sources])
# Apply the transform to the concatenated dataset
train_dataset = train_dataset.with_transform(
partial(
batch_fn,
processor=processor,
target_longest_image_dim=list(target_longest_image_dim),
target_anchor_text_len=list(target_anchor_text_len),
)
)
# Validation sets get put into a datasetdict so each can report a loss separately
valid_dataset = DatasetDict(
**{
source.name: get_rawdataset_from_source(config.valid_data, source).with_transform(
partial(
batch_fn,
processor=processor,
target_longest_image_dim=list(source.target_longest_image_dim),
target_anchor_text_len=list(source.target_anchor_text_len),
)
)
for source in config.valid_data.sources
}
)
return train_dataset, valid_dataset
def setup_environment(aws_config: Optional[AwsConfig] = None, wandb_config: Optional[WandbConfig] = None, **kwargs: str):
multiprocessing.set_start_method("spawn", force=True)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "false"
if wandb_config:
os.environ["WANDB_WATCH"] = "false"
for key, value in to_native_types(wandb_config or {}).items():
if value is not None:
os.environ[f"WANDB_{key.upper()}"] = str(value)
for key, value in to_native_types(aws_config or {}).items():
if value is not None:
os.environ[f"AWS_{key.upper()}"] = str(value)
os.environ.update(kwargs)
@dataclass
class RunName:
run: str
group: str
@classmethod
def get(cls, config: TrainConfig, accelerator: Optional[Accelerator] = None) -> "RunName":
job_rank = f"-{accelerator.process_index}" if accelerator else ""
if beaker_job_id := BeakerState().job_id:
job_id = f"-{beaker_job_id}"
else:
job_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
(config_hash := sha1()).update(json.dumps(to_native_types(config)).encode())
model_name = config.model.name_or_path.replace("/", "_")
group_name = f"{model_name}-{config_hash.hexdigest()[:6]}"
run_name = f"{group_name}{job_id}{job_rank}"
return cls(group=group_name, run=run_name)
@contextmanager
def override_torch_threads(n: int):
torch_num_threads = torch.get_num_threads()
torch.set_num_threads(n)
yield
torch.set_num_threads(torch_num_threads)
@contextmanager
def temp_args(obj: T, **kwargs) -> Generator[T, None, None]:
orig = {k: getattr(obj, k) for k in kwargs.keys()}
for k, v in kwargs.items():
setattr(obj, k, v)
yield obj
for k, v in orig.items():
setattr(obj, k, v)
def log_trainable_parameters(model: torch.nn.Module, logger: Optional[Logger] = None):
"""
Prints the number of trainable parameters in the model.
"""
trainable_params = 0
all_param = 0
for name, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
(logger or get_logger(__name__)).info(f"training with {name}")
trainable_params += param.numel()
(logger or get_logger(__name__)).info(
"trainable params: %s || all params: %s || trainable%%: %s",
f"{trainable_params:,}",
f"{all_param:,}",
f"{trainable_params / all_param:.2%}",
)
class TruncatingCollator:
def __init__(self, max_length: int):
self.max_length = max_length
def __call__(self, batch: List[Dict]) -> Dict:
# Assert that we are only handling batch size 1 for now
assert len(batch) == 1, "Only batch size 1 is supported for now"
if "pixel_values" in batch[0]:
# Qwen2 case
truncated_input_ids = torch.tensor(batch[0]["input_ids"][: self.max_length]).unsqueeze(0)
truncated_attention_mask = torch.tensor(batch[0]["attention_mask"][: self.max_length]).unsqueeze(0)
truncated_labels = torch.tensor(batch[0]["labels"][: self.max_length]).unsqueeze(0)
return {
"input_ids": truncated_input_ids,
"attention_mask": truncated_attention_mask,
"labels": truncated_labels,
"pixel_values": torch.tensor(batch[0]["pixel_values"]).unsqueeze(0),
"image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]).unsqueeze(0),
}
elif "image_input_idx" in batch[0]:
# molmo case
truncated_input_ids = batch[0]["input_ids"][: self.max_length].unsqueeze(0)
truncated_attention_mask = batch[0]["attention_mask"][: self.max_length].unsqueeze(0)
truncated_labels = batch[0]["labels"][: self.max_length].unsqueeze(0)
return {
"input_ids": truncated_input_ids,
"attention_mask": truncated_attention_mask,
"labels": truncated_labels,
"images": batch[0]["images"].unsqueeze(0),
"image_input_idx": batch[0]["image_input_idx"].unsqueeze(0),
"image_masks": batch[0]["image_masks"].unsqueeze(0),
}
else:
raise NotImplementedError()
@contextmanager
def get_local_dir(output_dir: str):
with TemporaryDirectory() as tmp_dir:
if is_local(output_dir):
yield output_dir
else:
yield tmp_dir
copy_dir(tmp_dir, output_dir)
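To make the collator's contract concrete, a hedged sketch of the Qwen2-style branch of the TruncatingCollator defined above, using synthetic tensors with illustrative shapes.

# Hedged sketch: TruncatingCollator on a synthetic Qwen2-style example (batch size 1 only).
import numpy as np
import torch

collator = TruncatingCollator(max_length=8192)
example = {
    "input_ids": np.arange(10_000),                            # deliberately longer than max_length
    "attention_mask": np.ones(10_000, dtype=np.int64),
    "labels": np.full(10_000, -100, dtype=np.int64),
    "pixel_values": np.zeros((1024, 1176), dtype=np.float32),  # illustrative shape only
    "image_grid_thw": np.array([1, 32, 32]),
}
batch = collator([example])
assert batch["input_ids"].shape == (1, 8192)   # truncated to max_length and given a batch dim
assert batch["pixel_values"].dtype == torch.float32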