"""
|
|
Processor class for Molmo.
|
|
"""
|
|
|
|
from typing import Optional

from PIL import ImageOps
from PIL.Image import Image

try:
    from typing import Unpack
except ImportError:
    from typing_extensions import Unpack

import numpy as np
import torch

from transformers.image_utils import ImageInput
from transformers.processing_utils import (
    TextKwargs,
    ProcessingKwargs,
    ProcessorMixin,
)
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
from transformers.utils import logging

from transformers import AutoTokenizer

from .image_preprocessing_molmo import MolmoImagesKwargs, MolmoImageProcessor


logger = logging.get_logger(__name__)


DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_IM_COL_TOKEN = "<im_col>"
IMAGE_PROMPT = "<|image|>"

EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)


def get_special_token_ids(tokenizer):
    # Encode all special tokens in one pass; each must map to exactly one id.
    ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
    assert len(ids) == len(EXTRA_TOKENS)
    return {k: i for k, i in zip(EXTRA_TOKENS, ids)}
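

# A minimal usage sketch for the helper above, assuming a tokenizer whose
# vocabulary already contains all five EXTRA_TOKENS (the checkpoint name is
# illustrative, not taken from this file):
#
#     tok = AutoTokenizer.from_pretrained("allenai/Molmo-7B-D-0924")
#     ids = get_special_token_ids(tok)
#     ids[IMAGE_PROMPT]  # -> the integer id of "<|image|>"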


class MolmoTextKwargs(TextKwargs, total=False):
    style: Optional[str]
    system_prompt: Optional[str]
    message_format: Optional[str]
    always_start_with_space: Optional[bool]
    sequence_length: Optional[int]


class MolmoProcessorKwargs(ProcessingKwargs, total=False):
    text_kwargs: MolmoTextKwargs
    images_kwargs: MolmoImagesKwargs
    _defaults = {
        "images_kwargs": {
            "max_crops": 12,
            "overlap_margins": [4, 4],
            "base_image_input_size": [336, 336],
            "image_token_length_w": 12,
            "image_token_length_h": 12,
            "image_patch_size": 14,
            "image_padding_mask": True,
        },
        "text_kwargs": {
            "style": "long_caption",
            "system_prompt": "none",
            "message_format": "role",
            "always_start_with_space": True,
            "sequence_length": 1536,
            "padding": False,
        },
    }
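

# Per-call overrides are merged over ``_defaults`` by ProcessorMixin._merge_kwargs,
# so a single field can be changed without editing the table above. A sketch,
# assuming a constructed ``processor`` instance:
#
#     processor.process(
#         text="Describe this image.",
#         images=img,
#         images_kwargs={"max_crops": 6},  # override one image default
#     )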


class MolmoProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast")

    def __init__(self, image_processor: MolmoImageProcessor = None, tokenizer: AutoTokenizer = None, **kwargs):
        super().__init__(image_processor, tokenizer)
        self._special_tokens = None

    @property
    def special_token_ids(self):
        # Resolve the special-token ids lazily, once the tokenizer is available.
        if self._special_tokens is None:
            self._special_tokens = get_special_token_ids(self.tokenizer)
        return self._special_tokens

    def get_tokens_input(self, prompt, message_format, always_start_with_space):
        if message_format == "none" or message_format is None:
            pass
        elif message_format == "role":
            prompt = "User: " + prompt + " Assistant:"
        else:
            raise NotImplementedError(f"Message format {message_format} not implemented")

        if always_start_with_space:
            prompt = " " + prompt

        tokens = self.tokenizer.encode(prompt, add_special_tokens=False)

        return tokens
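
    # A sketch of the two message formats handled above (the returned ids
    # depend on the tokenizer in use), assuming a constructed ``processor``:
    #
    #     processor.get_tokens_input("What is this?", "role", True)
    #     # encodes " User: What is this? Assistant:"
    #     processor.get_tokens_input("What is this?", "none", True)
    #     # encodes " What is this?"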

    def process(
        self,
        text: TextInput = None,
        images: ImageInput = None,
        *,
        tokens: Optional[PreTokenizedInput] = None,
        **kwargs: Unpack[MolmoProcessorKwargs],
    ):
        output_kwargs = self._merge_kwargs(
            MolmoProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if tokens is None:
            tokens = self.get_tokens_input(
                text,
                output_kwargs["text_kwargs"]["message_format"],
                output_kwargs["text_kwargs"]["always_start_with_space"],
            )

        image_token_id = self.special_token_ids[IMAGE_PROMPT]

        if images is not None:
            if not isinstance(images, (list, tuple)):
                images = [images]
            image_arrays = []
            for image in images:
                if isinstance(image, Image):
                    image = image.convert("RGB")
                    # Handle images with EXIF orientation tags, which PIL will ignore by default
                    # https://github.com/python-pillow/Pillow/issues/4703
                    image = ImageOps.exif_transpose(image)
                    image_arrays.append(np.array(image))
                else:
                    assert len(image.shape) == 3 and image.shape[-1] == 3
                    image_arrays.append(image.astype(np.uint8))
            images = image_arrays
            # For now only support inserting images at the start
            image_idx = [-1] * len(images)
        else:
            image_idx = None

        sequence_length = output_kwargs["text_kwargs"]["sequence_length"]

        image_patch_token_id = self.special_token_ids[DEFAULT_IMAGE_PATCH_TOKEN]
        image_col_token_id = self.special_token_ids[DEFAULT_IM_COL_TOKEN]
        image_start_token_id = self.special_token_ids[DEFAULT_IM_START_TOKEN]
        image_end_token_id = self.special_token_ids[DEFAULT_IM_END_TOKEN]
        out = self.image_processor.multimodal_preprocess(
            images=images,
            image_idx=image_idx,
            tokens=np.asarray(tokens).astype(np.int32),
            sequence_length=sequence_length,
            image_patch_token_id=image_patch_token_id,
            image_col_token_id=image_col_token_id,
            image_start_token_id=image_start_token_id,
            image_end_token_id=image_end_token_id,
            **output_kwargs["images_kwargs"],
        )
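
        # Note: ``out`` comes from MolmoImageProcessor.multimodal_preprocess.
        # As an assumption about its contract (it is defined in the sibling
        # module, not here), it carries at least "input_ids" and, when images
        # are given, the processed crops plus an "image_input_idx" map from
        # image patches to positions in the token sequence.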

        # Prepend BOS
        # qwen2 and olmo do not have a BOS, and instead use EOS as a generic separator token.
        bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
        decoder_input_tokens = np.pad(out["input_ids"], [[1, 0]], constant_values=bos)
        out["input_ids"] = decoder_input_tokens
        if "image_input_idx" in out:
            # Shift patch mapping up by one since we added BOS
            image_input_idx = out["image_input_idx"]
            out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)

        for k, v in out.items():
            out[k] = torch.from_numpy(v)

        return out


MolmoProcessor.register_for_auto_class()
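
# End-to-end usage sketch. Assumptions: the checkpoint name is illustrative,
# and it bundles this processor via ``auto_map`` so AutoProcessor can load it:
#
#     from transformers import AutoProcessor
#     processor = AutoProcessor.from_pretrained(
#         "allenai/Molmo-7B-D-0924", trust_remote_code=True
#     )
#     batch = processor.process(text="Describe this image.", images=pil_img)
#     # batch["input_ids"] is a torch.Tensor ready to feed to the model.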