import unittest
import base64
from io import BytesIO

from PIL import Image
from transformers import AutoProcessor

from pdelfin.train.dataloader import (
    build_batch_query_response_vision_dataset,
)

from pdelfin.train.dataprep import (
    prepare_data_for_qwen2_training, build_finetuning_prompt
)


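# This test checks that prepare_data_for_qwen2_training assembles the same
# token stream as running the Qwen2-VL processor directly on an equivalent
# chat, and that the loss labels mask everything except the assistant turn.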
class TestDataprep(unittest.TestCase):
    def testTokenizationMatches(self):
        ds = build_batch_query_response_vision_dataset(
            query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
            response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
        )

        example = ds[0]

        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

        full_messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": example["input_prompt_image_base64"]  # Placeholder; the decoded image is passed to the processor separately below
                    },
                    {"type": "text", "text": build_finetuning_prompt(example["raw_page_text"])},
                ],
            },
            {
                "role": "assistant",
                "content": example["response"]
            }
        ]

        text = processor.apply_chat_template(full_messages, tokenize=False, add_generation_prompt=False)

        # Decode image from base64
        main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))

        width, height = main_image.size
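        # Pages appear to be rendered at ~2048 px on the longest edge; halving
        # below presumably mirrors the resize applied inside
        # prepare_data_for_qwen2_training (an assumption, not verified here),
        # since the number of image tokens depends on image size and the two
        # token streams would otherwise differ.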
        assert 1800 <= max(width, height) <= 2200, f"Image size {width}x{height} invalid"
        main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)

        # Process inputs using processor
        inference_inputs = processor(
            text=[text],
            images=[main_image],
            padding=True,
            return_tensors="np",
        )

        print(inference_inputs)
        print(inference_inputs["input_ids"].shape)

        training_inputs = prepare_data_for_qwen2_training(example, processor=processor)

        print(training_inputs)
        print(training_inputs["input_ids"].shape)

        print("Original tokenization")
        print(processor.tokenizer.decode(inference_inputs["input_ids"][0]))
        print("\n\n")

        print("Assembled tokenization")
        print(processor.tokenizer.decode(training_inputs["input_ids"]))
        print("\n\n")

        # Make sure that the token streams are the same
        self.assertEqual(processor.tokenizer.decode(inference_inputs["input_ids"][0]),
                         processor.tokenizer.decode(training_inputs["input_ids"]))

        # Make sure that the labels are masked with -100s properly:
        # only the final assistant generation should be left unmasked, so that
        # it alone contributes to the loss

        # Find the positions where labels are not -100
        non_masked_positions = training_inputs['labels'] != -100

        # Extract the tokens at those positions
        label_tokens = training_inputs['input_ids'][non_masked_positions]

        # Decode those tokens
        decoded_labels = processor.tokenizer.decode(label_tokens)
        assistant_response_with_end = example["response"] + "<|im_end|>\n"

        # Assert that the decoded labels match the assistant's response followed by <|im_end|>\n
        self.assertEqual(decoded_labels, assistant_response_with_end)
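

# Standard unittest entry point so the file can also be run directly
# (e.g. `python test_dataprep.py`) rather than only via a test runner.
if __name__ == "__main__":
    unittest.main()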