Ok, dataloader from start to finish is running, now to write a trainer
commit cfe9aa102b (parent 105d5907d6)
@@ -12,11 +12,17 @@ from pypdf import PdfReader
from tqdm import tqdm
from dataclasses import dataclass, fields
from abc import ABC, abstractmethod

from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text

try:
    import numpy as np
except ImportError:
    np = None

# Type alias for samples
Sample: TypeAlias = Dict[str, Any]
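
For reference, a minimal sketch of the PipelineStep contract the steps below assume (the base class itself is defined earlier in the file, outside this hunk, so this is an illustration rather than the committed definition):

# Sketch only: assumed shape of the base class used by the pipeline steps.
class PipelineStep(ABC):
    """A pipeline step maps one sample dict to an (optionally enriched) sample dict."""

    @abstractmethod
    def __call__(self, sample: Sample) -> Sample:
        ...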
@@ -299,6 +305,67 @@ class InstructMessages(PipelineStep):
        return sample


@dataclass(frozen=True, slots=True)
class Tokenizer(PipelineStep):
    """Tokenizes messages and creates training labels with proper masking."""

    processor: Any  # The model processor (e.g., AutoProcessor)
    masking_index: int = -100

    def __call__(self, sample: Sample) -> Sample:
        """Tokenize messages and create labels for training."""
        if np is None:
            raise ImportError("numpy is required for the Tokenizer step")

        messages = sample["messages"]
        main_image = sample["image"]

        # Apply the chat template to the full conversation
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,  # the assistant response is already in the messages
        )

        # Process text and image together
        inputs = self.processor(
            text=[text],
            images=[main_image],
            padding=True,
            return_tensors="np",
        )
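        # For a Qwen2.5-VL-style processor, `inputs` is expected to carry
        # input_ids (with image placeholder tokens already expanded),
        # attention_mask, pixel_values, and image_grid_thw.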

        # Create labels by copying input_ids and masking the prompt portion
        labels = inputs.input_ids.copy()

        # Find where the assistant response starts. This assumes the chat
        # template inserts an "assistant" role marker between the user turn
        # and the response; adjust for your specific template.
        assistant_token = self.processor.tokenizer.encode("assistant", add_special_tokens=False)[0]
        assistant_start_idx = np.where(inputs.input_ids[0] == assistant_token)[0]

        if len(assistant_start_idx) > 0:
            # Mask everything before the assistant's actual response content;
            # there are usually a few tokens after the "assistant" role marker.
            response_start = assistant_start_idx[-1] + 2  # Adjust this offset as needed
            labels[0, :response_start] = self.masking_index
        else:
            raise ValueError("Could not find assistant tokens in the tokenized input")

        # Add tokenized data to the sample
        sample["input_ids"] = inputs.input_ids[0]
        sample["attention_mask"] = inputs.attention_mask[0]
        sample["labels"] = labels[0]

        # Add image-related tensors if present
        if hasattr(inputs, "pixel_values"):
            sample["pixel_values"] = inputs.pixel_values
        if hasattr(inputs, "image_grid_thw"):
            sample["image_grid_thw"] = inputs.image_grid_thw[0]

        return sample
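
The "assistant"-token search above is chat-template-dependent. A sturdier variant (a sketch, assuming the same processor API and that the last message is the assistant turn) masks by the tokenized length of the prompt-only rendering:

# Drop-in sketch for the masking heuristic inside __call__:
prompt_text = self.processor.apply_chat_template(
    messages[:-1],               # the conversation without the assistant response
    tokenize=False,
    add_generation_prompt=True,  # append the assistant role header
)
prompt_inputs = self.processor(text=[prompt_text], images=[main_image], return_tensors="np")
response_start = prompt_inputs.input_ids.shape[1]
labels[0, :response_start] = self.masking_index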


class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
    """Dataset that includes front matter parsing and PDF rendering by default."""
@@ -326,6 +393,7 @@ class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
        super().__init__(root_dir, pipeline_steps)


if __name__ == "__main__":
    import argparse
    from pathlib import Path
@@ -399,5 +467,94 @@ if __name__ == "__main__":
        print(f"PDF: {Path(first_sample['pdf_path']).name}")
        print(f"Image size: {first_sample['image'].size}")
        print(f"Page data: {first_sample['page_data']}")

        # Test with actual Qwen2.5-VL tokenization
        print("\n\n=== Testing with Qwen2.5-VL-7B-Instruct Tokenization ===")

        try:
            from transformers import AutoProcessor

            print("Loading Qwen2.5-VL processor...")
            processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

            # Create pipeline with real tokenizer
            tokenized_dataset = BaseMarkdownPDFDataset(
                args.root_dir,
                pipeline_steps=[
                    FrontMatterParser(front_matter_class=PageResponse),
                    PDFRenderer(target_longest_image_dim=512),
                    StaticLengthDocumentAnchoring(target_anchor_text_len=1000),
                    FinetuningPrompt(),
                    FrontMatterOutputFormat(),
                    InstructMessages(),
                    Tokenizer(processor),
                ],
            )
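            # Step order matters: later steps read keys written by earlier ones
            # (e.g., Tokenizer consumes "messages" from InstructMessages and
            # "image" from PDFRenderer), so Tokenizer must come last.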

            if len(tokenized_dataset) > 0:
                print("\nProcessing first sample with Qwen2.5-VL...")
                tokenized_sample = tokenized_dataset[0]

                print("\nTokenized output:")
                print(f"  Keys: {list(tokenized_sample.keys())}")
                print(f"  Input IDs shape: {tokenized_sample['input_ids'].shape}")
                print(f"  Labels shape: {tokenized_sample['labels'].shape}")
                print(f"  Attention mask shape: {tokenized_sample['attention_mask'].shape}")

                if 'pixel_values' in tokenized_sample:
                    print(f"  Pixel values shape: {tokenized_sample['pixel_values'].shape}")
                if 'image_grid_thw' in tokenized_sample:
                    print(f"  Image grid THW: {tokenized_sample['image_grid_thw']}")

                # Show label masking
                print("\nLabel masking analysis:")
                labels = tokenized_sample['labels']
                masked_count = np.sum(labels == -100)
                total_count = len(labels)
                print(f"  Total tokens: {total_count}")
                print(f"  Masked tokens: {masked_count} ({masked_count/total_count*100:.1f}%)")
                print(f"  Unmasked tokens: {total_count - masked_count} ({(total_count - masked_count)/total_count*100:.1f}%)")

                # Find the transition point from masked to unmasked labels
                transition_idx = None
                for i in range(len(labels) - 1):
                    if labels[i] == -100 and labels[i + 1] != -100:
                        transition_idx = i + 1
                        break

                if transition_idx:
                    print(f"  Transition from masked to unmasked at position: {transition_idx}")

                # Print all tokens
                input_ids = tokenized_sample['input_ids']
                print(f"\nAll tokens ({len(input_ids)} total):")
                print("Format: [index] Token (repr) | Label | Token ID")
                print("-" * 80)

                for i in range(len(input_ids)):
                    token = processor.tokenizer.decode([input_ids[i]])
                    token_repr = repr(token)
                    label = labels[i] if i < len(labels) else "N/A"
                    token_id = input_ids[i]

                    # Mark special positions
                    marker = ""
                    if transition_idx and i == transition_idx:
                        marker = " <-- TRANSITION (first unmasked)"
                    elif i == 0:
                        marker = " <-- START"
                    elif label != -100 and i > 0 and labels[i - 1] == -100:
                        marker = " <-- response begins"

                    print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}")

        except ImportError as e:
            print(f"\nCould not import transformers: {e}")
            print("Install with: pip install transformers")
        except Exception as e:
            print(f"\nError during tokenization test: {e}")
            import traceback
            traceback.print_exc()

    else:
        raise AssertionError("Expected some data to be created at this point")
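
Toward the trainer this commit points at, a minimal collate sketch for batching these samples (assumes PyTorch; the key names match what Tokenizer produces, while pad_token_id and batch_size are illustrative):

import numpy as np
import torch

def collate_fn(batch, pad_token_id=0, masking_index=-100):
    """Right-pad variable-length samples and stack them into batch tensors."""
    max_len = max(len(s["input_ids"]) for s in batch)

    def pad(arr, value):
        return np.pad(arr, (0, max_len - len(arr)), constant_values=value)

    return {
        "input_ids": torch.from_numpy(np.stack([pad(s["input_ids"], pad_token_id) for s in batch])),
        "attention_mask": torch.from_numpy(np.stack([pad(s["attention_mask"], 0) for s in batch])),
        "labels": torch.from_numpy(np.stack([pad(s["labels"], masking_index) for s in batch])),
    }

# Usage sketch: torch.utils.data.DataLoader(tokenized_dataset, batch_size=2, collate_fn=collate_fn)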