diff --git a/.gitignore b/.gitignore
index 2a1c30a..3dc7288 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@ localworkspace/*
 math_data/*
 math_data_big/*
 gpt4otestset/*
+old_train/
 gpt4otestset_output/*
 pdfs/*
 olmOCR-bench/*
diff --git a/olmocr/train/__init__.py b/olmocr/train/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/olmocr/train/compressqwen2checkpoint.py b/olmocr/train/compressqwen2checkpoint.py
deleted file mode 100644
index 70c6320..0000000
--- a/olmocr/train/compressqwen2checkpoint.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# pip install llmcompressor
-
-from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration
-
-MODEL_ID = "/home/ubuntu/olmocr/olmOCR-7B-0225-preview"
-
-model = Qwen2VLForConditionalGeneration.from_pretrained(
-    MODEL_ID, device_map="auto", torch_dtype="auto")
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-
-from llmcompressor import oneshot
-from llmcompressor.modifiers.quantization import QuantizationModifier
-
-# Configure the simple PTQ quantization
-# recipe = QuantizationModifier(
-#     targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])
-
-# Configure pre-defined qwen2vl recipe
-recipe = QuantizationModifier(
-    targets="Linear",
-    scheme="FP8_DYNAMIC",
-    ignore=["re:.*lm_head", "re:visual.*"],
-)
-
-# Apply the quantization algorithm.
-oneshot(model=model, recipe=recipe)
-
-# Save the model.
-SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic-Recipe"
-model.save_pretrained(SAVE_DIR)
-tokenizer.save_pretrained(SAVE_DIR)
\ No newline at end of file
diff --git a/olmocr/train/config/molmo-o-lora-8192.yaml b/olmocr/train/config/molmo-o-lora-8192.yaml
deleted file mode 100644
index 6c9300c..0000000
--- a/olmocr/train/config/molmo-o-lora-8192.yaml
+++ /dev/null
@@ -1,87 +0,0 @@
-model:
-  name_or_path: allenai/Molmo-7B-O-0924
-  arch: causal
-  use_flash_attn: true
-
-wandb:
-  project: pdelfin
-  entity: ai2-llm
-
-generate:
-  max_length: 8192
-
-train_data:
-  seed: 1337
-  cache_location: /data/jakep/pdfdata/pdelfin_cache
-  sources:
-    - name: openai_batch_data_v5_1_train
-      response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
-      target_longest_image_dim: [1024]
-      target_anchor_text_len: [6000]
-    - name: openai_batch_data_v5_1_iabooks_train
-      response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
-      target_longest_image_dim: [1024]
-      target_anchor_text_len: [6000]
-
-valid_data:
-  cache_location: /data/jakep/pdfdata/pdelfin_cache
-  metric_for_best_model: openai_batch_data_v5_1_eval_loss
-  sources:
-    # These tend to be small, so you can load from s3 it's no big deal
-    - name: openai_batch_data_v5_1_eval
-      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
-      target_longest_image_dim: [1024]
-      target_anchor_text_len: [6000]
-    - name: openai_batch_data_v5_1_iabooks_eval
-      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
-      target_longest_image_dim: [1024]
-      target_anchor_text_len: [6000]
-
-
-# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
-hparams:
-  batch_size: 1
-  eval_batch_size: 1
-  gradient_accumulation_steps: 4
-  gradient_checkpointing: true
-  find_unused_parameters: true
-  clip_grad_norm: 1.0
-  learning_rate: 3e-4
-  max_steps: 10000
-  pad_multiple_of: 16
-  log_every_steps: 10
-  eval_every_steps: 100
-  optim: adamw_torch
-  lr_scheduler: cosine
-  weight_decay: 0.01
-  warmup_ratio: 0.03
-
-# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
-lora:
-  rank: 32
-  alpha: 32
-  dropout: 0.05
-  task_type: CAUSAL_LM
-  target_modules:
-    # attention layers in main transformer
-    - att_proj
-    - ff_proj
-    - attn_out
-    - ff_out
-    # vision transformer attention and FF
-    - attention.wq
-    - attention.wk
-    - attention.wv
-    - attention.wo
-    - feed_forward.w1
-    - feed_forward.w2
-    # vision image projector
-    - vision_backbone.image_projector.w1
-    - vision_backbone.image_projector.w2
-    - vision_backbone.image_projector.w3
-
-save:
-  path: s3://ai2-oe-data/jakep/experiments/molmo-o-0924/v1/models/
-  save_every_steps: 1000
-
-max_workers: 10
\ No newline at end of file
diff --git a/olmocr/train/config/molmo-o-lora.yaml b/olmocr/train/config/molmo-o-lora.yaml
deleted file mode 100644
index e6b9e70..0000000
--- a/olmocr/train/config/molmo-o-lora.yaml
+++ /dev/null
@@ -1,89 +0,0 @@
-model:
-  name_or_path: allenai/Molmo-7B-O-0924
-  arch: causal
-  use_flash_attn: true
-
-wandb:
-  project: pdelfin
-  entity: ai2-llm
-
-generate:
-  max_length: 4096
-
-train_data:
-  seed: 1337
-  cache_location: /data/jakep/pdfdata/pdelfin_cache
-  sources:
-    - name: openai_batch_data_v5_1_train
-      response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
-      target_longest_image_dim: [1024]
-      target_anchor_text_len: [6000]
-    - name: openai_batch_data_v5_1_iabooks_train
-      response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
-      target_longest_image_dim: [1024]
-      target_anchor_text_len: [6000]
-
-valid_data:
-  cache_location: /data/jakep/pdfdata/pdelfin_cache
-  metric_for_best_model: openai_batch_data_v5_1_eval_loss
-  sources:
-    # These tend to be small, so you can load from s3 it's no big deal
-    - name: openai_batch_data_v5_1_eval
-      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
-      target_longest_image_dim: [1024]
-      target_anchor_text_len: [6000]
-    - name: openai_batch_data_v5_1_eval
-      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
-      target_longest_image_dim: [1024]
-      target_anchor_text_len: [6000]
-
-
-
-
-# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
-hparams:
-  batch_size: 1
-  eval_batch_size: 1
-  gradient_accumulation_steps: 4
-  gradient_checkpointing: true
-  find_unused_parameters: true
-  clip_grad_norm: 1.0
-  learning_rate: 1e-4
-  max_steps: 10000
-  pad_multiple_of: 16
-  log_every_steps: 10
-  eval_every_steps: 100
-  optim: adamw_torch
-  lr_scheduler: cosine
-  weight_decay: 0.01
-  warmup_ratio: 0.03
-
-# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
-lora:
-  rank: 32
-  alpha: 32
-  dropout: 0.05
-  task_type: CAUSAL_LM
-  target_modules:
-    # attention layers in main transformer
-    - att_proj
-    - ff_proj
-    - attn_out
-    - ff_out
-    # vision transformer attention and FF
-    - attention.wq
-    - attention.wk
-    - attention.wv
-    - attention.wo
-    - feed_forward.w1
-    - feed_forward.w2
-    # vision image projector
-    - vision_backbone.image_projector.w1
-    - vision_backbone.image_projector.w2
-    - vision_backbone.image_projector.w3
-
-save:
-  path: s3://ai2-oe-data/jakep/experiments/molmo-o-0924/v1/models/
-  save_every_steps: 1000
-
-max_workers: 10
\ No newline at end of file
diff --git a/olmocr/train/config/qwen25vl-7b.yaml b/olmocr/train/config/qwen25vl-7b.yaml
deleted file mode 100644
index 36d2239..0000000
--- a/olmocr/train/config/qwen25vl-7b.yaml
+++ /dev/null
@@ -1,73 +0,0 @@
-model:
-  name_or_path: 
Qwen/Qwen2.5-VL-7B-Instruct - arch: causal - use_flash_attn: true - -wandb: - project: pdelfin - entity: ai2-llm - -generate: - max_length: 8192 - -train_data: - seed: 1337 - cache_location: /data/jakep/pdfdata/pdelfin_cache - sources: - # These tend to be small, so you can load from s3 it's no big deal - - name: openai_batch_data_v5_1_eval - response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json - target_longest_image_dim: [1024] - target_anchor_text_len: [6000] - - name: openai_batch_data_v5_1_iabooks_eval - response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json - target_longest_image_dim: [1024] - target_anchor_text_len: [6000] - # - name: openai_batch_data_v5_1_train - # response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json - # target_longest_image_dim: [1024] - # target_anchor_text_len: [6000] - # - name: openai_batch_data_v5_1_iabooks_train - # response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json - # target_longest_image_dim: [1024] - # target_anchor_text_len: [6000] - -valid_data: - cache_location: /data/jakep/pdfdata/pdelfin_cache - metric_for_best_model: openai_batch_data_v5_1_eval_loss - sources: - # These tend to be small, so you can load from s3 it's no big deal - - name: openai_batch_data_v5_1_eval - response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json - target_longest_image_dim: [1024] - target_anchor_text_len: [6000] - - name: openai_batch_data_v5_1_iabooks_eval - response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json - target_longest_image_dim: [1024] - target_anchor_text_len: [6000] - - - -# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh -hparams: - batch_size: 1 - eval_batch_size: 1 - gradient_accumulation_steps: 4 - gradient_checkpointing: true - clip_grad_norm: 1.0 - learning_rate: 1e-6 - max_steps: 10000 - pad_multiple_of: 16 - log_every_steps: 10 - eval_every_steps: 100 - optim: adamw_torch - lr_scheduler: cosine - weight_decay: 0.01 - warmup_ratio: 0.03 - - -save: - path: s3://ai2-oe-data/jakep/experiments/qwen25vl-pdf/v1/models/ - save_every_steps: 9500 - -max_workers: 10 \ No newline at end of file diff --git a/olmocr/train/config/qwen2vl-2b-lora.yaml b/olmocr/train/config/qwen2vl-2b-lora.yaml deleted file mode 100644 index 7345ce8..0000000 --- a/olmocr/train/config/qwen2vl-2b-lora.yaml +++ /dev/null @@ -1,89 +0,0 @@ -model: - name_or_path: Qwen/Qwen2-VL-2B-Instruct - arch: causal - use_flash_attn: true - -wandb: - project: pdelfin - entity: ai2-llm - -# TODO This is not used -format: - instruction_template: "Original:" - response_template: "Rewritten:" - # Template from here: https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py#L30 - chat_template: | - {% for message in messages %} - {{'<|im_start|>' + message['role'] + '\n' + message['content']}} - {% if loop.last %} - {{ '<|im_end|>'}} - {% else %} - {{ '<|im_end|>\n' }} - {% endif %} - {% endfor %} - -generate: - max_length: 4096 - -train_data: - seed: 1337 - sources: - - name: openai_batch_data_v2 - query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl - response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json - backend: - - openai - size: 100_000 - -valid_data: - sources: - - name: openai_batch_data_eval_mini - query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl - response_glob_path: 
s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json - backend: - - openai - size: 100_000 - -# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh -hparams: - batch_size: 1 - eval_batch_size: 1 - gradient_accumulation_steps: 4 - gradient_checkpointing: false - clip_grad_norm: 1.0 - learning_rate: 3e-4 - max_steps: 2000 - pad_multiple_of: 16 - log_every_steps: 50 - eval_every_steps: 1000 - optim: adamw_torch - lr_scheduler: cosine - weight_decay: 0.01 - warmup_ratio: 0.03 - -# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py -lora: - rank: 32 - alpha: 32 - dropout: 0.05 - task_type: causal_lm - target_modules: - - q_proj - - k_proj - - v_proj - - o_proj - - gate_proj - - up_proj - - down_proj - - visual.blocks.[0-9]+.attn.qkv - - visual.blocks.[0-9]+.attn.proj - - visual.blocks.[0-9]+.mlp.fc1 - - visual.blocks.[0-9]+.mlp.fc2 - - visual.merger.mlp.0 - - visual.merger.mlp.2 - -save: - path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/ - save_every_steps: 1000 - -max_workers: 10 \ No newline at end of file diff --git a/olmocr/train/config/qwen2vl-2b.yaml b/olmocr/train/config/qwen2vl-2b.yaml deleted file mode 100644 index 6374e7f..0000000 --- a/olmocr/train/config/qwen2vl-2b.yaml +++ /dev/null @@ -1,84 +0,0 @@ -model: - name_or_path: Qwen/Qwen2-VL-2B-Instruct - arch: causal - use_flash_attn: true - -wandb: - project: pdelfin - entity: ai2-llm - -# TODO This is not used -format: - instruction_template: "Original:" - response_template: "Rewritten:" - # Template from here: https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py#L30 - chat_template: | - {% for message in messages %} - {{'<|im_start|>' + message['role'] + '\n' + message['content']}} - {% if loop.last %} - {{ '<|im_end|>'}} - {% else %} - {{ '<|im_end|>\n' }} - {% endif %} - {% endfor %} - -generate: - max_length: 4096 - -train_data: - seed: 1337 - sources: - - name: openai_batch_data_v2 - query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl - response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json - backend: - - openai - size: 100_000 - -valid_data: - sources: - - name: openai_batch_data_eval_mini - query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl - response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json - backend: - - openai - size: 100_000 - -# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh -hparams: - batch_size: 1 - eval_batch_size: 1 - gradient_accumulation_steps: 4 - gradient_checkpointing: false - clip_grad_norm: 1.0 - learning_rate: 3e-4 - max_steps: 2000 - pad_multiple_of: 16 - log_every_steps: 50 - eval_every_steps: 1000 - optim: adamw_torch - lr_scheduler: cosine - weight_decay: 0.01 - warmup_ratio: 0.03 - -# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py -# Disable LORA for now, because we want the visual network to get trained too -# lora: -# rank: 32 -# alpha: 32 -# dropout: 0.05 -# task_type: causal_lm -# target_modules: -# - q_proj -# - k_proj -# - v_proj -# - o_proj -# - gate_proj -# - up_proj -# - down_proj - -save: - path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/ - save_every_steps: 1000 - -max_workers: 10 \ No newline at end of file diff --git a/olmocr/train/config/qwen2vl-7b-lora.yaml b/olmocr/train/config/qwen2vl-7b-lora.yaml deleted file mode 100644 index aaa471e..0000000 --- a/olmocr/train/config/qwen2vl-7b-lora.yaml +++ /dev/null @@ -1,84 +0,0 @@ -model: - 
name_or_path: Qwen/Qwen2-VL-7B-Instruct - arch: causal - use_flash_attn: true - -wandb: - project: pdelfin - entity: ai2-llm - -generate: - max_length: 8192 - -train_data: - seed: 1337 - cache_location: /data/jakep/pdfdata/pdelfin_cache - sources: - - name: openai_batch_data_v5_1_train - response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json - target_longest_image_dim: 1024 - target_anchor_text_len: 6000 - - name: openai_batch_data_v5_1_iabooks_train - response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json - target_longest_image_dim: 1024 - target_anchor_text_len: 6000 - -valid_data: - cache_location: /data/jakep/pdfdata/pdelfin_cache - metric_for_best_model: openai_batch_data_v5_1_eval_loss - sources: - # These tend to be small, so you can load from s3 it's no big deal - - name: openai_batch_data_v5_1_eval - response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json - target_longest_image_dim: 1024 - target_anchor_text_len: 6000 - - name: openai_batch_data_v5_1_iabooks_eval - response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json - target_longest_image_dim: 1024 - target_anchor_text_len: 6000 - - - -# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh -hparams: - batch_size: 1 - eval_batch_size: 1 - gradient_accumulation_steps: 4 - gradient_checkpointing: true - clip_grad_norm: 1.0 - learning_rate: 1e-4 - max_steps: 10000 - pad_multiple_of: 16 - log_every_steps: 10 - eval_every_steps: 100 - optim: adamw_torch - lr_scheduler: cosine - weight_decay: 0.01 - warmup_ratio: 0.03 - -# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py -lora: - rank: 32 - alpha: 32 - dropout: 0.05 - task_type: causal_lm - target_modules: - - q_proj - - k_proj - - v_proj - - o_proj - - gate_proj - - up_proj - - down_proj - - visual.blocks.[0-9]+.attn.qkv - - visual.blocks.[0-9]+.attn.proj - - visual.blocks.[0-9]+.mlp.fc1 - - visual.blocks.[0-9]+.mlp.fc2 - - visual.merger.mlp.0 - - visual.merger.mlp.2 - -save: - path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/ - save_every_steps: 1000 - -max_workers: 10 \ No newline at end of file diff --git a/olmocr/train/config/qwen2vl-7b.yaml b/olmocr/train/config/qwen2vl-7b.yaml deleted file mode 100644 index 7642964..0000000 --- a/olmocr/train/config/qwen2vl-7b.yaml +++ /dev/null @@ -1,64 +0,0 @@ -model: - name_or_path: Qwen/Qwen2-VL-7B-Instruct - arch: causal - use_flash_attn: true - -wandb: - project: pdelfin - entity: ai2-llm - -generate: - max_length: 8192 - -train_data: - seed: 1337 - cache_location: /data/jakep/pdfdata/pdelfin_cache - sources: - - name: openai_batch_data_v5_1_train - response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json - target_longest_image_dim: [1024] - target_anchor_text_len: [6000] - - name: openai_batch_data_v5_1_iabooks_train - response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json - target_longest_image_dim: [1024] - target_anchor_text_len: [6000] - -valid_data: - cache_location: /data/jakep/pdfdata/pdelfin_cache - metric_for_best_model: openai_batch_data_v5_1_eval_loss - sources: - # These tend to be small, so you can load from s3 it's no big deal - - name: openai_batch_data_v5_1_eval - response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json - target_longest_image_dim: [1024] - target_anchor_text_len: [6000] - - name: openai_batch_data_v5_1_iabooks_eval - 
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json - target_longest_image_dim: [1024] - target_anchor_text_len: [6000] - - - -# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh -hparams: - batch_size: 1 - eval_batch_size: 1 - gradient_accumulation_steps: 4 - gradient_checkpointing: true - clip_grad_norm: 1.0 - learning_rate: 1e-6 - max_steps: 10000 - pad_multiple_of: 16 - log_every_steps: 10 - eval_every_steps: 100 - optim: adamw_torch - lr_scheduler: cosine - weight_decay: 0.01 - warmup_ratio: 0.03 - - -save: - path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/ - save_every_steps: 9500 - -max_workers: 10 \ No newline at end of file diff --git a/olmocr/train/core/__init__.py b/olmocr/train/core/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/olmocr/train/core/adapters.py b/olmocr/train/core/adapters.py deleted file mode 100644 index 3caeff9..0000000 --- a/olmocr/train/core/adapters.py +++ /dev/null @@ -1,96 +0,0 @@ -import json -from logging import Logger -from typing import Optional, Type - -import smart_open -import torch -from peft.peft_model import PeftModel -from transformers import ( - AutoModelForCausalLM, - AutoModelForSeq2SeqLM, - AutoModelWithLMHead, - AutoTokenizer, -) - -from .config import ModelConfig -from .loggers import get_logger -from .paths import cached_path, exists, get_cache_dir, join_path, resource_to_filename - -__all__ = ["load_model", "cache_merged_model"] - - -def get_model_cls(config: ModelConfig) -> Type[AutoModelWithLMHead]: - if config.arch == "seq2seq": - return AutoModelForSeq2SeqLM # pyright: ignore - elif config.arch == "causal" or config.arch == "vllm": - return AutoModelForCausalLM # pyright: ignore - else: - raise ValueError(f"Unsupported model architecture: {config.arch}") - - -def get_adapter_config(config: ModelConfig) -> dict: - local_path = cached_path(config.name_or_path) - if exists(adapter_config_path := join_path("", local_path, "adapter_config.json")): - with smart_open.open(adapter_config_path, "rt", encoding="utf-8") as f: - return json.load(f) - return {} - - -def load_model(config: ModelConfig, logger: Optional[Logger] = None) -> AutoModelWithLMHead: - logger = logger or get_logger(__file__, level="INFO") - - logger.info(f"Loading model from {config.name_or_path}") - local_path = cached_path(config.name_or_path) - if local_path != config.name_or_path: - logger.info(f"Model cached at {local_path}") - - if exists(adapter_config_path := join_path("", local_path, "adapter_config.json")): - logger.info(f"Loading LoRA adapter from {adapter_config_path}") - with smart_open.open(adapter_config_path) as f: - adapter_config = json.load(f) - base_model_name_or_path = adapter_config["base_model_name_or_path"] - enable_lora = True - else: - base_model_name_or_path = local_path - enable_lora = False - - model = get_model_cls(config).from_pretrained( - base_model_name_or_path, - device_map="auto", - trust_remote_code=config.trust_remote_code, - # low_cpu_mem_usage=model_config.low_cpu_mem_usage, - use_flash_attention_2=True if config.use_flash_attn else False, - revision=config.model_revision, - torch_dtype=torch.bfloat16 if config.use_flash_attn else getattr(torch, config.dtype), - ) - logger.info(f"Successfully loaded base model from {base_model_name_or_path}") - - if enable_lora: - peft_model = PeftModel.from_pretrained(model, local_path) - model = peft_model.merge_and_unload() - logger.info(f"Successfully loaded LoRA adapter from 
base model: {base_model_name_or_path}") - - return model - - -def cache_merged_model(config: ModelConfig, logger: Optional[Logger] = None) -> str: - logger = logger or get_logger(__file__, level="INFO") - - base_local_path = cached_path(config.name_or_path) - adapter_config = get_adapter_config(config) - if not adapter_config: - logger.info("No adapter config found; using base model") - return base_local_path - - local_fn = resource_to_filename(json.dumps({"adapter": adapter_config, "model": config.name_or_path})) - merged_local_path = f"{get_cache_dir()}/{local_fn}" - - if not exists(merged_local_path): - model = load_model(config=config, logger=logger) - tokenizer = AutoTokenizer.from_pretrained(base_local_path) - - logger.info(f"Saving merged model to {merged_local_path}") - model.save_pretrained(merged_local_path) - tokenizer.save_pretrained(merged_local_path) - - return merged_local_path diff --git a/olmocr/train/core/cli.py b/olmocr/train/core/cli.py deleted file mode 100644 index 4be8366..0000000 --- a/olmocr/train/core/cli.py +++ /dev/null @@ -1,327 +0,0 @@ -""" -Utilities to work with a OmegaConf structured config object - -From Dolma Toolkit: https://github.com/allenai/dolma/blob/64886d9db15bd99acea9e28740ae20a510875dfb/python/dolma/cli/__init__.py - -Author: Luca Soldaini (@soldni) -""" # noqa: E501 - -from argparse import ArgumentParser, Namespace -from collections.abc import Iterable -from copy import deepcopy -from dataclasses import Field -from dataclasses import field as dataclass_field -from dataclasses import is_dataclass -from logging import warning -from typing import ( - Any, - Dict, - Literal, - Optional, - Protocol, - Type, - TypeVar, - Union, - get_args, - get_origin, -) - -import smart_open -from necessary import necessary -from omegaconf import MISSING, DictConfig, ListConfig -from omegaconf import OmegaConf as om -from omegaconf.errors import OmegaConfBaseException -from rich.console import Console -from rich.syntax import Syntax -from yaml import safe_load # type: ignore - -from .errors import DolmaRefineError - -__all__ = ["field", "namespace_to_nested_omegaconf", "print_config", "make_cli", "read_config", "to_native_types"] - - -T = TypeVar("T", bound=Any) -D = TypeVar("D", bound="DataClass") -A = TypeVar("A", bound="ArgumentParser") - - -def _field_nargs(default: Any) -> Union[Literal["?"], Literal["*"]]: - # return '+' if _default is iterable but not string/bytes, else 1 - if isinstance(default, (str, bytes)): - return "?" - - if isinstance(default, Iterable): - return "*" - - return "?" 
- - -def field(default: T = MISSING, help: Optional[str] = None, **extra: Any) -> T: - metadata = {"help": help, "type": type(default), "default": default, "nargs": _field_nargs(default), **extra} - return dataclass_field(default_factory=lambda: deepcopy(default), metadata=metadata) - - -class DataClass(Protocol): - __dataclass_fields__: Dict[str, Field] - - -def read_config(path: Union[None, str]) -> Dict[str, Any]: - """Read a configuration file if it exists""" - if path is None: - return {} - - try: - with smart_open.open(path, mode="rt") as f: - return dict(safe_load(f)) - except FileNotFoundError as ex: - raise DolmaRefineError(f"Config file not found: {path}") from ex - except Exception as ex: - raise DolmaRefineError(f"Error while reading config file: {path}") from ex - - -def save_config(config: Union[dict, DictConfig, list, ListConfig, DataClass], path: str) -> None: - """Save a configuration to a file""" - if isinstance(config, (list, dict)): - config = om.create(config) - elif is_dataclass(config): - config = om.structured(config) - - with smart_open.open(path, mode="wt") as f: - f.write(om.to_yaml(config)) - - -def _make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = None) -> A: - for field_name, dt_field in config.__dataclass_fields__.items(): - # get type from annotations or metadata - typ_ = config.__annotations__.get(field_name, dt_field.metadata.get("type", MISSING)) - - if typ_ is MISSING: - warning(f"No type annotation for field {field_name} in {config.__name__}") - continue - - # join prefix and field name - field_name = f"{prefix}.{field_name}" if prefix else field_name - - # This section here is to handle Optional[T] types; we only care for cases where T is a dataclass - # So we first check if type is Union since Optional[T] is just a shorthand for Union[T, None] - # and that the union contains only one non-None type - if get_origin(typ_) == Union: - # get all non-None types - args = [a for a in get_args(typ_) if a is not type(None)] # noqa: E721 - - if len(args) == 1: - # simple Optional[T] type - typ_ = args[0] - - # here's where we check if T is a dataclass - if is_dataclass(typ_): - # recursively add subparsers - _make_parser(parser, typ_, prefix=field_name) # type: ignore - continue - - if typ_ is bool: - # for boolean values, we add two arguments: --field_name and --no-field_name - parser.add_argument( - f"--{field_name}", - help=dt_field.metadata.get("help"), - dest=field_name, - action="store_true", - default=MISSING, - ) - parser.add_argument( - f"--no-{field_name}", - help=f"Disable {field_name}", - dest=field_name, - action="store_false", - default=MISSING, - ) - else: - # else it's just a normal argument - parser.add_argument( - f"--{field_name}", - help=dt_field.metadata.get("help"), - nargs=dt_field.metadata.get("nargs", "?"), - default=MISSING, - ) - - return parser - - -def make_nested_dict(key: str, value: Any, d: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - d = d or {} - - if "." 
in key: - key, rest = key.split(".", 1) - value = make_nested_dict(rest, value, d.get(key)) - - # the value was provided (is not MISSING constant) and is not an empty dict or list - if value != MISSING and (not isinstance(value, (dict, list)) or len(value) > 0): - d[key] = value - - return d - - -def to_native_types(obj: Any, resolve: bool = True, throw_on_missing: bool = True, enum_to_str: bool = True) -> Any: - """Converts an OmegaConf object to native types (dicts, lists, etc.)""" - - # convert dataclass to structured config - if hasattr(obj, "to_dict"): - # huggingface objects have a to_dict method, we prefer that - obj = obj.to_dict() - elif is_dataclass(obj): - # we go through structured config instead and hope for the best - obj = om.to_container(obj) - - if isinstance(obj, DictConfig) or isinstance(obj, ListConfig): - obj = om.to_container(obj, resolve=resolve, throw_on_missing=throw_on_missing, enum_to_str=enum_to_str) - - if isinstance(obj, dict): - return {k: to_native_types(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [to_native_types(v) for v in obj] - else: - return obj - - -def namespace_to_nested_omegaconf(args: Namespace, structured: Type[T], config: Optional[dict] = None) -> T: - nested_config_dict: Dict[str, Any] = {} - for key, value in vars(args).items(): - nested_config_dict = make_nested_dict(key, value, nested_config_dict) - - untyped_config: DictConfig = om.merge( - om.create(config or {}), om.create(nested_config_dict) - ) # pyright: ignore (pylance is confused because om.create might return a DictConfig or a ListConfig) - - # resolve any interpolations in the config - om.resolve(untyped_config) - - # create structured config from cli dataclass - base_structured_config: DictConfig = om.structured(structured) - - # merge with options parsed from config file and - merged_config = om.merge(base_structured_config, untyped_config) - - # check for type - if not isinstance(merged_config, DictConfig): - raise DolmaRefineError(f"Expected a DictConfig, got {type(merged_config).__name__}") - - # try resolving all cross references in the config, raise a DolmaConfigError if it fails - try: - om.resolve(merged_config) - except OmegaConfBaseException as ex: - raise DolmaRefineError(f"Invalid error while parsing key `{ex.full_key}`: {type(ex).__name__}") from ex - - return merged_config # pyright: ignore - - -def print_config(config: Any, console: Optional[Console] = None) -> None: - if not isinstance(config, (DictConfig, ListConfig)): - config = om.create(config) - - # print the config as yaml using a rich syntax highlighter - console = console or Console() - yaml_config = om.to_yaml(config, sort_keys=True).strip() - highlighted = Syntax(code=yaml_config, lexer="yaml", theme="ansi_dark") - console.print(highlighted) - - -def _patch_old_omegaconf(): - """Monkey patch omegaconf below version 2.3.0 to support custom resolver returning - lists or dicts. 
Applies patch https://github.com/omry/omegaconf/pull/1093""" - - if necessary(("omegaconf", "2.4.0"), soft=True): - # no need to patch - return - - if getattr(_patch_old_omegaconf, "__patched__", False): - # already patched - return - - from omegaconf import _impl # pylint: disable=import-outside-toplevel - from omegaconf import ( # pylint: disable=import-outside-toplevel - Container, - Node, - ValueNode, - ) - from omegaconf._utils import ( # noqa: F401 # pylint: disable=import-outside-toplevel - _ensure_container, - _get_value, - is_primitive_container, - is_structured_config, - ) - from omegaconf.errors import ( # pylint: disable=import-outside-toplevel - InterpolationToMissingValueError, - ) - from omegaconf.nodes import ( # pylint: disable=import-outside-toplevel - InterpolationResultNode, - ) - - def _resolve_container_value(cfg: Container, key: Any) -> None: - node = cfg._get_child(key) # pylint: disable=protected-access - assert isinstance(node, Node) - if node._is_interpolation(): # pylint: disable=protected-access - try: - resolved = node._dereference_node() # pylint: disable=protected-access - except InterpolationToMissingValueError: - node._set_value(MISSING) # pylint: disable=protected-access - else: - if isinstance(resolved, Container): - _impl._resolve(resolved) # pylint: disable=protected-access - if isinstance(resolved, InterpolationResultNode): - resolved_value = _get_value(resolved) - if is_primitive_container(resolved_value) or is_structured_config(resolved_value): - resolved = _ensure_container(resolved_value) - if isinstance(resolved, Container) and isinstance(node, ValueNode): - cfg[key] = resolved - else: - node._set_value(_get_value(resolved)) # pylint: disable=protected-access - else: - _impl._resolve(node) # pylint: disable=protected-access - - # set new function and mark as patched - setattr(_impl, "_resolve_container_value", _resolve_container_value) - setattr(_patch_old_omegaconf, "__patched__", True) - - -# actually executes the patch -_patch_old_omegaconf() - - -def make_cli(config_cls: Type[D], _config_flag: str = "config", _dryrun_flag: str = "dryrun") -> D: - """Create a CLI parser for a dataclass and parse the arguments into a structured config object.""" - - if hasattr(config_cls, _config_flag): - raise DolmaRefineError(f"`{_config_flag}` is a reserved attribute; remove it from `{config_cls.__name__}`") - - if hasattr(config_cls, _dryrun_flag): - raise DolmaRefineError(f"`{_dryrun_flag}` is a reserved attribute; remove it from `{config_cls.__name__}`") - - parser = ArgumentParser() - parser.add_argument(f"-{_config_flag[0]}", f"--{_config_flag}", help="Path to config file", default=None, type=str) - parser.add_argument( - f"-{_dryrun_flag[0]}", - f"--{_dryrun_flag}", - help="Dry run mode: print config and exit", - action="store_true", - default=False, - ) - - parser = _make_parser(parser, config_cls) - args = parser.parse_args() - - parsed_config: Dict[str, Any] = {} - if (config_path := getattr(args, _config_flag)) is not None: - parsed_config = read_config(config_path) - delattr(args, _config_flag) - - only_dryrun = getattr(args, _dryrun_flag, False) - delattr(args, _dryrun_flag) - - full_config = namespace_to_nested_omegaconf(args, config_cls, parsed_config) - - print_config(full_config) - - if only_dryrun: - exit(0) - - return full_config diff --git a/olmocr/train/core/compression.py b/olmocr/train/core/compression.py deleted file mode 100644 index de8a6c4..0000000 --- a/olmocr/train/core/compression.py +++ /dev/null @@ -1,16 +0,0 @@ -from 
smart_open import register_compressor - -__all__ = ["mk_compression"] - - -def mk_compression(): - def _handle_zst(file_obj, mode): - try: - import zstandard as zstd - except ImportError: - raise ImportError("zstandard is required for zstd support") - - return zstd.open(file_obj, mode) - - register_compressor(".zstd", _handle_zst) - register_compressor(".zst", _handle_zst) diff --git a/olmocr/train/core/config.py b/olmocr/train/core/config.py deleted file mode 100644 index fe2b663..0000000 --- a/olmocr/train/core/config.py +++ /dev/null @@ -1,131 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional - -from peft import TaskType # pyright: ignore - -from .cli import field - - -@dataclass -class ModelConfig: - """Configuration for loading a model; includes model name and type.""" - - name_or_path: str = field(help="The model name or path to load; must be compatible with huggingface transformers.") - arch: str = field(help="The model type to load; can be 'vllm', 'causal', or 'vllm'") - dtype: str = field(help="The precision to use for the model", default="bfloat16") - use_flash_attn: bool = field(help="Whether to use the flash attention for the model.", default=False) - trust_remote_code: bool = field(help="Whether to trust remote code for the model.", default=False) - low_cpu_mem_usage: bool = field(help="Whether to use low cpu memory usage for the model.", default=False) - fast_tokenizer: bool = field(help="Whether to use the fast tokenizer for the model.", default=True) - model_revision: Optional[str] = field(help="The model revision to use for the model.", default=None) - - -@dataclass -class GenerateConfig: - max_length: int = field(help="The maximum length of the generated text", default=4096) - temperature: float = field(default=0.2, help="The temperature to use for generation") - top_k: int = field(default=50, help="The top k to use for generation") - top_p: float = field(default=1.0, help="The top p to use for generation") - num_beams: int = field(default=1, help="The number of beams to use for generation") - truncate_prompt_tokens: bool = field(default=True, help="Whether to truncate the prompt tokens for generation") - max_num_seqs: int = field(default=16, help="The maximum number of sequences to generate") - - -@dataclass -class WandbConfig: - entity: str = field(help="The wandb entity to use for logging", default="ai2-llm") - project: str = field(help="The wandb project to use for logging", default="pdf-qwen2vl") - wandb_api_key: Optional[str] = field(help="The wandb api key to use for logging", default=None) - mode: str = field(help="The wandb mode to use for logging. 
Set it to `offline`", default="online") - watch: str = field(help="The wandb watch to use for logging", default="false") - - -@dataclass -class AwsConfig: - profile: Optional[str] = field(help="The aws profile to use for s3 access", default=None) - access_key_id: Optional[str] = field(help="The aws access key id to use for s3 access", default=None) - secret_access_key: Optional[str] = field(help="The aws secret access key to use for s3 access", default=None) - default_region: Optional[str] = field(help="The default region to use for s3 access", default=None) - - -@dataclass -class SourceConfig: - name: str = field(help="The name of the source") - response_glob_path: str = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai") - target_longest_image_dim: list[int] = field(help="Dimensions to render the pdf page image to") - target_anchor_text_len: list[int] = field(help="Maximum amount of anchor text (aka prompt hint)") - - -@dataclass -class DataConfig: - seed: int = field(default=42, help="The seed to use for data loading") - cache_location: Optional[str] = field(help="Location to store s3 pdfs that need to be used to compute page images", default=None) - metric_for_best_model: Optional[str] = field(help="metric to pass to trainer args to use for picking best model checkpoint at end", default=None) - sources: List[SourceConfig] = field(help="The source configurations") - - -@dataclass -class HyperparamConfig: - batch_size: int = field(default=8, help="The batch size to use for training") - eval_batch_size: Optional[int] = field(default=None, help="The batch size to use for evaluation; default is the same as the training batch size") - learning_rate: float = field(default=2e-5, help="The learning rate to use for training") - max_steps: int = field(default=-1, help="The maximum number of steps to train the model") - pad_multiple_of: int = field(default=16, help="The padding multiple to use for the model") - log_every_steps: int = field(default=5, help="The number of steps to log training metrics") - eval_every_steps: int = field(default=100, help="The number of steps to evaluate the model") - weight_decay: float = field(default=0.0, help="The weight decay to use for training") - warmup_steps: int = field(default=0, help="The number of warmup steps to use for training") - warmup_ratio: float = field(default=0.0, help="The ratio of warmup steps to use for training") - lr_scheduler: str = field(default="linear", help="The learning rate scheduler to use for training") - gradient_accumulation_steps: int = field(default=1, help="The number of gradient accumulation steps to use for training") - gradient_checkpointing: bool = field(default=False, help="Whether to use gradient checkpointing for training") - seed: int = field(default=42, help="The seed to use for training") - reduce_loss: str = field(default="mean", help="The loss reduction to use for training") - clip_grad_norm: float = field(default=0.0, help="The gradient norm to clip to for training") - optim: str = field(default="adamw_torch", help="The optimizer to use for training") - find_unused_parameters: bool = field(default=False, help="Whether to find unused parameters for training") - - -@dataclass -class SaveConfig: - path: str = field(default="./results", help="The output directory to save the model") - limit: Optional[int] = field(default=None, help="The number of checkpoints to save") - save_every_steps: int = field(default="${hparams.eval_every_steps}", help="The number of steps to save the 
model") # type: ignore - - -@dataclass -class LoraConfig: - rank: int = field(default=16, help="The rank of the LoRA attention") - alpha: int = field(default=16, help="The alpha parameter for LoRA scaling") - dropout: float = field(default=0.05, help="The dropout probability for LoRA layers") - bias: str = field(default="none", help="The bias to use for LoRA layers (none, causal, or full)") - task_type: str = field(default=TaskType.CAUSAL_LM, help="The task type for the model") - target_modules: List[str] = field( - default=["k_proj", "q_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - help="The target modules in the model that will be replaced with LoRA layers", - ) - - -@dataclass -class TrainConfig: - model: ModelConfig = field(default=ModelConfig(), help="The model configuration") - lora: Optional[LoraConfig] = field(default=None, help="The LoRA configuration") - aws: AwsConfig = field(default=AwsConfig(), help="Configuration for AWS S3") - wandb: WandbConfig = field(default=WandbConfig(), help="Configuration for Weights and Biases") - train_data: DataConfig = field(default=DataConfig(), help="Configuration for the training data") - valid_data: DataConfig = field(default=DataConfig(), help="Configuration for the validation data") - generate: GenerateConfig = field(default=GenerateConfig(), help="Configuration for text generation") - num_proc: int = field(default=1, help="The maximum number of workers to use for data processing") - max_workers: int = field(default=1, help="The maximum number of workers to use for data loaders") - hparams: HyperparamConfig = field(default=HyperparamConfig(), help="Hyperparameters for training") - save: SaveConfig = field(default=SaveConfig(), help="Configuration for saving the model") - - -@dataclass -class DemoConfig: - title: str = field(default="# Dolma Rewriter Demo") - description: str = field(default="Internal use only, **DO NOT SHARE OUTSIDE AI2**.") - share: bool = field(default=False, help="Share the demo publicly.") - - model: ModelConfig = field(default=ModelConfig()) - generate: GenerateConfig = field(default=GenerateConfig()) diff --git a/olmocr/train/core/errors.py b/olmocr/train/core/errors.py deleted file mode 100644 index a24dbe0..0000000 --- a/olmocr/train/core/errors.py +++ /dev/null @@ -1 +0,0 @@ -class DolmaRefineError(RuntimeError): ... diff --git a/olmocr/train/core/loggers.py b/olmocr/train/core/loggers.py deleted file mode 100644 index ced2782..0000000 --- a/olmocr/train/core/loggers.py +++ /dev/null @@ -1,52 +0,0 @@ -import logging -import multiprocessing -from typing import Union - -LOGGER_PREFIX = "dolma-refine" - - -def get_logger(name: str, level: Union[int, str] = logging.WARN) -> logging.Logger: - if (proc_name := multiprocessing.current_process().name) == "MainProcess": - proc_name = "main" - proc_name = proc_name.replace(" ", "_") - - # set the log level - level = level if isinstance(level, int) else getattr(logging, level.strip().upper(), logging.WARN) - - # set name - name = f"{LOGGER_PREFIX}.{proc_name}.{name}" - logger = logging.getLogger(name) - logger.setLevel(level) - - # add handler - if not logger.handlers: - handler = logging.StreamHandler() - formatter = logging.Formatter("[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") - handler.setFormatter(formatter) - logger.addHandler(handler) - - return logger - - -def reset_level(level: Union[int, str]) -> None: - """ - Reset the log level for all Dolma loggers. - - Args: - level (Union[int, str]): The log level to set. 
It can be either an integer - representing the log level (e.g., logging.DEBUG) or a string - representing the log level name (e.g., 'debug'). - - Returns: - None - """ - if isinstance(level, str): - if (level_tmp := getattr(logging, level.strip().upper(), None)) is not None: - level = level_tmp - else: - raise ValueError(f"Invalid log level: {level}") - - for logger in logging.Logger.manager.loggerDict.values(): - if isinstance(logger, logging.Logger): - if logger.name.startswith(LOGGER_PREFIX): - logger.setLevel(level) diff --git a/olmocr/train/core/paths.py b/olmocr/train/core/paths.py deleted file mode 100644 index 977c6ca..0000000 --- a/olmocr/train/core/paths.py +++ /dev/null @@ -1,615 +0,0 @@ -import glob -import os -import re -from concurrent.futures import ThreadPoolExecutor -from functools import partial, reduce -from hashlib import sha256 -from itertools import chain -from pathlib import Path -from shutil import copyfileobj -from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union -from urllib.parse import urlparse - -import platformdirs -import smart_open -from fsspec import AbstractFileSystem, get_filesystem_class -from smart_open.compression import get_supported_extensions - -from .loggers import LOGGER_PREFIX, get_logger - -__all__ = [ - "glob_path", - "sub_prefix", - "add_suffix", - "sub_suffix", - "make_relative", - "mkdir_p", - "split_path", - "join_path", - "is_glob", - "split_glob", - "partition_path", -] - - -FS_KWARGS: Dict[str, Dict[str, Any]] = { - "": {"auto_mkdir": True}, -} - - -RE_ANY_ESCAPE = re.compile(r"(? AbstractFileSystem: - """ - Get the filesystem class for a given path. - """ - path = str(path) - protocol = urlparse(path).scheme - fs = get_filesystem_class(protocol)(**FS_KWARGS.get(protocol, {})) - - global PATCHED_GLOB # pylint: disable=global-statement - - # patch glob method to support recursive globbing - if protocol == "" and not PATCHED_GLOB: - fs.glob = partial(glob.glob, recursive=True) - - # only patch once - PATCHED_GLOB = True - - return fs - - -def _escape_glob(s: Union[str, Path]) -> str: - """ - Escape glob characters in a string. - """ - s = str(s) - s = RE_GLOB_STAR_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["*"], s) - s = RE_GLOB_ONE_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["?"], s) - s = RE_GLOB_OPEN_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["["], s) - s = RE_GLOB_CLOSE_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["]"], s) - return s - - -def _unescape_glob(s: Union[str, Path]) -> str: - """ - Unescape glob characters in a string. - """ - s = str(s) - for k, v in REVERSE_ESCAPE_SYMBOLS_MAP.items(): - s = s.replace(k, v) - return s - - -def _pathify(path: Union[Path, str]) -> Tuple[str, Path]: - """ - Return the protocol and path of a given path. - """ - path = _escape_glob(str(path)) - parsed = urlparse(path) - path = Path(f"{parsed.netloc}/{parsed.path}") if parsed.netloc else Path(parsed.path) - return parsed.scheme, path - - -def _unpathify(protocol: str, path: Path) -> str: - """ - Return a path from its protocol and path components. - """ - path_str = _unescape_glob(str(path)) - if protocol: - path_str = f"{protocol}://{path_str.lstrip('/')}" - return path_str - - -def remove_params(path: str) -> str: - """ - Remove parameters from a path. - """ - parsed = urlparse(path) - return (f"{parsed.scheme}://" if parsed.scheme else "") + f"{parsed.netloc}{parsed.path}" - - -def is_local(path: str) -> bool: - """ - Check if a path is local. 
- """ - prot, _ = _pathify(path) - return prot == "" or prot == "file" - - -def copy_file(src: str, dest: str) -> None: - """Copy a file using shutil.copyfileobj for efficient chunked copying.""" - with smart_open.open(src, "rb") as src_file, smart_open.open(dest, "wb") as dest_file: - copyfileobj(src_file, dest_file) - - -def copy_dir(src: str, dst: str, src_fs: Optional[AbstractFileSystem] = None, dst_fs: Optional[AbstractFileSystem] = None): - """Copy a directory using a ThreadPoolExecutor for parallel file copying.""" - src_fs = src_fs or get_fs(src) - dst_fs = dst_fs or get_fs(dst) - logger = get_logger(__name__) - - with ThreadPoolExecutor(max_workers=8) as executor: - futures = [] - - for src_path in glob_path(src, yield_dirs=True, fs=src_fs): - rel_path = sub_prefix(src_path, src) - dest_path = join_path("", dst, rel_path) - - if is_dir(src_path, fs=src_fs): - # Recursively copy directories - copy_dir(src=src_path, dst=dest_path, src_fs=src_fs, dst_fs=dst_fs) - else: - # File; copy over using the executor for parallelism - logger.info(f"Copying {src_path} to {dest_path}") - futures.append(executor.submit(copy_file, src_path, dest_path)) - - # Wait for all futures to complete - for future in futures: - future.result() # This will raise an exception if any of the threads failed - - -def delete_file(path: str, ignore_missing: bool = False, fs: Optional[AbstractFileSystem] = None) -> bool: - """Delete a file.""" - - fs = fs or get_fs(path) - try: - fs.rm(path) - deleted = True - except FileNotFoundError as ex: - if not ignore_missing: - raise ex - deleted = False - - return deleted - - -def get_size(path: str, fs: Optional[AbstractFileSystem] = None) -> int: - """Get the size of a file""" - - fs = fs or get_fs(path) - - if not exists(path, fs=fs): - raise ValueError(f"Path {path} does not exist") - if is_dir(path, fs=fs): - raise ValueError(f"Path {path} is a directory") - - return fs.info(path)["size"] - - -def delete_dir(path: str, ignore_missing: bool = False, fs: Optional[AbstractFileSystem] = None) -> bool: - """Delete a directory.""" - - fs = fs or get_fs(path) - try: - fs.rm(path, recursive=True) - deleted = True - except FileNotFoundError as ex: - if not ignore_missing: - raise ex - deleted = False - - return deleted - - -def partition_path(path: str) -> Tuple[str, Tuple[str, ...], Tuple[str, ...]]: - """Partition a path into its protocol, symbols before a glob, and symbols after a glob.""" - # split the path into its protocol and path components - prot, path_obj = _pathify(path) - - # we need to first figure out if this path has a glob by checking if any of the escaped symbols for - # globs are in the path. - glob_locs = [i for i, p in enumerate(path_obj.parts) if any(c in p for c in REVERSE_ESCAPE_SYMBOLS_MAP)] - - # make the path components before the glob - pre_glob_path = path_obj.parts[: glob_locs[0]] if glob_locs else path_obj.parts - pre_glob_path = tuple(_unescape_glob(p) for p in pre_glob_path) - - # make the path components after the glob - post_glob_path = path_obj.parts[glob_locs[0] + 1 :] if glob_locs else () - post_glob_path = tuple(_unescape_glob(p) for p in post_glob_path) - - return prot, pre_glob_path, post_glob_path - - -def split_path(path: str) -> Tuple[str, Tuple[str, ...]]: - """ - Split a path into its protocol and path components. 
- """ - protocol, _path = _pathify(path) - return protocol, tuple(_unescape_glob(p) for p in _path.parts) - - -def join_path(protocol: Union[str, None], *parts: Union[str, Iterable[str]]) -> str: - """ - Join a path from its protocol and path components. - """ - all_prots, all_parts = zip(*(_pathify(p) for p in chain.from_iterable([p] if isinstance(p, str) else p for p in parts))) - path = str(Path(*all_parts)).rstrip("/") - protocol = protocol or str(all_prots[0]) - - if protocol: - path = f"{protocol}://{path.lstrip('/')}" - return _unescape_glob(path) - - -def glob_path( - path: Union[Path, str], - hidden_files: bool = False, - autoglob_dirs: bool = True, - recursive_dirs: bool = False, - yield_dirs: bool = True, - fs: Optional[AbstractFileSystem] = None, -) -> Iterator[str]: - """ - Expand a glob path into a list of paths. - """ - protocol, parsed_path = _pathify(path) - fs = fs or get_fs(path) - - if autoglob_dirs and fs.isdir(path): - path = join_path(protocol, _unescape_glob(parsed_path), "*") - - if "*" not in str(path): - # nothing to glob - yield str(path) - return - - for gl in fs.glob(path): - gl = str(gl) - - if not hidden_files and Path(gl).name.startswith("."): - continue - - if fs.isdir(gl): - if recursive_dirs: - yield from glob_path( - gl, - hidden_files=hidden_files, - autoglob_dirs=autoglob_dirs, - recursive_dirs=recursive_dirs, - yield_dirs=yield_dirs, - fs=fs, - ) - if yield_dirs: - yield join_path(protocol, gl) - else: - yield join_path(protocol, gl) - - -def sub_prefix(a: str, b: str) -> str: - """ - Return the relative path of b from a. - """ - prot_a, path_a = _pathify(a) - prot_b, path_b = _pathify(b) - - if prot_a != prot_b: - raise ValueError(f"Protocols of {a} and {b} do not match") - - try: - diff = str(path_a.relative_to(path_b)) - except ValueError: - diff = join_path(prot_a, path_a.parts) - - return _unescape_glob(diff) - - -def sub_suffix(a: str, b: str) -> str: - """ - Remove b from the end of a. - """ - prot_a, path_a = _pathify(a) - prot_b, path_b = _pathify(b) - - if prot_b: - raise ValueError(f"{b} is not a relative path") - - sub_path = re.sub(f"{path_b}$", "", str(path_a)) - sub_prot = f"{prot_a}://" if prot_a else "" - - # need to trim '/' from the end if (a) '/' is not the only symbol in the path or - # (b) there is a protocol so absolute paths don't make sense - if sub_path != "/" or sub_prot: - sub_path = sub_path.rstrip("/") - - return _unescape_glob(sub_prot + sub_path) - - -def add_suffix(a: str, b: str) -> str: - """ - Return the the path of a joined with b. 
- """ - prot_a, path_a = _pathify(a) - prot_b, path_b = _pathify(b) - - if prot_b: - raise ValueError(f"{b} is not a relative path") - - return join_path(prot_a, str(path_a / path_b)) - - -def exists(path: str, fs: Optional[AbstractFileSystem] = None) -> bool: - """Check if a path exists.""" - - fs = fs or get_fs(path) - return fs.exists(path) - - -def is_dir(path: str, fs: Optional[AbstractFileSystem] = None) -> bool: - """Check if a path is a directory.""" - fs = fs or get_fs(path) - if exists(path, fs=fs): - return fs.isdir(path) - return False - - -def is_file(path: str, fs: Optional[AbstractFileSystem] = None) -> bool: - """Check if a path is a file.""" - fs = fs or get_fs(path) - if exists(path, fs=fs): - return fs.isfile(path) - return False - - -def parent(path: str) -> str: - """Get the parent directory of a path; if the parent is the root, return the root.""" - - prot, parts = split_path(path) - if len(parts) == 1: - return path - return join_path(prot, *parts[:-1]) - - -def mkdir_p(path: str, fs: Optional[AbstractFileSystem] = None) -> None: - """ - Create a directory if it does not exist. - """ - if is_glob(path): - raise ValueError(f"Cannot create directory with glob pattern: {path}") - - fs = fs or get_fs(path) - fs.makedirs(path, exist_ok=True) - - -def make_relative(paths: List[str]) -> Tuple[str, List[str]]: - """Find minimum longest root shared among all paths""" - if len(paths) == 0: - raise ValueError("Cannot make relative path of empty list") - - common_prot, common_parts, _ = partition_path(paths[0]) - - for path in paths: - current_prot, current_parts, _ = partition_path(path) - if current_prot != common_prot: - raise ValueError(f"Protocols of {path} and {paths[0]} do not match") - - for i in range(min(len(common_parts), len(current_parts))): - if common_parts[i] != current_parts[i]: - common_parts = common_parts[:i] - break - - if len(common_parts) > 0: - common_path = (f"{common_prot}://" if common_prot else "") + str(Path(*common_parts)) - relative_paths = [sub_prefix(path, common_path) for path in paths] - else: - common_path = f"{common_prot}://" if common_prot else "" - relative_paths = [_unpathify("", _pathify(path)[1]) for path in paths] - - return common_path, relative_paths - - -def is_glob(path: str) -> bool: - """ - Check if a path contains a glob wildcard. - """ - return bool(re.search(r"(? Tuple[str, str]: - """ - Partition a path on the first wildcard. - """ - if not is_glob(path): - # it's not a glob, so it's all path - return path, "" - - if path[0] == "*": - # starts with a glob, so it's all glob - return "", path - - protocol, parts = split_path(path) - - i = min(i for i, c in enumerate(parts) if is_glob(c)) - - if i == 0: - # no path, so it's all glob - return protocol, join_path("", *parts) - - path = join_path(protocol, *parts[:i]) - rest = join_path("", *parts[i:]) - return path, rest - - -def get_cache_dir() -> str: - """ - Returns the path to the cache directory for the Dolma toolkit. - If the directory does not exist, it will be created. - - Returns: - str: The path to the cache directory. - """ - loc = platformdirs.user_cache_dir(LOGGER_PREFIX) - mkdir_p(loc) - return loc - - -def resource_to_filename(resource: Union[str, bytes]) -> str: - """ - Convert a ``resource`` into a hashed filename in a repeatable way. Preserves the file extensions. 
- """ - _, (*_, orig_filename) = split_path(remove_params(str(resource))) - _, extensions = split_basename_and_extension(orig_filename) - - resource_bytes = str(resource).encode("utf-8") - resource_hash = sha256(resource_bytes) - hash_filename = resource_hash.hexdigest() + extensions - - return hash_filename - - -def cached_path(path: str, fs: Optional[AbstractFileSystem] = None) -> str: - """ - Returns the cached path for a given resource. - - If the resource is already available locally, the function returns the path as is. - Otherwise, it downloads the resource from the specified path and saves it in the cache directory. - - Args: - path (str): The path to the resource. - - Returns: - str: The cached path of the resource. - """ - if is_local(path): - # Implementation goes here - pass - return path - - destination = f"{get_cache_dir()}/{resource_to_filename(path)}" - - remote_fs = fs or get_fs(path) - local_fs = get_fs(destination) - - if exists(destination, fs=local_fs): - LOGGER.info(f"Using cached file {destination} for {path}") - return destination - - if is_dir(path, fs=remote_fs): - for sub_path in glob_path(path, fs=remote_fs): - rel_path = sub_prefix(sub_path, path) - dest_path = join_path("", destination, rel_path) - mkdir_p(parent(dest_path), fs=local_fs) - LOGGER.info(f"Downloading {sub_path} to {dest_path}") - with smart_open.open(sub_path, "rb") as src, smart_open.open(dest_path, "wb") as dest: - dest.write(src.read()) - else: - LOGGER.info(f"Downloading {path} to {destination}") - with smart_open.open(path, "rb") as src, smart_open.open(destination, "wb") as dest: - dest.write(src.read()) - - return destination - - -def split_basename_and_extension(path: str) -> Tuple[str, str]: - """ - Get the path and extension from a given file path. If a file has multiple - extensions, they will be joined with a period, e.g. "foo/bar/baz.tar.gz" - will return ("foo/bar/baz", ".tar.gz"). If the file has no extension, the - second element of the tuple will be an empty string. Works with both local - and remote (e.g. s3://) paths. - - Args: - path (str): The file path. - - Returns: - Tuple[str, str]: A tuple containing the path and extension. - """ - prot, (*parts, filename) = split_path(path) - base, *ext_parts = filename.split(".") - ext = ("." + ".".join(ext_parts)) if ext_parts else "" - return join_path(prot, *parts, base), ext - - -def decompress_path(path: str, dest: Optional[str] = None) -> str: - """ - Decompresses a file at the given path and returns the path to the decompressed file. - - Args: - path (str): The path to the file to be decompressed. - dest (str, optional): The destination path for the decompressed file. - If not provided, a destination path will be computed based on the original - file name and the cache directory. - - Returns: - str: The path to the decompressed file. If the file cannot be decompressed, - the original path will be returned. - """ - for supported_ext in get_supported_extensions(): - # not the supported extension - if not path.endswith(supported_ext): - continue - - if dest is None: - # compute the name for the decompressed file; to do this, we first hash for - # resource and then remove the extension. - base_fn, ext = split_basename_and_extension(resource_to_filename(path)) - - # to get the decompressed file name, we remove the bit of the extension that - # indicates the compression type. 
- decompressed_fn = base_fn + ext.replace(supported_ext, "") - - # finally, we get cache directory and join the decompressed file name to it - dest = join_path("", get_cache_dir(), decompressed_fn) - - # here we do the actual decompression - with smart_open.open(path, "rb") as fr, smart_open.open(dest, "wb") as fw: - fw.write(fr.read()) - - # return the path to the decompressed file - return dest - - # already decompressed or can't be decompressed - return path - - -def split_ext(path: str) -> Tuple[str, Tuple[str, ...], str]: - """ - Split a path into its protocol and extensions. - """ - prot, parts = split_path(path) - if not parts: - return prot, (), "" - - filename = parts[-1] - extensions = [] - while True: - filename, ext = os.path.splitext(filename) - if not ext: - break - extensions.append(ext) - - return prot, (*parts[:-1], filename), "".join(reversed(extensions)) - - -def get_unified_path(paths: List[str]) -> str: - """Get a unified path for a list of paths.""" - - if len(paths) == 1: - # if there is only one path, we don't need to unify anything - return paths[0] - - # get shared root for all paths; we will put the unified path here - root, relative = make_relative(paths) - - # get the extension from the first path; assume all paths have the same extension - _, _, ext = split_ext(relative[0]) - - # hash all the sorted relative paths in order to get a unique name - # the type: ignore is needed because mypy fails to infer the type of the lambda - # (the "or" ensures that the lambda returns the same type as the first argument, which is a hash) - h = reduce(lambda h, p: h.update(p.encode()) or h, sorted(relative), sha256()) # type: ignore - - # return the unified path - return join_path(root, h.hexdigest() + ext) diff --git a/olmocr/train/core/state.py b/olmocr/train/core/state.py deleted file mode 100644 index 60eece0..0000000 --- a/olmocr/train/core/state.py +++ /dev/null @@ -1,27 +0,0 @@ -import os -from dataclasses import dataclass -from typing import Optional - - -@dataclass -class BeakerState: - job_id: Optional[str] = None - job_kind: Optional[str] = None - task_id: Optional[str] = None - experiment_id: Optional[str] = None - replica_rank: Optional[str] = None - leader_replica_hostname: Optional[str] = None - leader_replica_node_id: Optional[str] = None - user_id: Optional[str] = None - - def __post_init__(self): - for key, value in os.environ.items(): - if not key.startswith("BEAKER_"): - continue - setattr(self, key.lstrip("BEAKER_").lower(), value) - - @property - def url(self) -> Optional[str]: - if self.job_id: - return f"https://beaker.org/jobs/{self.job_id}" - return None diff --git a/olmocr/train/dataloader.py b/olmocr/train/dataloader.py index 320aff2..e69de29 100644 --- a/olmocr/train/dataloader.py +++ b/olmocr/train/dataloader.py @@ -1,165 +0,0 @@ -import glob -import logging -import os -import re -from typing import Optional - -import boto3 -from datasets import Dataset, load_dataset -from filelock import FileLock - -from olmocr.data.renderpdf import get_pdf_media_box_width_height -from olmocr.prompts.anchor import get_anchor_text -from olmocr.s3_utils import parse_custom_id, parse_s3_path - -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -# Quiet logs from pypdf and smart open -logging.getLogger("pypdf").setLevel(logging.ERROR) -logging.getLogger("smart_open").setLevel(logging.ERROR) - - -def list_dataset_files(s3_glob_path: str): - """ - Lists files in the specified S3 path that match the glob pattern. 
- """ - if s3_glob_path.startswith("s3://"): - s3 = boto3.client("s3") - match = re.match(r"s3://([^/]+)/(.+)", s3_glob_path) - if not match: - logger.error(f"Invalid S3 path: {s3_glob_path}") - raise ValueError(f"Invalid S3 path: {s3_glob_path}") - - bucket, prefix_pattern = match.groups() - prefix = prefix_pattern.split("*")[0] # Extract prefix before the wildcard - paginator = s3.get_paginator("list_objects_v2") - pages = paginator.paginate(Bucket=bucket, Prefix=prefix) - - files = [] - pattern = re.compile(prefix_pattern.replace("*", ".*")) - for page in pages: - for obj in page.get("Contents", []): - key = obj["Key"] - if pattern.fullmatch(key): - files.append(f"s3://{bucket}/{key}") - return files - else: - return glob.glob(s3_glob_path) - - -def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> Dataset: - """ - Loads JSONL files from the specified S3 path into a Hugging Face Dataset. - """ - all_json_files = list_dataset_files(s3_glob_path) - - if first_n_files: - all_json_files = all_json_files[:first_n_files] - - # Use datasets library to load JSON files from S3 - dataset = load_dataset( - "json", - data_files=all_json_files, - ) - - return dataset - - -def extract_openai_batch_response(example): - custom_id = example.get("custom_id", None) - - # Parse the custom id into an s3 document path and page number (1indexed) - s3_path, page_num = parse_custom_id(custom_id) - - response_body = example.get("response", {}).get("body", {}) - choices = response_body.get("choices", []) - response = "" - finish_reason = "" - if choices: - first_choice = choices[0] - message = first_choice.get("message", {}) - response = message.get("content", "") - finish_reason = first_choice.get("finish_reason", "") - - # TODO Maybe in the future we can parse the response (which is a structured JSON document itself) - # into its own columns - - return {"s3_path": s3_path, "page_num": page_num, "response": response, "finish_reason": finish_reason} - - -def _cache_s3_file(s3_path: str, local_cache_dir: str): - """ - Downloads an S3 object to a local cache directory, ensuring no two writers corrupt the same file. - """ - bucket, key = parse_s3_path(s3_path) - - # Define the local file path - local_file_path = os.path.join(local_cache_dir, bucket + "__" + key.replace("/", "_")) - os.makedirs(os.path.dirname(local_file_path), exist_ok=True) - lock_file = f"{local_file_path}.lock" - - # Use a file lock to prevent concurrent writes - with FileLock(lock_file): - if not os.path.exists(local_file_path): - logger.info(f"Downloading {s3_path} to {local_file_path}") - s3_client = boto3.client("s3", aws_access_key_id=os.getenv("DS_AWS_ACCESS_KEY_ID"), aws_secret_access_key=os.getenv("DS_AWS_SECRET_ACCESS_KEY")) - s3_client.download_file(bucket, key, local_file_path) - else: - pass - # logger.info(f"File {local_file_path} already exists, skipping download.") - - return local_file_path - - -def cache_s3_files(dataset: Dataset, pdf_cache_location: str, num_proc: int = 32) -> Dataset: - """ - Caches all S3 paths in the dataset to the local cache directory. 
- """ - - # Define the download function to use in parallel processing - def cache_file(example): - s3_path = example["s3_path"] - if s3_path: - # Download the file and cache it locally - local_path = _cache_s3_file(s3_path, pdf_cache_location) - return {"local_pdf_path": local_path} - return {"local_pdf_path": None} - - # Map the caching function to the dataset (with parallelism if needed) - dataset = dataset.map(cache_file, num_proc=num_proc, load_from_cache_file=False) - - return dataset - - -def build_finetuning_dataset(response_glob_path: str, pdf_cache_location: Optional[str] = None, num_proc: int = 32) -> Dataset: - if pdf_cache_location is None: - pdf_cache_location = os.path.join(os.path.expanduser("~"), ".cache", "olmocr_pdfs") - - logger.info("Loading fine tuning dataset from OpenAI style batch responses") - response_data = load_jsonl_into_ds(response_glob_path) - response_data = response_data["train"] - - response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names, num_proc=num_proc) - - # Don't include data where the model cut off due to a length issue, or moderation issue - logger.info("Filtering on finish_reason == stop") - final_dataset = response_data.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc) - - # Cache all the s3_paths that were accessed to a local storage location, - final_dataset = cache_s3_files(final_dataset, pdf_cache_location, num_proc) - - # Filter out pages where you cannot get an anchor text generated, to prevent errors during actual training - def _can_create_anchor_text(example): - try: - anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=4000) - _ = get_pdf_media_box_width_height(example["local_pdf_path"], example["page_num"]) - return anchor_text is not None - except: - logger.exception("Could not generate anchor text for file, be sure you have all dependencies installed") - return False - - final_dataset = final_dataset.filter(_can_create_anchor_text, num_proc=num_proc) - - return final_dataset diff --git a/olmocr/train/dataprep.py b/olmocr/train/dataprep.py deleted file mode 100644 index 4acf6db..0000000 --- a/olmocr/train/dataprep.py +++ /dev/null @@ -1,164 +0,0 @@ -import base64 -import random -from io import BytesIO -from typing import Union - -import numpy as np -import torch # Make sure to import torch as it's used in the DataCollator -from PIL import Image - -from olmocr.data.renderpdf import render_pdf_to_base64png -from olmocr.prompts import build_finetuning_prompt -from olmocr.prompts.anchor import get_anchor_text - - -def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim: Union[int, list[int]], target_anchor_text_len: Union[int, list[int]]): - if isinstance(target_longest_image_dim, list): - target_longest_image_dim = random.choice(target_longest_image_dim) - - if isinstance(target_anchor_text_len, list): - target_anchor_text_len = random.choice(target_anchor_text_len) - - anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len) - base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim) - - # Prepare messages - messages = [ - { - "role": "user", - "content": [ - {"type": "image", "image": base64_page_image}, - {"type": "text", "text": build_finetuning_prompt(anchor_text)}, - ], - } - ] - # Apply chat template to get the 
text - text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - - # Decode image from base64 - main_image = Image.open(BytesIO(base64.b64decode(base64_page_image))) - - # Process inputs using processor - inputs = processor( - text=[text], - images=[main_image], - padding=True, - return_tensors="np", - ) - - # Get labels by tokenizing the output text - labels = processor(text=[example["response"]], padding=True, return_tensors="np") - - # Append an <|im_end|>\n" to the labels, because this is what it would look like - # if we passed the whole message stream in there - im_end_tokens = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"] - im_end_tokens = np.array(im_end_tokens, dtype=inputs.input_ids.dtype) # Ensure correct dtype - - # Handle the case where labels['input_ids'] is empty - if labels["input_ids"].shape[1] == 0: - labels_input_ids_0 = np.array([], dtype=inputs.input_ids.dtype) - else: - labels_input_ids_0 = labels["input_ids"][0].astype(inputs.input_ids.dtype) - - labels["input_ids"] = np.concatenate([labels_input_ids_0, im_end_tokens]) - labels["input_ids"] = np.expand_dims(labels["input_ids"], axis=0) - - # Concatenate input_ids and labels - input_ids = np.concatenate([inputs.input_ids[0], labels.input_ids[0]], axis=0) - - # All columns will participate in attention fully - attention_mask = np.ones_like(input_ids) - - # Create labels, masking the input portion with -100 - labels_full = np.full_like(input_ids, fill_value=-100) - labels_full[len(inputs.input_ids[0]) :] = labels.input_ids[0] - - # TODO Maybe cap the max length - - # Return as dict, including pixel_values - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "labels": labels_full, - "pixel_values": inputs.pixel_values, - "image_grid_thw": inputs["image_grid_thw"][0], - } - - -def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]): - # Process each example in the batch using the helper function - processed_examples = [] - for i in range(len(batch["response"])): - example = {"local_pdf_path": batch["local_pdf_path"][i], "page_num": batch["page_num"][i], "response": batch["response"][i]} - processed_example = prepare_data_for_qwen2_training( - example, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len - ) - processed_examples.append(processed_example) - - return { - "input_ids": [x["input_ids"] for x in processed_examples], - "attention_mask": [x["attention_mask"] for x in processed_examples], - "labels": [x["labels"] for x in processed_examples], - "pixel_values": [x["pixel_values"] for x in processed_examples], - "image_grid_thw": [x["image_grid_thw"] for x in processed_examples], - } - - -def prepare_data_for_molmo_training(example, processor, target_longest_image_dim: Union[int, list[int]], target_anchor_text_len: Union[int, list[int]]): - if isinstance(target_longest_image_dim, list): - target_longest_image_dim = random.choice(target_longest_image_dim) - - if isinstance(target_anchor_text_len, list): - target_anchor_text_len = random.choice(target_anchor_text_len) - - anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len) - base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim) - - # Decode image from base64 - main_image = 
Image.open(BytesIO(base64.b64decode(base64_page_image))) - - # Process the input text and image - inputs = processor.process( - images=[main_image], - text=build_finetuning_prompt(anchor_text), - ) - - # Get labels by tokenizing the output text - labels = processor.tokenizer(example["response"], return_tensors="np")["input_ids"][0] - # Concatenate input_ids and labels - full_input_ids = torch.cat([inputs["input_ids"], torch.from_numpy(labels)], dim=0) - - labels_full = torch.cat([torch.ones_like(inputs["input_ids"]) * -100, torch.from_numpy(labels)], dim=0) - - # Create a full attention mask - attention_mask = torch.ones_like(full_input_ids) - - # image_input_idx does not need adjustment as images are inserted before labels - image_input_idx = inputs["image_input_idx"] - - return { - "input_ids": full_input_ids, - "labels": labels_full, - "images": inputs["images"], - "image_input_idx": image_input_idx, - "image_masks": inputs["image_masks"], - "attention_mask": attention_mask, - } - - -def batch_prepare_data_for_molmo_training(batch, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]): - # Assume batch size 1 and process the single example - example = {"local_pdf_path": batch["local_pdf_path"][0], "page_num": batch["page_num"][0], "response": batch["response"][0]} - processed_example = prepare_data_for_molmo_training( - example, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len - ) - - # Return in the same format as the qwen2 function - return { - "input_ids": [processed_example["input_ids"]], - "attention_mask": [processed_example["attention_mask"]], - "labels": [processed_example["labels"]], - "images": [processed_example["images"]], - "image_input_idx": [processed_example["image_input_idx"]], - "image_masks": [processed_example["image_masks"]], - } diff --git a/olmocr/train/fixqwen25vlcheckpoint.py b/olmocr/train/fixqwen25vlcheckpoint.py deleted file mode 100644 index f577c84..0000000 --- a/olmocr/train/fixqwen25vlcheckpoint.py +++ /dev/null @@ -1,122 +0,0 @@ -import argparse -import concurrent.futures -import json -import os - -import boto3 -import torch -from smart_open import smart_open -from tqdm import tqdm -from transformers import Qwen2_5_VLForConditionalGeneration - -from olmocr.s3_utils import parse_s3_path - -s3_client = boto3.client("s3") - - -def download_file_from_s3(bucket_name, key, local_file_path): - """Download a single file from S3.""" - s3_client.download_file(bucket_name, key, local_file_path) - print(f"Downloaded {key} to {local_file_path}") - - -def download_model_from_s3(bucket_name, model_s3_key, local_model_dir): - if not os.path.exists(local_model_dir): - os.makedirs(local_model_dir) - - # List objects in the S3 model path - response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=model_s3_key) - objects = response.get("Contents", []) - - # Prepare list of download tasks - download_tasks = [] - for obj in objects: - key = obj["Key"] - if key.endswith("/"): - continue # Skip directories - - local_file_path = os.path.join(local_model_dir, os.path.basename(key)) - download_tasks.append((bucket_name, key, local_file_path)) - - # Use a ThreadPoolExecutor to download files in parallel - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(download_file_from_s3, bucket_name, key, local_file_path) for bucket_name, key, local_file_path in download_tasks] - - # Wait for all downloads to complete and handle any exceptions - for future in 
tqdm(concurrent.futures.as_completed(futures)): - try: - future.result() # This will raise any exceptions encountered during download - except Exception as e: - print(f"Error downloading file: {e}") - - -def upload_file_to_s3(local_file_path, bucket_name, s3_key): - """Upload a single file to S3.""" - try: - s3_client.upload_file(local_file_path, bucket_name, s3_key) - print(f"Uploaded {local_file_path} to s3://{bucket_name}/{s3_key}") - except Exception as e: - print(f"Error uploading {local_file_path} to s3://{bucket_name}/{s3_key}: {e}") - - -def save_model_to_s3(local_model_dir, bucket_name, s3_model_key): - """Upload the model directory to S3 in parallel.""" - # Collect all file paths to be uploaded - upload_tasks = [] - for root, dirs, files in os.walk(local_model_dir): - for file in files: - local_file_path = os.path.join(root, file) - s3_key = os.path.join(s3_model_key, file) - upload_tasks.append((local_file_path, bucket_name, s3_key)) - - # Use a ThreadPoolExecutor to upload files in parallel - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(upload_file_to_s3, local_file_path, bucket_name, s3_key) for local_file_path, bucket_name, s3_key in upload_tasks] - - # Wait for all uploads to complete and handle any exceptions - for future in concurrent.futures.as_completed(futures): - try: - future.result() # This will raise any exceptions encountered during upload - except Exception as e: - print(f"Error during upload: {e}") - - -def main(): - parser = argparse.ArgumentParser(description="Fix up a Qwen2VL checkpoint saved on s3 or otherwise, so that it will load properly in vllm/birr") - parser.add_argument("s3_path", type=str, help="S3 path to the Hugging Face checkpoint.") - args = parser.parse_args() - - # Now, download the config.json from the original path and verify the architectures - config_path = os.path.join(args.s3_path, "config.json") - - with smart_open(config_path, "r") as f: - config_data = json.load(f) - - assert config_data["architectures"] == ["Qwen2_5_VLForConditionalGeneration"] - - if config_data["torch_dtype"] == "float32": - print("Detected model is float32, this is probably an FSDP checkpoint") - print("Saving to _bf16 location with adjusted parameters") - - bucket, prefix = parse_s3_path(args.s3_path) - td = "/tmp/qwen2_checkpoint_saving" - download_model_from_s3(bucket, prefix, td) - - print("Downloaded entire model from s3, resaving as bfloat16") - model = Qwen2_5_VLForConditionalGeneration.from_pretrained(td) - model = model.to(torch.bfloat16) - os.makedirs(os.path.join(td, "bf16_checkpoint"), exist_ok=True) - - print("Saving...") - model.save_pretrained(os.path.join(td, "bf16_checkpoint")) - - print("Uploading") - save_model_to_s3(os.path.join(td, "bf16_checkpoint"), bucket, prefix.rstrip("/") + "/bf16") - - args.s3_path = args.s3_path.rstrip("/") + "/bf16" - - print("Model updated successfully.") - - -if __name__ == "__main__": - main() diff --git a/olmocr/train/hf/__init__.py b/olmocr/train/hf/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/olmocr/train/hf/convertjsontoparquet.py b/olmocr/train/hf/convertjsontoparquet.py deleted file mode 100644 index 3f76a5e..0000000 --- a/olmocr/train/hf/convertjsontoparquet.py +++ /dev/null @@ -1,371 +0,0 @@ -# Script to generate parquet dataset files to upload to hugging face -# Input is a dataset location /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json -# Each json line has a custom id that looks like {"custom_id": 
"s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1", ... more data} - -# Fix this script so that it works, and that it will take a path to an input dataset, and sqllite database location -# And then it will build a parquet file with rows that look like: "id", "url", "page_number", "response" -# Where Id will be the output of parse_pdf_hash plus "-" plus the page number -# The url will be the result of get_uri_from_db -# Rresponse will be NormalizedEntry.text -import argparse -import concurrent.futures -import glob -import json -import multiprocessing -import os -import re -import sqlite3 -from dataclasses import dataclass -from typing import Dict, List, Optional, Set, Tuple -from urllib.parse import urlparse - -import boto3 -import pandas as pd -from pypdf import PdfReader, PdfWriter -from tqdm import tqdm - - -def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]: - """ - Extracts a hash from a pretty PDF S3 URL. - For example, given: - s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1 - it will return "de80a57e6c57b45796d2e020173227f7eae44232". - """ - # Allow an optional "-" at the end. - if pretty_pdf_path.startswith("s3://ai2-s2-pdfs/"): - pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf(?:-\d+)?$" - match = re.match(pattern, pretty_pdf_path) - if match: - return match.group(1) + match.group(2) - return None - elif pretty_pdf_path.startswith("s3://ai2-oe-data/reganh/iabooks/"): - return urlparse(pretty_pdf_path).path.split("/")[-1] - else: - raise NotImplementedError() - - -def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]: - """ - Looks up the URL for the given pdf_hash in the sqlite database. - Assumes there is a table called 'pdf_mapping' with a column 'uri'. - """ - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,)) - result = cursor.fetchone() - conn.close() - return result[0].strip() if result and result[0] else None - - -@dataclass(frozen=True) -class NormalizedEntry: - s3_path: str - pagenum: int - text: Optional[str] - finish_reason: Optional[str] - error: Optional[str] = None - - @staticmethod - def from_goldkey(goldkey: str, **kwargs): - """ - Constructs a NormalizedEntry from a goldkey string. - The goldkey is expected to be of the format: - - - """ - s3_path = goldkey[: goldkey.rindex("-")] - page_num = int(goldkey[goldkey.rindex("-") + 1 :]) - return NormalizedEntry(s3_path, page_num, **kwargs) - - @property - def goldkey(self): - return f"{self.s3_path}-{self.pagenum}" - - -def normalize_json_entry(data: dict) -> NormalizedEntry: - """ - Normalizes a JSON entry from any of the supported formats. - It supports: - - Birr: looks for an "outputs" field. - - Already normalized entries: if they contain s3_path, pagenum, etc. - - OpenAI: where the response is in data["response"]["body"]["choices"]. - - SGLang: where the response is in data["response"]["choices"]. 
- """ - if "outputs" in data: - # Birr case - if data["outputs"] is None: - text = None - finish_reason = None - else: - text = data["outputs"][0]["text"] - finish_reason = data["outputs"][0]["finish_reason"] - - return NormalizedEntry.from_goldkey( - goldkey=data["custom_id"], - text=text, - finish_reason=finish_reason, - error=data.get("completion_error", None), - ) - elif all(field in data for field in ["s3_path", "pagenum", "text", "error", "finish_reason"]): - # Already normalized - return NormalizedEntry(**data) - elif "response" in data and "body" in data["response"] and "choices" in data["response"]["body"]: - return NormalizedEntry.from_goldkey( - goldkey=data["custom_id"], - text=data["response"]["body"]["choices"][0]["message"]["content"], - finish_reason=data["response"]["body"]["choices"][0]["finish_reason"], - ) - else: - raise ValueError("Unsupported JSON format") - - -def parse_s3_url(s3_url: str) -> Tuple[str, str]: - """ - Parses an S3 URL of the form s3://bucket/key and returns (bucket, key). - """ - if not s3_url.startswith("s3://"): - raise ValueError(f"Invalid S3 URL: {s3_url}") - s3_path = s3_url[5:] - bucket, key = s3_path.split("/", 1) - return bucket, key - - -def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]: - """ - Downloads the PDF from the given S3 URL into the specified cache directory. - The destination filename is based on the parsed PDF hash. - Returns the path to the downloaded PDF. - """ - try: - bucket, key = parse_s3_url(s3_url) - s3_client = boto3.client("s3") - pdf_hash = parse_pdf_hash(s3_url) - if not pdf_hash: - # Fallback: use a sanitized version of the s3_url - pdf_hash = re.sub(r"\W+", "_", s3_url) - dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf") - # Avoid re-downloading if already exists - if not os.path.exists(dest_path): - s3_client.download_file(bucket, key, dest_path) - return dest_path - except Exception as e: - print(f"Error downloading {s3_url}: {e}") - return None - - -def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]: - """ - Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url. - Writes a new single-page PDF to the output_pdf_dir using the combined_id as the filename. - Returns the relative path to the new PDF (e.g., "pdfs/.pdf"). 
- """ - try: - local_cached_pdf = pdf_cache.get(s3_url) - if not local_cached_pdf or not os.path.exists(local_cached_pdf): - print(f"Cached PDF not found for {s3_url}") - return None - reader = PdfReader(local_cached_pdf) - # pypdf uses 0-indexed page numbers - page_index = page_number - 1 - if page_index < 0 or page_index >= len(reader.pages): - print(f"Page number {page_number} out of range for PDF {s3_url}") - return None - writer = PdfWriter() - writer.add_page(reader.pages[page_index]) - output_filename = f"{combined_id}.pdf" - output_path = os.path.join(output_pdf_dir, output_filename) - with open(output_path, "wb") as f_out: - writer.write(f_out) - # Return the relative path (assuming pdfs/ folder is relative to the parquet file location) - return os.path.join("pdfs", output_filename) - except Exception as e: - print(f"Error processing PDF page for {s3_url} page {page_number}: {e}") - return None - - -def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]: - """ - Process a single file and return a tuple: - (list of valid rows, number of rows skipped due to missing URL or PDF extraction/filtering). - For each JSON entry, the function: - - Normalizes the JSON. - - Skips entries whose response contains the word "resume" (any case) along with either an email address or a phone number. - - Extracts the PDF hash and builds the combined id. - - Looks up the corresponding URL from the sqlite database. - - Extracts the specified page from the cached PDF and writes it to output_pdf_dir. - - Outputs a row with "id", "url", "page_number", "response". - """ - rows = [] - missing_count = 0 - email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b" - phone_regex = r"\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b" - - try: - with open(file_path, "r", encoding="utf-8") as f: - for line_num, line in enumerate(f, start=1): - line = line.strip() - if not line: - continue - try: - data = json.loads(line) - except json.JSONDecodeError as e: - print(f"Skipping invalid JSON at {file_path}:{line_num} - {e}") - continue - - try: - normalized = normalize_json_entry(data) - except Exception as e: - print(f"Error normalizing entry at {file_path}:{line_num} - {e}") - continue - - # Apply filter: skip if response contains "resume" (any case) and an email or phone number. - response_text = normalized.text if normalized.text else "" - if re.search(r"resume", response_text, re.IGNORECASE) and (re.search(email_regex, response_text) or re.search(phone_regex, response_text)): - print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}") - continue - - # Extract the PDF hash from the s3_path. - pdf_hash = parse_pdf_hash(normalized.s3_path) - if pdf_hash is None: - print(f"Could not parse pdf hash from {normalized.s3_path} at {file_path}:{line_num}") - continue - - # The output id is the pdf hash plus '-' plus the page number. - combined_id = f"{pdf_hash}-{normalized.pagenum}" - - # Look up the corresponding URL from the sqlite database. - url = get_uri_from_db(db_path, pdf_hash) - if not url: - print(f"Missing URL for pdf hash {pdf_hash} at {file_path}:{line_num}") - missing_count += 1 - continue - - # Process PDF: extract the specified page from the cached PDF. 
- local_pdf_path = process_pdf_page(normalized.s3_path, normalized.pagenum, combined_id, output_pdf_dir, pdf_cache) - if local_pdf_path is None: - print(f"Skipping entry because PDF processing failed for {normalized.s3_path} page {normalized.pagenum} at {file_path}:{line_num}") - missing_count += 1 - continue - - row = { - "id": combined_id, - "url": url, - "page_number": normalized.pagenum, - "response": normalized.text, - } - rows.append(row) - except Exception as e: - print(f"Error processing file {file_path}: {e}") - return rows, missing_count - - -def scan_file_for_s3_urls(file_path: str) -> Set[str]: - """ - Scans a single file and returns a set of unique S3 URLs found in the JSON entries. - """ - urls = set() - try: - with open(file_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - try: - data = json.loads(line) - normalized = normalize_json_entry(data) - urls.add(normalized.s3_path) - except Exception: - # Skip entries that cannot be normalized - continue - except Exception as e: - print(f"Error reading file {file_path}: {e}") - return urls - - -def main(): - parser = argparse.ArgumentParser(description="Generate a Parquet dataset file for HuggingFace upload.") - parser.add_argument( - "input_dataset", - help="Input dataset file pattern (e.g., '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json')", - ) - parser.add_argument("db_path", help="Path to the SQLite database file.") - parser.add_argument("--output", default="output.parquet", help="Output Parquet file path.") - - args = parser.parse_args() - - files = glob.glob(args.input_dataset) - print(f"Found {len(files)} files matching pattern: {args.input_dataset}") - - # Determine output directory and create 'pdfs' subfolder. - output_abs_path = os.path.abspath(args.output) - output_dir = os.path.dirname(output_abs_path) - pdfs_dir = os.path.join(output_dir, "pdfs") - os.makedirs(pdfs_dir, exist_ok=True) - - # Create a temporary directory for caching PDFs. - pdf_cache_dir = "/tmp/pdf_cache" - os.makedirs(pdf_cache_dir, exist_ok=True) - - print(f"Caching PDFs to temporary directory: {pdf_cache_dir}") - - # --------------------------------------------------------------------- - # Step 1: Scan input files to collect all unique S3 URLs using a ProcessPoolExecutor. - unique_s3_urls: Set[str] = set() - print("Scanning input files to collect unique PDF URLs...") - num_cpus = multiprocessing.cpu_count() - with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 4) as executor: - results = list(tqdm(executor.map(scan_file_for_s3_urls, files), total=len(files), desc="Scanning files")) - for url_set in results: - unique_s3_urls |= url_set - - print(f"Found {len(unique_s3_urls)} unique PDF URLs.") - - # --------------------------------------------------------------------- - # Step 2: Download all unique PDFs to the cache directory. 
- pdf_cache: Dict[str, str] = {} - print("Caching PDFs from S3...") - with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 8) as executor: - future_to_url = {executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url for s3_url in unique_s3_urls} - for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(future_to_url), desc="Downloading PDFs"): - s3_url = future_to_url[future] - try: - local_path = future.result() - if local_path: - pdf_cache[s3_url] = local_path - else: - print(f"Failed to cache PDF for {s3_url}") - except Exception as e: - print(f"Error caching PDF for {s3_url}: {e}") - - # --------------------------------------------------------------------- - # Step 3: Process input files using the precached PDFs. - all_rows = [] - total_missing = 0 - print("Processing files...") - with concurrent.futures.ProcessPoolExecutor() as executor: - futures = {executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path for file_path in files} - for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing files"): - file_path = futures[future] - try: - rows, missing_count = future.result() - all_rows.extend(rows) - total_missing += missing_count - except Exception as e: - print(f"Error processing file {file_path}: {e}") - - if all_rows: - df = pd.DataFrame(all_rows) - # Set the "id" column as the index. - df.set_index("id", inplace=True) - df.to_parquet(args.output) - - valid_count = len(df) - total_processed = valid_count + total_missing - print(f"Successfully wrote {valid_count} rows to {args.output}") - print(f"Rows skipped due to missing URL/PDF or filtering: {total_missing} out of {total_processed} processed rows") - else: - print("No valid rows to write. Exiting.") - - -if __name__ == "__main__": - main() diff --git a/olmocr/train/hf/hfhub_upload.py b/olmocr/train/hf/hfhub_upload.py deleted file mode 100644 index 33f922a..0000000 --- a/olmocr/train/hf/hfhub_upload.py +++ /dev/null @@ -1,97 +0,0 @@ -import logging -import os -import tarfile -from math import ceil - -from huggingface_hub import HfApi - -# Configuration -pdf_dir = "pdfs" # Directory with PDF files (flat structure) -tarball_dir = "tarballs" # Directory where tar.gz files will be saved -os.makedirs(tarball_dir, exist_ok=True) -repo_id = "allenai/olmOCR-mix-0225" # Hugging Face dataset repo ID - -# Set up logging to file -logging.basicConfig(filename="upload.log", level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") - - -def process_chunk(args): - """ - Worker function to create a tar.gz file for a given chunk. - Returns a tuple: (chunk_index, success (bool), message). 
- """ - chunk_index, chunk_files = args - tarball_name = f"pdf_chunk_{chunk_index:04d}.tar.gz" - tarball_path = os.path.join(tarball_dir, tarball_name) - - try: - with tarfile.open(tarball_path, "w:gz") as tar: - for pdf_filename in chunk_files: - pdf_path = os.path.join(pdf_dir, pdf_filename) - # Add the file with its basename to maintain a flat structure - tar.add(pdf_path, arcname=pdf_filename) - logging.info(f"Chunk {chunk_index:04d}: Created '{tarball_name}' with {len(chunk_files)} PDFs.") - return chunk_index, True, "Success" - except Exception as e: - error_msg = f"Chunk {chunk_index:04d}: Error creating '{tarball_name}': {e}" - logging.error(error_msg) - return chunk_index, False, error_msg - - -def main(): - # List all PDF files (assuming a flat directory) - try: - pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith(".pdf")]) - except Exception as e: - logging.error(f"Error listing PDFs in '{pdf_dir}': {e}") - return - - total_files = len(pdf_files) - chunk_size = 5000 - total_chunks = ceil(total_files / chunk_size) - logging.info(f"Found {total_files} PDFs; dividing into {total_chunks} chunks of up to {chunk_size} files each.") - - # # Enumerate chunks (starting at 0000) - # chunks = [] - # for idx in range(total_chunks): - # start = idx * chunk_size - # end = start + chunk_size - # chunk_files = pdf_files[start:end] - # chunks.append((idx, chunk_files)) - - # # Create tarballs in parallel - # results = [] - # with ProcessPoolExecutor() as executor: - # futures = {executor.submit(process_chunk, chunk): chunk for chunk in chunks} - # for future in tqdm(as_completed(futures), total=len(futures), desc="Creating tarballs"): - # try: - # result = future.result() - # results.append(result) - # chunk_index, success, message = result - # if not success: - # logging.error(f"Chunk {chunk_index:04d} failed: {message}") - # except Exception as e: - # logging.error(f"Unexpected error processing a chunk: {e}") - - # # Abort upload if any tarball creation failed - # failed_chunks = [r for r in results if not r[1]] - # if failed_chunks: - # logging.error(f"{len(failed_chunks)} chunk(s) failed to create. Aborting upload.") - # return - - # All tarballs created successfully; now upload the entire tarball directory - - api = HfApi() - logging.info("Starting upload of tarballs folder to Hugging Face Hub...") - # This will upload all files in tarball_dir to the repo under "pdf_tarballs" - api.upload_large_folder( - folder_path=tarball_dir, - repo_id=repo_id, - # path_in_repo="pdf_tarballs", - repo_type="dataset", - ) - logging.info("Successfully uploaded tarballs folder to Hugging Face Hub.") - - -if __name__ == "__main__": - main() diff --git a/olmocr/train/hf/warc_parser.py b/olmocr/train/hf/warc_parser.py deleted file mode 100644 index 715c5e4..0000000 --- a/olmocr/train/hf/warc_parser.py +++ /dev/null @@ -1,138 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import sqlite3 -from concurrent.futures import ThreadPoolExecutor -from functools import partial - -import boto3 -from tqdm import tqdm -from warcio.archiveiterator import ArchiveIterator - - -def parse_s3_path(s3_path): - """ - Parses an S3 path of the form s3://bucket/prefix and returns the bucket and prefix. 
- """ - if not s3_path.startswith("s3://"): - raise ValueError("S3 path must start with s3://") - without_prefix = s3_path[5:] - parts = without_prefix.split("/", 1) - bucket = parts[0] - prefix = parts[1] if len(parts) > 1 else "" - return bucket, prefix - - -def list_s3_warc_objects(s3_path, suffix=".warc.gz"): - """ - Lists all objects under the given S3 path that end with the provided suffix. - Uses a paginator to handle large result sets. - """ - bucket, prefix = parse_s3_path(s3_path) - s3_client = boto3.client("s3") - paginator = s3_client.get_paginator("list_objects_v2") - warc_keys = [] - for page in paginator.paginate(Bucket=bucket, Prefix=prefix): - if "Contents" in page: - for obj in page["Contents"]: - key = obj["Key"] - if key.endswith(suffix): - warc_keys.append(key) - return bucket, warc_keys, s3_client - - -def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576): - """ - Retrieves the first head_bytes bytes (1 MB by default) from the S3 object using a range request, - and extracts the first response record's target URI from the HTTP headers. - """ - target_uri = None - try: - response = s3_client.get_object(Bucket=bucket, Key=key, Range=f"bytes=0-{head_bytes-1}") - stream = response["Body"] - for record in ArchiveIterator(stream): - for name, value in record.rec_headers.headers: - if name == "WARC-Target-URI": - target_uri = value - break - if target_uri: - break # Only use the first valid response record - except Exception as e: - tqdm.write(f"Error processing s3://{bucket}/{key}: {e}") - return target_uri - - -def create_db(db_path): - """ - Creates (or opens) the SQLite database and ensures that the pdf_mapping table exists, - including an index on pdf_hash. - """ - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS pdf_mapping ( - pdf_hash TEXT PRIMARY KEY, - uri TEXT - ) - """ - ) - cursor.execute( - """ - CREATE INDEX IF NOT EXISTS idx_pdf_hash ON pdf_mapping (pdf_hash) - """ - ) - conn.commit() - return conn - - -def process_warc_file(key, bucket, s3_client): - """ - Processes a single WARC file from S3 and returns a tuple (pdf_hash, uri) - if successful, otherwise returns None. - """ - uri = extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576) - if uri: - # Derive pdf_hash as the file's basename with .warc.gz replaced by .pdf. - pdf_hash = key.split("/")[-1].replace(".warc.gz", ".pdf") - return (pdf_hash, uri) - else: - tqdm.write(f"Warning: No valid response record found in s3://{bucket}/{key}") - return None - - -def process_s3_folder(s3_path, db_path): - """ - Lists all .warc.gz files under the provided S3 path, then processes each file in parallel - to extract the target URI from the HTTP headers. The resulting mapping (derived from the file's - basename with .warc.gz replaced by .pdf) is stored in the SQLite database. - """ - bucket, warc_keys, s3_client = list_s3_warc_objects(s3_path, suffix=".warc.gz") - conn = create_db(db_path) - cursor = conn.cursor() - - # Process WARC files concurrently using ThreadPoolExecutor. - results = [] - func = partial(process_warc_file, bucket=bucket, s3_client=s3_client) - with ThreadPoolExecutor() as executor: - for result in tqdm(executor.map(func, warc_keys), total=len(warc_keys), desc="Processing S3 WARC files"): - if result is not None: - results.append(result) - - # Bulk insert into the database. 
- conn.execute("BEGIN") - for pdf_hash, uri in results: - cursor.execute("INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", (pdf_hash, uri)) - conn.commit() - conn.close() - - -def main(): - parser = argparse.ArgumentParser(description="Create an SQLite database mapping PDF file names to target URIs from S3 WARC files.") - parser.add_argument("s3_path", help="S3 path (e.g., s3://bucket/prefix) containing .warc.gz files") - parser.add_argument("db_file", help="Path for the output SQLite database file") - args = parser.parse_args() - process_s3_folder(args.s3_path, args.db_file) - - -if __name__ == "__main__": - main() diff --git a/olmocr/train/inference.py b/olmocr/train/inference.py deleted file mode 100644 index d28bffa..0000000 --- a/olmocr/train/inference.py +++ /dev/null @@ -1,66 +0,0 @@ -import base64 -from io import BytesIO - -import torch -import torch.distributed -from PIL import Image -from transformers import AutoConfig, AutoProcessor, Qwen2_5_VLForConditionalGeneration - -from olmocr.data.renderpdf import render_pdf_to_base64png -from olmocr.prompts.anchor import get_anchor_text -from olmocr.prompts.prompts import build_openai_silver_data_prompt - - -@torch.no_grad() -def run_inference(model_name: str): - config = AutoConfig.from_pretrained(model_name) - processor = AutoProcessor.from_pretrained(model_name) - - # If it doesn't load, change the type:mrope key to "default" - - # model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config) - model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config) - model.eval() - - # local_pdf_path = os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "horribleocr.pdf") - local_pdf_path = "/root/brochure.pdf" - page = 1 - - image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024) - anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport") - - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": build_openai_silver_data_prompt(anchor_text)}, - {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}, - ], - } - ] - - # Preparation for inference - text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - - main_image = Image.open(BytesIO(base64.b64decode(image_base64))) - - inputs = processor( - text=[text], - images=[main_image], - padding=True, - return_tensors="pt", - ) - inputs = inputs.to("cuda") - - output_ids = model.generate(**inputs, temperature=0.8, do_sample=True, max_new_tokens=1500) - generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(inputs["input_ids"], output_ids)] - output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) - print(output_text[0]) - - -def main(): - run_inference(model_name="Qwen/Qwen2.5-VL-7B-Instruct") - - -if __name__ == "__main__": - main() diff --git a/olmocr/train/loaddataset.py b/olmocr/train/loaddataset.py deleted file mode 100644 index d192650..0000000 --- a/olmocr/train/loaddataset.py +++ /dev/null @@ -1,50 +0,0 @@ -from transformers import AutoProcessor - -from olmocr.train.core.cli import make_cli -from olmocr.train.core.config import TrainConfig - -from .utils import make_dataset - - -def main(): - train_config = make_cli(TrainConfig) # pyright: ignore - - processor = AutoProcessor.from_pretrained(train_config.model.name_or_path, trust_remote_code=True) - train_dataset, 
valid_dataset = make_dataset(train_config, processor) - - print("Training dataset........") - print(train_dataset) - - train_example = train_dataset[0] - print(train_example) - print({(x, y.shape) for x, y in train_example.items()}) - print("\nTokens") - print(processor.tokenizer.batch_decode(train_example["input_ids"])) - - print("\n\n") - - print("Validation dataset........") - print(valid_dataset) - print(valid_dataset[list(valid_dataset.keys())[0]][0]) - print("\n\n") - - print("Datasets loaded into hugging face cache directory") - - # data_collator = TruncatingCollator( - # max_length=4096 - # ) - - # train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False, collate_fn=data_collator) - # max_seen_len = 0 - # for index, entry in tqdm(enumerate(train_dataloader)): - # if index == 0: - # print(entry) - - # num_input_tokens = entry["input_ids"].shape[1] - # max_seen_len = max(max_seen_len, num_input_tokens) - - # print(max_seen_len) - - -if __name__ == "__main__": - main() diff --git a/olmocr/train/molmo/__init__.py b/olmocr/train/molmo/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/olmocr/train/molmo/config_molmo.py b/olmocr/train/molmo/config_molmo.py deleted file mode 100644 index 24d09e4..0000000 --- a/olmocr/train/molmo/config_molmo.py +++ /dev/null @@ -1,59 +0,0 @@ -from transformers import PretrainedConfig - - -class MolmoConfig(PretrainedConfig): - model_type = "molmo" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=50304, - embedding_size=50304, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - max_position_embeddings=2048, - initializer_range=0.02, - use_cache=True, - layer_norm_eps: float = 1e-5, - rope_theta=10000.0, - clip_qkv=None, - qkv_bias: bool = False, - weight_tying: bool = False, - use_position_ids: bool = True, - tie_word_embeddings: bool = True, - attention_layer_norm: bool = False, - norm_after: bool = False, - layer_norm_type: str = "rms", - **kwargs, - ): - self.vocab_size = vocab_size - self.embedding_size = embedding_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.layer_norm_eps = layer_norm_eps - self.weight_tying = weight_tying - self.use_position_ids = use_position_ids - self.attention_layer_norm = attention_layer_norm - self.num_key_value_heads = num_key_value_heads - self.initializer_range = initializer_range - self.use_cache = use_cache - self.rope_theta = rope_theta - self.clip_qkv = clip_qkv - self.qkv_bias = qkv_bias - self.norm_after = norm_after - self.tie_word_embeddings = tie_word_embeddings - self.layer_norm_type = layer_norm_type - - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - -MolmoConfig.register_for_auto_class() diff --git a/olmocr/train/molmo/image_processing_molmo.py b/olmocr/train/molmo/image_processing_molmo.py deleted file mode 100644 index ba68435..0000000 --- a/olmocr/train/molmo/image_processing_molmo.py +++ /dev/null @@ -1,497 +0,0 @@ -"""Image processor class for Molmo""" - -from typing import List, Optional, Union - -import einops -import numpy as np -import torch -import torchvision.transforms -from torchvision.transforms import InterpolationMode -from torchvision.transforms.functional import convert_image_dtype -from 
transformers.image_processing_utils import BaseImageProcessor -from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput -from transformers.processing_utils import ImagesKwargs -from transformers.utils import logging - -logger = logging.get_logger(__name__) - - -def pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width, value=0): - height, width = image.shape[:2] - after_padding_width = target_width - offset_width - width - after_padding_height = target_height - offset_height - height - return np.pad(image, [[offset_height, after_padding_height], [offset_width, after_padding_width], [0, 0]], constant_values=value) - - -def normalize_image(image, offset, scale): - image -= np.array(offset, dtype=np.float32)[None, None, :] - image /= np.array(scale, dtype=np.float32)[None, None, :] - return image - - -def resize_and_pad( - image, - desired_output_size, - resize_method="torch-bilinear", - pad_value=0, - normalize=True, - image_mean=OPENAI_CLIP_MEAN, - image_std=OPENAI_CLIP_STD, -): - desired_height, desired_width = desired_output_size - height, width = image.shape[:2] - - # Cast into float32 since the training code did this in float32 and it (very rarely) effects - # the results after rounding. - image_scale_y = np.array(desired_height, np.float32) / np.array(height, np.float32) - image_scale_x = np.array(desired_width, np.float32) / np.array(width, np.float32) - image_scale = min(image_scale_x, image_scale_y) - scaled_height = int(np.array(height, np.float32) * image_scale) - scaled_width = int(np.array(width, np.float32) * image_scale) - - if resize_method == "tensorflow": - # This how the original training code did resizing, it can produce slightly different - # results then using torch resize so we keep it just in case - import tensorflow as tf - - image = tf.image.convert_image_dtype(tf.constant(image), dtype=tf.float32) - image = tf.image.resize( - image, - [scaled_height, scaled_width], - method=tf.image.ResizeMethod.BILINEAR, - antialias=True, - ) - image = tf.clip_by_value(image, 0.0, 1.0) - image = image.numpy() - elif resize_method == "torch-bilinear": - image = torch.permute(torch.from_numpy(image), [2, 0, 1]) - image = convert_image_dtype(image) # resize in float32 to match the training code - image = torchvision.transforms.Resize([scaled_height, scaled_width], InterpolationMode.BILINEAR, antialias=True)(image) - image = torch.clip(image, 0.0, 1.0) - image = torch.permute(image, [1, 2, 0]).numpy() - else: - raise NotImplementedError(resize_method) - - top_pad = (desired_height - scaled_height) // 2 - left_pad = (desired_width - scaled_width) // 2 - padding = [[top_pad, desired_height - scaled_height - top_pad], [left_pad, desired_width - scaled_width - left_pad], [0, 0]] - image_mask = np.pad(np.ones_like(image[:, :, 0], dtype=bool), padding[:2]) - image = np.pad(image, padding, constant_values=pad_value) - if normalize: - image = normalize_image(image, offset=image_mean, scale=image_std) - return image, image_mask - - -def select_tiling(h, w, patch_size, max_num_patches): - """Decide how best to divide in image of size [w, h] in up to max_num_patches of size patch_size""" - original_size = np.stack([h, w]) # [1, 2] - original_res = h * w - tilings = [] - for i in range(1, max_num_patches + 1): - for j in range(1, max_num_patches + 1): - if i * j <= max_num_patches: - tilings.append((i, j)) - # sort so argmin and argmax favour smaller tilings in the event of a tie - tilings.sort(key=lambda x: (x[0] * x[1], x[0])) - 
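# NOTE (annotation added for clarity; not part of the original file): a worked example with
# illustrative numbers. For h=700, w=1000, patch_size=336, max_num_patches=6, no candidate tiling can
# avoid downscaling (that would require at least a 3x3 grid, and 3*3 > 6), so the np.argmax branch
# below picks the tiling that downscales the least: (2, 3), i.e. 2 rows by 3 columns, whose scale
# factor is min(2*336/700, 3*336/1000) = 0.96.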
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2] - candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2] - - # How much we would need to scale the image to fit exactly in each tiling - original_size = np.stack([h, w], dtype=np.float32) # [1, 2] - required_scale_d = candidate_resolutions.astype(np.float32) / original_size - required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1] - if np.all(required_scale < 1): - # We are forced to downscale, so try to minimize the amount of downscaling - ix = np.argmax(required_scale) - else: - # Pick the resolution that required the least upscaling so that it most closely fits the image - required_scale = np.where(required_scale < 1.0, 10e9, required_scale) - ix = np.argmin(required_scale) - return candidate_tilings[ix] - - -class MolmoImagesKwargs(ImagesKwargs, total=False): - max_crops: Optional[int] - overlap_margins: Optional[List[int]] - base_image_input_size: Optional[List[int]] - image_token_length_w: Optional[int] - image_token_length_h: Optional[int] - image_patch_size: Optional[int] - image_padding_mask: Optional[bool] - - -class MolmoImageProcessor(BaseImageProcessor): - """Preprocess images and multi-model inputs""" - - def __init__( - self, - max_crops: int = 12, - overlap_margins: List[int] = (4, 4), - base_image_input_size: List[int] = (336, 336), - image_token_length_w: int = 12, - image_token_length_h: int = 12, - image_patch_size: int = 14, - image_padding_mask: bool = True, - do_normalize: bool = True, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - **kwargs, - ): - super().__init__(**kwargs) - self.max_crops = max_crops - self.overlap_margins = overlap_margins - self.base_image_input_size = base_image_input_size - self.image_token_length_w = image_token_length_w - self.image_token_length_h = image_token_length_h - self.image_patch_size = image_patch_size - self.image_padding_mask = image_padding_mask - self.do_normalize = do_normalize - self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN - self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD - - def image_to_patches_and_tokens( - self, - image: ImageInput, - image_patch_token_id: int, - image_col_token_id: int, - image_start_token_id: int, - image_end_token_id: int, - max_crops: Optional[int] = None, - overlap_margins: Optional[List[int]] = None, - base_image_input_size: Optional[Union[int, List[int]]] = None, - image_token_length_w: Optional[int] = None, - image_token_length_h: Optional[int] = None, - image_patch_size: Optional[int] = None, - ): - if isinstance(base_image_input_size, int): - base_image_input_size = (base_image_input_size, base_image_input_size) - - base_image_input_d = image_patch_size - tokens_per_image = image_token_length_w * image_token_length_h - image_base_patch_w = base_image_input_size[1] // base_image_input_d - image_base_patch_h = base_image_input_size[0] // base_image_input_d - - original_image_h, original_image_w = image.shape[:2] - crop_size = base_image_input_size[0] - - # Discard this many patches from the (left/top, right/bottom) of crops - left_margin, right_margin = overlap_margins - # left_margin, right_margin = 2, 2 - assert left_margin % 2 == 0 # Required for compatibility with 2x2 pooling - total_margin_pixels = base_image_input_d * (right_margin + left_margin) # pixels removed per dim - crop_patches = base_image_input_size[0] // base_image_input_d # 
patches per crop dim - crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches - crop_window_size = crop_window_patches * base_image_input_d - tiling = select_tiling(original_image_h - total_margin_pixels, original_image_w - total_margin_pixels, crop_window_size, max_crops) - src, img_mask = resize_and_pad(image, [tiling[0] * crop_window_size + total_margin_pixels, tiling[1] * crop_window_size + total_margin_pixels]) - - # Now we have to split the image into crops, while keeping track of how each patch in the - # each crop should be ordered in the global image, this require a lot of tricky booking - n_crops = tiling[0] * tiling[1] - patches_arr = [] - mask_arr = [] - patch_ordering_arr = [] - - # We assume 2x2 pooling, but can allow padding the right/bottom with extra - # patches if the number of patches per side is not even - assert (crop_patches + 1) // 2 == image_token_length_h - assert (crop_patches + 1) // 2 == image_token_length_w - on = 0 - on_patch = 0 - for i in range(tiling[0]): - y0 = i * crop_window_size - if i == 0: - crop_y0 = 0 - else: - crop_y0 = left_margin // 2 - - crop_h = image_base_patch_h - (right_margin + left_margin) - if i == 0: - crop_h += left_margin - if i == (tiling[0] - 1): - crop_h += right_margin - for j in range(tiling[1]): - x0 = j * crop_window_size - if j == 0: - crop_x0 = 0 - else: - crop_x0 = left_margin // 2 - - crop_w = image_base_patch_w - (right_margin + left_margin) - if j == 0: - crop_w += left_margin - if j == (tiling[1] - 1): - crop_w += right_margin - - pooled_w = (crop_w + 1) // 2 - pooled_h = (crop_h + 1) // 2 - patch_ordering_arr.append( - pad_to_bounding_box( - np.reshape(np.arange(on, on + pooled_h * pooled_w, dtype=np.int32), (pooled_h, pooled_w, 1)), - crop_y0, - crop_x0, - image_token_length_h, - image_token_length_w, - value=-1, - )[:, :, 0] - ) - patches_arr.append(src[y0 : y0 + crop_size, x0 : x0 + crop_size]) - mask_arr.append(img_mask[y0 : y0 + crop_size, x0 : x0 + crop_size]) - - on += pooled_h * pooled_w - on_patch += 1 - patches = np.stack(patches_arr) - patch_ordering = np.stack(patch_ordering_arr) - img_mask = np.stack(mask_arr) - - # Switch to [n_crops, n_patches, pixels_per_patch] format - image_layout_impatch_w, image_layout_impatch_h = tiling[0], tiling[1] - patches = einops.rearrange( - patches, "p (h dh) (w dw) c -> p (h w) (dh dw c)", dh=base_image_input_d, dw=base_image_input_d, h=image_base_patch_h, w=image_base_patch_w - ) - img_mask = einops.rearrange( - img_mask, "p (h dh) (w dw) -> p (h w) (dh dw)", dh=base_image_input_d, dw=base_image_input_d, h=image_base_patch_h, w=image_base_patch_w - ) - - img_mask = img_mask.astype(np.float32).mean(axis=-1) - patch_ordering = np.reshape(patch_ordering, [-1]) - valid = patch_ordering >= 0 - - # Transpose order, to get left-to-right order instead of crop-by-crop order - patch_ordering_rh = np.reshape(patch_ordering, [tiling[0], tiling[1], image_token_length_h, image_token_length_w]) - patch_ordering_rh = np.transpose(patch_ordering_rh, [0, 2, 1, 3]) - patch_ordering_rh = np.reshape(patch_ordering_rh, [-1]) - - # The transpose will screw up which patches are masked, project the - # new order into sparse structure of `patch_ordering` to fix this - patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0] - - # Now build the output tokens - h = tiling[0] * crop_window_patches + (right_margin + left_margin) - w = tiling[1] * crop_window_patches + (right_margin + left_margin) - per_row = np.full( - ((w + 1) // 2,), - image_patch_token_id, - ) - 
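# NOTE (annotation added for clarity; not part of the original file): the sequence assembled here uses
# one image_patch_token_id per 2x2-pooled patch position, i.e. (w + 1) // 2 patch tokens per pooled
# row; the concatenate below appends one image_col_token_id to mark the end of each row, the row
# pattern is tiled (h + 1) // 2 times, and the result is wrapped in image_start_token_id /
# image_end_token_id. The same pattern is built again further down for the low-resolution global
# image, whose tokens are placed first in the final token sequence.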
per_row = np.concatenate([per_row, [image_col_token_id]], 0) - - joint = np.tile(per_row, [(h + 1) // 2]) - joint = [[image_start_token_id], joint, [image_end_token_id]] - - # Finally do the same for the global image - resized, _ = resize_and_pad(image, base_image_input_size) - resized = einops.rearrange( - resized, "(h dh) (w dw) c -> (h w) (dh dw c)", dh=base_image_input_d, dw=base_image_input_d, h=image_base_patch_h, w=image_base_patch_w - ) - patches = np.concatenate([np.expand_dims(resized, 0), patches], 0) - - # Global image goes first, so the order of patches in previous crops gets increased - patch_ordering = np.where(patch_ordering >= 0, patch_ordering + tokens_per_image, -1) - patch_ordering = np.concatenate([np.arange(0, tokens_per_image), patch_ordering], 0) - per_row = np.full( - (image_token_length_w,), - image_patch_token_id, - ) - per_row = np.concatenate([per_row, [image_col_token_id]], 0) - extra_tokens = np.tile(per_row, [image_token_length_h]) - joint = [ - [image_start_token_id], - extra_tokens, - [image_end_token_id], - ] + joint - - joint = np.concatenate(joint, 0) - img_mask = np.pad(img_mask, [[0, 1], [0, 0]], constant_values=-1) - return patches, joint, patch_ordering, img_mask - - def build_image_input_idx( - self, - image_tokens: np.ndarray, - patch_order: np.ndarray, - image_patch_token_id: int, - no_image: Optional[bool] = None, - image_token_length_w: Optional[int] = None, - image_token_length_h: Optional[int] = None, - ): - """Converts `patch_order` into a mapping of token_id -> patch_id""" - - tokens_per_image = image_token_length_w * image_token_length_h - if no_image is not None and no_image: - return np.zeros((0, tokens_per_image), np.int32) - - # Indices to insert the patches - image_input_idx = image_tokens == image_patch_token_id - image_input_idx = np.nonzero(image_input_idx)[0].astype(np.int32) - - if patch_order is not None: - n_tokens = image_input_idx.shape[0] - patch_order = np.reshape(patch_order, [-1]) - n_patches = patch_order.shape[0] - - valid = patch_order >= 0 - n_valid_patches = valid.sum() - assert len(image_input_idx) == n_valid_patches - - sorted_patch_ixs = np.zeros([n_tokens], np.int32) - sorted_patch_ixs[patch_order[valid]] = np.arange(n_valid_patches, dtype=np.int32) - - # Project the inverted mapping into same sparse structure - sorted_patch_ixs_ex = np.full(np.shape(patch_order), -1) - sorted_patch_ixs_ex[valid] = sorted_patch_ixs - - # Do the gather and then re-masked outputs that were masked in `sorted_patch_ixs` - valid = (sorted_patch_ixs_ex >= 0).astype(np.int32) - image_input_idx = image_input_idx[sorted_patch_ixs_ex * valid] - image_input_idx = image_input_idx * valid - 100 * (1 - valid) - image_input_idx = np.reshape(image_input_idx, [-1, tokens_per_image]) - return image_input_idx - - def preprocess( - self, - image: np.ndarray, - image_patch_token_id: int, - image_col_token_id: int, - image_start_token_id: int, - image_end_token_id: int, - max_crops: Optional[int] = None, - overlap_margins: Optional[List[int]] = None, - base_image_input_size: Optional[Union[int, List[int]]] = None, - image_token_length_w: Optional[int] = None, - image_token_length_h: Optional[int] = None, - image_patch_size: Optional[int] = None, - **kwargs, - ): - """Preprocesses an image - - Returns: - crops: (n_crops, n_patches, patch_dim) individual crops, `n_crops` might - change between images but the other dimension are fixed - tokens: (n_tokens,) int32 tokens, pad tokens indicate where to insert the - patch features, might include other special 
tokens as well - image_idx: (n_crops, n_patches) index in `tokens` to put the patch features from the - crops after pooling, negative values indicates patches features to exclude - padding_mask: (n_crops, n_patches) what percent of each crop is padding, can be None - if the image mask is not being used. - """ - - max_crops = max_crops or self.max_crops - overlap_margins = overlap_margins or self.overlap_margins - base_image_input_size = base_image_input_size or self.base_image_input_size - image_token_length_w = image_token_length_w or self.image_token_length_w - image_token_length_h = image_token_length_h or self.image_token_length_h - image_patch_size = image_patch_size or self.image_patch_size - - crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens( - image, - image_patch_token_id, - image_col_token_id, - image_start_token_id, - image_end_token_id, - max_crops, - overlap_margins, - base_image_input_size, - image_token_length_w, - image_token_length_h, - image_patch_size, - ) - patch_idx = self.build_image_input_idx( - image_tokens, - patch_ordering, - image_patch_token_id, - image_token_length_w=image_token_length_w, - image_token_length_h=image_token_length_h, - ) - return crops, image_tokens, patch_idx, img_mask - - def multimodal_preprocess( - self, - images: np.ndarray, - tokens: List[int], - image_idx: np.ndarray, - sequence_length: int, - image_patch_token_id: int, - image_col_token_id: int, - image_start_token_id: int, - image_end_token_id: int, - **kwargs, - ): - """Merge images and text tokens into multi-modal features for the model - - :param images: images to use as input - :param tokens: input text tokens - :param image_idx: where to insert the images into `tokens` - :params image_patch_token_id: id to use of tokens that will contain image features - :params image_col_token_id: token id for image column special tokens - :params image_start_token_id: token id for image start special tokens - :params image_end_token_id: token id for image end special tokens - :params kwargs: override preprocessor default args - """ - max_total_crops = kwargs.get("max_crops") or self.max_crops - image_token_length_w = kwargs.get("image_token_length_w") or self.image_token_length_w - image_token_length_h = kwargs.get("image_token_length_h") or self.image_token_length_h - image_patch_size = kwargs.get("image_patch_size") or self.image_patch_size - base_image_input_size = kwargs.get("base_image_input_size") or self.base_image_input_size - image_num_patch = ( - base_image_input_size[0] // image_patch_size, - base_image_input_size[1] // image_patch_size, - ) - image_padding_mask = kwargs.get("image_padding_mask") or self.image_padding_mask - - tokens_per_image = image_token_length_w * image_token_length_h - n_pixels = image_patch_size * image_patch_size * 3 - n_patches = image_num_patch[0] * image_num_patch[1] - - if images is None: - return { - "input_ids": tokens, - } - else: - n = len(images) - all_crops = [] - all_image_idx = [] - out_tokens = [] - all_crop_masks = [] - - for ix in range(n): - token_ix = image_idx[ix] - crops, image_tokens, patch_idx, img_mask = self.preprocess( - images[ix], - image_patch_token_id, - image_col_token_id, - image_start_token_id, - image_end_token_id, - **kwargs, - ) - - if token_ix == -1: # -1 is an image inserted at the very start - start = 0 - token_ix = 0 - end = 0 - else: - start = 0 if ix == 0 else image_idx[ix - 1] + 1 - end = token_ix + 1 - - all_image_idx.append(patch_idx + token_ix) - all_crops.append(crops) - 
out_tokens.append(tokens[start:token_ix]) - out_tokens.append(image_tokens) - if ix == (n - 1): - out_tokens.append(tokens[end:]) - if image_padding_mask: - all_crop_masks.append(img_mask) - - input_ids = np.concatenate(out_tokens, 0) - images = np.concatenate(all_crops, 0) - image_input_idx = np.concatenate(all_image_idx, 0) - if image_padding_mask: - image_masks = np.concatenate(all_crop_masks, 0) - else: - image_masks = None - - out = {"input_ids": input_ids, "images": images, "image_input_idx": image_input_idx} - if image_masks is not None: - out["image_masks"] = image_masks - return out - - -MolmoImageProcessor.register_for_auto_class() diff --git a/olmocr/train/molmo/modeling_molmo.py b/olmocr/train/molmo/modeling_molmo.py deleted file mode 100644 index da274e4..0000000 --- a/olmocr/train/molmo/modeling_molmo.py +++ /dev/null @@ -1,2319 +0,0 @@ -# type: ignore -import logging -import math -from copy import deepcopy -from dataclasses import dataclass, replace -from enum import Enum -from typing import ( - Any, - Callable, - Dict, - List, - MutableMapping, - Optional, - Sequence, - Tuple, - Union, - cast, -) - -import torch -from einops import einops -from torch import nn -from torch.nn import functional as F -from transformers import GenerationConfig, PreTrainedModel -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput -from transformers.models.auto import AutoModelForCausalLM - -from .config_molmo import MolmoConfig - -log = logging.getLogger(__name__) - - -class BufferCache(dict, MutableMapping[str, torch.Tensor]): - """ - Cache for attention biases and other things that would normally be stored as buffers. - We avoid using buffers because we've run into various issues doing so with FSDP. - In general it appears the way FSDP handles buffers is not well-defined. - It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid - since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into - NaNs when they're synchronized due to casting or some other issue. - """ - - -class StrEnum(str, Enum): - def __str__(self) -> str: - return self.value - - def __repr__(self) -> str: - return f"'{str(self)}'" - - -class ImageProjectType(StrEnum): - mlp = "mlp" - mlpx2 = "2mlp" - linear = "linear" - - -class ImagePooling2DType(StrEnum): - attention = "attention" - attention_meanq = "attention-meanq" - attention_2wide = "attention_2wide" - attention_v2 = "attention-v2" - none = "none" - stack = "stack" - - -class ActivationType(StrEnum): - quick_gelu = "quick_gelu" - gelu = "gelu" - gelu_tanh = "gelu_tanh" - relu = "relu" - silu = "silu" - llama_geglu = "llama_geglu" - llama_geglu_tanh = "llama_geglu_tanh" - llama_swiglu = "llama_swiglu" - swiglu = "swiglu" - - -def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False): - """ - Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf`` - is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``. 
- """ - if check_neg_inf: - x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min) - if check_pos_inf: - x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max) - - -class MolmoConfigurationError(Exception): - pass - - -def _non_meta_init_device(config) -> torch.device: - if config.init_device is not None and config.init_device != "meta": - return torch.device(config.init_device) - else: - return torch.device("cuda" if torch.cuda.is_available() else "cpu") - - -class RotaryEmbedding(nn.Module): - """ - [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864). - """ - - def __init__(self, config: MolmoConfig, cache: BufferCache): - super().__init__() - self.config = config - self.__cache = cache - # Warm up cache. - self.get_rotary_embedding(config.max_position_embeddings or config.max_sequence_length, _non_meta_init_device(config)) - - def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: - if ( - (pos_sin := self.__cache.get("rope_pos_sin")) is not None - and (pos_cos := self.__cache.get("rope_pos_cos")) is not None - and pos_sin.shape[-2] >= seq_len - and pos_cos.shape[-2] >= seq_len - ): - if pos_sin.device != device: - pos_sin = pos_sin.to(device) - self.__cache["rope_pos_sin"] = pos_sin - if pos_cos.device != device: - pos_cos = pos_cos.to(device) - self.__cache["rope_pos_cos"] = pos_cos - return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :] - - with torch.autocast(device.type, enabled=False): - dim = self.config.d_model // self.config.n_heads - inv_freq = 1.0 / (self.config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)) - seq = torch.arange(seq_len, device=device, dtype=torch.float) - freqs = torch.einsum("i , j -> i j", seq, inv_freq) - if self.config.rope_impl == "interleave": - positions = freqs.repeat_interleave(2, dim=-1) - else: - positions = torch.cat((freqs, freqs), dim=-1) - pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :] - self.__cache["rope_pos_sin"] = pos_sin - self.__cache["rope_pos_cos"] = pos_cos - return pos_sin, pos_cos - - def rotate_half(self, x: torch.Tensor) -> torch.Tensor: - B, nh, T, hs = x.size() - x = x.view(B, nh, T, 2, hs // 2) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) - - def rotate_every_two(self, x: torch.Tensor) -> torch.Tensor: - B, nh, T, hs = x.size() - x = x.view(B, nh, T, hs // 2, 2) - x1, x2 = x.unbind(dim=-1) - x = torch.stack((-x2, x1), dim=-1) - return x.view(B, nh, T, hs) - - def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - if self.config.rope_impl == "interleave": - return ((t * pos_cos) + (self.rotate_every_two(t) * pos_sin)).to(t.dtype) - else: - return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype) - - def forward(self, q: torch.Tensor, k: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: - if self.config.rope_full_precision: - q_, k_ = q.float(), k.float() - else: - q_, k_ = q, k - - with torch.autocast(q.device.type, enabled=False): - batch_size = q_.shape[0] - query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None - if position_ids is not None: - freqs_cis_len = self.config.max_position_embeddings or self.config.max_sequence_length - else: - freqs_cis_len = key_len - pos_sin, pos_cos = self.get_rotary_embedding(freqs_cis_len, q_.device) - pos_sin = pos_sin.type_as(q_) - pos_cos = 
pos_cos.type_as(q_) - if position_ids is not None: - assert query_len == key_len, "Query and key lengths must be equal when using position IDs." - pos_sin = pos_sin[0, 0][position_ids].view((batch_size, 1, key_len, pos_sin.shape[-1])) - pos_cos = pos_cos[0, 0][position_ids].view((batch_size, 1, key_len, pos_cos.shape[-1])) - q_ = self.apply_rotary_pos_emb( - pos_sin[:, :, key_len - query_len : key_len, :], - pos_cos[:, :, key_len - query_len : key_len, :], - q_, - ) - k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_) - return q_.type_as(q), k_.type_as(k) - - -class MolmoBlock(nn.Module): - """ - A base class for transformer block implementations. - """ - - def __init__(self, layer_id: int, config: MolmoConfig, cache: BufferCache): - super().__init__() - self.layer_id = layer_id - self.config = config - self.hidden_size = config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model - self.__cache = cache - self._activation_checkpoint_fn = None - - # Dropout. - self.dropout = Dropout(config.residual_dropout) - - # Layer norms. - self.k_norm: Optional[LayerNormBase] = None - self.q_norm: Optional[LayerNormBase] = None - if config.attention_layer_norm: - assert config.effective_n_kv_heads is not None - self.k_norm = LayerNormBase.build( - config, - size=(config.d_model // config.n_heads) * config.effective_n_kv_heads, - elementwise_affine=config.attention_layer_norm_with_affine, - ) - self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine) - - # Make sure QKV clip coefficient is positive, otherwise it's not well-defined. - if config.clip_qkv is not None: - assert config.clip_qkv > 0 - - # Activation function. - self.act = Activation.build(config) - assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 - - # Attention output projection. - input_dim = config.d_model - self.attn_out = nn.Linear(input_dim, config.d_model, bias=config.include_bias, device=config.init_device) - - # Feed-forward output projection. - self.ff_out = nn.Linear( - int(self.act.output_multiplier * self.hidden_size), - config.d_model, - bias=config.include_bias, - device=config.init_device, - ) - self.ff_out._is_residual = True # type: ignore - - # Rotary embeddings. - if self.config.rope: - self.rotary_emb = RotaryEmbedding(config, self.__cache) - - self.flash_attn_func = None - if config.attention_type == "flash": - try: - from flash_attn import flash_attn_func # type: ignore - - self.flash_attn_func = flash_attn_func - except ModuleNotFoundError: - pass - - def reset_parameters(self): - if self.k_norm is not None: - self.k_norm.reset_parameters() - if self.q_norm is not None: - self.q_norm.reset_parameters() - init_weights( - self.config, - self.attn_out, - d=self.config.d_model, - layer_id=self.layer_id, - type_of_module=ModuleType.out_module, - ) - init_weights( - self.config, - self.ff_out, - d=self.ff_out.in_features, - layer_id=self.layer_id, - type_of_module=ModuleType.out_module, - ) - - @classmethod - def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor: - target_dtype = input_dtype - # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function - # `is_autocast_cpu_enabled()` for CPU autocast. - # See https://github.com/pytorch/pytorch/issues/110966. 
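A minimal, self-contained sketch (illustrative values only) of the failure mode that ensure_finite_ guards against when the bias is cast: an additive attention-bias row that is entirely -inf (for example, a padding-only position) produces NaNs under softmax, while clamping to the dtype minimum keeps the weights finite.

    import torch

    # A fully masked additive-bias row (all -inf), as can occur for padding-only positions.
    bias = torch.full((1, 1, 1, 4), float("-inf"))
    print(torch.softmax(bias, dim=-1))   # all NaN: exp(-inf) sums to zero

    # ensure_finite_ replaces -inf with the smallest representable value instead.
    bias.masked_fill_(bias == float("-inf"), torch.finfo(bias.dtype).min)
    print(torch.softmax(bias, dim=-1))   # finite (uniform) weights
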
- if bias.device.type == "cuda" and torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled(): - target_dtype = torch.get_autocast_cpu_dtype() - if bias.dtype != target_dtype: - bias = bias.to(target_dtype) - ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False) - return bias - - def _scaled_dot_product_attention( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - response_dropout_p: float = 0.0, - is_causal: bool = False, - ) -> torch.Tensor: - """ - Computes scaled dot product attention on query, key and value tensors, using an optional - attention mask if passed, and applying dropout if a probability greater than 0.0 is specified. - """ - if attn_mask is not None: - attn_mask = attn_mask.to(q.device) - - if self.flash_attn_func is not None and attn_mask is None: - r = self.flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=is_causal) - return r.transpose(1, 2) - else: - # torch's sdpa doesn't support GQA, so we're doing this - assert k.size(1) == v.size(1) - num_kv_heads = k.size(1) - num_q_heads = q.size(1) - if num_q_heads != num_kv_heads: - assert num_q_heads % num_kv_heads == 0 - k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads) - v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads) - - return F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - ) - - def attention( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - attention_bias: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: - B, T, C = q.size() # batch size, sequence length, d_model - dtype = k.dtype - - # Optionally apply layer norm to keys and queries. - if self.q_norm is not None and self.k_norm is not None: - q = self.q_norm(q).to(dtype=dtype) - k = self.k_norm(k).to(dtype=dtype) - - # Move head forward to be next to the batch dim. - # shape: (B, nh, T, hs) - q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2) - # shape: (B, n_kv_h, T, hs) - k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) - # shape: (B, n_kv_h, T, hs) - v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) - - if self.config.use_position_ids and self.config.rope: - # Apply rotary embeddings - q, k = self.rotary_emb(q, k, position_ids=position_ids) - - if layer_past is not None: - past_key, past_value = layer_past - k = torch.cat((past_key.to(k.device), k), dim=-2) - v = torch.cat((past_value.to(v.device), v), dim=-2) - - present = (k, v) if use_cache else None - query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None - - if not self.config.use_position_ids and self.config.rope: - # Apply rotary embeddings - q, k = self.rotary_emb(q, k) - - if attention_bias is not None: - # Resize and cast attention bias. 
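The grouped-query handling in _scaled_dot_product_attention above can be exercised on its own; a minimal sketch with hypothetical head counts, showing how repeat_interleave expands the KV heads so that F.scaled_dot_product_attention sees one KV head per query head:

    import torch
    import torch.nn.functional as F

    B, T, head_dim = 2, 5, 16
    num_q_heads, num_kv_heads = 8, 2          # hypothetical GQA configuration

    q = torch.randn(B, num_q_heads, T, head_dim)
    k = torch.randn(B, num_kv_heads, T, head_dim)
    v = torch.randn(B, num_kv_heads, T, head_dim)

    # Duplicate each KV head to cover its group of query heads.
    k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
    v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)

    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)                           # torch.Size([2, 8, 5, 16])
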
- # The current dtype of the attention bias might not match the dtype that the SDP attn function will - # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding - # as down-casting the attention bias to the autocast precision will result in -infs, which will - # cause the SDP attn function to produce NaNs. - attention_bias = self._cast_attn_bias(attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype) - - # Get the attention scores. - # shape: (B, nh, T, hs) - att = self._scaled_dot_product_attention( - q, - k, - v, - attn_mask=attention_bias, - dropout_p=0.0 if not self.training else self.config.attention_dropout, - response_dropout_p=0.0 if not self.training else self.config.response_attention_dropout, - is_causal=attention_bias is None, - ) - - # Re-assemble all head outputs side-by-side. - att = att.transpose(1, 2).contiguous().view(B, T, C) - - # Apply output projection. - return self.attn_out(att), present - - def forward( - self, - x: torch.Tensor, - attention_bias: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.Tensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: - raise NotImplementedError - - @classmethod - def build(cls, layer_id: int, config: MolmoConfig, cache: BufferCache): - return MolmoSequentialBlock(layer_id, config, cache) - - -class MolmoSequentialBlock(MolmoBlock): - """ - This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` - (plus another skip connection). - """ - - def __init__(self, layer_id: int, config: MolmoConfig, cache: BufferCache): - super().__init__(layer_id, config, cache) - # Layer norms. - self.attn_norm = LayerNorm.build(config) - self.ff_norm = LayerNorm.build(config) - # Attention input projection. Projects x -> (q, k, v) - - head_dim = config.d_model // config.n_heads - self.fused_dims = ( - config.d_model, - config.effective_n_kv_heads * head_dim, - config.effective_n_kv_heads * head_dim, - ) - self.att_proj = nn.Linear(config.d_model, sum(self.fused_dims), bias=config.include_bias or config.qkv_bias, device=config.init_device) - # Feed-forward input projection. - self.ff_proj = nn.Linear(config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device) - - def reset_parameters(self): - super().reset_parameters() - self.attn_norm.reset_parameters() - self.ff_norm.reset_parameters() - # NOTE: the standard deviation for these weights does not depend on the layer. - init_weights(self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module) - init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module) - - def forward( - self, - x: torch.Tensor, - attention_bias: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: - # Get query, key, value projections. 
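A small standalone sketch (hypothetical dimensions) of how the fused att_proj layer defined above packs Q, K and V into a single linear projection, and how split(self.fused_dims, dim=-1) recovers the three pieces:

    import torch

    d_model, n_heads, n_kv_heads = 64, 8, 2    # made-up sizes for illustration
    head_dim = d_model // n_heads
    fused_dims = (d_model, n_kv_heads * head_dim, n_kv_heads * head_dim)

    att_proj = torch.nn.Linear(d_model, sum(fused_dims))
    x = torch.randn(2, 5, d_model)             # (batch, seq_len, d_model)
    q, k, v = att_proj(x).split(fused_dims, dim=-1)
    print(q.shape, k.shape, v.shape)           # (2, 5, 64) (2, 5, 16) (2, 5, 16)
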
- # shape: - # - for regular attn q, k, v: (batch_size, seq_len, d_model) - # - for multi-query attn q: (batch_size, seq_len, d_model) - # k, v: (batch_size, seq_len, d_model // n_heads) - # - for group query attn q: (batch_size, seq_len, d_model) - # k, v: (batch_size, seq_len, d_model // n_kv_heads) - - if not self.config.norm_after: - if self._activation_checkpoint_fn is not None: - atten_in = self._activation_checkpoint_fn(self.attn_norm, x) - else: - atten_in = self.attn_norm(x) - else: - atten_in = x - qkv = self.att_proj(atten_in) - - if self.config.clip_qkv is not None: - qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - - q, k, v = qkv.split(self.fused_dims, dim=-1) - - # Get attention scores. - if self._activation_checkpoint_fn is not None: - att, cache = self._activation_checkpoint_fn( # type: ignore - self.attention, q, k, v, attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache - ) - else: - att, cache = self.attention(q, k, v, attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache) - - if self.config.norm_after: - if self._activation_checkpoint_fn is not None: - att = self._activation_checkpoint_fn(self.attn_norm, att) - else: - att = self.attn_norm(att) - - # Add attention scores. - # shape: (B, T, C) - x = x + self.dropout(att) - - # Add feed-forward projection. - # shape: (batch_size, seq_len, d_model) - og_x = x - - if not self.config.norm_after: - if self._activation_checkpoint_fn is not None: - x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore - else: - x = self.ff_norm(x) - - x = self.ff_proj(x) - if self._activation_checkpoint_fn is not None: - x = self._activation_checkpoint_fn(self.act, x) # type: ignore - else: - x = self.act(x) - x = self.ff_out(x) - - if self.config.norm_after: - if self._activation_checkpoint_fn is not None: - x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore - else: - x = self.ff_norm(x) - - x = self.dropout(x) - x = og_x + x - - return x, cache - - -class Embedding(nn.Module): - def __init__( - self, - num_embeddings: int, - num_new_embeddings: int, - features: int, - device: Union[str, torch.device], - initializer_range: float = 0.02, - new_embed_initializer_range: float = 0.02, - ): - super().__init__() - self.initializer_range = initializer_range - self.new_embed_initializer_range = new_embed_initializer_range - self.embedding = nn.Parameter( - torch.zeros(num_embeddings, features, device=device), - ) - self.new_embedding = nn.Parameter( - torch.zeros(num_new_embeddings, features, device=device), - ) - - def reset_parameters(self): - nn.init.normal_(self.embedding, std=self.initializer_range) - nn.init.normal_(self.new_embedding, std=self.new_embed_initializer_range) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0)) - - -class Dropout(nn.Dropout): - def __init__( - self, - p: float = 0.5, - inplace: bool = False, - mask_p: float = 0, - broadcast_dims: Sequence[int] = (), - ): - super().__init__(p, inplace) - self.mask_p = mask_p - self.broadcast_dims = broadcast_dims - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - :param input: A tensor of shape `(batch_size, seq_len, embed_dim)` - """ - if self.p == 0.0 and (self.mask_p is None or self.mask_p == 0.0): - return input - else: - if self.p > 0.0 and len(self.broadcast_dims) > 0 and self.training: - keep_prob = 1.0 - self.p - dropout_shape = list(input.shape) - for dim in 
self.broadcast_dims: - dropout_shape[dim] = 1 - keep = input.new_empty(dropout_shape).bernoulli_(keep_prob) - multiplier = keep.broadcast_to(input.shape) - multiplier.div_(keep_prob) - input = input * multiplier - else: - return F.dropout(input, self.p, self.training, self.inplace) - - -@dataclass -class VisionBackboneConfig: - image_default_input_size: Tuple[int, int] = (336, 336) - image_patch_size: int = 14 - image_pos_patch_size: int = 14 - image_emb_dim: int = 1024 - image_num_heads: int = 16 - image_num_key_value_heads: int = 16 - image_num_layers: int = 24 - image_head_dim: int = 64 - image_mlp_dim: int = 4096 - image_mlp_activations: str = "gelu" - image_dropout_rate: float = 0.0 - image_num_pos: int = 577 - image_norm_eps: float = 1e-5 - attention_dropout: float = 0.0 - residual_dropout: float = 0.0 - initializer_range: float = 0.02 - fsdp_wrap: bool = False - resize_mode: str = "default" - - def __post_init__(self): - self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment] - - @property - def image_num_patch(self): - h, w = self.image_default_input_size - return h // self.image_patch_size, w // self.image_patch_size - - -@dataclass -class FullMolmoConfig: - d_model: int = 768 - n_heads: int = 12 - n_kv_heads: Optional[int] = None - qkv_bias: bool = False - clip_qkv: Optional[float] = None - n_layers: int = 12 - mlp_ratio: int = 4 - mlp_hidden_size: Optional[int] = None - activation_type: str = "swiglu" - block_group_size: int = 1 - rope: bool = True - rope_full_precision: bool = True - rope_theta: float = 10000.0 - rope_impl: str = "interleave" - vision_backbone: Optional[VisionBackboneConfig] = None - attention_type: str = "sdpa" - float32_attention: bool = True - attention_dropout: float = 0.1 - response_attention_dropout: float = 0.0 - multi_query_attention: Optional[bool] = None - attention_layer_norm: bool = False - residual_dropout: float = 0.1 - embedding_dropout: float = 0.1 - layer_norm_type: str = "default" - layer_norm_with_affine: bool = True - layer_norm_eps: Optional[float] = None - attention_layer_norm_with_affine: bool = True - max_sequence_length: int = 1024 - max_position_embeddings: Optional[int] = None - include_bias: bool = True - bias_for_layer_norm: Optional[bool] = None - scale_logits: bool = False - vocab_size: int = 50257 - embedding_size: Optional[int] = 50304 - additional_vocab_size: Optional[int] = None - new_embedding_init_range: float = 0.02 - weight_tying: bool = True - pad_token_id: int = -1 - init_device: Optional[str] = None - init_std: float = 0.02 - init_cutoff_factor: Optional[float] = None - norm_after: bool = False - precision: Optional[str] = None - image_padding_embed: Optional[str] = None - vit_layers: Tuple = (-1,) - image_pooling_h: int = 2 - image_pooling_w: int = 2 - image_pooling_2d: str = "attention" - image_projector: str = "mlp" - image_feature_dropout: float = 0.0 - initializer_range: float = 0.02 - normalize_input_embeds: bool = False - use_position_ids: bool = True - - @property - def effective_n_kv_heads(self) -> int: - if self.n_kv_heads is None: - if self.multi_query_attention is True: - return 1 - else: - return self.n_heads - else: - if self.multi_query_attention is None: - return self.n_kv_heads - if self.multi_query_attention: - n_kv_heads_should_be = 1 - else: - n_kv_heads_should_be = self.n_heads - if self.n_kv_heads == n_kv_heads_should_be: - return n_kv_heads_should_be - else: - raise MolmoConfigurationError("You can't set `multi_query_attention` and `n_kv_heads` at the same 
time.") - - @property - def image_num_patch(self): - assert self.vision_backbone is not None - return self.vision_backbone.image_num_patch - - @property - def image_patch_size(self): - assert self.vision_backbone is not None - return self.visoin_backbone.image_patch_size - - def llm_patches_per_crop(self): - h, w = self.image_num_patch - # Round up in case we need to pad the image features for pooling - h = (h + self.image_pooling_h - 1) // self.image_pooling_h - w = (w + self.image_pooling_w - 1) // self.image_pooling_w - return h, w - - -def _expand_token(token, batch_size: int): - return token.view(1, 1, -1).expand(batch_size, -1, -1) - - -class ViTMLP(nn.Module): - def __init__(self, config: FullMolmoConfig): - super().__init__() - self.config = config - v_cfg = config.vision_backbone - - self.w1 = nn.Linear( - v_cfg.image_emb_dim, - v_cfg.image_mlp_dim, - bias=True, - device=config.init_device, - ) - # Activation function. - cfg = deepcopy(config) - cfg.activation_type = v_cfg.image_mlp_activations - self.act = Activation.build(cfg) - self.w2 = nn.Linear( - v_cfg.image_mlp_dim, - v_cfg.image_emb_dim, - bias=True, - device=config.init_device, - ) - - def reset_parameters(self): - v_cfg = self.config.vision_backbone - nn.init.trunc_normal_(self.w1.weight, std=math.sqrt(1 / v_cfg.image_emb_dim), a=-2.0, b=2.0) - nn.init.trunc_normal_(self.w2.weight, std=math.sqrt(1 / v_cfg.image_mlp_dim), a=-2.0, b=2.0) - nn.init.zeros_(self.w1.bias) - nn.init.zeros_(self.w2.bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.w1(x) - x = self.act(x) - x = self.w2(x) - return x - - -class ResidualAttentionBlock(nn.Module): - def __init__(self, config: FullMolmoConfig): - super().__init__() - self.config = config - - v_cfg = config.vision_backbone - self.attention = MultiHeadDotProductAttention(config) - self.feed_forward = ViTMLP(config) - self.attention_norm = nn.LayerNorm( - v_cfg.image_emb_dim, - eps=v_cfg.image_norm_eps, - device=config.init_device, - ) - self.ffn_norm = nn.LayerNorm( - v_cfg.image_emb_dim, - eps=v_cfg.image_norm_eps, - device=config.init_device, - ) - - def reset_parameters(self): - self.attention.reset_parameters() - self.feed_forward.reset_parameters() - self.attention_norm.reset_parameters() - self.ffn_norm.reset_parameters() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.attention(self.attention_norm(x)) - x = x + self.feed_forward(self.ffn_norm(x)) - return x - - -class BlockCollection(nn.Module): - def __init__(self, config: FullMolmoConfig): - super().__init__() - self.config = config - self.grad_checkpointing: bool = False - - v_cfg = config.vision_backbone - self.resblocks = nn.ModuleList([ResidualAttentionBlock(config) for _ in range(v_cfg.image_num_layers)]) - - def reset_parameters(self): - for r in self.resblocks: - r.reset_parameters() - - def forward(self, x: torch.Tensor) -> List[torch.Tensor]: - hidden_states = [] - for r in self.resblocks: - x = r(x) - hidden_states.append(x) - return hidden_states - - -class LayerNormFp32(nn.LayerNorm): - def forward(self, x: torch.Tensor) -> torch.Tensor: - orig_type = x.dtype - x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight.to(torch.float32), self.bias.to(torch.float32), self.eps) - return x.to(orig_type) - - -class VisionTransformer(nn.Module): - def __init__(self, config: FullMolmoConfig): - super().__init__() - self.config = config - - v_cfg = config.vision_backbone - # class embeddings and positional embeddings - self.scale = v_cfg.image_emb_dim**-0.5 - 
self.class_embedding = nn.Parameter( - torch.zeros(v_cfg.image_emb_dim, device=config.init_device), - ) - self.num_prefix_tokens: int = 1 - self.positional_embedding = nn.Parameter( - torch.zeros(v_cfg.image_num_pos, v_cfg.image_emb_dim, device=config.init_device), - ) - - image_patch_size = v_cfg.image_patch_size - self.patch_embedding = nn.Linear( - image_patch_size * image_patch_size * 3, - v_cfg.image_emb_dim, - bias=False, - device=config.init_device, - ) - - self.pre_ln = LayerNormFp32( - v_cfg.image_emb_dim, - eps=v_cfg.image_norm_eps, - ) - - self.transformer = BlockCollection(config) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.transformer.grad_checkpointing = enable - - def reset_parameters(self): - nn.init.normal_(self.class_embedding, std=self.scale) - nn.init.normal_(self.positional_embedding, std=self.scale) - nn.init.normal_(self.patch_embedding.weight, std=0.02) - self.pre_ln.reset_parameters() - self.transformer.reset_parameters() - - def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: - cls_emb = self.positional_embedding[0:1] - pos_emb = self.positional_embedding[1:] - - pos_emb = pos_emb.reshape((int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])) - - (patch_num_0, patch_num_1) = patch_num - - if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1: - # Dervied from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py - # antialias: default True in jax.image.resize - pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2) - pos_emb = F.interpolate( - pos_emb, - size=(patch_num_0, patch_num_1), - mode="bicubic", - align_corners=False, - antialias=True, - ) - pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0) - - pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1]) - x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype) - return x - - def forward(self, x: torch.Tensor, patch_num: int = None) -> List[torch.Tensor]: - """ - : param x: (batch_size, num_patch, n_pixels) - """ - if patch_num is None: - patch_num = self.config.vision_backbone.image_num_patch - B, N, D = x.shape - - x = self.patch_embedding(x) - - # class embeddings and positional embeddings - x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) - x = self.add_pos_emb(x, patch_num) - - x = self.pre_ln(x) - - hidden_states = self.transformer(x) - return hidden_states - - -class MultiHeadDotProductAttention(nn.Module): - def __init__(self, config: FullMolmoConfig, use_bias: bool = True, is_vit_layer: Optional[bool] = True): - super().__init__() - self.config = config - self.use_bias = use_bias - - v_cfg = config.vision_backbone - self.embed_dim = v_cfg.image_emb_dim - self.num_heads = v_cfg.image_num_heads - self.head_dim = v_cfg.image_head_dim - self.num_key_value_heads = v_cfg.image_num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.initializer_range = v_cfg.initializer_range - self.is_vit_layer = is_vit_layer - - nlayers = 1 if (is_vit_layer or config.vit_layers is None) else len(config.vit_layers) - - self.wq = nn.Linear( - nlayers * self.embed_dim, - self.num_heads * self.head_dim, - bias=use_bias, - device=config.init_device, - ) - self.wk = nn.Linear( - nlayers * self.embed_dim, - self.num_key_value_heads * self.head_dim, - bias=use_bias, - device=config.init_device, - ) - self.wv = nn.Linear( - nlayers * self.embed_dim, - self.num_key_value_heads * self.head_dim, - bias=use_bias, - 
device=config.init_device, - ) - self.wo = nn.Linear( - self.num_heads * self.head_dim, - self.embed_dim, - bias=use_bias, - device=config.init_device, - ) - self.attention_dropout: Optional[Dropout] = None - if v_cfg.attention_dropout > 0: - self.attention_dropout = Dropout(v_cfg.attention_dropout, broadcast_dims=(0, 1)) - self.residual_dropout = Dropout(v_cfg.residual_dropout) - - def reset_parameters(self): - nn.init.normal_(self.wq.weight, std=self.initializer_range) - nn.init.normal_(self.wk.weight, std=self.initializer_range) - nn.init.normal_(self.wv.weight, std=self.initializer_range) - nn.init.normal_(self.wo.weight, std=self.initializer_range) - if self.use_bias: - nn.init.constant_(self.wq.bias, 0) - nn.init.constant_(self.wk.bias, 0) - nn.init.constant_(self.wv.bias, 0) - nn.init.constant_(self.wo.bias, 0) - - def _split_heads(self, hidden_states, num_heads) -> torch.Tensor: - return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states) -> torch.Tensor: - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor: - if inputs_kv is not None: - inputs_k = inputs_kv - inputs_v = inputs_kv - else: - inputs_k = inputs_q - inputs_v = inputs_q - - xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v) - - xq = self._split_heads(xq, self.num_heads) - xk = self._split_heads(xk, self.num_key_value_heads) - xv = self._split_heads(xv, self.num_key_value_heads) - - if self.num_heads != self.num_key_value_heads: - xk = xk.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads) - xv = xv.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads) - - og_dtype = xq.dtype - - if self.config.float32_attention: - xq = xq.to(torch.float) - xk = xk.to(torch.float) - - if self.config.attention_type == "direct": - attn_weights = torch.einsum("...qhd,...khd->...hqk", xq / math.sqrt(xq.size(-1)), xk) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype) - if self.attention_dropout is not None: - attn_weights = self.attention_dropout(attn_weights) - attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv) - - elif self.config.attention_type == "sdpa": - if self.config.float32_attention and not torch.is_autocast_enabled(): - xv = xv.to(torch.float32) - attn_output = F.scaled_dot_product_attention( - xq.transpose(1, 2).contiguous(), - xk.transpose(1, 2).contiguous(), - xv.transpose(1, 2).contiguous(), - is_causal=False, - dropout_p=self.config.vision_backbone.attention_dropout, - ).transpose(1, 2) - else: - raise NotImplementedError(self.config.attention_type) - attn_output = attn_output.to(og_dtype) - attn_output = self._merge_heads(attn_output) - attn_output = self.wo(attn_output) - attn_output = self.residual_dropout(attn_output) - - return attn_output - - -class MultiHeadAttentionPool(nn.Module): - def __init__( - self, - config: FullMolmoConfig, - factor: int = 1, - use_bias: bool = True, - dropout: bool = True, - output_layer: bool = True, - mean_residual: bool = False, - query: str = "mean", - is_vit_layer: Optional[bool] = True, - ): - super().__init__() - self.config = config - self.factor = factor - self.use_bias = use_bias - self.dropout = dropout - self.output_layer = output_layer - self.mean_residual = mean_residual - self.query = query - - v_cfg = config.vision_backbone - input_dim = 
v_cfg.image_emb_dim - self.embed_dim = v_cfg.image_emb_dim * factor - self.num_heads = v_cfg.image_num_heads - self.head_dim = v_cfg.image_head_dim * factor - self.num_key_value_heads = v_cfg.image_num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.initializer_range = v_cfg.initializer_range - - nlayers = 1 if (is_vit_layer or config.vit_layers is None) else len(config.vit_layers) - - if query != "vector": - self.wq = nn.Linear( - nlayers * input_dim, - self.num_heads * self.head_dim, - bias=use_bias, - device=config.init_device, - ) - self.wk = nn.Linear( - nlayers * input_dim, - self.num_key_value_heads * self.head_dim, - bias=use_bias, - device=config.init_device, - ) - self.wv = nn.Linear( - nlayers * input_dim, - self.num_key_value_heads * self.head_dim, - bias=use_bias, - device=config.init_device, - ) - - if query == "vector": - self.attention_query = nn.Parameter( - torch.zeros( - 1, - self.num_key_value_heads * self.head_dim, - device=config.init_device, - ), - ) - - if output_layer: - self.wo = nn.Linear( - self.num_heads * self.head_dim, - self.embed_dim, - bias=use_bias, - device=config.init_device, - ) - self.attention_dropout = Dropout(v_cfg.attention_dropout, broadcast_dims=(0, 1)) - if dropout: - self.residual_dropout = Dropout(v_cfg.residual_dropout) - - def reset_parameters(self): - if self.query != "vector": - nn.init.normal_(self.wq.weight, std=self.initializer_range) - nn.init.normal_(self.wk.weight, std=self.initializer_range) - nn.init.normal_(self.wv.weight, std=self.initializer_range) - if self.output_layer: - nn.init.normal_(self.wo.weight, std=self.initializer_range) - if self.use_bias: - if self.query != "vector": - nn.init.constant_(self.wq.bias, 0) - nn.init.constant_(self.wk.bias, 0) - nn.init.constant_(self.wv.bias, 0) - if self.output_layer: - nn.init.constant_(self.wo.bias, 0) - if self.query == "vector": - nn.init.normal_(self.attention_query, std=self.initializer_range) - - def _split_heads(self, hidden_states, num_heads): - return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) - - def forward(self, inputs_kv: torch.Tensor) -> torch.Tensor: - xk, xv = self.wk(inputs_kv), self.wv(inputs_kv) - - if self.query == "mean": - inputs_q = inputs_kv.mean(dim=1, keepdim=True) - xq = self.wq(inputs_q) - elif self.query == "first": - inputs_q = inputs_kv[:, :1] - xq = self.wq(inputs_q) - elif self.query == "vector": - xq = self.attention_query.expand(inputs_kv.size(0), -1, -1) - elif self.query == "constant": - inputs_q = torch.ones_like(inputs_kv[:, :1]) / math.sqrt(inputs_kv.shape[-1]) - xq = self.wq(inputs_q) - else: - raise ValueError(f"Unknown query type: {self.query}") - - xq = self._split_heads(xq, self.num_heads) - xk = self._split_heads(xk, self.num_key_value_heads) - xv = self._split_heads(xv, self.num_key_value_heads) - - if self.num_heads != self.num_key_value_heads: - xk = xk.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads) - xv = xv.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads) - - xq = xq.to(torch.float) - xk = xk.to(torch.float) - - xq = xq / math.sqrt(xq.size(-1)) - attn_weights = torch.einsum("...qhd,...khd->...hqk", xq, xk) - - attn_weights = F.softmax(attn_weights, dim=-1).to(xq.dtype) - - attn_weights = self.attention_dropout(attn_weights).to(xv.dtype) - - attn_output = 
torch.einsum("...hqk,...khd->...qhd", attn_weights, xv) - attn_output = self._merge_heads(attn_output) - if self.output_layer: - attn_output = self.wo(attn_output) - if self.dropout: - attn_output = self.residual_dropout(attn_output) - if self.mean_residual: - attn_output += inputs_kv.mean(dim=1, keepdim=True) - - return attn_output - - -class MLP(nn.Module): - def __init__(self, config: FullMolmoConfig, input_dim: int, dropout: float = 0.0): - super().__init__() - self.config = config - self.hidden_size = config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model - self.initializer_range = config.initializer_range - - self.w1 = nn.Linear( - input_dim, - self.hidden_size // 2, - bias=False, - device=config.init_device, - ) - self.w2 = nn.Linear( - self.hidden_size // 2, - config.d_model, - bias=False, - device=config.init_device, - ) - self.w3 = nn.Linear( - input_dim, - self.hidden_size // 2, - bias=False, - device=config.init_device, - ) - # Activation function. - self.act = Activation.build(config) - self.dropout = Dropout(dropout) - - def reset_parameters(self): - nn.init.normal_(self.w1.weight, std=self.initializer_range) - nn.init.normal_(self.w2.weight, std=self.initializer_range) - nn.init.normal_(self.w3.weight, std=self.initializer_range) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.w2(self.act(self.w1(x), self.w3(x))) - x = self.dropout(x) - return x - - -class Residual(nn.Module): - def __init__(self, submodule: nn.Module): - super().__init__() - self.submodule = submodule - - def reset_parameters(self): - self.submodule.reset_parameters() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x + self.submodule(x) - - -class OLMoVisionBackbone(nn.Module): - def __init__(self, config: FullMolmoConfig): - super().__init__() - self.config = config - self.image_vit = VisionTransformer(config) - - input_dim: int = None - self.image_pooling_2d: nn.Module = None - if config.image_pooling_2d in {ImagePooling2DType.attention, ImagePooling2DType.attention_meanq}: - self.image_pooling_2d = MultiHeadDotProductAttention(config, is_vit_layer=False) - input_dim = config.vision_backbone.image_emb_dim - elif config.image_pooling_2d == ImagePooling2DType.attention_2wide: - cfg = deepcopy(config) - cfg.vision_backbone.image_emb_dim *= 2 - cfg.vision_backbone.image_head_dim *= 2 - self.image_pooling_2d = MultiHeadDotProductAttention(cfg, is_vit_layer=False) - input_dim = cfg.vision_backbone.image_emb_dim - elif config.image_pooling_2d == ImagePooling2DType.attention_v2: - assert config.vit_layers is not None - use_bias = True - dropout = True - output_layer = True - query = "mean" - mean_residual = False - factor = len(config.vit_layers) - self.image_pooling_2d = MultiHeadAttentionPool( - config, - factor=factor, - use_bias=use_bias, - dropout=dropout, - output_layer=output_layer, - mean_residual=mean_residual, - query=query, - is_vit_layer=False, - ) - input_dim = config.vision_backbone.image_emb_dim * factor - elif config.image_pooling_2d in [ImagePooling2DType.none, ImagePooling2DType.stack]: - self.image_pooling_2d = None - nlayers = 1 if config.vit_layers is None else len(config.vit_layers) - input_dim = nlayers * config.vision_backbone.image_emb_dim - else: - raise NotImplementedError(f"Unknown image pooling 2D method: {config.image_pooling_2d}") - - self.input_dim = input_dim - - # `MLP` assume the activation takes two inputs, so it must be a 'llama' version - if config.activation_type == ActivationType.swiglu: - 
mlp_config = replace(config, activation_type=ActivationType.llama_swiglu) - elif config.activation_type == ActivationType.gelu: - mlp_config = replace(config, activation_type=ActivationType.llama_geglu) - else: - mlp_config = config - if config.image_projector == ImageProjectType.mlpx2: - self.image_projector = nn.ModuleList([MLP(mlp_config, input_dim), Residual(MLP(config, input_dim))]) - elif config.image_projector == ImageProjectType.mlp: - self.image_projector = MLP(mlp_config, input_dim) - elif config.image_projector == ImageProjectType.linear: - self.image_projector = nn.Linear( - input_dim, - config.d_model, - bias=False, - device=config.init_device, - ) - else: - raise NotImplementedError(f"Unknown image projector: {config.image_projector}") - - self.image_feature_dropout = Dropout(config.image_feature_dropout) - - def reset_parameters(self): - if self.image_pooling_2d is not None: - self.image_pooling_2d.reset_parameters() - if self.config.image_projector == "2mlp": - for module in self.image_projector: - module.reset_parameters() - elif self.config.image_projector == "linear": - nn.init.xavier_uniform_(self.image_projector.weight) - else: - self.image_projector.reset_parameters() - - def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - raise NotImplementedError - - -class OLMoPretrainedVisionBackbone(OLMoVisionBackbone): - def __init__(self, config: FullMolmoConfig): - super().__init__(config) - v_cfg = self.config.vision_backbone - self.grad_checkpointing = True - - self.num_prefix_tokens = self.image_vit.num_prefix_tokens - assert self.num_prefix_tokens in {0, 1}, "Only 0 or 1 prefix tokens are supported" - - self.pad_embed = None - if config.image_padding_embed: - image_dim = v_cfg.image_emb_dim * len(self.config.vit_layers) - if config.image_padding_embed in ["pad_embed", "regress"]: - self.pad_embed = nn.Parameter(torch.zeros((image_dim,), device=config.init_device)) - elif config.image_padding_embed == "pad_and_partial_pad": - self.pad_embed = nn.Parameter(torch.zeros((2, image_dim), device=config.init_device)) - else: - raise ValueError(config.image_padding_embed) - - def reset_parameters(self): - super().reset_parameters() - self.image_vit.reset_parameters() - - def encode_image(self, images: torch.Tensor) -> torch.Tensor: - """ - : param images: (batch_size, num_crops, num_patch, n_pixels) - """ - cfg = self.config - v_cfg = self.config.vision_backbone - B, T, N, D = images.shape - - mask = ~torch.all(images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True) - - # Output all hidden states - # n_layers x (batch_num_crops, (1+)n_tokens, image_emb_dim) - images = images.view(B * T, N, D) - image_features = self.image_vit(images) - - if cfg.vit_layers is not None: - features = [] - for layer in cfg.vit_layers: - features.append(image_features[layer]) - image_features = torch.cat(features, dim=-1) - else: - image_features = image_features[-1] - - cls_embed: torch.Tensor = None - if self.num_prefix_tokens > 0: - cls_embed = image_features[:, 0] - image_features = image_features[:, 1:] - - image_features = image_features * mask - image_features = image_features.view(B, T, N, -1) - - cls_embed = cls_embed.view(B, T, -1) if cls_embed is not None else None - - return image_features, cls_embed - - def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - cfg = self.config - - # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) - 
batch_size, num_image = images.shape[:2] - image_features, cls_embed = self.encode_image(images) - - if cfg.image_padding_embed: - assert image_masks is not None - if cfg.image_padding_embed == "pad_embed": - all_pad = (image_masks == 0).to(dtype=torch.float32) - pad_embed = self.pad_embed[None, None, None, :] - image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1) - elif cfg.image_padding_embed == "regress": - pad_embed = self.pad_embed[None, None, None, :] - image_features = image_features + pad_embed * torch.unsqueeze(torch.maximum(image_masks, torch.zeros_like(image_masks)), -1) - elif cfg.image_padding_embed == "pad_and_partial_pad": - pad_embed = self.pad_embed[:, None, None, None, :] - all_pad = image_masks == 0 - partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=image_features.dtype) - all_pad = all_pad.to(dtype=image_features.dtype) - image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1) - image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1) - else: - raise ValueError(cfg.image_padding_embed) - - image_features = self.image_feature_dropout(image_features) - if cls_embed is not None: - cls_embed = self.image_feature_dropout(cls_embed) - - image_features = image_features.reshape( - (batch_size, num_image) + cfg.image_num_patch + (-1,), - ) - - if cfg.image_num_patch[0] % cfg.image_pooling_h == 1: - # Pad so we can still pool 2x2 patches - image_features = F.pad( - image_features, - (0, 0, 0, 1, 0, 1, 0, 0, 0, 0), - ) - - # image pooling - image_features = einops.rearrange( - image_features, - "b n (h dh) (w dw) c -> (b n h w) (dh dw) c", - dh=cfg.image_pooling_h, - dw=cfg.image_pooling_w, - ) - - if cfg.image_pooling_2d == ImagePooling2DType.attention_meanq: - query = image_features.mean(-2, keepdim=True) - image_features = self.image_pooling_2d(query, image_features) - elif cfg.image_pooling_2d not in {ImagePooling2DType.none, ImagePooling2DType.stack}: - if self.grad_checkpointing: - from torch.utils.checkpoint import checkpoint - - image_features = checkpoint(self.image_pooling_2d, image_features[:, :1, :], image_features, use_reentrant=False) - else: - image_features = self.image_pooling_2d(image_features[:, :1, :], image_features) - - h, w = cfg.llm_patches_per_crop() - image_features = image_features.reshape(batch_size, num_image, h * w, -1) - - # MLP layer to map the feature. 
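The 2x2 pooling rearrangement above can be checked in isolation; a standalone sketch with a made-up feature width, showing that each crop's 24x24 patch grid collapses into the 12x12 = 144 pooled positions that llm_patches_per_crop() reports:

    import torch
    import einops

    batch, num_crops, emb = 1, 2, 8             # hypothetical sizes
    grid = 24                                   # 336 // 14 patches per side
    feats = torch.randn(batch, num_crops, grid, grid, emb)

    # Group every 2x2 block of patches together, as in the rearrange above.
    pooled_groups = einops.rearrange(
        feats, "b n (h dh) (w dw) c -> (b n h w) (dh dw) c", dh=2, dw=2
    )
    print(pooled_groups.shape)                  # (1 * 2 * 12 * 12, 4, 8)
    assert pooled_groups.shape[0] == batch * num_crops * 12 * 12
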
- if self.grad_checkpointing: - from torch.utils.checkpoint import checkpoint - - image_features = checkpoint(self.image_projector, image_features, use_reentrant=False) - else: - image_features = self.image_projector(image_features) - - # image_features: (batch_size, num_image, num_patch, d_model) - # cls_embed: (batch_size, num_image, d_model) - return image_features, cls_embed - - -class ModuleType(str, Enum): - in_module = "in" - out_module = "out" - emb = "emb" - final_out = "final_out" - - -def init_weights( - config: FullMolmoConfig, - module: Union[nn.Linear, nn.Embedding], - d: Optional[int] = None, - layer_id: Optional[int] = None, - std_factor: float = 1.0, - type_of_module: Optional[ModuleType] = None, -) -> None: - d = d if d is not None else config.d_model - std = config.init_std * std_factor - if config.init_cutoff_factor is not None: - cutoff_value = config.init_cutoff_factor * std - nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value) - else: - nn.init.normal_(module.weight, mean=0.0, std=std) - - -class LlamaSwiGLU(nn.Module): - def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - return F.silu(x1) * x2 - - @property - def output_multiplier(self) -> float: - return 0.5 - - -class SwiGLU(nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - x, gate = x.chunk(2, dim=-1) - return F.silu(gate) * x - - @property - def output_multiplier(self) -> float: - return 0.5 - - -class Activation(nn.Module): - def __init__(self, config: FullMolmoConfig): - super().__init__() - self.config = config - - def forward(self, x: torch.Tensor) -> torch.Tensor: - raise NotImplementedError - - @property - def output_multiplier(self) -> float: - raise NotImplementedError - - @classmethod - def build(cls, config: FullMolmoConfig) -> "Activation": - if config.activation_type == "quick_gelu": - return QuickGELU(config) - elif config.activation_type == "gelu": - return cast(Activation, GELU(approximate="none")) - elif config.activation_type == "gelu_tanh": - return cast(Activation, GELU(approximate="tanh")) - elif config.activation_type == "relu": - return cast(Activation, ReLU(inplace=False)) - elif config.activation_type == "silu": - return cast(Activation, SiLU(inplace=False)) - # elif config.activation_type == "llama_geglu": - # return LlamaGEGLU(config) - # elif config.activation_type == "llama_geglu_tanh": - # return LlamaGEGLUTanh(config) - elif config.activation_type == "llama_swiglu": - return LlamaSwiGLU() - elif config.activation_type == "swiglu": - return SwiGLU() - else: - raise NotImplementedError(f"Unknown activation: '{config.activation_type}'") - - -class QuickGELU(Activation): - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x * torch.sigmoid(1.702 * x) - - @property - def output_multiplier(self) -> float: - return 1.0 - - -class GELU(nn.GELU): - @property - def output_multiplier(self) -> float: - return 1.0 - - -class ReLU(nn.ReLU): - @property - def output_multiplier(self) -> float: - return 1.0 - - -class SiLU(nn.SiLU): - @property - def output_multiplier(self) -> float: - return 1.0 - - -def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor: - att_bias = torch.triu( - torch.ones(seq_len, seq_len, device=device, dtype=torch.float), - diagonal=1, - ) - att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min) - return att_bias.view(1, 1, seq_len, seq_len) # type: ignore - - -def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) 
-> torch.Tensor: - if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len: - if causal_bias.device != device: - causal_bias = causal_bias.to(device) - cache["causal_attention_bias"] = causal_bias - return causal_bias - with torch.autocast(device.type, enabled=False): - causal_bias = causal_attention_bias(seq_len, device) - cache["causal_attention_bias"] = causal_bias - return causal_bias - - -class LayerNormBase(nn.Module): - def __init__( - self, - config: MolmoConfig, - *, - size: Optional[int] = None, - elementwise_affine: Optional[bool] = True, - eps: float = 1e-05, - weight_initializer: Optional[Callable] = torch.ones, - bias_initializer: Optional[Callable] = torch.zeros, - ): - super().__init__() - self.config = config - self.eps = self.config.layer_norm_eps or eps - self.normalized_shape = (size or config.d_model,) - if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine): - self.weight = nn.Parameter(weight_initializer(self.normalized_shape, device=config.init_device)) - use_bias = self.config.bias_for_layer_norm - if use_bias is None: - use_bias = self.config.include_bias - if use_bias: - self.bias = nn.Parameter(bias_initializer(self.normalized_shape, device=config.init_device)) - else: - self.register_parameter("bias", None) - else: - self.register_parameter("bias", None) - self.register_parameter("weight", None) - - @classmethod - def build(cls, config: FullMolmoConfig, size: Optional[int] = None, **kwargs): - if config.layer_norm_type == "default": - return LayerNorm(config, size=size, low_precision=False, **kwargs) - elif config.layer_norm_type == "low_precision": - return LayerNorm(config, size=size, low_precision=True, **kwargs) - elif config.layer_norm_type == "rms": - return RMSLayerNorm(config, size=size, **kwargs) - else: - raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'") - - -class RMSLayerNorm(LayerNormBase): - """ - RMS layer norm, a simplified :class:`LayerNorm` implementation - """ - - def __init__( - self, - config: FullMolmoConfig, - size: Optional[int] = None, - elementwise_affine: Optional[bool] = None, - eps: float = 1e-5, - ): - super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - with torch.autocast(enabled=False, device_type=x.device.type): - og_dtype = x.dtype - x = x.to(torch.float32) - variance = x.pow(2).mean(-1, keepdim=True) - x = x * torch.rsqrt(variance + self.eps) - x = x.to(og_dtype) - - if self.weight is not None: - if self.bias is not None: - return self.weight * x + self.bias - else: - return self.weight * x - else: - return x - - -class LayerNorm(LayerNormBase): - """ - The default :class:`LayerNorm` implementation which can optionally run in low precision. 
- """ - - def __init__( - self, - config: FullMolmoConfig, - size: Optional[int] = None, - low_precision: bool = False, - elementwise_affine: Optional[bool] = None, - eps: float = 1e-05, - ): - super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps) - self.low_precision = low_precision - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.low_precision: - module_device = x.device - downcast_x = self._cast_if_autocast_enabled(x) - downcast_weight = self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight - downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias - with torch.autocast(enabled=False, device_type=module_device.type): - return F.layer_norm(downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps) - else: - return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps) - - -class Molmo(nn.Module): - def __init__(self, config: FullMolmoConfig, init_params: bool = True): - super().__init__() - self.config = config - self.__cache = BufferCache() - - # Validate config. - if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size: - if self.config.embedding_size < self.config.vocab_size: - raise MolmoConfigurationError("embedding size should be at least as big as vocab size") - elif self.config.embedding_size % 128 != 0: - import warnings - - warnings.warn("Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning) - torch.backends.cuda.enable_flash_sdp(True) - torch.backends.cuda.enable_mem_efficient_sdp( - True - ) # jakep: I found that setting this to true in torch 2.5.1 greatly increased performance (6sec/it from 22sec/it) - - wte = None - if self.config.additional_vocab_size is not None: - wte = Embedding( - config.embedding_size or config.vocab_size, - config.additional_vocab_size, - config.d_model, - device=config.init_device, - initializer_range=config.initializer_range, - new_embed_initializer_range=config.new_embedding_init_range, - ) - else: - wte = nn.Embedding(config.embedding_size or config.vocab_size, config.d_model, device=config.init_device) - - self.transformer = nn.ModuleDict( - dict( - wte=wte, - emb_drop=Dropout(config.embedding_dropout), - ln_f=LayerNorm.build(config), - ) - ) - - blocks = [MolmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)] - if self.config.block_group_size > 1: - raise NotImplementedError() - else: - self.transformer.update({"blocks": nn.ModuleList(blocks)}) - - if not self.config.rope: - self.transformer.update({"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}) - if not config.weight_tying: - self.transformer.update( - { - "ff_out": nn.Linear( - config.d_model, - config.embedding_size or config.vocab_size, - bias=config.include_bias, - device=config.init_device, - ) - } - ) - - self.vision_backbone: Optional[OLMoVisionBackbone] = None - if config.vision_backbone is not None: - self.vision_backbone = OLMoPretrainedVisionBackbone(config) - - self.__num_fwd_flops: Optional[int] = None - - self.gradient_checkpointing = False - - def reset_parameters(self): - if self.vision_backbone is not None: - self.vision_backbone.reset_parameters() - self.reset_non_vision_parameters() - - def reset_non_vision_parameters(self): - self.transformer.wte.reset_parameters() - if hasattr(self.transformer.wte, "new_embedding"): - 
nn.init.normal_(self.transformer.wte.new_embedding, std=self.config.new_embedding_init_range) - - if hasattr(self.transformer, "wpe"): - nn.init.normal_(self.transformer.wpe, mean=0.0, std=1.0) - - self.transformer.ln_f.reset_parameters() # type: ignore - - if hasattr(self.transformer, "ff_out"): - nn.init.normal_(self.transformer.ff_out, mean=0.0, std=0.02) - - if self.config.block_group_size == 1: - for block in self.transformer.blocks: - block.reset_parameters() - else: - for block_group in self.transformer.block_groups: - block_group.reset_parameters() - - def forward( - self, - input_ids: torch.LongTensor, - input_embeddings: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - attention_bias: Optional[torch.Tensor] = None, - response_mask: Optional[torch.Tensor] = None, - images: Optional[torch.Tensor] = None, - image_masks: Optional[torch.Tensor] = None, - image_input_idx: Optional[torch.Tensor] = None, - subsegment_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None, - use_cache: bool = False, - last_logits_only: bool = False, - output_hidden_states: Optional[bool] = None, - append_last_valid_logits: Optional[torch.Tensor] = None, - ) -> ModelOutput: - """ - :param input_ids: A tensor of shape `(batch_size, seq_len)`. - :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input - embeddings. When provided, it is treated as the output of the input embedding layer. - :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates - which input IDs are masked. A `1` value in the mask means that - the corresponding input ID should *not* be ignored. A `0` means - that the corresponding input ID is masked. - - This has the same meaning as the `attention_mask` in HuggingFace's `transformers` - library. - :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`, - `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used - to introduce causal or other biases. - - If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]` - indicates that the i-th element in the sequence is allowed to attend to the j-th - element in the sequence. - - If the tensor is a float tensor, it will just be added to the attention - scores before the softmax. - - The default is causal, which corresponds to a lower-diagonal byte matrix of ones. - :param response_mask: A tensor of shape `(batch_size, seq_len)` that indicates - the response mask. A `1` value in the mask means that the corresponding token - is a response token. A `0` means that the corresponding token is not - a response token. - :param past_key_values: Pre-computed keys and values for each attention block. - Can be used to speed up sequential decoding. The `input_ids` which have - their past given to this model should not be passed as `input_ids` as they have already been computed. - :param use_cache: If `True`, return key and value tensors for each block. - :param last_logits_only: If `True`, only compute the logits for the last token of each sequence. - This can speed up decoding when you only care about the next token. 
- """ - output_hidden_states = output_hidden_states if output_hidden_states is not None else False - - if past_key_values: - assert len(past_key_values) == self.config.n_layers - - has_image = images is not None - - assert not (has_image and input_embeddings is not None), "Cannot provide both images and input embeddings." - assert not (has_image and past_key_values is not None), "Cached key and values should not be used with images." - - batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2] - if past_key_values is None: - past_length = 0 - else: - past_length = past_key_values[0][0].size(-2) - - if self.config.use_position_ids and attention_mask is None: - attention_mask = input_ids != -1 - - if subsegment_ids is not None: - assert not use_cache, "Subsegment_ids cannot be used with cache." - subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1) - attention_mask = subsegment_mask.to(attention_mask.dtype) * attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1) - if position_ids is None: - raise ValueError("Positioned ids must be given if using subsegment_ids") - else: - if self.config.use_position_ids and position_ids is None: - position_ids = torch.clamp( - torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1, - min=0, - ).broadcast_to((batch_size, attention_mask.shape[-1])) - - # Get embeddings of input. - # shape: (batch_size, seq_len, d_model) - if input_ids is not None: - input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) - x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore - - num_image: Optional[int] = None - if images is not None: - # shape: (batch_size, num_image, num_patch, d_model) - # cls_embed: (batch_size, num_image, d_model) - image_features, cls_embed = self.vision_backbone(images, image_masks) - num_image, num_patch = image_features.shape[1:3] - assert image_input_idx.shape == (batch_size, num_image, num_patch) - - # inster the image feature into the embedding. - image_features = image_features.view(batch_size, num_image * num_patch, -1) - image_input_idx = image_input_idx.view(batch_size, num_image * num_patch) - - valid = image_input_idx >= 0 - batch_idx = torch.arange(batch_size, device=x.device) - batch_idx = torch.tile(batch_idx[:, None], [1, image_features.shape[1]]) - - # For hf demo/endpoint - image_features = image_features.to(x.device) - - x[batch_idx[valid], image_input_idx[valid]] += image_features[valid] - - if not self.config.rope: - # Get positional embeddings. - # shape: (1, seq_len) - pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0) - # shape: (1, seq_len, d_model) - pos_emb = self.transformer.wpe(pos) # type: ignore - x = pos_emb + x - - # Add input + positional embeddings and apply dropout. - # shape: (batch_size, seq_len, d_model) - x = self.transformer.emb_drop(x) # type: ignore - - # normalized - if self.config.normalize_input_embeds: - x = x * (self.config.d_model**0.5) - - # Transform the attention mask into what the blocks expect. 
- if attention_mask is not None: - # shape: (batch_size, 1, 1, seq_len) - if len(attention_mask.shape) == 2: - attention_mask = attention_mask[:, : past_length + seq_len] - attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :] - else: - attention_mask = attention_mask.unsqueeze(1).to(dtype=torch.float) - attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min - - # Merge attention mask with attention bias. - if ( - attention_bias is not None - or attention_mask is not None - # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly - # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute - # scores correctly. - or past_key_values is not None - ): - if attention_bias is None: - attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device) - elif attention_bias.dtype in (torch.int8, torch.bool): - attention_bias = attention_bias.to(dtype=torch.float) - attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min) - - # Transform to the right shape and data type. - mask_len = seq_len - if attention_mask is not None: - mask_len = attention_mask.shape[-1] - elif past_key_values is not None: - mask_len = past_key_values[0][0].shape[-2] + seq_len - attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float) - - # Add in the masking bias. - if attention_mask is not None: - attention_bias = attention_bias + attention_mask - # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf. - # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead - # it can produce NaNs. - ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False) - - attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None - - # decoder layers - all_hidden_states = [] - - # Apply blocks one-by-one. 
- if self.config.block_group_size == 1: - for block_idx, block in enumerate(self.transformer.blocks): - if output_hidden_states: - # add hidden states - all_hidden_states.append(x) - - layer_past = None if past_key_values is None else past_key_values[block_idx] - - if self.gradient_checkpointing and self.training: - x, cache = self._gradient_checkpointing_func( - block, x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache - ) - else: - x, cache = block(x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache) - - if attn_key_values is not None: - assert cache is not None - attn_key_values.append(cache) - else: - for group_idx, block_group in enumerate(self.transformer.block_groups): - if output_hidden_states: - # add hidden states - all_hidden_states.append(x) - - layers_past = ( - None - if past_key_values is None - else past_key_values[group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size] - ) - x, cache = block_group(x, attention_bias=attention_bias, position_ids=position_ids, layers_past=layers_past, use_cache=use_cache) - if attn_key_values is not None: - assert cache is not None - attn_key_values.extend(cache) - - if last_logits_only: - # shape: (batch_size, 1, d_model) - if append_last_valid_logits is not None: - last_valid_output = x[torch.arange(x.shape[0], device=x.device), append_last_valid_logits.to(x.device)] - x = last_valid_output.unsqueeze(1) - else: - x = x[:, -1, :].unsqueeze(1) - - # Apply final layer norm. - # shape: (batch_size, seq_len or 1, d_model) - x = self.transformer.ln_f(x) # type: ignore - if output_hidden_states: - # add final hidden state post-final-layernorm, following HuggingFace's convention - all_hidden_states.append(x) - - # Get logits. 
- # shape: (batch_size, seq_len or 1, vocab_size) - if self.config.weight_tying: - logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore - else: - logits = self.transformer.ff_out(x) # type: ignore - if self.config.scale_logits: - logits.mul_(1 / math.sqrt(self.config.d_model)) - - if not last_logits_only and append_last_valid_logits is not None: - last_valid_logit = logits[torch.arange(logits.shape[0], device=logits.device), append_last_valid_logits] - logits = torch.cat([logits[:, :-1], last_valid_logit[:, None]], dim=1) - - return ModelOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type] - - -class MolmoForCausalLM(PreTrainedModel): - config_class = MolmoConfig - supports_gradient_checkpointing = True - base_model_prefix = "model" - _no_split_modules = ["MolmoBlock"] - - def __init__(self, config: MolmoConfig, model: Optional[Molmo] = None, init_params: bool = False): - super().__init__(config) - - if not model: - full_config = FullMolmoConfig( - image_padding_embed="pad_and_partial_pad", - image_pooling_2d="attention-meanq", - attention_layer_norm=config.attention_layer_norm, - rope_impl="llama", - vocab_size=config.vocab_size, - max_sequence_length=config.max_position_embeddings, - qkv_bias=config.qkv_bias, - norm_after=config.norm_after, - embedding_size=config.embedding_size, - attention_type="sdpa", - embedding_dropout=0, - attention_dropout=0, - residual_dropout=0, - rope=True, - weight_tying=False, - include_bias=False, - d_model=config.hidden_size, - mlp_hidden_size=config.intermediate_size, - n_layers=config.num_hidden_layers, - additional_vocab_size=128, - n_heads=config.num_attention_heads, - n_kv_heads=config.num_key_value_heads, - rope_theta=config.rope_theta, - layer_norm_eps=config.layer_norm_eps, - layer_norm_type=config.layer_norm_type, - vit_layers=[-2, -9], - vision_backbone=VisionBackboneConfig( - image_default_input_size=(336, 336), - image_patch_size=14, - image_pos_patch_size=14, - image_emb_dim=1024, - image_num_heads=16, - image_num_key_value_heads=16, - image_num_layers=23, - image_head_dim=64, - image_mlp_dim=4096, - image_mlp_activations="quick_gelu", - image_dropout_rate=0.0, - image_num_pos=577, - image_norm_eps=1e-5, - attention_dropout=0.0, - residual_dropout=0.0, - initializer_range=0.02, - ), - ) - self.model = Molmo(full_config, init_params=init_params) - else: - self.model = model - - def forward( - self, - input_ids: torch.LongTensor = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - attention_bias: Optional[torch.Tensor] = None, - response_mask: Optional[torch.Tensor] = None, - images: Optional[torch.Tensor] = None, - image_masks: Optional[torch.Tensor] = None, - image_input_idx: Optional[torch.Tensor] = None, - subsegment_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - labels: Optional[torch.LongTensor] = None, - loss_masks: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - last_logits_only: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - append_last_valid_logits: Optional[torch.Tensor] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[ - Cache - ] = None, # This is a hack mitigation of an issue in transformers `4.39.x` 
https://github.com/huggingface/transformers/issues/29426 - ) -> Union[Tuple, CausalLMOutputWithPast]: - if use_cache is None: - use_cache = self.config.use_cache - - if output_attentions: - raise ValueError("output_attentions is not yet supported in Molmo") - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model.forward( - input_ids=input_ids, - input_embeddings=inputs_embeds, - attention_mask=attention_mask, - attention_bias=attention_bias, - response_mask=response_mask, - images=images, - image_masks=image_masks, - image_input_idx=image_input_idx, - subsegment_ids=subsegment_ids, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - last_logits_only=last_logits_only, - output_hidden_states=output_hidden_states, - append_last_valid_logits=append_last_valid_logits, - ) - - logits = outputs.logits - hidden_states = outputs.hidden_states - - loss = None - if labels is not None: - if loss_masks is not None: - loss_masks = loss_masks * (loss_masks > 0) - batch_size_in_tokens = max(loss_masks.sum().item(), 1) - labels = labels.long() - labels.masked_fill_(~(loss_masks > 0), -100) - labels = labels.view(-1) - logits_for_loss = logits.to(torch.float32).view(-1, logits.size(-1)) - loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none") - loss = loss_fct(logits_for_loss, labels) - loss = loss.view(input_ids.shape[0], -1) - loss = loss * loss_masks - loss = loss.sum() / batch_size_in_tokens - use_zloss = getattr(self.config, "softmax_auxiliary_loss", False) - if use_zloss: - z_squared = logits_for_loss.logsumexp(-1).pow(2) - z_loss = self.config.softmax_auxiliary_loss_scale * z_squared - z_loss = z_loss.view(input_ids.shape[0], -1) - z_loss = z_loss * loss_masks - z_loss = z_loss.sum() / batch_size_in_tokens - loss += z_loss - else: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = torch.nn.CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.embedding_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.attn_key_values, - hidden_states=hidden_states, - ) - - def can_generate(self) -> bool: - return True - - @torch.no_grad() - def generate_from_batch( - self, - batch: Dict[str, Any], - generation_config: Optional[GenerationConfig] = None, - **kwargs, - ): - if generation_config is not None: - assert generation_config.use_cache - - images = batch.get("images") - image_masks = batch.get("image_masks") - image_input_idx = batch.get("image_input_idx") - - # Validate inputs. 
- input_ids = batch["input_ids"] - batch_size, seq_len = input_ids.shape - attention_mask = batch.get("attention_mask", None) - max_new_tokens = generation_config.max_new_tokens - assert max_new_tokens is not None - mask_len = seq_len + max_new_tokens if self.config.use_position_ids else seq_len - position_ids: Optional[torch.Tensor] = None - append_last_valid_logits: Optional[torch.Tensor] = None - if self.config.use_position_ids and attention_mask is None: - attention_mask = input_ids != -1 - position_ids = torch.clamp(torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1, min=0) - append_last_valid_logits = attention_mask.long().sum(dim=-1) - 1 - attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((batch_size, max_new_tokens))], - dim=1, - ) - if attention_mask is not None: - assert attention_mask.shape == (batch_size, mask_len) - - out = super().generate( - batch["input_ids"], - generation_config, - attention_mask=attention_mask, - images=images, - image_masks=image_masks, - image_input_idx=image_input_idx, - position_ids=position_ids, - append_last_valid_logits=append_last_valid_logits, - **kwargs, - ) - - return out - - def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs): - if past_key_values: - # This is because we want the model to only process the last generated token. - input_ids = input_ids[:, -1:] - - if self.config.use_position_ids: - attention_mask = kwargs.get("attention_mask") - images = kwargs.get("images") - image_masks = kwargs.get("image_masks") - image_input_idx = kwargs.get("image_input_idx") - position_ids = kwargs.get("position_ids") - append_last_valid_logits = kwargs.get("append_last_valid_logits") - model_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": True, - "last_logits_only": True, - } - if past_key_values is None: - model_inputs["images"] = images - model_inputs["image_masks"] = image_masks - model_inputs["image_input_idx"] = image_input_idx - model_inputs["append_last_valid_logits"] = append_last_valid_logits - else: - model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values} - - model_inputs.update(kwargs) - model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache) - return model_inputs - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - num_new_tokens: int = 1, - ) -> Dict[str, Any]: - if self.config.use_position_ids: - model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 - if "append_last_valid_logits" in model_kwargs: - del model_kwargs["append_last_valid_logits"] - if "images" in model_kwargs: - del model_kwargs["images"] - del model_kwargs["image_masks"] - del model_kwargs["image_input_idx"] - cache_name, cache = super()._extract_past_from_model_output(outputs) - model_kwargs[cache_name] = cache - model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens - return model_kwargs - - def get_input_embeddings(self) -> torch.nn.Module: - return self.model.transformer.wte - - def set_input_embeddings(self, value: torch.nn.Module): - self.model.transformer.wte = value - - def get_output_embeddings(self): - if self.config.weight_tying: - return self.model.transformer.wte - else: - return self.model.transformer.ff_out - - def set_output_embeddings(self, value: torch.nn.Module): - if 
self.config.weight_tying: - self.model.transformer.wte = value - else: - self.model.transformer.ff_out = value - - def tie_weights(self): - """ - This function is intentionally left as a no-op. - - Weight tying is handled as follows: - - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration. - See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`. - - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled. - See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method. - - Therefore, there is no need to explicitly tie the weights in this function. - """ - pass - - def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None) -> torch.nn.Embedding: - """ - Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`. - - Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. - - Arguments: - new_num_tokens (`int`, *optional*): - The new number of tokens in the embedding matrix. Increasing the size will add newly initialized - vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just - returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. - pad_to_multiple_of (`int`, *optional*): - If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to - `None` will just pad the embedding to a multiple of `pad_to_multiple_of`. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more - details about this, or help on choosing the correct value for resizing, refer to this guide: - https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc - - Return: - `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. - - Note: - This method differs from the base class implementation by resizing the `embedding_size` attribute of the - model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size` - is less than the `vocab_size`. In OLMo, `embedding_size` refers to the dimensionality of the model's token - embeddings, while `vocab_size` refers to the number of unique tokens in the vocabulary. - """ - model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of) - if new_num_tokens is None and pad_to_multiple_of is None: - return model_embeds - - # Update base model and current model config - self.config.embedding_size = model_embeds.weight.shape[0] - self.model.config.embedding_size = model_embeds.weight.shape[0] - - # Check if the embedding size is less than the vocab size - if self.config.embedding_size < self.config.vocab_size: - warning_message = ( - f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size " - f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary " - "size is less than or equal to the new token embedding size." 
- ) - log.warning(warning_message) - - # Tie weights again if needed - self.tie_weights() - - return model_embeds - - -# Always register for multi-modal features -AutoModelForCausalLM.register(MolmoConfig, MolmoForCausalLM) diff --git a/olmocr/train/molmo/preprocessing_molmo.py b/olmocr/train/molmo/preprocessing_molmo.py deleted file mode 100644 index a7c63cc..0000000 --- a/olmocr/train/molmo/preprocessing_molmo.py +++ /dev/null @@ -1,184 +0,0 @@ -""" -Processor class for Molmo. -""" - -from typing import Optional - -from PIL import ImageOps -from PIL.Image import Image - -try: - from typing import Unpack -except ImportError: - from typing_extensions import Unpack - -import numpy as np -import torch -from transformers import AutoTokenizer -from transformers.image_utils import ImageInput -from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs -from transformers.tokenization_utils_base import PreTokenizedInput, TextInput -from transformers.utils import logging - -from .image_preprocessing_molmo import MolmoImageProcessor, MolmoImagesKwargs - -logger = logging.get_logger(__name__) - - -DEFAULT_IMAGE_PATCH_TOKEN = "" -DEFAULT_IM_START_TOKEN = "" -DEFAULT_IM_END_TOKEN = "" -DEFAULT_IM_COL_TOKEN = "" -IMAGE_PROMPT = "<|image|>" - -EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT) - - -def get_special_token_ids(tokenizer): - ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False) - assert len(ids) == len(EXTRA_TOKENS) - return {k: i for k, i in zip(EXTRA_TOKENS, ids)} - - -class MolmoTextKwargs(TextKwargs, total=False): - style: Optional[str] - system_prompt: Optional[str] - message_format: Optional[str] - always_start_with_space: Optional[bool] - sequence_length: Optional[int] - - -class MolmoProcessorKwargs(ProcessingKwargs, total=False): - text_kwargs: MolmoTextKwargs - images_kwargs: MolmoImagesKwargs - _defaults = { - "images_kwargs": { - "max_crops": 12, - "overlap_margins": [4, 4], - "base_image_input_size": [336, 336], - "image_token_length_w": 12, - "image_token_length_h": 12, - "image_patch_size": 14, - "image_padding_mask": True, - }, - "text_kwargs": { - "style": "long_caption", - "system_prompt": "none", - "message_format": "role", - "always_start_with_space": True, - "sequence_length": 1536, - "padding": False, - }, - } - - -class MolmoProcessor(ProcessorMixin): - attributes = ["image_processor", "tokenizer"] - image_processor_class = "AutoImageProcessor" - tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast") - - def __init__(self, image_processor: MolmoImageProcessor = None, tokenizer: AutoTokenizer = None, **kwargs): - # self.image_processor = image_processor - # self.tokenizer = tokenizer - super().__init__(image_processor, tokenizer) - self._special_tokens = None - - @property - def special_token_ids(self): - if self._special_tokens is None: - self._special_tokens = get_special_token_ids(self.tokenizer) - return self._special_tokens - - def get_tokens_input(self, prompt, message_format, always_start_with_space): - if message_format == "none" or message_format is None: - pass - elif message_format == "role": - prompt = "User: " + prompt + " Assistant:" - else: - raise NotImplementedError(f"Message format {message_format} not implemented") - - if always_start_with_space: - prompt = " " + prompt - - tokens = self.tokenizer.encode(prompt, add_special_tokens=False) - - return tokens - - def process( - self, - text: TextInput = None, - images: ImageInput = 
None, - *, - tokens: Optional[PreTokenizedInput] = None, - **kwargs: Unpack[MolmoProcessorKwargs], - ): - output_kwargs = self._merge_kwargs( - MolmoProcessorKwargs, - tokenizer_init_kwargs=self.tokenizer.init_kwargs, - **kwargs, - ) - - if tokens is None: - tokens = self.get_tokens_input( - text, - output_kwargs["text_kwargs"]["message_format"], - output_kwargs["text_kwargs"]["always_start_with_space"], - ) - - image_token_id = self.special_token_ids[IMAGE_PROMPT] - - if images is not None: - if not isinstance(images, (list, tuple)): - images = [images] - image_arrays = [] - for image in images: - if isinstance(image, Image): - image = image.convert("RGB") - # Handle images with EXIF orientation tags, which PIL will ignore by default - # https://github.com/python-pillow/Pillow/issues/4703 - img = ImageOps.exif_transpose(image) - image_arrays.append(np.array(image)) - else: - assert len(image.shape) == 3 and image.shape[-1] == 3 - image_arrays.append(image.astype(np.uint8)) - images = image_arrays - # For now only support inserting images at the start - image_idx = [-1] * len(images) - else: - image_idx = None - - sequence_length = output_kwargs["text_kwargs"]["sequence_length"] - - image_patch_token_id = self.special_token_ids[DEFAULT_IMAGE_PATCH_TOKEN] - image_col_token_id = self.special_token_ids[DEFAULT_IM_COL_TOKEN] - image_start_token_id = self.special_token_ids[DEFAULT_IM_START_TOKEN] - image_end_token_id = self.special_token_ids[DEFAULT_IM_END_TOKEN] - out = self.image_processor.multimodal_preprocess( - images=images, - image_idx=image_idx, - tokens=np.asarray(tokens).astype(np.int32), - sequence_length=sequence_length, - image_patch_token_id=image_patch_token_id, - image_col_token_id=image_col_token_id, - image_start_token_id=image_start_token_id, - image_end_token_id=image_end_token_id, - **output_kwargs["images_kwargs"], - ) - - # Prepend BOS - # qwen2 and olmo do not have a BOS, and instead use EOS as a generic seperator token. - bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id - decoder_input_tokens = np.pad(out["input_ids"], [[1, 0]], constant_values=bos) - out["input_ids"] = decoder_input_tokens - if "image_input_idx" in out: - # Shift patch mapping up by one since we added BOS - image_input_idx = out["image_input_idx"] - out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1) - - for k, v in out.items(): - out[k] = torch.from_numpy(v) - - return out - - -MolmoProcessor.register_for_auto_class() diff --git a/olmocr/train/prepare_olmocrmix.py b/olmocr/train/prepare_olmocrmix.py new file mode 100644 index 0000000..75cbe1d --- /dev/null +++ b/olmocr/train/prepare_olmocrmix.py @@ -0,0 +1,170 @@ +import argparse +import json +from os import PathLike +from pathlib import Path +from typing import Optional +import pandas as pd + +from huggingface_hub import snapshot_download + + +def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination: str | PathLike, max_examples: Optional[int] = None) -> str: + """ + Prepare OLMoCR mix dataset by downloading from HuggingFace and organizing into a folder structure. 
+ + Args: + dataset_path: HuggingFace dataset path + subset: Dataset subset name + split: Dataset split (train/validation/test) + destination: Destination directory path + max_examples: Maximum number of examples to process (None for all) + """ + # Step 1: Download dataset using hugging face hub snapshot_download to destination/hugging_face folder + dest_path = Path(destination) + hugging_face_dir = dest_path / "hugging_face" + hugging_face_dir.mkdir(parents=True, exist_ok=True) + + print(f"Downloading dataset {dataset_path} to {hugging_face_dir}...") + + # Download the entire repository including PDFs and parquet files + local_dir = snapshot_download( + repo_id=dataset_path, + repo_type="dataset", + local_dir=hugging_face_dir, + ) + + print(f"Downloaded to: {local_dir}") + + # Step 2: Create destination folder structure for processed markdown files + processed_dir = dest_path / f"processed_{subset}_{split}" + processed_dir.mkdir(exist_ok=True) + + # Manual map to parquet files for now + assert dataset_path == "allenai/olmOCR-mix-0225", "Only supporting the olmocr-mix for now, later will support other training sets" + if subset == "00_documents" and split == "train_s2pdf": + parquet_files = [dest_path / "hugging_face" / "train-s2pdf.parquet"] + elif subset == "00_documents" and split == "eval_s2pdf": + parquet_files = [dest_path / "hugging_face" / "eval-s2pdf.parquet"] + elif subset == "01_books" and split == "train_s2pdf": + parquet_files = [dest_path / "hugging_face" / "train-iabooks.parquet"] + elif subset == "01_books" and split == "eval_s2pdf": + parquet_files = [dest_path / "hugging_face" / "eval-iabooks.parquet"] + else: + raise NotImplementedError() + + + # Step 3: Process parquet files + total_processed = 0 + total_errors = 0 + + for parquet_file in parquet_files: + print(f"Processing {parquet_file.name}...") + df = pd.read_parquet(parquet_file) + + # Process each row + for idx, row in df.iterrows(): + if max_examples and total_processed >= max_examples: + break + + try: + + # Extract fields from the row + # The rows in the parquet will look like url, page_number, response (json format), and id + response = row.get('response', '') + doc_id = str(idx) + + assert len(doc_id) > 4 + + # Parse response if it's a JSON string + response_data = json.loads(response) + response = response_data + + # Create folder structure using first 4 digits of id + # Make a folder structure, to prevent a huge amount of files in one folder, using the first 4 digits of the id, ex. id[:4]/id[4:].md + folder_name = doc_id[:4] + file_name = f"{doc_id[4:]}.md" + + # Create directory + output_dir = processed_dir / folder_name + output_dir.mkdir(exist_ok=True) + + # Write markdown file with front matter and natural text + output_file = output_dir / file_name + with open(output_file, 'w', encoding='utf-8') as f: + # Extract natural_text and other fields for front matter + natural_text = response.get('natural_text', '') + # Create front matter from other fields + front_matter = {k: v for k, v in response.items() if k != 'natural_text'} + + # Write front matter + f.write("---\n") + for k, v in front_matter.items(): + f.write(f"{k}: {v}\n") + f.write("---\n") + + # Write natural text + f.write(natural_text) + + total_processed += 1 + if total_processed % 1000 == 0: + print(f"Processed {total_processed} examples...") + except Exception as ex: + print(f"Error processing row: {ex}") + total_errors += 1 + + if max_examples and total_processed >= max_examples: + break + + print(f"Completed!
Processed {total_processed} examples to {processed_dir}") + print(f"Total errors: {total_errors}") + return str(processed_dir) + + +def main(): + parser = argparse.ArgumentParser(description="Prepare OLMoCR mix dataset") + parser.add_argument( + "--dataset-path", + type=str, + default="allenai/olmOCR-mix-0225", + help="HuggingFace dataset path (e.g., 'allenai/olmOCR-mix-0225')" + ) + parser.add_argument( + "--subset", + type=str, + default="00_documents", + required=True, + help="Dataset subset name" + ) + parser.add_argument( + "--split", + type=str, + default="eval_s2pdf", + required=True, + help="Dataset split, e.g. eval_s2pdf" + ) + parser.add_argument( + "--destination", + type=str, + required=True, + help="Destination directory path" + ) + parser.add_argument( + "--max-examples", + type=int, + default=None, + help="Maximum number of examples to process (default: all)" + ) + + args = parser.parse_args() + + prepare_olmocr_mix( + dataset_path=args.dataset_path, + subset=args.subset, + split=args.split, + destination=args.destination, + max_examples=args.max_examples + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/olmocr/train/train.py b/olmocr/train/train.py index eb40449..5088848 100644 --- a/olmocr/train/train.py +++ b/olmocr/train/train.py @@ -1,228 +1,9 @@ -import logging -import os -from logging import Logger -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Optional - -import torch -import torch.distributed -import wandb -from datasets.utils import disable_progress_bars -from datasets.utils.logging import set_verbosity -from peft import LoraConfig, get_peft_model # pyright: ignore -from transformers import ( - AutoProcessor, - Qwen2VLForConditionalGeneration, - Trainer, - TrainerCallback, - TrainingArguments, -) -from transformers.integrations import WandbCallback -from transformers.trainer_callback import TrainerControl, TrainerState -from transformers.trainer_utils import get_last_checkpoint - -from olmocr.train.core.cli import make_cli, save_config, to_native_types -from olmocr.train.core.config import TrainConfig -from olmocr.train.core.loggers import get_logger -from olmocr.train.core.paths import copy_dir, join_path -from olmocr.train.core.state import BeakerState - -from .utils import ( - RunName, - TruncatingCollator, - get_local_dir, - log_trainable_parameters, - make_dataset, - setup_environment, -) - - -class CheckpointUploadCallback(TrainerCallback): - def __init__(self, save_path: str, logger: Optional[Logger] = None): - self.save_path = save_path - self.logger = logger or get_logger(self.__class__.__name__) - - def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - if state.is_local_process_zero: - latest_checkpoint = get_last_checkpoint(args.output_dir) - if not latest_checkpoint: - return - - dir_name = Path(latest_checkpoint).name - copy_dir(str(latest_checkpoint), f"{self.save_path}/{dir_name}") - self.logger.info("Saved checkpoint to %s", f"{self.save_path}/{dir_name}") - - -def update_wandb_config(config: TrainConfig, trainer: Trainer, model: torch.nn.Module): - # finding wandb callback - callbacks = [c for c in trainer.callback_handler.callbacks if isinstance(c, WandbCallback)] # pyright: ignore - if not callbacks: - raise ValueError("WandbCallback not found in trainer callbacks") - - wandb_callback = callbacks[0] - peft_config = to_native_types(getattr(model, "peft_config", {})) - script_config = to_native_types(config) - beaker_envs = {k: v for
k, v in os.environ.items() if k.lower().startswith("beaker")} - - on_setup_fn = wandb_callback.setup - - def setup_and_update(args, state, model, **kwargs): - on_setup_fn(args=args, state=state, model=model, **kwargs) - wandb.config.update({"peft": peft_config}, allow_val_change=True) - wandb.config.update({"script": script_config}, allow_val_change=True) - wandb.config.update({"beaker": beaker_envs}, allow_val_change=True) - if (run := wandb.run) and (beaker_url := BeakerState().url): - run.notes = beaker_url - - wandb_callback.setup = setup_and_update - - -def get_rank() -> int: - if torch.distributed.is_available() and torch.distributed.is_initialized(): - return torch.distributed.get_rank() - return 0 - - -def run_train(config: TrainConfig): - if get_rank() == 0: - logger_level = logging.INFO - else: - logger_level = logging.WARN - disable_progress_bars() - - logger = get_logger(__name__, level=logger_level) - set_verbosity(logger_level) - - run_name = RunName.get(config) - - setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group) - - processor = AutoProcessor.from_pretrained(config.model.name_or_path, trust_remote_code=True) - train_dataset, valid_dataset = make_dataset(config, processor) - logger.info(train_dataset) - logger.info(valid_dataset) - - if "qwen" in config.model.name_or_path.lower(): - model = Qwen2VLForConditionalGeneration.from_pretrained( - config.model.name_or_path, torch_dtype=torch.bfloat16, _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None - ) - else: - from .molmo.config_molmo import MolmoConfig - from .molmo.modeling_molmo import MolmoForCausalLM - - model_config = MolmoConfig.from_pretrained(config.model.name_or_path, trust_remote_code=True) - - if model_config.max_position_embeddings < config.generate.max_length: - logger.warning( - f"ALERT, force adjusting model config max_position_embeddings upwards from {model_config.max_position_embeddings} to {config.generate.max_length}" - ) - model_config.max_position_embeddings = config.generate.max_length - - if config.model.use_flash_attn: - model_config.attention_type = "flash" - - model = MolmoForCausalLM.from_pretrained(config.model.name_or_path, torch_dtype=torch.bfloat16, config=model_config, trust_remote_code=True) - - logger.info(model) - - if config.lora is not None: - peft_config = LoraConfig( - r=config.lora.rank, - lora_alpha=config.lora.alpha, - lora_dropout=config.lora.dropout, - bias=config.lora.bias, # pyright: ignore - task_type=config.lora.task_type, - target_modules=list(config.lora.target_modules), - ) - model = get_peft_model(model=model, peft_config=peft_config) - log_trainable_parameters(model=model, logger=logger) - - save_path = join_path("", config.save.path, run_name.run) - - # Make sure directory exists if local - if not save_path.startswith("s3://"): - os.makedirs(os.path.dirname(save_path), exist_ok=True) - - save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore - - with TemporaryDirectory() as output_dir: - training_args = TrainingArguments( - run_name=run_name.run, - logging_steps=config.hparams.log_every_steps, - output_dir=output_dir, - eval_strategy="steps", - report_to="wandb", - # report_to=[], # disable logging to wandb, we will use a custom callback - optim=config.hparams.optim, - eval_steps=config.hparams.eval_every_steps, - learning_rate=config.hparams.learning_rate, - per_device_train_batch_size=config.hparams.batch_size, - 
per_device_eval_batch_size=config.hparams.eval_batch_size or config.hparams.batch_size, - gradient_checkpointing=config.hparams.gradient_checkpointing, - gradient_checkpointing_kwargs=( - dict(use_reentrant=False) # from this issue: https://github.com/huggingface/peft/issues/1142 - if config.hparams.gradient_checkpointing and config.lora is not None - else {} - ), - gradient_accumulation_steps=config.hparams.gradient_accumulation_steps, - max_steps=config.hparams.max_steps, - weight_decay=config.hparams.weight_decay, - dataloader_num_workers=config.max_workers, - load_best_model_at_end=True, - save_strategy="steps", - ddp_find_unused_parameters=config.hparams.find_unused_parameters, - save_steps=config.save.save_every_steps, - warmup_steps=config.hparams.warmup_steps, - warmup_ratio=config.hparams.warmup_ratio, - bf16=True, - label_names=["labels"], # fix from https://github.com/huggingface/transformers/issues/22885 - max_grad_norm=config.hparams.clip_grad_norm, - remove_unused_columns=False, - eval_on_start=True, - metric_for_best_model=config.valid_data.metric_for_best_model, - ) - - data_collator = TruncatingCollator(max_length=config.generate.max_length) - - checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger) - - # Initialize Trainer - trainer = Trainer( - model=model, - args=training_args, - train_dataset=train_dataset, - eval_dataset=valid_dataset, - tokenizer=processor.tokenizer, - data_collator=data_collator, - callbacks=[checkpoint_callback], - ) - - # Train the model - trainer.train() # pyright: ignore - - if get_rank() == 0: - with get_local_dir(join_path("", save_path, "best")) as best_dir: - if config.lora is not None: - logger.info("Merging LoRA adapters into the base model...") - model = model.merge_and_unload() - logger.info("LoRA adapters merged successfully.") - - model.save_pretrained(best_dir) - - logger.info("Saved best model to %s", best_dir) - - # Uncomment to test speed of data loader - # train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False) - # for entry in tqdm(train_dataloader): - # print("Step!") - # model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()}) - - -def main(): - train_config = make_cli(TrainConfig) # pyright: ignore - run_train(train_config) - - -if __name__ == "__main__": - main() +# TODO Overall, this code will read in a config YAML file with OmegaConf +# From that config, we are going to use HuggingFace Trainer to train a model +# TODOS: +# Build a script to convert olmocr-mix to a new dataloader format +# Write a new dataloader and collator, with tests that bring in everything; only needs to support batch size 1 for this first version +# Get a basic config yaml file system working +# Get a basic HuggingFace Trainer running, supporting Qwen2.5VL for now +# Saving and restoring training checkpoints +# Converting training checkpoints to vLLM-compatible checkpoints diff --git a/olmocr/train/utils.py b/olmocr/train/utils.py deleted file mode 100644 index 0ab59da..0000000 --- a/olmocr/train/utils.py +++ /dev/null @@ -1,232 +0,0 @@ -import json -import multiprocessing -import os -import random -from contextlib import contextmanager -from dataclasses import dataclass -from datetime import datetime -from functools import partial -from hashlib import sha1 -from logging import Logger -from tempfile import TemporaryDirectory -from typing import Dict, Generator, List, Optional, TypeVar - -import torch -from accelerate import Accelerator -from accelerate.utils import
PrecisionType -from datasets import Dataset, DatasetDict, concatenate_datasets -from transformers import AutoProcessor - -from olmocr.train.dataloader import build_finetuning_dataset -from olmocr.train.dataprep import ( - batch_prepare_data_for_molmo_training, - batch_prepare_data_for_qwen2_training, -) - -from .core.cli import to_native_types -from .core.config import AwsConfig, DataConfig, SourceConfig, TrainConfig, WandbConfig -from .core.loggers import get_logger -from .core.paths import copy_dir, is_local -from .core.state import BeakerState - -T = TypeVar("T") - - -def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype: - pt = PrecisionType(accelerator.mixed_precision) - if pt == PrecisionType.FP16: - return torch.float16 - elif pt == PrecisionType.BF16: - return torch.bfloat16 - elif pt == PrecisionType.FP8: - return torch.float8_e4m3fn - return torch.float32 - - -def get_rawdataset_from_source(data_config: DataConfig, source: SourceConfig) -> Dataset: - return build_finetuning_dataset(source.response_glob_path, pdf_cache_location=data_config.cache_location) - - -def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]: - random.seed(config.train_data.seed) - - if "qwen" in config.model.name_or_path.lower(): - batch_fn = batch_prepare_data_for_qwen2_training - elif "molmo" in config.model.name_or_path.lower(): - batch_fn = batch_prepare_data_for_molmo_training - else: - raise NotImplementedError("Model format not supported") - - # Retrieve the two target lengths from the first source for comparison - first_source = config.train_data.sources[0] - target_longest_image_dim = first_source.target_longest_image_dim - target_anchor_text_len = first_source.target_anchor_text_len - - # Verify that all sources have the same target lengths - for source in config.train_data.sources: - if source.target_longest_image_dim != target_longest_image_dim: - raise ValueError(f"Inconsistent target_longest_image_dim found in source {source}") - if source.target_anchor_text_len != target_anchor_text_len: - raise ValueError(f"Inconsistent target_anchor_text_len found in source {source}") - - # Concatenate datasets first, unfortunately you can't apply the transform before concatenation due to the library - train_dataset = concatenate_datasets([get_rawdataset_from_source(config.train_data, source) for source in config.train_data.sources]) - - # Apply the transform to the concatenated dataset - train_dataset = train_dataset.with_transform( - partial( - batch_fn, - processor=processor, - target_longest_image_dim=list(target_longest_image_dim), - target_anchor_text_len=list(target_anchor_text_len), - ) - ) - - # Validation sets get put into a datasetdict so each can report a loss separately - valid_dataset = DatasetDict( - **{ - source.name: get_rawdataset_from_source(config.valid_data, source).with_transform( - partial( - batch_fn, - processor=processor, - target_longest_image_dim=list(source.target_longest_image_dim), - target_anchor_text_len=list(source.target_anchor_text_len), - ) - ) - for source in config.valid_data.sources - } - ) - - return train_dataset, valid_dataset - - -def setup_environment(aws_config: Optional[AwsConfig] = None, wandb_config: Optional[WandbConfig] = None, **kwargs: str): - multiprocessing.set_start_method("spawn", force=True) - os.environ["TOKENIZERS_PARALLELISM"] = "false" - os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "false" - - if wandb_config: - os.environ["WANDB_WATCH"] = "false" - - for key, value in 
to_native_types(wandb_config or {}).items(): - if value is not None: - os.environ[f"WANDB_{key.upper()}"] = str(value) - - for key, value in to_native_types(aws_config or {}).items(): - if value is not None: - os.environ[f"AWS_{key.upper()}"] = str(value) - - os.environ.update(kwargs) - - -@dataclass -class RunName: - run: str - group: str - - @classmethod - def get(cls, config: TrainConfig, accelerator: Optional[Accelerator] = None) -> "RunName": - job_rank = f"-{accelerator.process_index}" if accelerator else "" - - if beaker_job_id := BeakerState().job_id: - job_id = f"-{beaker_job_id}" - else: - job_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - - (config_hash := sha1()).update(json.dumps(to_native_types(config)).encode()) - model_name = config.model.name_or_path.replace("/", "_") - - group_name = f"{model_name}-{config_hash.hexdigest()[:6]}" - run_name = f"{group_name}{job_id}{job_rank}" - return cls(group=group_name, run=run_name) - - -@contextmanager -def override_torch_threads(n: int): - torch_num_threads = torch.get_num_threads() - torch.set_num_threads(n) - - yield - - torch.set_num_threads(torch_num_threads) - - -@contextmanager -def temp_args(obj: T, **kwargs) -> Generator[T, None, None]: - orig = {k: getattr(obj, k) for k in kwargs.keys()} - for k, v in kwargs.items(): - setattr(obj, k, v) - - yield obj - - for k, v in orig.items(): - setattr(obj, k, v) - - -def log_trainable_parameters(model: torch.nn.Module, logger: Optional[Logger] = None): - """ - Prints the number of trainable parameters in the model. - """ - trainable_params = 0 - all_param = 0 - for name, param in model.named_parameters(): - all_param += param.numel() - if param.requires_grad: - (logger or get_logger(__name__)).info(f"training with {name}") - trainable_params += param.numel() - - (logger or get_logger(__name__)).info( - "trainable params: %s || all params: %s || trainable%%: %s", - f"{trainable_params:,}", - f"{all_param:,}", - f"{trainable_params / all_param:.2%}", - ) - - -class TruncatingCollator: - def __init__(self, max_length: int): - self.max_length = max_length - - def __call__(self, batch: List[Dict]) -> Dict: - # Assert that we are only handling batch size 1 for now - assert len(batch) == 1, "Only batch size 1 is supported for now" - - if "pixel_values" in batch[0]: - # Qwen2 case - truncated_input_ids = torch.tensor(batch[0]["input_ids"][: self.max_length]).unsqueeze(0) - truncated_attention_mask = torch.tensor(batch[0]["attention_mask"][: self.max_length]).unsqueeze(0) - truncated_labels = torch.tensor(batch[0]["labels"][: self.max_length]).unsqueeze(0) - - return { - "input_ids": truncated_input_ids, - "attention_mask": truncated_attention_mask, - "labels": truncated_labels, - "pixel_values": torch.tensor(batch[0]["pixel_values"]).unsqueeze(0), - "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]).unsqueeze(0), - } - elif "image_input_idx" in batch[0]: - # molmo case - truncated_input_ids = batch[0]["input_ids"][: self.max_length].unsqueeze(0) - truncated_attention_mask = batch[0]["attention_mask"][: self.max_length].unsqueeze(0) - truncated_labels = batch[0]["labels"][: self.max_length].unsqueeze(0) - - return { - "input_ids": truncated_input_ids, - "attention_mask": truncated_attention_mask, - "labels": truncated_labels, - "images": batch[0]["images"].unsqueeze(0), - "image_input_idx": batch[0]["image_input_idx"].unsqueeze(0), - "image_masks": batch[0]["image_masks"].unsqueeze(0), - } - else: - raise NotImplementedError() - - -@contextmanager -def get_local_dir(output_dir: 
str): - with TemporaryDirectory() as tmp_dir: - if is_local(output_dir): - yield output_dir - else: - yield tmp_dir - copy_dir(tmp_dir, output_dir) diff --git a/olmocr/data/__init__.py b/scripts/data/__init__.py similarity index 100% rename from olmocr/data/__init__.py rename to scripts/data/__init__.py diff --git a/olmocr/data/buildsilver.py b/scripts/data/buildsilver.py similarity index 100% rename from olmocr/data/buildsilver.py rename to scripts/data/buildsilver.py diff --git a/olmocr/data/buildsilverdatasummary.py b/scripts/data/buildsilverdatasummary.py similarity index 100% rename from olmocr/data/buildsilverdatasummary.py rename to scripts/data/buildsilverdatasummary.py diff --git a/olmocr/data/buildtestset.py b/scripts/data/buildtestset.py similarity index 100% rename from olmocr/data/buildtestset.py rename to scripts/data/buildtestset.py diff --git a/olmocr/data/convertsilver_birr.py b/scripts/data/convertsilver_birr.py similarity index 100% rename from olmocr/data/convertsilver_birr.py rename to scripts/data/convertsilver_birr.py diff --git a/olmocr/data/convertsilver_openai.py b/scripts/data/convertsilver_openai.py similarity index 100% rename from olmocr/data/convertsilver_openai.py rename to scripts/data/convertsilver_openai.py diff --git a/olmocr/data/renderpdf.py b/scripts/data/renderpdf.py similarity index 100% rename from olmocr/data/renderpdf.py rename to scripts/data/renderpdf.py diff --git a/olmocr/data/runopenaibatch.py b/scripts/data/runopenaibatch.py similarity index 100% rename from olmocr/data/runopenaibatch.py rename to scripts/data/runopenaibatch.py diff --git a/olmocr/eval/__init__.py b/scripts/eval/__init__.py similarity index 100% rename from olmocr/eval/__init__.py rename to scripts/eval/__init__.py diff --git a/olmocr/eval/buildelo.py b/scripts/eval/buildelo.py similarity index 100% rename from olmocr/eval/buildelo.py rename to scripts/eval/buildelo.py diff --git a/olmocr/eval/dolma_refine/aligners.py b/scripts/eval/dolma_refine/aligners.py similarity index 100% rename from olmocr/eval/dolma_refine/aligners.py rename to scripts/eval/dolma_refine/aligners.py diff --git a/olmocr/eval/dolma_refine/metrics.py b/scripts/eval/dolma_refine/metrics.py similarity index 100% rename from olmocr/eval/dolma_refine/metrics.py rename to scripts/eval/dolma_refine/metrics.py diff --git a/olmocr/eval/dolma_refine/registry.py b/scripts/eval/dolma_refine/registry.py similarity index 100% rename from olmocr/eval/dolma_refine/registry.py rename to scripts/eval/dolma_refine/registry.py diff --git a/olmocr/eval/dolma_refine/segmenters.py b/scripts/eval/dolma_refine/segmenters.py similarity index 100% rename from olmocr/eval/dolma_refine/segmenters.py rename to scripts/eval/dolma_refine/segmenters.py diff --git a/olmocr/eval/evalhtml.py b/scripts/eval/evalhtml.py similarity index 100% rename from olmocr/eval/evalhtml.py rename to scripts/eval/evalhtml.py diff --git a/olmocr/eval/evalhtml_template.html b/scripts/eval/evalhtml_template.html similarity index 100% rename from olmocr/eval/evalhtml_template.html rename to scripts/eval/evalhtml_template.html diff --git a/olmocr/eval/runeval.py b/scripts/eval/runeval.py similarity index 100% rename from olmocr/eval/runeval.py rename to scripts/eval/runeval.py diff --git a/olmocr/eval/scoreelo.py b/scripts/eval/scoreelo.py similarity index 100% rename from olmocr/eval/scoreelo.py rename to scripts/eval/scoreelo.py