mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-13 17:22:13 +00:00
Ok, dataloader from start to finish is running, now to write a trainer
This commit is contained in:
parent
105d5907d6
commit
cfe9aa102b
@ -12,11 +12,17 @@ from pypdf import PdfReader
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from olmocr.data.renderpdf import render_pdf_to_base64png
|
from olmocr.data.renderpdf import render_pdf_to_base64png
|
||||||
from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt
|
from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt
|
||||||
from olmocr.prompts.anchor import get_anchor_text
|
from olmocr.prompts.anchor import get_anchor_text
|
||||||
|
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
except ImportError:
|
||||||
|
np = None
|
||||||
|
|
||||||
# Type alias for samples
|
# Type alias for samples
|
||||||
Sample: TypeAlias = Dict[str, Any]
|
Sample: TypeAlias = Dict[str, Any]
|
||||||
|
|
||||||
@ -299,6 +305,67 @@ class InstructMessages(PipelineStep):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
class Tokenizer(PipelineStep):
    """Tokenizes messages and creates training labels with proper masking.

    The step reads ``sample["messages"]`` and ``sample["image"]``, renders the
    conversation through the processor's chat template, runs the processor on
    the rendered text and the image together, and writes ``input_ids``,
    ``attention_mask`` and ``labels`` back into the sample.  ``labels`` is a
    copy of ``input_ids`` in which every position before the assistant
    response is overwritten with ``masking_index``.
    """

    processor: Any  # The model processor (e.g., AutoProcessor)
    masking_index: int = -100  # Value written over the prompt positions of the labels
    # Distance from the assistant role-marker token to the first response
    # token.  The default of 2 is the value the original implementation
    # hard-coded; adjust it for chat templates with a different header layout.
    response_offset: int = 2

    def __call__(self, sample: Sample) -> Sample:
        """Tokenize messages and create labels for training.

        Args:
            sample: Pipeline sample providing ``"messages"`` and ``"image"``.

        Returns:
            The same sample object, updated in place with ``"input_ids"``,
            ``"attention_mask"``, ``"labels"`` and, when the processor
            returns them, ``"pixel_values"`` and ``"image_grid_thw"``.

        Raises:
            ImportError: If numpy could not be imported.
            ValueError: If no assistant role marker is found in the tokens.
        """
        if np is None:
            raise ImportError("numpy is required for Tokenizer step")

        messages = sample["messages"]
        main_image = sample["image"]

        # Render the full conversation; no generation prompt is appended
        # because the assistant response is already part of the messages.
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )

        # Tokenize the text and preprocess the image in a single call.
        inputs = self.processor(
            text=[text],
            images=[main_image],
            padding=True,
            return_tensors="np",
        )

        # Labels start as a copy of the input ids; the prompt part is masked.
        labels = inputs.input_ids.copy()
        response_start = self._find_response_start(inputs.input_ids[0])
        labels[0, :response_start] = self.masking_index

        # The processor was called with a batch of one, so drop the batch axis.
        sample["input_ids"] = inputs.input_ids[0]
        sample["attention_mask"] = inputs.attention_mask[0]
        sample["labels"] = labels[0]

        # Image-related outputs are optional.  pixel_values is stored exactly
        # as returned, while only the first image_grid_thw entry is kept.
        if hasattr(inputs, "pixel_values"):
            sample["pixel_values"] = inputs.pixel_values
        if hasattr(inputs, "image_grid_thw"):
            sample["image_grid_thw"] = inputs.image_grid_thw[0]

        return sample

    def _find_response_start(self, input_ids: Any) -> int:
        """Return the index of the first token of the assistant response.

        The role marker is located by the first token id of the encoded
        string ``"assistant"``.  Because that token can also occur inside the
        prompt or the response text, occurrences that directly follow a
        special token are preferred; if none qualifies (or the tokenizer
        exposes no ``all_special_ids``), the last occurrence is used, which
        is the original behavior.
        """
        tokenizer = self.processor.tokenizer
        assistant_token = tokenizer.encode("assistant", add_special_tokens=False)[0]
        candidates = np.where(input_ids == assistant_token)[0]
        if candidates.size == 0:
            raise ValueError("Could not find assistant tokens")

        role_markers = candidates
        # NOTE(review): assumes the chat template places a special token
        # (e.g. a turn-start token) immediately before the role name -- confirm
        # for templates other than the Qwen2.5-VL one used by the demo script.
        special_ids = list(getattr(tokenizer, "all_special_ids", None) or [])
        if special_ids:
            has_predecessor = candidates > 0
            predecessors = input_ids[np.maximum(candidates - 1, 0)]
            follows_special = has_predecessor & np.isin(predecessors, special_ids)
            if follows_special.any():
                role_markers = candidates[follows_special]

        return int(role_markers[-1]) + self.response_offset
|
||||||
|
|
||||||
|
|
||||||
class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
|
class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
|
||||||
"""Dataset that includes front matter parsing and PDF rendering by default."""
|
"""Dataset that includes front matter parsing and PDF rendering by default."""
|
||||||
|
|
||||||
@ -326,6 +393,7 @@ class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
|
|||||||
super().__init__(root_dir, pipeline_steps)
|
super().__init__(root_dir, pipeline_steps)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -399,5 +467,94 @@ if __name__ == "__main__":
|
|||||||
print(f"PDF: {Path(first_sample['pdf_path']).name}")
|
print(f"PDF: {Path(first_sample['pdf_path']).name}")
|
||||||
print(f"Image size: {first_sample['image'].size}")
|
print(f"Image size: {first_sample['image'].size}")
|
||||||
print(f"Page data: {first_sample['page_data']}")
|
print(f"Page data: {first_sample['page_data']}")
|
||||||
|
|
||||||
|
# Test with actual Qwen2.5-VL tokenization
|
||||||
|
print("\n\n=== Testing with Qwen2.5-VL-7B-Instruct Tokenization ===")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers import AutoProcessor
|
||||||
|
|
||||||
|
print("Loading Qwen2.5-VL processor...")
|
||||||
|
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||||
|
|
||||||
|
# Create pipeline with real tokenizer
|
||||||
|
tokenized_dataset = BaseMarkdownPDFDataset(
|
||||||
|
args.root_dir,
|
||||||
|
pipeline_steps=[
|
||||||
|
FrontMatterParser(front_matter_class=PageResponse),
|
||||||
|
PDFRenderer(target_longest_image_dim=512),
|
||||||
|
StaticLengthDocumentAnchoring(target_anchor_text_len=1000),
|
||||||
|
FinetuningPrompt(),
|
||||||
|
FrontMatterOutputFormat(),
|
||||||
|
InstructMessages(),
|
||||||
|
Tokenizer(processor),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(tokenized_dataset) > 0:
|
||||||
|
print("\nProcessing first sample with Qwen2.5-VL...")
|
||||||
|
tokenized_sample = tokenized_dataset[0]
|
||||||
|
|
||||||
|
print("\nTokenized output:")
|
||||||
|
print(f" Keys: {list(tokenized_sample.keys())}")
|
||||||
|
print(f" Input IDs shape: {tokenized_sample['input_ids'].shape}")
|
||||||
|
print(f" Labels shape: {tokenized_sample['labels'].shape}")
|
||||||
|
print(f" Attention mask shape: {tokenized_sample['attention_mask'].shape}")
|
||||||
|
|
||||||
|
if 'pixel_values' in tokenized_sample:
|
||||||
|
print(f" Pixel values shape: {tokenized_sample['pixel_values'].shape}")
|
||||||
|
if 'image_grid_thw' in tokenized_sample:
|
||||||
|
print(f" Image grid THW: {tokenized_sample['image_grid_thw']}")
|
||||||
|
|
||||||
|
# Show label masking
|
||||||
|
print(f"\nLabel masking analysis:")
|
||||||
|
labels = tokenized_sample['labels']
|
||||||
|
masked_count = np.sum(labels == -100)
|
||||||
|
total_count = len(labels)
|
||||||
|
print(f" Total tokens: {total_count}")
|
||||||
|
print(f" Masked tokens: {masked_count} ({masked_count/total_count*100:.1f}%)")
|
||||||
|
print(f" Unmasked tokens: {total_count - masked_count} ({(total_count - masked_count)/total_count*100:.1f}%)")
|
||||||
|
|
||||||
|
# Find the transition point
|
||||||
|
transition_idx = None
|
||||||
|
for i in range(len(labels) - 1):
|
||||||
|
if labels[i] == -100 and labels[i + 1] != -100:
|
||||||
|
transition_idx = i + 1
|
||||||
|
break
|
||||||
|
|
||||||
|
if transition_idx:
|
||||||
|
print(f" Transition from masked to unmasked at position: {transition_idx}")
|
||||||
|
|
||||||
|
# Print all tokens
|
||||||
|
input_ids = tokenized_sample['input_ids']
|
||||||
|
print(f"\nAll tokens ({len(input_ids)} total):")
|
||||||
|
print("Format: [index] Token (repr) | Label | Token ID")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
for i in range(len(input_ids)):
|
||||||
|
token = processor.tokenizer.decode([input_ids[i]])
|
||||||
|
token_repr = repr(token)
|
||||||
|
label = labels[i] if i < len(labels) else "N/A"
|
||||||
|
token_id = input_ids[i]
|
||||||
|
|
||||||
|
# Mark special positions
|
||||||
|
marker = ""
|
||||||
|
if transition_idx and i == transition_idx:
|
||||||
|
marker = " <-- TRANSITION (first unmasked)"
|
||||||
|
elif i == 0:
|
||||||
|
marker = " <-- START"
|
||||||
|
elif label != -100 and i > 0 and labels[i-1] == -100:
|
||||||
|
marker = " <-- response begins"
|
||||||
|
|
||||||
|
print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}")
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"\nCould not import transformers: {e}")
|
||||||
|
print("Install with: pip install transformers")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nError during tokenization test: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise AssertionError("Expected some data to be created at this point")
|
raise AssertionError("Expected some data to be created at this point")
|
Loading…
x
Reference in New Issue
Block a user