diff --git a/olmocr/train/dataloader.py b/olmocr/train/dataloader.py
index 9d09695..b9a63b4 100644
--- a/olmocr/train/dataloader.py
+++ b/olmocr/train/dataloader.py
@@ -12,11 +12,17 @@
 from pypdf import PdfReader
 from tqdm import tqdm
 from dataclasses import dataclass, fields
 from abc import ABC, abstractmethod
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt
 from olmocr.prompts.anchor import get_anchor_text
 
+# numpy is optional at import time; the Tokenizer step checks for it at runtime
+try:
+    import numpy as np
+except ImportError:
+    np = None
+
 # Type alias for samples
 Sample: TypeAlias = Dict[str, Any]
 
@@ -299,6 +305,67 @@ class InstructMessages(PipelineStep):
         return sample
 
 
+@dataclass(frozen=True, slots=True)
+class Tokenizer(PipelineStep):
+    """Tokenizes messages and creates training labels with proper masking."""
+    processor: Any  # The model processor (e.g., AutoProcessor)
+    masking_index: int = -100
+
+    def __call__(self, sample: Sample) -> Sample:
+        """Tokenize messages and create labels for training."""
+        if np is None:
+            raise ImportError("numpy is required for the Tokenizer step")
+
+        messages = sample["messages"]
+        main_image = sample["image"]
+
+        # Apply the chat template to the full conversation
+        text = self.processor.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=False  # The assistant response is already in messages
+        )
+
+        # Tokenize the text and preprocess the image together
+        inputs = self.processor(
+            text=[text],
+            images=[main_image],
+            padding=True,
+            return_tensors="np",
+        )
+
+        # Create labels by copying input_ids and masking the prompt portion
+        labels = inputs.input_ids.copy()
+
+        # Find where the assistant response starts. This assumes the chat
+        # template emits an "assistant" role marker; the offset below may
+        # need adjusting for other templates.
+
+        assistant_token = self.processor.tokenizer.encode("assistant", add_special_tokens=False)[0]
+        assistant_start_idx = np.where(inputs.input_ids[0] == assistant_token)[0]
+
+        if len(assistant_start_idx) > 0:
+            # Mask everything before the assistant's actual response content;
+            # the last occurrence of the role marker is used, plus an offset
+            response_start = assistant_start_idx[-1] + 2  # Adjust offset as needed
+            labels[0, :response_start] = self.masking_index
+        else:
+            raise ValueError("Could not find an assistant role token in input_ids")
+
+        # Add tokenized data to the sample
+        sample["input_ids"] = inputs.input_ids[0]
+        sample["attention_mask"] = inputs.attention_mask[0]
+        sample["labels"] = labels[0]
+
+        # Add image-related tensors if present
+        if hasattr(inputs, 'pixel_values'):
+            sample["pixel_values"] = inputs.pixel_values
+        if hasattr(inputs, 'image_grid_thw'):
+            sample["image_grid_thw"] = inputs.image_grid_thw[0]
+
+        return sample
+
+
 class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
     """Dataset that includes front matter parsing and PDF rendering by default."""
 
@@ -326,6 +393,7 @@ class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
 
         super().__init__(root_dir, pipeline_steps)
 
+
 if __name__ == "__main__":
     import argparse
     from pathlib import Path
@@ -399,5 +467,94 @@ if __name__ == "__main__":
         print(f"PDF: {Path(first_sample['pdf_path']).name}")
         print(f"Image size: {first_sample['image'].size}")
         print(f"Page data: {first_sample['page_data']}")
+
+        # Test with actual Qwen2.5-VL tokenization
+        print("\n\n=== Testing with Qwen2.5-VL-7B-Instruct Tokenization ===")
+
+        try:
+            from transformers import AutoProcessor
+
+            print("Loading Qwen2.5-VL processor...")
+            processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+            # Create a pipeline that ends with the real tokenizer
+            tokenized_dataset = BaseMarkdownPDFDataset(
+                args.root_dir,
+                pipeline_steps=[
+                    FrontMatterParser(front_matter_class=PageResponse),
+                    PDFRenderer(target_longest_image_dim=512),
+                    StaticLengthDocumentAnchoring(target_anchor_text_len=1000),
+                    FinetuningPrompt(),
+                    FrontMatterOutputFormat(),
+                    InstructMessages(),
+                    Tokenizer(processor),
+                ]
+            )
+
+            if len(tokenized_dataset) > 0:
+                print("\nProcessing first sample with Qwen2.5-VL...")
+                tokenized_sample = tokenized_dataset[0]
+
+                print("\nTokenized output:")
+                print(f"  Keys: {list(tokenized_sample.keys())}")
+                print(f"  Input IDs shape: {tokenized_sample['input_ids'].shape}")
+                print(f"  Labels shape: {tokenized_sample['labels'].shape}")
+                print(f"  Attention mask shape: {tokenized_sample['attention_mask'].shape}")
+
+                if 'pixel_values' in tokenized_sample:
+                    print(f"  Pixel values shape: {tokenized_sample['pixel_values'].shape}")
+                if 'image_grid_thw' in tokenized_sample:
+                    print(f"  Image grid THW: {tokenized_sample['image_grid_thw']}")
+
+                # Show label masking
+                print("\nLabel masking analysis:")
+                labels = tokenized_sample['labels']
+                masked_count = np.sum(labels == -100)
+                total_count = len(labels)
+                print(f"  Total tokens: {total_count}")
+                print(f"  Masked tokens: {masked_count} ({masked_count/total_count*100:.1f}%)")
+                print(f"  Unmasked tokens: {total_count - masked_count} ({(total_count - masked_count)/total_count*100:.1f}%)")
+
+                # Find the first transition from masked to unmasked labels
+                transition_idx = None
+                for i in range(len(labels) - 1):
+                    if labels[i] == -100 and labels[i + 1] != -100:
+                        transition_idx = i + 1
+                        break
+
+                if transition_idx is not None:
+                    print(f"  Transition from masked to unmasked at position: {transition_idx}")
+
+                # Print all tokens
+                input_ids = tokenized_sample['input_ids']
+                print(f"\nAll tokens ({len(input_ids)} total):")
+                print("Format: [index] Token (repr) | Label | Token ID")
+                print("-" * 80)
+
+                for i in range(len(input_ids)):
+                    token = processor.tokenizer.decode([input_ids[i]])
+                    token_repr = repr(token)
+                    label = labels[i] if i < len(labels) else "N/A"
+                    token_id = input_ids[i]
+
+                    # Mark special positions
+                    marker = ""
+                    if i == transition_idx:
+                        marker = " <-- TRANSITION (first unmasked)"
+                    elif i == 0:
+                        marker = " <-- START"
+                    elif label != -100 and labels[i - 1] == -100:
+                        marker = " <-- response begins"
+
+                    print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}")
+
+        except ImportError as e:
+            print(f"\nCould not import transformers: {e}")
+            print("Install with: pip install transformers")
+        except Exception as e:
+            print(f"\nError during tokenization test: {e}")
+            import traceback
+            traceback.print_exc()
+
     else:
         raise AssertionError("Expected some data to be created at this point")
\ No newline at end of file
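
Note for reviewers (not part of the diff): the `Tokenizer` step emits variable-length numpy arrays per sample, so a training `DataLoader` will need a collator that pads them into a batch. Below is a minimal sketch of such a collator; the function name and the `pad_token_id` default are assumptions (the tokenizer's real pad id should be used), and image tensors such as `pixel_values` are omitted because Qwen2-VL-style processors batch them by patch concatenation rather than stacking.

```python
# Hypothetical helper, not part of this diff: pads the variable-length
# arrays produced by the Tokenizer step into a fixed-shape batch.
from typing import Any, Dict, List

import numpy as np


def collate_tokenized_samples(samples: List[Dict[str, Any]], pad_token_id: int = 0) -> Dict[str, np.ndarray]:
    """Right-pad input_ids/attention_mask/labels to the longest sample in the batch."""
    max_len = max(len(s["input_ids"]) for s in samples)

    def pad(seq: np.ndarray, fill: int) -> np.ndarray:
        out = np.full(max_len, fill, dtype=seq.dtype)
        out[: len(seq)] = seq
        return out

    return {
        # pad_token_id=0 is an assumption; use processor.tokenizer.pad_token_id
        "input_ids": np.stack([pad(s["input_ids"], pad_token_id) for s in samples]),
        # 0 in the attention mask hides padded positions from attention
        "attention_mask": np.stack([pad(s["attention_mask"], 0) for s in samples]),
        # -100 matches masking_index, so padded positions are ignored by the loss
        "labels": np.stack([pad(s["labels"], -100) for s in samples]),
    }
```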