mirror of https://github.com/allenai/olmocr.git (synced 2025-06-27 04:00:02 +00:00)

Starting to write molmo formatters

parent e65747e591
commit bede854cd5
@@ -115,3 +115,67 @@ def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image

        "image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
    }

def prepare_data_for_molmo_training(example, processor, target_longest_image_dim: Union[int, list[int]], target_anchor_text_len: Union[int, list[int]]):
    if isinstance(target_longest_image_dim, list):
        target_longest_image_dim = random.choice(target_longest_image_dim)

    if isinstance(target_anchor_text_len, list):
        target_anchor_text_len = random.choice(target_anchor_text_len)

    anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
    base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)

    # Decode image from base64
    main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))

    # Process the input text and image
    inputs = processor.process(
        images=[main_image],
        text=build_finetuning_prompt(anchor_text),
    )

    # Get labels by tokenizing the output text
    labels = processor.tokenizer(example["response"], return_tensors="np")['input_ids'][0]

    # Concatenate input_ids and labels
    full_input_ids = torch.cat([inputs['input_ids'], torch.from_numpy(labels)], dim=0)

    labels_full = torch.cat([torch.ones_like(inputs['input_ids']) * -100, torch.from_numpy(labels)], dim=0)

    # Create a full attention mask
    attention_mask = torch.ones_like(full_input_ids)

    # image_input_idx does not need adjustment as images are inserted before labels
    image_input_idx = inputs['image_input_idx']

    return {
        'input_ids': full_input_ids,
        'labels': labels_full,
        'images': inputs['images'],
        'image_input_idx': image_input_idx,
        'image_masks': inputs['image_masks'],
        'attention_mask': attention_mask,
    }

def batch_prepare_data_for_molmo_training(batch, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]):
    # Assume batch size 1 and process the single example
    example = {
        "local_pdf_path": batch["local_pdf_path"][0],
        "page_num": batch["page_num"][0],
        "response": batch["response"][0]
    }
    processed_example = prepare_data_for_molmo_training(
        example,
        processor,
        target_longest_image_dim=target_longest_image_dim,
        target_anchor_text_len=target_anchor_text_len
    )

    # Return in the same format as the qwen2 function
    return {
        "input_ids": [processed_example["input_ids"]],
        "attention_mask": [processed_example["attention_mask"]],
        "labels": [processed_example["labels"]],
        "images": [processed_example["images"]],
        "image_input_idx": [processed_example["image_input_idx"]],
        "image_masks": [processed_example["image_masks"]],
    }
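
For orientation, here is one way a training script could wire this batch formatter into a dataset. This is a minimal sketch, not code from this commit: the `dataset` variable and the `with_transform` wiring are assumptions, chosen to mirror how per-batch formatters are typically applied to HuggingFace datasets.

from functools import partial

from transformers import AutoProcessor

from pdelfin.train.dataprep import batch_prepare_data_for_molmo_training

processor = AutoProcessor.from_pretrained(
    'allenai/Molmo-7B-O-0924',
    trust_remote_code=True,
)

# Assumption: `dataset` is a HuggingFace Dataset exposing the columns the
# formatter reads: local_pdf_path, page_num, and response.
format_fn = partial(
    batch_prepare_data_for_molmo_training,
    processor=processor,
    target_longest_image_dim=[1024],
    target_anchor_text_len=[0, 6000],
)
train_dataset = dataset.with_transform(format_fn)  # formatting runs lazily, per batch

Because the anchor text length and image dimension are sampled with random.choice on every call, applying the transform lazily means each epoch sees freshly sampled augmentations.
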
@@ -31,6 +31,7 @@ dependencies = [

    "bleach",
    "markdown2",
    "filelock",
    "orjson",
]
license = {file = "LICENSE"}

@@ -61,6 +62,7 @@ dev = [

    "sphinx-autodoc-typehints==1.23.3",
    "packaging",
    "necessary",
    "requests",
]

train = [
@@ -1,16 +1,21 @@

import unittest
import random
import requests
import base64
import os
import re
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor
from unittest.mock import patch

from pdelfin.train.dataloader import (
    build_finetuning_dataset,
)

from pdelfin.train.dataprep import (
    prepare_data_for_qwen2_training, build_finetuning_prompt,
    prepare_data_for_molmo_training, batch_prepare_data_for_molmo_training
)
import numpy as np
from tqdm import tqdm

@@ -158,4 +163,75 @@ class TestDataprep(unittest.TestCase):

        # Verify total adds up to 100%
        self.assertEqual(zero_count + full_count, num_iterations,
                         "Total count should equal number of iterations")


class TestMolmoDataPrep(unittest.TestCase):
    def testMolmoDefaultSetup(self):
        processor = AutoProcessor.from_pretrained(
            'allenai/Molmo-7B-O-0924',
            trust_remote_code=True,
            torch_dtype='auto',
            device_map='auto'
        )

        # process the image and text
        inputs = processor.process(
            images=[Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)],
            text="Describe this image."
        )

        print(inputs.keys())
        print(inputs["input_ids"])
        print(processor.tokenizer.batch_decode(inputs["input_ids"]))

        labels = processor.tokenizer("This is a page of the pdf that's the text", return_tensors="np")

        print(labels)
        print(processor.tokenizer.batch_decode(labels["input_ids"]))

    def testMolmoDataPrep(self):
        # Initialize the processor for Molmo
        processor = AutoProcessor.from_pretrained(
            'allenai/Molmo-7B-O-0924',
            trust_remote_code=True,
            torch_dtype='auto',
            device_map='auto'
        )

        # Create a mock example
        example = {
            "local_pdf_path": os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"),
            "page_num": 1,
            "response": "This is the response text."
        }

        # Define target dimensions and anchor text lengths
        target_longest_image_dim = [1024]
        target_anchor_text_len = [0, 6000]

        # Set a fixed seed for reproducibility
        random.seed(42)

        # Mock the functions that require actual PDF files
        with patch('pdelfin.prompts.anchor.get_anchor_text') as mock_get_anchor_text, \
             patch('pdelfin.data.renderpdf.render_pdf_to_base64png') as mock_render_pdf_to_base64png:

            # Set return values for the mocked functions
            mock_get_anchor_text.return_value = "This is the anchor text."

            # Create a red square image and encode it in base64
            img = Image.new('RGB', (100, 100), color='red')
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
            mock_render_pdf_to_base64png.return_value = img_str

            # Process the example using the prepare_data_for_molmo_training function
            processed_example = prepare_data_for_molmo_training(
                example,
                processor,
                target_longest_image_dim=target_longest_image_dim,
                target_anchor_text_len=target_anchor_text_len
            )
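
The diff view is truncated at this point, so the test's actual assertions are not shown. Purely as an illustration of what checks at this step could look like (continuing inside the with block above, and not code from this commit), one could verify the labeling convention that prepare_data_for_molmo_training establishes: prompt tokens are masked to -100 so the loss is computed only over the response tokens, and all sequence tensors agree in length.

            # Hypothetical follow-up checks, not part of this commit:
            response_len = len(processor.tokenizer(example["response"], return_tensors="np")["input_ids"][0])
            prompt_len = len(processed_example["input_ids"]) - response_len

            # Prompt positions are masked to -100, so loss covers only the response
            self.assertTrue((processed_example["labels"][:prompt_len] == -100).all())

            # input_ids, labels, and attention_mask agree in length
            self.assertEqual(processed_example["labels"].shape, processed_example["input_ids"].shape)
            self.assertEqual(processed_example["attention_mask"].shape, processed_example["input_ids"].shape)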