mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-11 16:22:29 +00:00
commit
cbfc80355a
@ -290,26 +290,9 @@ def make_report(urls):
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
urls = [
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_0-e09ebadf34a7.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=JEQpJxSaMIHuc9DFHyfHuxx0dEU%3D&Expires=1737654586",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_1-c2d267f97a73.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=KMiOTQiFEvgxU94ZrlJRFAgSQZA%3D&Expires=1737654587",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_10-b806c811fb67.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=NaoHNU2ZmEGrgMsxg2JHK%2Fv5zd0%3D&Expires=1737654587",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_11-19c1936b4372.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=BjkVydyKjzzH3uZiZ1GkWAk6cbk%3D&Expires=1737654588",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_12-cd41808a7974.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=jsk8TzJTKJwHi1Ru4%2Bw%2BiHZG638%3D&Expires=1737654589",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_13-8b055079b5eb.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=SE7kkobEBip44O8JY5axoMTV2Bs%3D&Expires=1737654590",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_14-1126e0da563c.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=jLSEWpDUzpmS8P9mNXbBoDYDOwU%3D&Expires=1737654590",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_15-05704e3d000d.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=TaCbyv2%2FDGCnCOgTzUvfEXdO%2Fmo%3D&Expires=1737654591",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_16-e57f795a89da.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=stqm1etAfDIpAQGNvZwe9c%2BYUbA%3D&Expires=1737654592",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_17-041a6d042764.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=rOTroBcSqCh3oM65bOJHEfaeal8%3D&Expires=1737654592",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_18-7a29697cee63.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=abmYM9KtzjicmdacRykPWXCdQr0%3D&Expires=1737654593",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_19-d32f14c067f8.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=iVg3nxrZXVpYybkLJIgOEJ3v37E%3D&Expires=1737654594",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_2-43c553548e69.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=IK27gl7b6NY05YNnnsimMVJc99I%3D&Expires=1737654595",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_3-fb42a458ecd5.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=d1qevJe8ZQONnu7zezYSJe3cbBw%3D&Expires=1737654595",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_4-76a50eed331a.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=qwZu2q1H4Y%2Bf3Kw7DNSYcTxwI7A%3D&Expires=1737654596",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_5-150b4d3583de.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=c%2FeqjnDSIRirgQviFWRLWVowKmA%3D&Expires=1737654597",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_6-6ca285526fd3.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=tkWDDuRinY77BLQCqumtlMiFJU8%3D&Expires=1737654598",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_7-01d711ee8bf7.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=eQtFo6CHJYHGu85wK0YG5khlE5U%3D&Expires=1737654598",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_8-0f36b852f274.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=weI3WB8vhjBYjk6t85DmyLdP97k%3D&Expires=1737654599",
|
||||
"https://jakep-tinyhost.s3.amazonaws.com/review_page_9-115e33463fd2.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=b4CpkHprCUtZoL0u%2FFYzsu%2BB1yU%3D&Expires=1737654600",
|
||||
]
|
||||
urls = ['https://jakep-tinyhost.s3.amazonaws.com/review_page_0-ff70abb8f517.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=NarEyyCfvusCh%2FHdB47VfHOnnBs%3D&Expires=1738359221', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_1-0800f9af46cf.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=ncTWAu5rSndBJJsU26HRYDaK6i8%3D&Expires=1738359222', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_10-f7081f6ca6f9.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=gYX8yjGyYshRqXGgdsX17%2Fdi9Ig%3D&Expires=1738359223', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_11-355dc69335bc.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=7%2Bc5qoa8Tbk06z0VcvJiIIVAz9M%3D&Expires=1738359224', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_12-95fce9bf0c18.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=fw4PBo0LnxikmLZ8xH%2BGD%2F%2BhXMU%3D&Expires=1738359225', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_13-f88f7d7482bf.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=yXkQp9oFDtroKgiO50EwpYdGLcA%3D&Expires=1738359226', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_14-8ac0b974bfd5.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=EgZTpj1%2FdzMBUgd%2BX4pVZ1Sp%2FrA%3D&Expires=1738359226', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_15-e3136188de5c.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=YKhAv4unNIlRcerQAaHN4kjc4qI%3D&Expires=1738359227', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_16-2c5abde50d49.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=Mj8%2BK5ISKzAYQFeYvmzTgCPcRwA%3D&Expires=1738359228', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_17-f13132a4cdcc.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=%2FHuzw2cjJ4oFm91UXojPnGzYi8Q%3D&Expires=1738359229', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_18-25070f2aa05e.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=ctd%2BUIM%2FxryJm%2FcwA%2BRZ%2FbRzBp8%3D&Expires=1738359230', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_19-d436ee434162.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=jVdFKobIoHlbTQ7zziG%2BXiIQ0Fo%3D&Expires=1738359230', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_2-a5ece743fd31.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=K8hIrjWtvo4SLVQrOB8TiXLgNJk%3D&Expires=1738359231', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_3-9ce03af05f51.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=T0fLGSH%2Bv%2F19veqbxnLxoSf7gVA%3D&Expires=1738359232', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_4-94eec18f8027.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=u2R1LundKpfnAUCcD%2BdGHA6uIR0%3D&Expires=1738359233', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_5-377d0a7d8f5a.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=5R38ZQAR9ew5x%2BRmMVQbTqbfVh0%3D&Expires=1738359234', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_6-537b22646a26.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=PLOELum1qzOXW8Cm5rfZphlFeMw%3D&Expires=1738359235', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_7-a4a7dcb08f20.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=DxPHukGXEpPrEPL6TF9QBKPE1Xg%3D&Expires=1738359236', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_8-48a71c829863.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=TjEINKj69HdmXsKY59k4f3PieeM%3D&Expires=1738359237', 'https://jakep-tinyhost.s3.amazonaws.com/review_page_9-8557438928c3.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=F7sQxw5A%2FDOcOaa%2FQSeqepH0PQc%3D&Expires=1738359238']
|
||||
# import tinyhost
|
||||
|
||||
# print(tinyhost.tinyhost(urls))
|
||||
|
||||
make_report(urls)
|
||||
|
87
pdelfin/train/config/molmo-o-lora-8192.yaml
Normal file
87
pdelfin/train/config/molmo-o-lora-8192.yaml
Normal file
@ -0,0 +1,87 @@
|
||||
model:
|
||||
name_or_path: allenai/Molmo-7B-O-0924
|
||||
arch: causal
|
||||
use_flash_attn: true
|
||||
|
||||
wandb:
|
||||
project: pdelfin
|
||||
entity: ai2-llm
|
||||
|
||||
generate:
|
||||
max_length: 8192
|
||||
|
||||
train_data:
|
||||
seed: 1337
|
||||
cache_location: /data/jakep/pdfdata/pdelfin_cache
|
||||
sources:
|
||||
- name: openai_batch_data_v5_1_train
|
||||
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
- name: openai_batch_data_v5_1_iabooks_train
|
||||
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
|
||||
valid_data:
|
||||
cache_location: /data/jakep/pdfdata/pdelfin_cache
|
||||
metric_for_best_model: openai_batch_data_v5_1_eval_loss
|
||||
sources:
|
||||
# These tend to be small, so you can load from s3 it's no big deal
|
||||
- name: openai_batch_data_v5_1_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
- name: openai_batch_data_v5_1_iabooks_eval
|
||||
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
|
||||
target_longest_image_dim: [1024]
|
||||
target_anchor_text_len: [6000]
|
||||
|
||||
|
||||
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
|
||||
hparams:
|
||||
batch_size: 1
|
||||
eval_batch_size: 1
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
find_unused_parameters: true
|
||||
clip_grad_norm: 1.0
|
||||
learning_rate: 3e-4
|
||||
max_steps: 10000
|
||||
pad_multiple_of: 16
|
||||
log_every_steps: 10
|
||||
eval_every_steps: 100
|
||||
optim: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
weight_decay: 0.01
|
||||
warmup_ratio: 0.03
|
||||
|
||||
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
|
||||
lora:
|
||||
rank: 32
|
||||
alpha: 32
|
||||
dropout: 0.05
|
||||
task_type: CAUSAL_LM
|
||||
target_modules:
|
||||
# attention layers in main transformer
|
||||
- att_proj
|
||||
- ff_proj
|
||||
- attn_out
|
||||
- ff_out
|
||||
# vision transformer attention and FF
|
||||
- attention.wq
|
||||
- attention.wk
|
||||
- attention.wv
|
||||
- attention.wo
|
||||
- feed_forward.w1
|
||||
- feed_forward.w2
|
||||
# vision image projector
|
||||
- vision_backbone.image_projector.w1
|
||||
- vision_backbone.image_projector.w2
|
||||
- vision_backbone.image_projector.w3
|
||||
|
||||
save:
|
||||
path: s3://ai2-oe-data/jakep/experiments/molmo-o-0924/v1/models/
|
||||
save_every_steps: 1000
|
||||
|
||||
max_workers: 10
|
@ -45,7 +45,8 @@ hparams:
|
||||
batch_size: 1
|
||||
eval_batch_size: 1
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: false
|
||||
gradient_checkpointing: true
|
||||
find_unused_parameters: true
|
||||
clip_grad_norm: 1.0
|
||||
learning_rate: 1e-4
|
||||
max_steps: 10000
|
||||
|
0
pdelfin/train/molmo/__init__.py
Normal file
0
pdelfin/train/molmo/__init__.py
Normal file
60
pdelfin/train/molmo/config_molmo.py
Normal file
60
pdelfin/train/molmo/config_molmo.py
Normal file
@ -0,0 +1,60 @@
|
||||
from typing import List
|
||||
|
||||
from transformers import PretrainedConfig, AutoTokenizer
|
||||
|
||||
|
||||
class MolmoConfig(PretrainedConfig):
|
||||
model_type = "molmo"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50304,
|
||||
embedding_size=50304,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
use_cache=True,
|
||||
layer_norm_eps: float = 1e-5,
|
||||
rope_theta=10000.0,
|
||||
clip_qkv=None,
|
||||
qkv_bias: bool = False,
|
||||
weight_tying: bool = False,
|
||||
use_position_ids: bool=True,
|
||||
tie_word_embeddings: bool=True,
|
||||
attention_layer_norm: bool=False,
|
||||
norm_after: bool = False,
|
||||
layer_norm_type: str="rms",
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_size = embedding_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.weight_tying = weight_tying
|
||||
self.use_position_ids = use_position_ids
|
||||
self.attention_layer_norm = attention_layer_norm
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.clip_qkv = clip_qkv
|
||||
self.qkv_bias = qkv_bias
|
||||
self.norm_after = norm_after
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.layer_norm_type = layer_norm_type
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
MolmoConfig.register_for_auto_class()
|
546
pdelfin/train/molmo/image_processing_molmo.py
Normal file
546
pdelfin/train/molmo/image_processing_molmo.py
Normal file
@ -0,0 +1,546 @@
|
||||
"""Image processor class for Molmo"""
|
||||
from typing import List, Optional, Union, Mapping
|
||||
|
||||
import numpy as np
|
||||
import einops
|
||||
import torch
|
||||
import torchvision.transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from torchvision.transforms.functional import convert_image_dtype
|
||||
|
||||
from transformers.image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ImageInput,
|
||||
is_valid_image,
|
||||
)
|
||||
from transformers.processing_utils import ImagesKwargs
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def pad_to_bounding_box(
|
||||
image, offset_height, offset_width, target_height,
|
||||
target_width, value=0
|
||||
):
|
||||
height, width = image.shape[:2]
|
||||
after_padding_width = target_width - offset_width - width
|
||||
after_padding_height = target_height - offset_height - height
|
||||
return np.pad(image, [
|
||||
[offset_height, after_padding_height],
|
||||
[offset_width, after_padding_width],
|
||||
[0, 0]
|
||||
], constant_values=value)
|
||||
|
||||
|
||||
def normalize_image(image, offset, scale):
|
||||
image -= np.array(offset, dtype=np.float32)[None, None, :]
|
||||
image /= np.array(scale, dtype=np.float32)[None, None, :]
|
||||
return image
|
||||
|
||||
|
||||
def resize_and_pad(
|
||||
image,
|
||||
desired_output_size,
|
||||
resize_method="torch-bilinear",
|
||||
pad_value=0,
|
||||
normalize=True,
|
||||
image_mean=OPENAI_CLIP_MEAN,
|
||||
image_std=OPENAI_CLIP_STD,
|
||||
):
|
||||
desired_height, desired_width = desired_output_size
|
||||
height, width = image.shape[:2]
|
||||
|
||||
# Cast into float32 since the training code did this in float32 and it (very rarely) effects
|
||||
# the results after rounding.
|
||||
image_scale_y = np.array(desired_height, np.float32) / np.array(height, np.float32)
|
||||
image_scale_x = np.array(desired_width, np.float32) / np.array(width, np.float32)
|
||||
image_scale = min(image_scale_x, image_scale_y)
|
||||
scaled_height = int(np.array(height, np.float32) * image_scale)
|
||||
scaled_width = int(np.array(width, np.float32) * image_scale)
|
||||
|
||||
if resize_method == "tensorflow":
|
||||
# This how the original training code did resizing, it can produce slightly different
|
||||
# results then using torch resize so we keep it just in case
|
||||
import tensorflow as tf
|
||||
image = tf.image.convert_image_dtype(tf.constant(image), dtype=tf.float32)
|
||||
image = tf.image.resize(
|
||||
image,
|
||||
[scaled_height, scaled_width],
|
||||
method=tf.image.ResizeMethod.BILINEAR,
|
||||
antialias=True,
|
||||
)
|
||||
image = tf.clip_by_value(image, 0.0, 1.0)
|
||||
image = image.numpy()
|
||||
elif resize_method == "torch-bilinear":
|
||||
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
||||
image = convert_image_dtype(image) # resize in float32 to match the training code
|
||||
image = torchvision.transforms.Resize(
|
||||
[scaled_height, scaled_width], InterpolationMode.BILINEAR, antialias=True
|
||||
)(image)
|
||||
image = torch.clip(image, 0.0, 1.0)
|
||||
image = torch.permute(image, [1, 2, 0]).numpy()
|
||||
else:
|
||||
raise NotImplementedError(resize_method)
|
||||
|
||||
top_pad = (desired_height - scaled_height) // 2
|
||||
left_pad = (desired_width - scaled_width) // 2
|
||||
padding = [
|
||||
[top_pad, desired_height - scaled_height - top_pad],
|
||||
[left_pad, desired_width - scaled_width - left_pad],
|
||||
[0, 0]
|
||||
]
|
||||
image_mask = np.pad(np.ones_like(image[:, :, 0], dtype=bool), padding[:2])
|
||||
image = np.pad(image, padding, constant_values=pad_value)
|
||||
if normalize:
|
||||
image = normalize_image(image, offset=image_mean, scale=image_std)
|
||||
return image, image_mask
|
||||
|
||||
|
||||
def select_tiling(h, w, patch_size, max_num_patches):
|
||||
"""Decide how best to divide in image of size [w, h] in up to max_num_patches of size patch_size"""
|
||||
original_size = np.stack([h, w]) # [1, 2]
|
||||
original_res = h * w
|
||||
tilings = []
|
||||
for i in range(1, max_num_patches+1):
|
||||
for j in range(1, max_num_patches+1):
|
||||
if i*j <= max_num_patches:
|
||||
tilings.append((i, j))
|
||||
# sort so argmin and argmax favour smaller tilings in the event of a tie
|
||||
tilings.sort(key=lambda x: (x[0]*x[1], x[0]))
|
||||
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
|
||||
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
|
||||
|
||||
# How much we would need to scale the image to fit exactly in each tiling
|
||||
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
|
||||
required_scale_d = candidate_resolutions.astype(np.float32) / original_size
|
||||
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
|
||||
if np.all(required_scale < 1):
|
||||
# We are forced to downscale, so try to minimize the amount of downscaling
|
||||
ix = np.argmax(required_scale)
|
||||
else:
|
||||
# Pick the resolution that required the least upscaling so that it most closely fits the image
|
||||
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
|
||||
ix = np.argmin(required_scale)
|
||||
return candidate_tilings[ix]
|
||||
|
||||
|
||||
class MolmoImagesKwargs(ImagesKwargs, total=False):
|
||||
max_crops: Optional[int]
|
||||
overlap_margins: Optional[List[int]]
|
||||
base_image_input_size: Optional[List[int]]
|
||||
image_token_length_w: Optional[int]
|
||||
image_token_length_h: Optional[int]
|
||||
image_patch_size: Optional[int]
|
||||
image_padding_mask: Optional[bool]
|
||||
|
||||
|
||||
class MolmoImageProcessor(BaseImageProcessor):
|
||||
"""Preprocess images and multi-model inputs"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_crops: int = 12,
|
||||
overlap_margins: List[int] = (4, 4),
|
||||
base_image_input_size: List[int] = (336, 336),
|
||||
image_token_length_w: int = 12,
|
||||
image_token_length_h: int = 12,
|
||||
image_patch_size: int = 14,
|
||||
image_padding_mask: bool = True,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.max_crops = max_crops
|
||||
self.overlap_margins = overlap_margins
|
||||
self.base_image_input_size = base_image_input_size
|
||||
self.image_token_length_w = image_token_length_w
|
||||
self.image_token_length_h = image_token_length_h
|
||||
self.image_patch_size = image_patch_size
|
||||
self.image_padding_mask = image_padding_mask
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
||||
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
||||
|
||||
def image_to_patches_and_tokens(
|
||||
self,
|
||||
image: ImageInput,
|
||||
image_patch_token_id: int,
|
||||
image_col_token_id: int,
|
||||
image_start_token_id: int,
|
||||
image_end_token_id: int,
|
||||
max_crops: Optional[int] = None,
|
||||
overlap_margins: Optional[List[int]] = None,
|
||||
base_image_input_size: Optional[Union[int, List[int]]] = None,
|
||||
image_token_length_w: Optional[int] = None,
|
||||
image_token_length_h: Optional[int] = None,
|
||||
image_patch_size: Optional[int] = None,
|
||||
):
|
||||
if isinstance(base_image_input_size, int):
|
||||
base_image_input_size = (base_image_input_size, base_image_input_size)
|
||||
|
||||
base_image_input_d = image_patch_size
|
||||
tokens_per_image = image_token_length_w * image_token_length_h
|
||||
image_base_patch_w = base_image_input_size[1] // base_image_input_d
|
||||
image_base_patch_h = base_image_input_size[0] // base_image_input_d
|
||||
|
||||
original_image_h, original_image_w = image.shape[:2]
|
||||
crop_size = base_image_input_size[0]
|
||||
|
||||
# Discard this many patches from the (left/top, right/bottom) of crops
|
||||
left_margin, right_margin = overlap_margins
|
||||
# left_margin, right_margin = 2, 2
|
||||
assert left_margin % 2 == 0 # Required for compatibility with 2x2 pooling
|
||||
total_margin_pixels = base_image_input_d*(right_margin + left_margin) # pixels removed per dim
|
||||
crop_patches = base_image_input_size[0] // base_image_input_d # patches per crop dim
|
||||
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
|
||||
crop_window_size = crop_window_patches * base_image_input_d
|
||||
tiling = select_tiling(
|
||||
original_image_h - total_margin_pixels,
|
||||
original_image_w - total_margin_pixels,
|
||||
crop_window_size,
|
||||
max_crops
|
||||
)
|
||||
src, img_mask = resize_and_pad(
|
||||
image,
|
||||
[tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels]
|
||||
)
|
||||
|
||||
# Now we have to split the image into crops, while keeping track of how each patch in the
|
||||
# each crop should be ordered in the global image, this require a lot of tricky booking
|
||||
n_crops = tiling[0] * tiling[1]
|
||||
patches_arr = []
|
||||
mask_arr = []
|
||||
patch_ordering_arr = []
|
||||
|
||||
# We assume 2x2 pooling, but can allow padding the right/bottom with extra
|
||||
# patches if the number of patches per side is not even
|
||||
assert (crop_patches+1)//2 == image_token_length_h
|
||||
assert (crop_patches+1)//2 == image_token_length_w
|
||||
on = 0
|
||||
on_patch = 0
|
||||
for i in range(tiling[0]):
|
||||
y0 = i*crop_window_size
|
||||
if i == 0:
|
||||
crop_y0 = 0
|
||||
else:
|
||||
crop_y0 = left_margin // 2
|
||||
|
||||
crop_h = image_base_patch_h - (right_margin + left_margin)
|
||||
if i == 0:
|
||||
crop_h += left_margin
|
||||
if i == (tiling[0]-1):
|
||||
crop_h += right_margin
|
||||
for j in range(tiling[1]):
|
||||
x0 = j*crop_window_size
|
||||
if j == 0:
|
||||
crop_x0 = 0
|
||||
else:
|
||||
crop_x0 = left_margin // 2
|
||||
|
||||
crop_w = image_base_patch_w - (right_margin + left_margin)
|
||||
if j == 0:
|
||||
crop_w += left_margin
|
||||
if j == (tiling[1]-1):
|
||||
crop_w += right_margin
|
||||
|
||||
pooled_w = (crop_w + 1) // 2
|
||||
pooled_h = (crop_h + 1) // 2
|
||||
patch_ordering_arr.append(
|
||||
pad_to_bounding_box(
|
||||
np.reshape(np.arange(on, on+pooled_h*pooled_w, dtype=np.int32), (pooled_h, pooled_w, 1)),
|
||||
crop_y0, crop_x0, image_token_length_h, image_token_length_w, value=-1
|
||||
)[:, :, 0]
|
||||
)
|
||||
patches_arr.append(src[y0:y0+crop_size, x0:x0+crop_size])
|
||||
mask_arr.append(img_mask[y0:y0+crop_size, x0:x0+crop_size])
|
||||
|
||||
on += pooled_h*pooled_w
|
||||
on_patch += 1
|
||||
patches = np.stack(patches_arr)
|
||||
patch_ordering = np.stack(patch_ordering_arr)
|
||||
img_mask = np.stack(mask_arr)
|
||||
|
||||
# Switch to [n_crops, n_patches, pixels_per_patch] format
|
||||
image_layout_impatch_w, image_layout_impatch_h = tiling[0], tiling[1]
|
||||
patches = einops.rearrange(
|
||||
patches, 'p (h dh) (w dw) c -> p (h w) (dh dw c)',
|
||||
dh=base_image_input_d,
|
||||
dw=base_image_input_d,
|
||||
h=image_base_patch_h,
|
||||
w=image_base_patch_w
|
||||
)
|
||||
img_mask = einops.rearrange(
|
||||
img_mask, 'p (h dh) (w dw) -> p (h w) (dh dw)',
|
||||
dh=base_image_input_d,
|
||||
dw=base_image_input_d,
|
||||
h=image_base_patch_h,
|
||||
w=image_base_patch_w
|
||||
)
|
||||
|
||||
img_mask = img_mask.astype(np.float32).mean(axis=-1)
|
||||
patch_ordering = np.reshape(patch_ordering, [-1])
|
||||
valid = patch_ordering >= 0
|
||||
|
||||
# Transpose order, to get left-to-right order instead of crop-by-crop order
|
||||
patch_ordering_rh = np.reshape(
|
||||
patch_ordering,
|
||||
[tiling[0], tiling[1], image_token_length_h, image_token_length_w]
|
||||
)
|
||||
patch_ordering_rh = np.transpose(patch_ordering_rh, [0, 2, 1, 3])
|
||||
patch_ordering_rh = np.reshape(patch_ordering_rh, [-1])
|
||||
|
||||
# The transpose will screw up which patches are masked, project the
|
||||
# new order into sparse structure of `patch_ordering` to fix this
|
||||
patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
|
||||
|
||||
# Now build the output tokens
|
||||
h = tiling[0] * crop_window_patches + (right_margin+left_margin)
|
||||
w = tiling[1] * crop_window_patches + (right_margin+left_margin)
|
||||
per_row = np.full(
|
||||
((w+1)//2,),
|
||||
image_patch_token_id,
|
||||
)
|
||||
per_row = np.concatenate([per_row, [image_col_token_id]], 0)
|
||||
|
||||
joint = np.tile(per_row, [(h+1)//2])
|
||||
joint = [
|
||||
[image_start_token_id],
|
||||
joint,
|
||||
[image_end_token_id]
|
||||
]
|
||||
|
||||
# Finally do the same for the global image
|
||||
resized, _ = resize_and_pad(image, base_image_input_size)
|
||||
resized = einops.rearrange(
|
||||
resized, '(h dh) (w dw) c -> (h w) (dh dw c)',
|
||||
dh=base_image_input_d,
|
||||
dw=base_image_input_d,
|
||||
h=image_base_patch_h,
|
||||
w=image_base_patch_w
|
||||
)
|
||||
patches = np.concatenate([np.expand_dims(resized, 0), patches], 0)
|
||||
|
||||
# Global image goes first, so the order of patches in previous crops gets increased
|
||||
patch_ordering = np.where(
|
||||
patch_ordering >= 0,
|
||||
patch_ordering + tokens_per_image,
|
||||
-1
|
||||
)
|
||||
patch_ordering = np.concatenate([np.arange(0, tokens_per_image), patch_ordering], 0)
|
||||
per_row = np.full(
|
||||
(image_token_length_w,),
|
||||
image_patch_token_id,
|
||||
)
|
||||
per_row = np.concatenate([per_row, [image_col_token_id]], 0)
|
||||
extra_tokens = np.tile(per_row, [image_token_length_h])
|
||||
joint = [
|
||||
[image_start_token_id],
|
||||
extra_tokens,
|
||||
[image_end_token_id],
|
||||
] + joint
|
||||
|
||||
joint = np.concatenate(joint, 0)
|
||||
img_mask = np.pad(img_mask, [[0, 1], [0, 0]], constant_values=-1)
|
||||
return patches, joint, patch_ordering, img_mask
|
||||
|
||||
def build_image_input_idx(
|
||||
self,
|
||||
image_tokens: np.ndarray,
|
||||
patch_order: np.ndarray,
|
||||
image_patch_token_id: int,
|
||||
no_image: Optional[bool] = None,
|
||||
image_token_length_w: Optional[int] = None,
|
||||
image_token_length_h: Optional[int] = None,
|
||||
):
|
||||
"""Converts `patch_order` into a mapping of token_id -> patch_id"""
|
||||
|
||||
tokens_per_image = image_token_length_w * image_token_length_h
|
||||
if no_image is not None and no_image:
|
||||
return np.zeros((0, tokens_per_image), np.int32)
|
||||
|
||||
# Indices to insert the patches
|
||||
image_input_idx = image_tokens == image_patch_token_id
|
||||
image_input_idx = np.nonzero(image_input_idx)[0].astype(np.int32)
|
||||
|
||||
if patch_order is not None:
|
||||
n_tokens = image_input_idx.shape[0]
|
||||
patch_order = np.reshape(patch_order, [-1])
|
||||
n_patches = patch_order.shape[0]
|
||||
|
||||
valid = patch_order >= 0
|
||||
n_valid_patches = valid.sum()
|
||||
assert len(image_input_idx) == n_valid_patches
|
||||
|
||||
sorted_patch_ixs = np.zeros([n_tokens], np.int32)
|
||||
sorted_patch_ixs[patch_order[valid]] = np.arange(n_valid_patches, dtype=np.int32)
|
||||
|
||||
# Project the inverted mapping into same sparse structure
|
||||
sorted_patch_ixs_ex = np.full(np.shape(patch_order), -1)
|
||||
sorted_patch_ixs_ex[valid] = sorted_patch_ixs
|
||||
|
||||
# Do the gather and then re-masked outputs that were masked in `sorted_patch_ixs`
|
||||
valid = (sorted_patch_ixs_ex >= 0).astype(np.int32)
|
||||
image_input_idx = image_input_idx[sorted_patch_ixs_ex*valid]
|
||||
image_input_idx = image_input_idx*valid - 100*(1 - valid)
|
||||
image_input_idx = np.reshape(image_input_idx, [-1, tokens_per_image])
|
||||
return image_input_idx
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
image_patch_token_id: int,
|
||||
image_col_token_id: int,
|
||||
image_start_token_id: int,
|
||||
image_end_token_id: int,
|
||||
max_crops: Optional[int] = None,
|
||||
overlap_margins: Optional[List[int]] = None,
|
||||
base_image_input_size: Optional[Union[int, List[int]]] = None,
|
||||
image_token_length_w: Optional[int] = None,
|
||||
image_token_length_h: Optional[int] = None,
|
||||
image_patch_size: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Preprocesses an image
|
||||
|
||||
Returns:
|
||||
crops: (n_crops, n_patches, patch_dim) individual crops, `n_crops` might
|
||||
change between images but the other dimension are fixed
|
||||
tokens: (n_tokens,) int32 tokens, pad tokens indicate where to insert the
|
||||
patch features, might include other special tokens as well
|
||||
image_idx: (n_crops, n_patches) index in `tokens` to put the patch features from the
|
||||
crops after pooling, negative values indicates patches features to exclude
|
||||
padding_mask: (n_crops, n_patches) what percent of each crop is padding, can be None
|
||||
if the image mask is not being used.
|
||||
"""
|
||||
|
||||
max_crops = max_crops or self.max_crops
|
||||
overlap_margins = overlap_margins or self.overlap_margins
|
||||
base_image_input_size = base_image_input_size or self.base_image_input_size
|
||||
image_token_length_w = image_token_length_w or self.image_token_length_w
|
||||
image_token_length_h = image_token_length_h or self.image_token_length_h
|
||||
image_patch_size = image_patch_size or self.image_patch_size
|
||||
|
||||
crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens(
|
||||
image,
|
||||
image_patch_token_id,
|
||||
image_col_token_id,
|
||||
image_start_token_id,
|
||||
image_end_token_id,
|
||||
max_crops,
|
||||
overlap_margins,
|
||||
base_image_input_size,
|
||||
image_token_length_w,
|
||||
image_token_length_h,
|
||||
image_patch_size,
|
||||
)
|
||||
patch_idx = self.build_image_input_idx(
|
||||
image_tokens,
|
||||
patch_ordering,
|
||||
image_patch_token_id,
|
||||
image_token_length_w=image_token_length_w,
|
||||
image_token_length_h=image_token_length_h,
|
||||
)
|
||||
return crops, image_tokens, patch_idx, img_mask
|
||||
|
||||
def multimodal_preprocess(
|
||||
self,
|
||||
images: np.ndarray,
|
||||
tokens: List[int],
|
||||
image_idx: np.ndarray,
|
||||
sequence_length: int,
|
||||
image_patch_token_id: int,
|
||||
image_col_token_id: int,
|
||||
image_start_token_id: int,
|
||||
image_end_token_id: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""Merge images and text tokens into multi-modal features for the model
|
||||
|
||||
:param images: images to use as input
|
||||
:param tokens: input text tokens
|
||||
:param image_idx: where to insert the images into `tokens`
|
||||
:params image_patch_token_id: id to use of tokens that will contain image features
|
||||
:params image_col_token_id: token id for image column special tokens
|
||||
:params image_start_token_id: token id for image start special tokens
|
||||
:params image_end_token_id: token id for image end special tokens
|
||||
:params kwargs: override preprocessor default args
|
||||
"""
|
||||
max_total_crops = kwargs.get("max_crops") or self.max_crops
|
||||
image_token_length_w = kwargs.get("image_token_length_w") or self.image_token_length_w
|
||||
image_token_length_h = kwargs.get("image_token_length_h") or self.image_token_length_h
|
||||
image_patch_size = kwargs.get("image_patch_size") or self.image_patch_size
|
||||
base_image_input_size = kwargs.get("base_image_input_size") or self.base_image_input_size
|
||||
image_num_patch = (
|
||||
base_image_input_size[0] // image_patch_size,
|
||||
base_image_input_size[1] // image_patch_size,
|
||||
)
|
||||
image_padding_mask = kwargs.get("image_padding_mask") or self.image_padding_mask
|
||||
|
||||
tokens_per_image = image_token_length_w * image_token_length_h
|
||||
n_pixels = image_patch_size * image_patch_size * 3
|
||||
n_patches = image_num_patch[0] * image_num_patch[1]
|
||||
|
||||
if images is None:
|
||||
return {
|
||||
"input_ids": tokens,
|
||||
}
|
||||
else:
|
||||
n = len(images)
|
||||
all_crops = []
|
||||
all_image_idx = []
|
||||
out_tokens = []
|
||||
all_crop_masks = []
|
||||
|
||||
for ix in range(n):
|
||||
token_ix = image_idx[ix]
|
||||
crops, image_tokens, patch_idx, img_mask = self.preprocess(
|
||||
images[ix],
|
||||
image_patch_token_id,
|
||||
image_col_token_id,
|
||||
image_start_token_id,
|
||||
image_end_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if token_ix == -1: # -1 is an image inserted at the very start
|
||||
start = 0
|
||||
token_ix = 0
|
||||
end = 0
|
||||
else:
|
||||
start = 0 if ix == 0 else image_idx[ix-1] + 1
|
||||
end = token_ix + 1
|
||||
|
||||
all_image_idx.append(patch_idx + token_ix)
|
||||
all_crops.append(crops)
|
||||
out_tokens.append(tokens[start:token_ix])
|
||||
out_tokens.append(image_tokens)
|
||||
if ix == (n - 1):
|
||||
out_tokens.append(tokens[end:])
|
||||
if image_padding_mask:
|
||||
all_crop_masks.append(img_mask)
|
||||
|
||||
input_ids = np.concatenate(out_tokens, 0)
|
||||
images = np.concatenate(all_crops, 0)
|
||||
image_input_idx = np.concatenate(all_image_idx, 0)
|
||||
if image_padding_mask:
|
||||
image_masks = np.concatenate(all_crop_masks, 0)
|
||||
else:
|
||||
image_masks = None
|
||||
|
||||
out = {
|
||||
"input_ids": input_ids,
|
||||
"images": images,
|
||||
"image_input_idx": image_input_idx
|
||||
}
|
||||
if image_masks is not None:
|
||||
out["image_masks"] = image_masks
|
||||
return out
|
||||
|
||||
|
||||
MolmoImageProcessor.register_for_auto_class()
|
2374
pdelfin/train/molmo/modeling_molmo.py
Normal file
2374
pdelfin/train/molmo/modeling_molmo.py
Normal file
File diff suppressed because it is too large
Load Diff
192
pdelfin/train/molmo/preprocessing_molmo.py
Normal file
192
pdelfin/train/molmo/preprocessing_molmo.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""
|
||||
Processor class for Molmo.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import PIL
|
||||
from PIL import ImageOps
|
||||
from PIL.Image import Image
|
||||
|
||||
try:
|
||||
from typing import Unpack
|
||||
except ImportError:
|
||||
from typing_extensions import Unpack
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.processing_utils import (
|
||||
TextKwargs,
|
||||
ProcessingKwargs,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
|
||||
from transformers.utils import logging
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from .image_preprocessing_molmo import MolmoImagesKwargs, MolmoImageProcessor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
DEFAULT_IMAGE_PATCH_TOKEN = f"<im_patch>"
|
||||
DEFAULT_IM_START_TOKEN = f"<im_start>"
|
||||
DEFAULT_IM_END_TOKEN = f"<im_end>"
|
||||
DEFAULT_IM_COL_TOKEN = f"<im_col>"
|
||||
IMAGE_PROMPT = "<|image|>"
|
||||
|
||||
EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)
|
||||
|
||||
|
||||
def get_special_token_ids(tokenizer):
|
||||
ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
|
||||
assert len(ids) == len(EXTRA_TOKENS)
|
||||
return {k: i for k, i in zip(EXTRA_TOKENS, ids)}
|
||||
|
||||
|
||||
class MolmoTextKwargs(TextKwargs, total=False):
|
||||
style: Optional[str]
|
||||
system_prompt: Optional[str]
|
||||
message_format: Optional[str]
|
||||
always_start_with_space: Optional[bool]
|
||||
sequence_length: Optional[int]
|
||||
|
||||
|
||||
class MolmoProcessorKwargs(ProcessingKwargs, total=False):
|
||||
text_kwargs: MolmoTextKwargs
|
||||
images_kwargs: MolmoImagesKwargs
|
||||
_defaults = {
|
||||
"images_kwargs": {
|
||||
"max_crops": 12,
|
||||
"overlap_margins": [4, 4],
|
||||
"base_image_input_size": [336, 336],
|
||||
"image_token_length_w": 12,
|
||||
"image_token_length_h": 12,
|
||||
"image_patch_size": 14,
|
||||
"image_padding_mask": True,
|
||||
},
|
||||
"text_kwargs": {
|
||||
"style": "long_caption",
|
||||
"system_prompt": "none",
|
||||
"message_format": "role",
|
||||
"always_start_with_space": True,
|
||||
"sequence_length": 1536,
|
||||
"padding": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class MolmoProcessor(ProcessorMixin):
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast")
|
||||
|
||||
def __init__(self, image_processor: MolmoImageProcessor = None, tokenizer : AutoTokenizer = None, **kwargs):
|
||||
# self.image_processor = image_processor
|
||||
# self.tokenizer = tokenizer
|
||||
super().__init__(image_processor, tokenizer)
|
||||
self._special_tokens = None
|
||||
|
||||
@property
|
||||
def special_token_ids(self):
|
||||
if self._special_tokens is None:
|
||||
self._special_tokens = get_special_token_ids(self.tokenizer)
|
||||
return self._special_tokens
|
||||
|
||||
def get_tokens_input(self, prompt, message_format, always_start_with_space):
|
||||
if message_format == "none" or message_format is None:
|
||||
pass
|
||||
elif message_format == "role":
|
||||
prompt = "User: " + prompt + " Assistant:"
|
||||
else:
|
||||
raise NotImplementedError(f"Message format {message_format} not implemented")
|
||||
|
||||
if always_start_with_space:
|
||||
prompt = " " + prompt
|
||||
|
||||
tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
|
||||
|
||||
return tokens
|
||||
|
||||
def process(
|
||||
self,
|
||||
text: TextInput = None,
|
||||
images: ImageInput = None,
|
||||
*,
|
||||
tokens: Optional[PreTokenizedInput] = None,
|
||||
**kwargs: Unpack[MolmoProcessorKwargs],
|
||||
):
|
||||
output_kwargs = self._merge_kwargs(
|
||||
MolmoProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if tokens is None:
|
||||
tokens = self.get_tokens_input(
|
||||
text,
|
||||
output_kwargs["text_kwargs"]["message_format"],
|
||||
output_kwargs["text_kwargs"]["always_start_with_space"],
|
||||
)
|
||||
|
||||
image_token_id = self.special_token_ids[IMAGE_PROMPT]
|
||||
|
||||
if images is not None:
|
||||
if not isinstance(images, (list, tuple)):
|
||||
images = [images]
|
||||
image_arrays = []
|
||||
for image in images:
|
||||
if isinstance(image, Image):
|
||||
image = image.convert("RGB")
|
||||
# Handle images with EXIF orientation tags, which PIL will ignore by default
|
||||
# https://github.com/python-pillow/Pillow/issues/4703
|
||||
img = ImageOps.exif_transpose(image)
|
||||
image_arrays.append(np.array(image))
|
||||
else:
|
||||
assert len(image.shape) == 3 and image.shape[-1] == 3
|
||||
image_arrays.append(image.astype(np.uint8))
|
||||
images = image_arrays
|
||||
# For now only support inserting images at the start
|
||||
image_idx = [-1]*len(images)
|
||||
else:
|
||||
image_idx = None
|
||||
|
||||
sequence_length = output_kwargs["text_kwargs"]["sequence_length"]
|
||||
|
||||
image_patch_token_id = self.special_token_ids[DEFAULT_IMAGE_PATCH_TOKEN]
|
||||
image_col_token_id = self.special_token_ids[DEFAULT_IM_COL_TOKEN]
|
||||
image_start_token_id = self.special_token_ids[DEFAULT_IM_START_TOKEN]
|
||||
image_end_token_id = self.special_token_ids[DEFAULT_IM_END_TOKEN]
|
||||
out = self.image_processor.multimodal_preprocess(
|
||||
images=images,
|
||||
image_idx=image_idx,
|
||||
tokens=np.asarray(tokens).astype(np.int32),
|
||||
sequence_length=sequence_length,
|
||||
image_patch_token_id=image_patch_token_id,
|
||||
image_col_token_id=image_col_token_id,
|
||||
image_start_token_id=image_start_token_id,
|
||||
image_end_token_id=image_end_token_id,
|
||||
**output_kwargs["images_kwargs"]
|
||||
)
|
||||
|
||||
# Prepend BOS
|
||||
# qwen2 and olmo do not have a BOS, and instead use EOS as a generic seperator token.
|
||||
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
|
||||
decoder_input_tokens = np.pad(out["input_ids"], [[1, 0]], constant_values=bos)
|
||||
out["input_ids"] = decoder_input_tokens
|
||||
if "image_input_idx" in out:
|
||||
# Shift patch mapping up by one since we added BOS
|
||||
image_input_idx = out["image_input_idx"]
|
||||
out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)
|
||||
|
||||
for k, v in out.items():
|
||||
out[k] = torch.from_numpy(v)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
MolmoProcessor.register_for_auto_class()
|
@ -122,7 +122,10 @@ def run_train(config: TrainConfig):
|
||||
_attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
|
||||
)
|
||||
else:
|
||||
model_config = AutoConfig.from_pretrained(config.model.name_or_path, trust_remote_code=True)
|
||||
from .molmo.config_molmo import MolmoConfig
|
||||
from .molmo.modeling_molmo import MolmoForCausalLM
|
||||
|
||||
model_config = MolmoConfig.from_pretrained(config.model.name_or_path, trust_remote_code=True)
|
||||
|
||||
if model_config.max_position_embeddings < config.generate.max_length:
|
||||
logger.warning(f"ALERT, force adjusting model config max_position_embeddings upwards from {model_config.max_position_embeddings} to {config.generate.max_length}")
|
||||
@ -131,7 +134,7 @@ def run_train(config: TrainConfig):
|
||||
if config.model.use_flash_attn:
|
||||
model_config.attention_type = "flash"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model = MolmoForCausalLM.from_pretrained(
|
||||
config.model.name_or_path, torch_dtype=torch.bfloat16,
|
||||
config=model_config,
|
||||
trust_remote_code=True
|
||||
|
@ -10,7 +10,7 @@ then
|
||||
fi
|
||||
|
||||
|
||||
EXTRA_ARGS="-c pdelfin/train/config/molmo-o-lora.yaml --num_proc 64 --save.path \"s3://ai2-oe-data/jakep/experiments/molmo-pdf/v1/models/\${BEAKER_USER_ID}\""
|
||||
EXTRA_ARGS="-c pdelfin/train/config/molmo-o-lora-8192.yaml --num_proc 64 --save.path \"s3://ai2-oe-data/jakep/experiments/molmo-pdf/v1/models/\${BEAKER_USER_ID}\""
|
||||
|
||||
run_name=$(basename "$0" .sh)
|
||||
|
||||
@ -22,8 +22,8 @@ run_name=$(basename "$0" .sh)
|
||||
CLUSTER='jupiter'
|
||||
|
||||
gantry run \
|
||||
--description "${run_name}"\
|
||||
--task-name "${run_name}"\
|
||||
--description "${run_name}-8192"\
|
||||
--task-name "${run_name}-8192"\
|
||||
--allow-dirty \
|
||||
--host-networking \
|
||||
--workspace ai2/oe-data-model-based-cleanup \
|
||||
@ -32,7 +32,6 @@ gantry run \
|
||||
--pip gantry-requirements.txt \
|
||||
--priority high \
|
||||
--gpus 8 \
|
||||
--preemptible \
|
||||
--cluster "ai2/${CLUSTER}*" \
|
||||
--budget ai2/oe-data \
|
||||
--weka "oe-data-default:/data" \
|
||||
|
Loading…
x
Reference in New Issue
Block a user