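"""Tests for the pdelfin Qwen2-VL data-preparation pipeline.

testFullDataloader checks that dataset entries built via make_dataset have the expected
dtypes, end with the <|im_end|> marker, and mask labels with -100 everywhere except the
assistant response. testTokenizationMatches checks that prepare_data_for_qwen2_training
reproduces the token stream the Qwen2-VL processor builds for the same example at
inference time.
"""
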
import unittest
import base64
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor

from pdelfin.train.dataloader import (
    build_batch_query_response_vision_dataset,
)

from pdelfin.train.dataprep import (
    prepare_data_for_qwen2_training, build_finetuning_prompt
)
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from pdelfin.train.utils import make_dataset
from pdelfin.train.core.config import TrainConfig, DataConfig, SourceConfig


class TestDataprep(unittest.TestCase):
    def testFullDataloader(self):
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
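        # Train and valid both read the same eval query/response globs from S3, so the
        # full make_dataset path is exercised end to end on real batch data.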
        config = TrainConfig(
            train_data=DataConfig(seed=42,
                                  sources=[SourceConfig(name="eval_test",
                                                        query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
                                                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json")]),

            valid_data=DataConfig(seed=42,
                                  sources=[SourceConfig(name="eval_test",
                                                        query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
                                                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json")])
        )
        train_dataset, valid_dataset = make_dataset(config, processor)
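
        # Token ids for the end-of-turn marker; used below to check that both the full
        # input sequence and the unmasked label span end with <|im_end|>\n.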
        im_end_token_ids = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]

        # train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False)
        for entry in train_dataset:
            print({x: (y.shape, y.dtype) for (x, y) in entry.items()})

            self.assertEqual(entry["input_ids"].dtype, np.int64)
            self.assertEqual(entry["attention_mask"].dtype, np.int64)
            self.assertEqual(entry["labels"].dtype, np.int64)
            self.assertEqual(entry["pixel_values"].dtype, np.float32)
            self.assertEqual(entry["image_grid_thw"].dtype, np.int64)

            # Extract input_ids and labels
            input_ids = entry["input_ids"]
            labels = entry["labels"]

            # 1. Verify that the last token is the end token
            # Ensure input_ids is long enough
            self.assertTrue(len(input_ids) >= len(im_end_token_ids), "Input IDs are shorter than the end token sequence.")

            # Compare the last tokens of input_ids with im_end_token_ids
            self.assertEqual(
                input_ids[-len(im_end_token_ids):].tolist(),
                im_end_token_ids,
                "The last tokens of input_ids do not match the end token sequence."
            )

            # 2. Ensure labels are masked correctly and match input_ids after the mask
            # Find where labels start being non-masked (-100 is the mask value)
            label_indices = np.where(labels != -100)[0]

            # There should be at least one label that is not masked
            self.assertTrue(len(label_indices) > 0, "No unmasked labels found in labels array.")

            first_label_index = label_indices[0]

            # Ensure the masked portion is at least 10 tokens long
            self.assertTrue(first_label_index >= 10, "Masked portion of labels is less than 10 tokens.")

            # Check that all values before first_label_index are -100
            self.assertTrue(
                np.all(labels[:first_label_index] == -100),
                "Labels before the first unmasked token are not all -100."
            )

            # Check that the unmasked labels match the corresponding input_ids
            self.assertTrue(
                np.array_equal(labels[first_label_index:], input_ids[first_label_index:]),
                "Unmasked labels do not match the corresponding input_ids."
            )

            # Optionally, verify that the last unmasked tokens in labels match the end token IDs
            unmasked_labels = labels[labels != -100]
            self.assertEqual(
                unmasked_labels[-len(im_end_token_ids):].tolist(),
                im_end_token_ids,
                "The last unmasked tokens in labels do not match the end token sequence."
            )

    def testTokenizationMatches(self):
        ds = build_batch_query_response_vision_dataset(
            query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
            response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
        )

        example = ds[0]

        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

        full_messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": example["input_prompt_image_base64"]  # Placeholder
                    },
                    {"type": "text", "text": build_finetuning_prompt(example["raw_page_text"])},
                ],
            },

            {
                "role": "assistant",
                "content": example["response"]
            }
        ]
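
        # Render the messages with the processor's chat template, so the text follows
        # the exact prompt format Qwen2-VL sees at inference time.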
        text = processor.apply_chat_template(full_messages, tokenize=False, add_generation_prompt=False)

        # Decode image from base64
        main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))

        width, height = main_image.size
        assert 1800 <= max(width, height) <= 2200, f"Image size {width}x{height} invalid"
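        # Downscale by half; the page is rendered at roughly 2048 px on its long side
        # (checked above), and halving is assumed here to match the image size used by
        # the training-time prep.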
        main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)

        # Process inputs using processor
        inference_inputs = processor(
            text=[text],
            images=[main_image],
            padding=True,
            return_tensors="np",
        )

        print(inference_inputs)
        print(inference_inputs["input_ids"].shape)
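
        # Run the repo's training-time prep on the same example; its token stream
        # should match the processor output above.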
        training_inputs = prepare_data_for_qwen2_training(example, processor=processor)

        print(training_inputs)
        print(training_inputs["input_ids"].shape)

        print("Original tokenization")
        print(processor.tokenizer.decode(inference_inputs["input_ids"][0]))
        print("\n\n")

        print("Assembled tokenization")
        print(processor.tokenizer.decode(training_inputs["input_ids"]))
        print("\n\n")

        # Make sure that the token streams are the same
        self.assertEqual(processor.tokenizer.decode(inference_inputs["input_ids"][0]),
                         processor.tokenizer.decode(training_inputs["input_ids"]))

        # Make sure that the labels are masked with -100s properly
        # Only the final assistant generation should be left unmasked (not -100), so that
        # it alone contributes to the loss.

        # Find the positions where labels are not -100
        non_masked_positions = training_inputs['labels'] != -100

        # Extract the tokens at those positions
        label_tokens = training_inputs['input_ids'][non_masked_positions]

        # Decode those tokens
        decoded_labels = processor.tokenizer.decode(label_tokens)
        assistant_response_with_end = example["response"] + "<|im_end|>\n"

        # Assert that the decoded labels match the assistant's response with <|im_end|>\n
        self.assertEqual(decoded_labels, assistant_response_with_end)
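

if __name__ == "__main__":
    # Hypothetical convenience entry point (an addition, not confirmed by the source):
    # lets the module be run directly; a normal unittest/pytest runner works the same way.
    unittest.main()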