Ok, dataloader from start to finish is running, now to write a trainer

This commit is contained in:
Jake Poznanski 2025-06-11 23:30:02 +00:00
parent 105d5907d6
commit cfe9aa102b

View File

@ -12,11 +12,17 @@ from pypdf import PdfReader
from tqdm import tqdm from tqdm import tqdm
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import numpy as np
from olmocr.data.renderpdf import render_pdf_to_base64png from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text from olmocr.prompts.anchor import get_anchor_text
try:
import numpy as np
except ImportError:
np = None
# Type alias for samples # Type alias for samples
Sample: TypeAlias = Dict[str, Any] Sample: TypeAlias = Dict[str, Any]
@ -299,6 +305,67 @@ class InstructMessages(PipelineStep):
return sample return sample
@dataclass(frozen=True, slots=True)
class Tokenizer(PipelineStep):
    """Tokenize a chat sample and build training labels with the prompt masked.

    The step reads ``sample["messages"]`` and ``sample["image"]``, renders the
    conversation with the processor's chat template, runs the processor on the
    rendered text plus the image, and writes ``input_ids``, ``attention_mask``
    and ``labels`` back into the same sample dict. ``pixel_values`` and
    ``image_grid_thw`` are copied over when the processor output exposes them.

    Every label position before the start of the assistant response is set to
    ``masking_index`` so only the response contributes to the loss.
    """

    processor: Any  # The model processor (e.g., AutoProcessor)
    masking_index: int = -100  # Label value written over the prompt tokens
    # Text whose first token id marks the assistant role header in the
    # tokenized conversation.
    assistant_marker: str = "assistant"
    # Number of positions from the marker token to the first response token.
    # The default of 2 skips the marker itself plus one following token
    # (presumably the newline of the chat template -- confirm per template).
    response_offset: int = 2

    def __call__(self, sample: Sample) -> Sample:
        """Tokenize messages and create labels for training.

        Args:
            sample: Pipeline sample holding ``"messages"`` and ``"image"``.

        Returns:
            The same dict, mutated in place with the tokenized fields added.

        Raises:
            ImportError: If numpy is not available.
            ValueError: If the assistant marker cannot be located, or if the
                computed response start leaves no unmasked label.
        """
        if np is None:
            raise ImportError("numpy is required for Tokenizer step")

        messages = sample["messages"]
        main_image = sample["image"]

        # Render the full conversation; no generation prompt is appended
        # because the assistant response is already part of the messages.
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )

        # Tokenize text and image together as a batch of one.
        inputs = self.processor(
            text=[text],
            images=[main_image],
            padding=True,
            return_tensors="np",
        )

        input_ids = inputs.input_ids[0]
        response_start = self._find_response_start(input_ids)

        # Labels are a copy of the ids with the whole prompt blanked out.
        labels = input_ids.copy()
        labels[:response_start] = self.masking_index

        sample["input_ids"] = input_ids
        sample["attention_mask"] = inputs.attention_mask[0]
        sample["labels"] = labels

        # Image tensors are optional in the processor output. pixel_values is
        # stored as returned (not indexed), image_grid_thw as its first row.
        if hasattr(inputs, "pixel_values"):
            sample["pixel_values"] = inputs.pixel_values
        if hasattr(inputs, "image_grid_thw"):
            sample["image_grid_thw"] = inputs.image_grid_thw[0]

        return sample

    def _find_response_start(self, input_ids: Any) -> int:
        """Return the index of the first token of the assistant response."""
        tokenizer = self.processor.tokenizer
        marker_ids = tokenizer.encode(self.assistant_marker, add_special_tokens=False)
        if not marker_ids:
            raise ValueError(f"Assistant marker {self.assistant_marker!r} produced no tokens")

        hits = np.where(input_ids == marker_ids[0])[0]
        if len(hits) == 0:
            raise ValueError("Could not find assistant tokens")

        # The marker token can also occur inside the prompt or the response
        # text. Prefer occurrences that directly follow a special token (the
        # role header); otherwise a response containing the marker word would
        # have its leading part masked. If the tokenizer exposes no special
        # ids, or none precede a hit, keep the plain last-occurrence choice.
        special_ids = list(getattr(tokenizer, "all_special_ids", None) or [])
        if special_ids:
            with_prev = hits[hits > 0]
            anchored = with_prev[np.isin(input_ids[with_prev - 1], special_ids)]
            if len(anchored) > 0:
                hits = anchored

        response_start = int(hits[-1]) + self.response_offset
        if response_start >= len(input_ids):
            # Nothing would be left to train on; fail loudly instead of
            # emitting a fully masked sample.
            raise ValueError("Assistant response is empty: all labels would be masked")
        return response_start
class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset): class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
"""Dataset that includes front matter parsing and PDF rendering by default.""" """Dataset that includes front matter parsing and PDF rendering by default."""
@ -326,6 +393,7 @@ class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
super().__init__(root_dir, pipeline_steps) super().__init__(root_dir, pipeline_steps)
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
from pathlib import Path from pathlib import Path
@ -399,5 +467,94 @@ if __name__ == "__main__":
print(f"PDF: {Path(first_sample['pdf_path']).name}") print(f"PDF: {Path(first_sample['pdf_path']).name}")
print(f"Image size: {first_sample['image'].size}") print(f"Image size: {first_sample['image'].size}")
print(f"Page data: {first_sample['page_data']}") print(f"Page data: {first_sample['page_data']}")
# Test with actual Qwen2.5-VL tokenization: run the full pipeline on the
# first sample and dump every token next to its training label.
print("\n\n=== Testing with Qwen2.5-VL-7B-Instruct Tokenization ===")
try:
    from transformers import AutoProcessor
except ImportError as e:
    # Only a failed transformers import is reported here. ImportErrors raised
    # later (e.g. by a pipeline step) fall through to the generic handler
    # below instead of being misreported as a missing transformers install.
    print(f"\nCould not import transformers: {e}")
    print("Install with: pip install transformers")
else:
    try:
        print("Loading Qwen2.5-VL processor...")
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

        # Keep a handle on the step so the analysis below uses the same
        # masking value the step wrote, rather than a hard-coded -100.
        tokenizer_step = Tokenizer(processor)
        mask_value = tokenizer_step.masking_index

        # Create pipeline with real tokenizer
        tokenized_dataset = BaseMarkdownPDFDataset(
            args.root_dir,
            pipeline_steps=[
                FrontMatterParser(front_matter_class=PageResponse),
                PDFRenderer(target_longest_image_dim=512),
                StaticLengthDocumentAnchoring(target_anchor_text_len=1000),
                FinetuningPrompt(),
                FrontMatterOutputFormat(),
                InstructMessages(),
                tokenizer_step,
            ],
        )

        if len(tokenized_dataset) > 0:
            print("\nProcessing first sample with Qwen2.5-VL...")
            tokenized_sample = tokenized_dataset[0]

            print("\nTokenized output:")
            print(f"  Keys: {list(tokenized_sample.keys())}")
            print(f"  Input IDs shape: {tokenized_sample['input_ids'].shape}")
            print(f"  Labels shape: {tokenized_sample['labels'].shape}")
            print(f"  Attention mask shape: {tokenized_sample['attention_mask'].shape}")
            if 'pixel_values' in tokenized_sample:
                print(f"  Pixel values shape: {tokenized_sample['pixel_values'].shape}")
            if 'image_grid_thw' in tokenized_sample:
                print(f"  Image grid THW: {tokenized_sample['image_grid_thw']}")

            # Show label masking
            print("\nLabel masking analysis:")
            labels = tokenized_sample['labels']
            is_masked = labels == mask_value
            masked_count = int(np.sum(is_masked))
            total_count = len(labels)
            unmasked_count = total_count - masked_count
            print(f"  Total tokens: {total_count}")
            print(f"  Masked tokens: {masked_count} ({masked_count/total_count*100:.1f}%)")
            print(f"  Unmasked tokens: {unmasked_count} ({unmasked_count/total_count*100:.1f}%)")

            # Find the transition point: the first unmasked position that
            # directly follows a masked one.
            boundaries = np.flatnonzero(is_masked[:-1] & ~is_masked[1:])
            transition_idx = int(boundaries[0]) + 1 if len(boundaries) > 0 else None

            if transition_idx is not None:
                print(f"  Transition from masked to unmasked at position: {transition_idx}")

            # Print all tokens
            input_ids = tokenized_sample['input_ids']
            print(f"\nAll tokens ({len(input_ids)} total):")
            print("Format: [index] Token (repr) | Label | Token ID")
            print("-" * 80)

            for i, (token_id, label) in enumerate(zip(input_ids, labels)):
                token_repr = repr(processor.tokenizer.decode([int(token_id)]))

                # Mark special positions
                marker = ""
                if transition_idx is not None and i == transition_idx:
                    marker = " <-- TRANSITION (first unmasked)"
                elif i == 0:
                    marker = " <-- START"
                elif label != mask_value and labels[i - 1] == mask_value:
                    marker = " <-- response begins"

                print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}")
    except Exception as e:
        # Top-level script boundary: report the failure with a traceback
        # rather than aborting the remaining output.
        print(f"\nError during tokenization test: {e}")
        import traceback
        traceback.print_exc()
else: else:
raise AssertionError("Expected some data to be created at this point") raise AssertionError("Expected some data to be created at this point")