Ok, dataloader from start to finish is running, now to write a trainer

This commit is contained in:
Jake Poznanski 2025-06-11 23:30:02 +00:00
parent 105d5907d6
commit cfe9aa102b

View File

@ -12,11 +12,17 @@ from pypdf import PdfReader
from tqdm import tqdm
from dataclasses import dataclass, fields
from abc import ABC, abstractmethod
import numpy as np
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
try:
import numpy as np
except ImportError:
np = None
# Type alias for samples
Sample: TypeAlias = Dict[str, Any]
@ -299,6 +305,67 @@ class InstructMessages(PipelineStep):
return sample
@dataclass(frozen=True, slots=True)
class Tokenizer(PipelineStep):
"""Tokenizes messages and creates training labels with proper masking."""
processor: Any # The model processor (e.g., AutoProcessor)
masking_index: int = -100
def __call__(self, sample: Sample) -> Sample:
"""Tokenize messages and create labels for training."""
if np is None:
raise ImportError("numpy is required for Tokenizer step")
messages = sample["messages"]
main_image = sample["image"]
# Apply chat template to full conversation
text = self.processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False # Don't add prompt since we have the response
)
# Process everything together
inputs = self.processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="np",
)
# Create labels by copying input_ids and masking the prompt portion
labels = inputs.input_ids.copy()
# Find where the assistant response starts
# This assumes the processor adds some delimiter between user and assistant
# You might need to adjust based on your specific chat template
assistant_token = self.processor.tokenizer.encode("assistant", add_special_tokens=False)[0]
assistant_start_idx = np.where(inputs.input_ids[0] == assistant_token)[0]
if len(assistant_start_idx) > 0:
# Mask everything before the assistant's actual response content
# Usually there's a few tokens after "assistant" role marker
response_start = assistant_start_idx[-1] + 2 # Adjust offset as needed
labels[0, :response_start] = self.masking_index
else:
raise Exception("Could not find assistant tokens")
# Add tokenized data to sample
sample["input_ids"] = inputs.input_ids[0]
sample["attention_mask"] = inputs.attention_mask[0]
sample["labels"] = labels[0]
# Add image-related tensors if present
if hasattr(inputs, 'pixel_values'):
sample["pixel_values"] = inputs.pixel_values
if hasattr(inputs, 'image_grid_thw'):
sample["image_grid_thw"] = inputs.image_grid_thw[0]
return sample
class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
"""Dataset that includes front matter parsing and PDF rendering by default."""
@ -326,6 +393,7 @@ class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
super().__init__(root_dir, pipeline_steps)
if __name__ == "__main__":
import argparse
from pathlib import Path
@ -399,5 +467,94 @@ if __name__ == "__main__":
print(f"PDF: {Path(first_sample['pdf_path']).name}")
print(f"Image size: {first_sample['image'].size}")
print(f"Page data: {first_sample['page_data']}")
# Test with actual Qwen2.5-VL tokenization
print("\n\n=== Testing with Qwen2.5-VL-7B-Instruct Tokenization ===")
try:
from transformers import AutoProcessor
print("Loading Qwen2.5-VL processor...")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
# Create pipeline with real tokenizer
tokenized_dataset = BaseMarkdownPDFDataset(
args.root_dir,
pipeline_steps=[
FrontMatterParser(front_matter_class=PageResponse),
PDFRenderer(target_longest_image_dim=512),
StaticLengthDocumentAnchoring(target_anchor_text_len=1000),
FinetuningPrompt(),
FrontMatterOutputFormat(),
InstructMessages(),
Tokenizer(processor),
]
)
if len(tokenized_dataset) > 0:
print("\nProcessing first sample with Qwen2.5-VL...")
tokenized_sample = tokenized_dataset[0]
print("\nTokenized output:")
print(f" Keys: {list(tokenized_sample.keys())}")
print(f" Input IDs shape: {tokenized_sample['input_ids'].shape}")
print(f" Labels shape: {tokenized_sample['labels'].shape}")
print(f" Attention mask shape: {tokenized_sample['attention_mask'].shape}")
if 'pixel_values' in tokenized_sample:
print(f" Pixel values shape: {tokenized_sample['pixel_values'].shape}")
if 'image_grid_thw' in tokenized_sample:
print(f" Image grid THW: {tokenized_sample['image_grid_thw']}")
# Show label masking
print(f"\nLabel masking analysis:")
labels = tokenized_sample['labels']
masked_count = np.sum(labels == -100)
total_count = len(labels)
print(f" Total tokens: {total_count}")
print(f" Masked tokens: {masked_count} ({masked_count/total_count*100:.1f}%)")
print(f" Unmasked tokens: {total_count - masked_count} ({(total_count - masked_count)/total_count*100:.1f}%)")
# Find the transition point
transition_idx = None
for i in range(len(labels) - 1):
if labels[i] == -100 and labels[i + 1] != -100:
transition_idx = i + 1
break
if transition_idx:
print(f" Transition from masked to unmasked at position: {transition_idx}")
# Print all tokens
input_ids = tokenized_sample['input_ids']
print(f"\nAll tokens ({len(input_ids)} total):")
print("Format: [index] Token (repr) | Label | Token ID")
print("-" * 80)
for i in range(len(input_ids)):
token = processor.tokenizer.decode([input_ids[i]])
token_repr = repr(token)
label = labels[i] if i < len(labels) else "N/A"
token_id = input_ids[i]
# Mark special positions
marker = ""
if transition_idx and i == transition_idx:
marker = " <-- TRANSITION (first unmasked)"
elif i == 0:
marker = " <-- START"
elif label != -100 and i > 0 and labels[i-1] == -100:
marker = " <-- response begins"
print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}")
except ImportError as e:
print(f"\nCould not import transformers: {e}")
print("Install with: pip install transformers")
except Exception as e:
print(f"\nError during tokenization test: {e}")
import traceback
traceback.print_exc()
else:
raise AssertionError("Expected some data to be created at this point")