import base64
import os
import random
import re
import unittest
from io import BytesIO
from unittest.mock import patch

import numpy as np
import pytest
import requests
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor

from olmocr.train.core.config import DataConfig, SourceConfig, TrainConfig
from olmocr.train.dataloader import build_finetuning_dataset
from olmocr.train.dataprep import (
    batch_prepare_data_for_molmo_training,
    build_finetuning_prompt,
    prepare_data_for_molmo_training,
    prepare_data_for_qwen2_training,
)
from olmocr.train.utils import make_dataset


@pytest.mark.nonci
class TestDataprep(unittest.TestCase):
def testFullDataloader(self):
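        """Iterates over the full Qwen2-VL training dataset and checks each entry's
        dtypes, end-token placement, and label masking."""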
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
config = TrainConfig(
train_data=DataConfig(
seed=42,
sources=[
SourceConfig(
name="eval_test",
target_longest_image_dim=1024,
target_anchor_text_len=6000,
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
)
],
),
valid_data=DataConfig(
seed=42,
sources=[
SourceConfig(
name="eval_test",
target_longest_image_dim=1024,
target_anchor_text_len=6000,
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
)
],
),
)
train_dataset, valid_dataset = make_dataset(config, processor)
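        # Token ids for the "<|im_end|>\n" suffix that should terminate every training example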
im_end_token_ids = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]
# train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False)
for entry in train_dataset:
print({x: (y.shape, y.dtype) for (x, y) in entry.items()})
self.assertEqual(entry["input_ids"].dtype, np.int64)
self.assertEqual(entry["attention_mask"].dtype, np.int64)
self.assertEqual(entry["labels"].dtype, np.int64)
self.assertEqual(entry["pixel_values"].dtype, np.float32)
self.assertEqual(entry["image_grid_thw"].dtype, np.int64)
# Extract input_ids and labels
input_ids = entry["input_ids"]
labels = entry["labels"]
# 1. Verify that the last token is the end token
# Ensure input_ids is long enough
self.assertTrue(len(input_ids) >= len(im_end_token_ids), "Input IDs are shorter than the end token sequence.")
# Compare the last tokens of input_ids with im_end_token_ids
self.assertEqual(
input_ids[-len(im_end_token_ids) :].tolist(), im_end_token_ids, "The last tokens of input_ids do not match the end token sequence."
)
# 2. Ensure labels are masked correctly and match input_ids after the mask
# Find where labels start being non-masked (-100 is the mask value)
label_indices = np.where(labels != -100)[0]
# There should be at least one label that is not masked
self.assertTrue(len(label_indices) > 0, "No unmasked labels found in labels array.")
first_label_index = label_indices[0]
# Ensure the masked portion is at least 10 tokens long
self.assertTrue(first_label_index >= 10, "Masked portion of labels is less than 10 tokens.")
# Check that all values before first_label_index are -100
self.assertTrue(np.all(labels[:first_label_index] == -100), "Labels before the first unmasked token are not all -100.")
# Check that the unmasked labels match the corresponding input_ids
self.assertTrue(
np.array_equal(labels[first_label_index:], input_ids[first_label_index:]), "Unmasked labels do not match the corresponding input_ids."
)
# Optionally, verify that the last unmasked tokens in labels match the end token IDs
unmasked_labels = labels[labels != -100]
self.assertEqual(
unmasked_labels[-len(im_end_token_ids) :].tolist(), im_end_token_ids, "The last unmasked tokens in labels do not match the end token sequence."
            )

    def testListTargetAnchorLength(self):
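        """Checks that a list-valued target_anchor_text_len ([0, 6000]) is sampled
        roughly evenly between the two anchor lengths."""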
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
config = TrainConfig(
train_data=DataConfig(
seed=42,
sources=[
SourceConfig(
name="eval_test",
target_longest_image_dim=1024,
target_anchor_text_len=[0, 6000], # Only 0 and 6000
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
)
],
),
valid_data=DataConfig(
seed=42,
sources=[
SourceConfig(
name="eval_test",
target_longest_image_dim=1024,
target_anchor_text_len=[0, 6000], # Only 0 and 6000
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
)
],
),
)
# Set a fixed seed for reproducibility
random.seed(42)
train_dataset, valid_dataset = make_dataset(config, processor)
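        # Count how often the zero-length vs. full-length anchor text is sampled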
zero_count = 0
full_count = 0
        num_iterations = 100

        for _ in range(num_iterations):
entry = train_dataset[0] # Get the first entry repeatedly
# Basic type checks
self.assertEqual(entry["input_ids"].dtype, np.int64)
self.assertEqual(entry["attention_mask"].dtype, np.int64)
self.assertEqual(entry["labels"].dtype, np.int64)
self.assertEqual(entry["pixel_values"].dtype, np.float32)
self.assertEqual(entry["image_grid_thw"].dtype, np.int64)
# Get the input text before the response
# Find where labels start being non-masked (-100 is the mask value)
label_indices = np.where(entry["labels"] != -100)[0]
first_label_index = label_indices[0] if len(label_indices) > 0 else len(entry["input_ids"])
# Decode the input portion to check the prompt
input_text = processor.tokenizer.decode(entry["input_ids"][:first_label_index])
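            # An empty anchor text leaves only the page dimensions between the
            # RAW_TEXT_START/RAW_TEXT_END markers, so a pattern match means the
            # zero-length anchor was sampled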
pattern = r"RAW_TEXT_START\nPage dimensions: (\d+\.?\d*)x(\d+\.?\d*)\s+RAW_TEXT_END"
match = re.search(pattern, input_text, flags=re.MULTILINE)
if match:
zero_count += 1
else:
                full_count += 1

        # Verify the distribution: should be roughly 50% zero-length, 50% full-length
zero_ratio = zero_count / num_iterations
full_ratio = full_count / num_iterations
print(zero_count, full_count)
self.assertTrue(0.45 <= zero_ratio <= 0.55, f"Expected zero-length ratio around 0.5, got {zero_ratio:.2f}")
self.assertTrue(0.45 <= full_ratio <= 0.55, f"Expected full-length ratio around 0.5, got {full_ratio:.2f}")
# Verify total adds up to 100%
        self.assertEqual(zero_count + full_count, num_iterations, "Total count should equal number of iterations")


@pytest.mark.nonci
class TestMolmoDataPrep(unittest.TestCase):
def testMolmoDefaultSetup(self):
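        """Smoke test of the stock Molmo processor on a sample image and prompt."""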
processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")
# process the image and text
inputs = processor.process(images=[Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)], text="Describe this image.")
print(inputs.keys())
print(inputs["input_ids"])
print(processor.tokenizer.batch_decode(inputs["input_ids"]))
labels = processor.tokenizer("This is a page of the pdf that's the text", return_tensors="np")
print(labels)
        print(processor.tokenizer.batch_decode(labels["input_ids"]))

    def testMolmoDataPrep(self):
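        """Runs prepare_data_for_molmo_training on a mocked PDF example and validates
        tensor types, shapes, label masking, and decoded content."""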
# Initialize the processor for Molmo
processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")
# Create a mock example
example = {
"local_pdf_path": os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"),
"page_num": 1,
"response": "This is the response text.",
}
# Define target dimensions and anchor text lengths
target_longest_image_dim = [1024]
target_anchor_text_len = [0, 6000]
# Set a fixed seed for reproducibility
random.seed(42)
# Mock the functions that require actual PDF files
with (
patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
):
# Set return values for the mocked functions
mock_get_anchor_text.return_value = "This is the anchor text."
# Create a red square image and encode it in base64
img = Image.new("RGB", (100, 100), color="red")
buffered = BytesIO()
img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
mock_render_pdf_to_base64png.return_value = img_str
# Process the example using the prepare_data_for_molmo_training function
processed_example = prepare_data_for_molmo_training(
example, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
)
# Basic type checks
self.assertIsInstance(processed_example["input_ids"], torch.Tensor, "input_ids should be a torch.Tensor")
self.assertIsInstance(processed_example["attention_mask"], torch.Tensor, "attention_mask should be a torch.Tensor")
self.assertIsInstance(processed_example["labels"], torch.Tensor, "labels should be a torch.Tensor")
self.assertIsInstance(processed_example["images"], torch.Tensor, "images should be a torch.Tensor")
self.assertIsInstance(processed_example["image_input_idx"], torch.Tensor, "image_input_idx should be a torch.Tensor")
self.assertIsInstance(processed_example["image_masks"], torch.Tensor, "image_masks should be a torch.Tensor")
# Check tensor dimensions
self.assertEqual(len(processed_example["input_ids"].shape), 1, "input_ids should be a 1D tensor")
self.assertEqual(
processed_example["input_ids"].shape, processed_example["attention_mask"].shape, "input_ids and attention_mask should have the same shape"
)
self.assertEqual(processed_example["input_ids"].shape, processed_example["labels"].shape, "input_ids and labels should have the same shape")
# Verify label masking
# Find where labels start being non-masked (-100 is the mask value)
label_indices = torch.where(processed_example["labels"] != -100)[0]
# There should be at least one label that is not masked
self.assertTrue(len(label_indices) > 0, "No unmasked labels found in labels array.")
first_label_index = label_indices[0]
# Ensure the masked portion is reasonable (at least a few tokens long)
self.assertTrue(first_label_index >= 5, "Masked portion of labels is too short")
# Check that all values before first_label_index are -100
self.assertTrue(torch.all(processed_example["labels"][:first_label_index] == -100), "Labels before the first unmasked token are not all -100.")
# Verify attention mask
self.assertTrue(torch.all(processed_example["attention_mask"] == 1), "All attention mask values should be 1")
# Verify image input indices
self.assertTrue(
torch.all(processed_example["image_input_idx"] < len(processed_example["input_ids"])),
"Image input indices should be within the range of input_ids length",
)
# Decode and verify content structure
decoded_input = processor.tokenizer.decode(processed_example["input_ids"])
self.assertIn("This is the anchor text", decoded_input, "Anchor text should be present in the decoded input")
# Verify that unmasked labels decode to the response text
unmasked_labels = processed_example["labels"][processed_example["labels"] != -100]
decoded_labels = processor.tokenizer.decode(unmasked_labels)
self.assertIn("This is the response text", decoded_labels, "Response text should be present in the decoded labels")
def testBatchMolmoDataPrep(self):
"""Test the batch preparation function for Molmo"""
processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")
# Create a mock batch
batch = {
"local_pdf_path": [os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf")],
"page_num": [1],
"response": ["This is the response text."],
}
target_longest_image_dim = [1024]
target_anchor_text_len = [0, 6000]
# Mock the necessary functions
with (
patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
):
mock_get_anchor_text.return_value = "This is the anchor text."
img = Image.new("RGB", (100, 100), color="red")
buffered = BytesIO()
img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
mock_render_pdf_to_base64png.return_value = img_str
# Process the batch
processed_batch = batch_prepare_data_for_molmo_training(
batch, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
)
# Verify batch structure
self.assertEqual(len(processed_batch["input_ids"]), 1, "Batch size should be 1")
self.assertEqual(len(processed_batch["attention_mask"]), 1, "Batch size should be 1")
self.assertEqual(len(processed_batch["labels"]), 1, "Batch size should be 1")
self.assertEqual(len(processed_batch["images"]), 1, "Batch size should be 1")
self.assertEqual(len(processed_batch["image_input_idx"]), 1, "Batch size should be 1")
self.assertEqual(len(processed_batch["image_masks"]), 1, "Batch size should be 1")
# Verify the first item in the batch
first_item = {k: v[0] for k, v in processed_batch.items()}
self.assertIsInstance(first_item["input_ids"], torch.Tensor, "Batch item should contain torch.Tensor")
self.assertTrue(torch.all(first_item["attention_mask"] == 1), "All attention mask values should be 1")