Starting to write molmo formatters

This commit is contained in:
Jake Poznanski 2024-10-30 13:24:11 -07:00
parent e65747e591
commit bede854cd5
3 changed files with 144 additions and 2 deletions

pdelfin/train/dataprep.py View File

@@ -115,3 +115,67 @@ def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image
        "image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
    }

def prepare_data_for_molmo_training(example, processor, target_longest_image_dim: Union[int, list[int]], target_anchor_text_len: Union[int, list[int]]):
    if isinstance(target_longest_image_dim, list):
        target_longest_image_dim = random.choice(target_longest_image_dim)

    if isinstance(target_anchor_text_len, list):
        target_anchor_text_len = random.choice(target_anchor_text_len)

    anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
    base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)

    # Decode image from base64
    main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))

    # Process the input text and image
    inputs = processor.process(
        images=[main_image],
        text=build_finetuning_prompt(anchor_text),
    )

    # Get labels by tokenizing the output text
    labels = processor.tokenizer(example["response"], return_tensors="np")['input_ids'][0]

    # Concatenate input_ids and labels
    full_input_ids = torch.cat([inputs['input_ids'], torch.from_numpy(labels)], dim=0)
    labels_full = torch.cat([torch.ones_like(inputs['input_ids']) * -100, torch.from_numpy(labels)], dim=0)

    # Create a full attention mask
    attention_mask = torch.ones_like(full_input_ids)

    # image_input_idx does not need adjustment as images are inserted before labels
    image_input_idx = inputs['image_input_idx']

    return {
        'input_ids': full_input_ids,
        'labels': labels_full,
        'images': inputs['images'],
        'image_input_idx': image_input_idx,
        'image_masks': inputs['image_masks'],
        'attention_mask': attention_mask,
    }
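
A minimal sketch (an editor's illustration, not part of this commit) of the label-masking convention above: prompt positions are filled with -100, the ignore_index used by Hugging Face cross-entropy losses, so only the response tokens contribute to the loss. The token ids below are made up.

import torch

prompt_ids = torch.tensor([101, 7592, 2088])    # hypothetical prompt token ids
response_ids = torch.tensor([2023, 2003, 102])  # hypothetical response token ids

full_input_ids = torch.cat([prompt_ids, response_ids], dim=0)
labels_full = torch.cat([torch.ones_like(prompt_ids) * -100, response_ids], dim=0)

# Only the last three positions contribute to the loss.
assert labels_full.tolist() == [-100, -100, -100, 2023, 2003, 102]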

def batch_prepare_data_for_molmo_training(batch, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]):
    # Assume batch size 1 and process the single example
    example = {
        "local_pdf_path": batch["local_pdf_path"][0],
        "page_num": batch["page_num"][0],
        "response": batch["response"][0]
    }

    processed_example = prepare_data_for_molmo_training(
        example,
        processor,
        target_longest_image_dim=target_longest_image_dim,
        target_anchor_text_len=target_anchor_text_len
    )

    # Return in the same format as the qwen2 function
    return {
        "input_ids": [processed_example["input_ids"]],
        "attention_mask": [processed_example["attention_mask"]],
        "labels": [processed_example["labels"]],
        "images": [processed_example["images"]],
        "image_input_idx": [processed_example["image_input_idx"]],
        "image_masks": [processed_example["image_masks"]],
    }
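
Below is a hedged usage sketch (an editor's illustration, not part of this commit) of how the batch collator above might be called, for example from a datasets map with batch_size=1. The PDF path and response text are hypothetical placeholders, and the processor is assumed to be the same Molmo AutoProcessor loaded in the tests further down.

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained('allenai/Molmo-7B-O-0924', trust_remote_code=True)

# Hypothetical single-example batch; columns arrive as lists of length 1.
batch = {
    "local_pdf_path": ["/path/to/example.pdf"],  # hypothetical path
    "page_num": [1],
    "response": ["Extracted page text."],        # hypothetical target text
}

out = batch_prepare_data_for_molmo_training(
    batch,
    processor,
    target_longest_image_dim=[1024],
    target_anchor_text_len=[0, 6000],
)

# Every field is wrapped in a single-element list, mirroring the qwen2 collator.
print(out["input_ids"][0].shape, out["labels"][0].shape)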

pyproject.toml View File

@@ -31,6 +31,7 @@ dependencies = [
    "bleach",
    "markdown2",
    "filelock",
    "orjson",
]
license = {file = "LICENSE"}
@@ -61,6 +62,7 @@ dev = [
    "sphinx-autodoc-typehints==1.23.3",
    "packaging",
    "necessary",
    "requests",
]
train = [

tests/test_dataprep.py View File

@@ -1,16 +1,21 @@
import unittest
import random
import requests
import base64
import os
import re
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor
from unittest.mock import patch
from pdelfin.train.dataloader import (
    build_finetuning_dataset,
)
from pdelfin.train.dataprep import (
    prepare_data_for_qwen2_training, build_finetuning_prompt,
    prepare_data_for_molmo_training, batch_prepare_data_for_molmo_training
)
import numpy as np
from tqdm import tqdm
@@ -158,4 +163,75 @@ class TestDataprep(unittest.TestCase):
        # Verify total adds up to 100%
        self.assertEqual(zero_count + full_count, num_iterations,
                         "Total count should equal number of iterations")

class TestMolmoDataPrep(unittest.TestCase):
    def testMolmoDefaultSetup(self):
        processor = AutoProcessor.from_pretrained(
            'allenai/Molmo-7B-O-0924',
            trust_remote_code=True,
            torch_dtype='auto',
            device_map='auto'
        )

        # Process the image and text
        inputs = processor.process(
            images=[Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)],
            text="Describe this image."
        )

        print(inputs.keys())
        print(inputs["input_ids"])
        print(processor.tokenizer.batch_decode(inputs["input_ids"]))

        labels = processor.tokenizer("This is a page of the pdf that's the text", return_tensors="np")

        print(labels)
        print(processor.tokenizer.batch_decode(labels["input_ids"]))
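
        # Editor's sketch, not in the original commit: the training helpers
        # above read these keys from processor.process() output, so a cheap
        # sanity check is that all of them are present here.
        expected_keys = {"input_ids", "images", "image_input_idx", "image_masks"}
        self.assertTrue(expected_keys.issubset(set(inputs.keys())))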

    def testMolmoDataPrep(self):
        # Initialize the processor for Molmo
        processor = AutoProcessor.from_pretrained(
            'allenai/Molmo-7B-O-0924',
            trust_remote_code=True,
            torch_dtype='auto',
            device_map='auto'
        )

        # Create a mock example
        example = {
            "local_pdf_path": os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"),
            "page_num": 1,
            "response": "This is the response text."
        }

        # Define target dimensions and anchor text lengths
        target_longest_image_dim = [1024]
        target_anchor_text_len = [0, 6000]

        # Set a fixed seed for reproducibility
        random.seed(42)

        # Mock the functions that require actual PDF files. Patch the names
        # where they are looked up (the dataprep module), not where they are
        # defined, so the mocks actually take effect inside
        # prepare_data_for_molmo_training, which calls them unqualified.
        with patch('pdelfin.train.dataprep.get_anchor_text') as mock_get_anchor_text, \
             patch('pdelfin.train.dataprep.render_pdf_to_base64png') as mock_render_pdf_to_base64png:

            # Set return values for the mocked functions
            mock_get_anchor_text.return_value = "This is the anchor text."

            # Create a red square image and encode it in base64
            img = Image.new('RGB', (100, 100), color='red')
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
            mock_render_pdf_to_base64png.return_value = img_str

            # Process the example using the prepare_data_for_molmo_training function
            processed_example = prepare_data_for_molmo_training(
                example,
                processor,
                target_longest_image_dim=target_longest_image_dim,
                target_anchor_text_len=target_anchor_text_len
            )