import base64
import os
import random
import re
import unittest
from io import BytesIO
from unittest.mock import patch

import numpy as np
import pytest
import requests
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor

from olmocr.train.core.config import DataConfig, SourceConfig, TrainConfig
from olmocr.train.dataloader import build_finetuning_dataset
from olmocr.train.dataprep import (
    batch_prepare_data_for_molmo_training,
    build_finetuning_prompt,
    prepare_data_for_molmo_training,
    prepare_data_for_qwen2_training,
)
from olmocr.train.utils import make_dataset


@pytest.mark.nonci
class TestDataprep(unittest.TestCase):
    def testFullDataloader(self):
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        config = TrainConfig(
            train_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=6000,
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
            valid_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=6000,
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
        )
        train_dataset, valid_dataset = make_dataset(config, processor)

        im_end_token_ids = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]

        # train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False)
        for entry in train_dataset:
            print({x: (y.shape, y.dtype) for (x, y) in entry.items()})

            self.assertEqual(entry["input_ids"].dtype, np.int64)
            self.assertEqual(entry["attention_mask"].dtype, np.int64)
            self.assertEqual(entry["labels"].dtype, np.int64)
            self.assertEqual(entry["pixel_values"].dtype, np.float32)
            self.assertEqual(entry["image_grid_thw"].dtype, np.int64)

            # Extract input_ids and labels
            input_ids = entry["input_ids"]
            labels = entry["labels"]

            # 1. Verify that the last token is the end token
            # Ensure input_ids is long enough
            self.assertTrue(len(input_ids) >= len(im_end_token_ids), "Input IDs are shorter than the end token sequence.")

            # Compare the last tokens of input_ids with im_end_token_ids
            self.assertEqual(
                input_ids[-len(im_end_token_ids) :].tolist(), im_end_token_ids, "The last tokens of input_ids do not match the end token sequence."
            )

            # 2. Ensure labels are masked correctly and match input_ids after the mask
            # Find where labels start being non-masked (-100 is the mask value)
            label_indices = np.where(labels != -100)[0]

            # There should be at least one label that is not masked
            self.assertTrue(len(label_indices) > 0, "No unmasked labels found in labels array.")

            first_label_index = label_indices[0]

            # Ensure the masked portion is at least 10 tokens long
            self.assertTrue(first_label_index >= 10, "Masked portion of labels is less than 10 tokens.")

            # Check that all values before first_label_index are -100
            self.assertTrue(np.all(labels[:first_label_index] == -100), "Labels before the first unmasked token are not all -100.")

            # Check that the unmasked labels match the corresponding input_ids
            self.assertTrue(
                np.array_equal(labels[first_label_index:], input_ids[first_label_index:]),
                "Unmasked labels do not match the corresponding input_ids.",
            )

            # Optionally, verify that the last unmasked tokens in labels match the end token IDs
            unmasked_labels = labels[labels != -100]
            self.assertEqual(
                unmasked_labels[-len(im_end_token_ids) :].tolist(),
                im_end_token_ids,
                "The last unmasked tokens in labels do not match the end token sequence.",
            )

    def testListTargetAnchorLength(self):
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        config = TrainConfig(
            train_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=[0, 6000],  # Only 0 and 6000
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
            valid_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=[0, 6000],  # Only 0 and 6000
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
        )

        # Set a fixed seed for reproducibility
        random.seed(42)

        train_dataset, valid_dataset = make_dataset(config, processor)

        zero_count = 0
        full_count = 0
        num_iterations = 100

        for i in range(num_iterations):
            entry = train_dataset[0]  # Get the first entry repeatedly

            # Basic type checks
            self.assertEqual(entry["input_ids"].dtype, np.int64)
            self.assertEqual(entry["attention_mask"].dtype, np.int64)
            self.assertEqual(entry["labels"].dtype, np.int64)
            self.assertEqual(entry["pixel_values"].dtype, np.float32)
            self.assertEqual(entry["image_grid_thw"].dtype, np.int64)

            # Get the input text before the response
            # Find where labels start being non-masked (-100 is the mask value)
            label_indices = np.where(entry["labels"] != -100)[0]
            first_label_index = label_indices[0] if len(label_indices) > 0 else len(entry["input_ids"])

            # Decode the input portion to check the prompt
            input_text = processor.tokenizer.decode(entry["input_ids"][:first_label_index])

            # A raw-text block containing only the page dimensions means the zero-length anchor text was chosen
            pattern = r"RAW_TEXT_START\nPage dimensions: (\d+\.?\d*)x(\d+\.?\d*)\s+RAW_TEXT_END"
            match = re.search(pattern, input_text, flags=re.MULTILINE)

            if match:
                zero_count += 1
            else:
                full_count += 1

        # Verify the distribution: with two choices, expect roughly 50% zero-length and 50% full-length anchor text
        zero_ratio = zero_count / num_iterations
        full_ratio = full_count / num_iterations

        print(zero_count, full_count)

        self.assertTrue(0.45 <= zero_ratio <= 0.55, f"Expected zero-length ratio around 0.5, got {zero_ratio:.2f}")
        self.assertTrue(0.45 <= full_ratio <= 0.55, f"Expected full-length ratio around 0.5, got {full_ratio:.2f}")

        # Verify the counts add up to the number of iterations
        self.assertEqual(zero_count + full_count, num_iterations, "Total count should equal number of iterations")


@pytest.mark.nonci
class TestMolmoDataPrep(unittest.TestCase):
    def testMolmoDefaultSetup(self):
        processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")

        # Process the image and text
        inputs = processor.process(
            images=[Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)],
            text="Describe this image.",
        )

        print(inputs.keys())
        print(inputs["input_ids"])
        print(processor.tokenizer.batch_decode(inputs["input_ids"]))

        labels = processor.tokenizer("This is a page of the pdf that's the text", return_tensors="np")

        print(labels)
        print(processor.tokenizer.batch_decode(labels["input_ids"]))

    def testMolmoDataPrep(self):
        # Initialize the processor for Molmo
        processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")

        # Create a mock example
        example = {
            "local_pdf_path": os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"),
            "page_num": 1,
            "response": "This is the response text.",
        }

        # Define target dimensions and anchor text lengths
        target_longest_image_dim = [1024]
        target_anchor_text_len = [0, 6000]

        # Set a fixed seed for reproducibility
        random.seed(42)

        # Mock the functions that require actual PDF files
        with (
            patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
            patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
        ):
            # Set return values for the mocked functions
            mock_get_anchor_text.return_value = "This is the anchor text."

            # Create a red square image and encode it in base64
            img = Image.new("RGB", (100, 100), color="red")
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            mock_render_pdf_to_base64png.return_value = img_str

            # Process the example using the prepare_data_for_molmo_training function
            processed_example = prepare_data_for_molmo_training(
                example, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
            )

            # Basic type checks
            self.assertIsInstance(processed_example["input_ids"], torch.Tensor, "input_ids should be a torch.Tensor")
            self.assertIsInstance(processed_example["attention_mask"], torch.Tensor, "attention_mask should be a torch.Tensor")
            self.assertIsInstance(processed_example["labels"], torch.Tensor, "labels should be a torch.Tensor")
            self.assertIsInstance(processed_example["images"], torch.Tensor, "images should be a torch.Tensor")
            self.assertIsInstance(processed_example["image_input_idx"], torch.Tensor, "image_input_idx should be a torch.Tensor")
            self.assertIsInstance(processed_example["image_masks"], torch.Tensor, "image_masks should be a torch.Tensor")

            # Check tensor dimensions
            self.assertEqual(len(processed_example["input_ids"].shape), 1, "input_ids should be a 1D tensor")
            self.assertEqual(
                processed_example["input_ids"].shape, processed_example["attention_mask"].shape, "input_ids and attention_mask should have the same shape"
            )
            self.assertEqual(processed_example["input_ids"].shape, processed_example["labels"].shape, "input_ids and labels should have the same shape")

            # Verify label masking
            # Find where labels start being non-masked (-100 is the mask value)
            label_indices = torch.where(processed_example["labels"] != -100)[0]

            # There should be at least one label that is not masked
            self.assertTrue(len(label_indices) > 0, "No unmasked labels found in labels array.")

            first_label_index = label_indices[0]

            # Ensure the masked portion is reasonable (at least a few tokens long)
            self.assertTrue(first_label_index >= 5, "Masked portion of labels is too short")

            # Check that all values before first_label_index are -100
            self.assertTrue(torch.all(processed_example["labels"][:first_label_index] == -100), "Labels before the first unmasked token are not all -100.")

            # Verify attention mask
            self.assertTrue(torch.all(processed_example["attention_mask"] == 1), "All attention mask values should be 1")

            # Verify image input indices
            self.assertTrue(
                torch.all(processed_example["image_input_idx"] < len(processed_example["input_ids"])),
                "Image input indices should be within the range of input_ids length",
            )

            # Decode and verify content structure
            decoded_input = processor.tokenizer.decode(processed_example["input_ids"])
            self.assertIn("This is the anchor text", decoded_input, "Anchor text should be present in the decoded input")

            # Verify that unmasked labels decode to the response text
            unmasked_labels = processed_example["labels"][processed_example["labels"] != -100]
            decoded_labels = processor.tokenizer.decode(unmasked_labels)
            self.assertIn("This is the response text", decoded_labels, "Response text should be present in the decoded labels")

    def testBatchMolmoDataPrep(self):
        """Test the batch preparation function for Molmo"""
        processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")

        # Create a mock batch
        batch = {
            "local_pdf_path": [os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf")],
            "page_num": [1],
            "response": ["This is the response text."],
        }

        target_longest_image_dim = [1024]
        target_anchor_text_len = [0, 6000]

        # Mock the necessary functions
        with (
            patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
            patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
        ):
            mock_get_anchor_text.return_value = "This is the anchor text."

            img = Image.new("RGB", (100, 100), color="red")
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            mock_render_pdf_to_base64png.return_value = img_str

            # Process the batch
            processed_batch = batch_prepare_data_for_molmo_training(
                batch, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
            )

            # Verify batch structure
            self.assertEqual(len(processed_batch["input_ids"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["attention_mask"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["labels"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["images"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["image_input_idx"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["image_masks"]), 1, "Batch size should be 1")

            # Verify the first item in the batch
            first_item = {k: v[0] for k, v in processed_batch.items()}
            self.assertIsInstance(first_item["input_ids"], torch.Tensor, "Batch item should contain torch.Tensor")
            self.assertTrue(torch.all(first_item["attention_mask"] == 1), "All attention mask values should be 1")