Dataloader progress

This commit is contained in:
Jake Poznanski 2025-06-11 22:35:35 +00:00
parent 9f50bda6bf
commit 105d5907d6

View File

@ -15,6 +15,7 @@ from abc import ABC, abstractmethod
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
# Type alias for samples
Sample: TypeAlias = Dict[str, Any]
@ -191,13 +192,12 @@ class FrontMatterParser(PipelineStep):
# Parse front matter to dataclass if specified
try:
parsed_front_matter = self._parse_front_matter(front_matter, text)
page_data = self._parse_front_matter(front_matter, text)
except Exception as e:
raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}")
# Update sample
sample['text'] = text
sample['front_matter'] = parsed_front_matter
# Only add page_data field
sample['page_data'] = page_data
return sample
@ -230,16 +230,74 @@ class PDFRenderer(PipelineStep):
@dataclass(frozen=True, slots=True)
class PromptBuilder(PipelineStep):
    """Pipeline step that feeds one text field of the sample through the finetuning prompt template."""

    # Key of the sample entry whose value is handed to build_finetuning_prompt.
    base_text_field: str = 'text'

    def __call__(self, sample: Sample) -> Sample:
        """Write the generated prompt to ``sample['prompt']`` and return the sample.

        When the configured ``base_text_field`` key is absent from the sample,
        an empty string is passed to the prompt template instead.
        """
        source_text = sample.get(self.base_text_field, '')
        prompt = build_finetuning_prompt(source_text)
        sample['prompt'] = prompt
        return sample
@dataclass(frozen=True, slots=True)
class StaticLengthDocumentAnchoring(PipelineStep):
    """Pipeline step that runs document anchoring on the PDF and stores the result for later prompting stages.

    The anchor text is produced by ``get_anchor_text`` and written to
    ``sample["anchor_text"]``.
    """

    # Forwarded to get_anchor_text as its ``target_length`` argument.
    target_anchor_text_len: int

    def __call__(self, sample: Sample) -> Sample:
        """Compute anchor text for ``sample["pdf_path"]`` and attach it to the sample.

        Args:
            sample: Sample dict; must contain a ``"pdf_path"`` entry.

        Returns:
            The same sample, with ``"anchor_text"`` set.
        """
        sample["anchor_text"] = get_anchor_text(
            sample["pdf_path"],
            # NOTE(review): page is fixed to 1 — presumably every PDF in the
            # dataset is a single page; confirm against the dataset layout.
            page=1,
            pdf_engine="pdfreport",
            target_length=self.target_anchor_text_len,
        )
        return sample
@dataclass(frozen=True, slots=True)
class FinetuningPrompt(PipelineStep):
    """Pipeline step that turns the sample's anchor text into the standard fine-tuning prompt."""

    def __call__(self, sample: Sample) -> Sample:
        """Read ``sample["anchor_text"]`` and write ``sample["instruction_prompt"]``.

        The sample must already contain ``"anchor_text"``; the lookup is a
        plain dict index, so a missing key propagates as ``KeyError``.
        """
        anchor_text = sample["anchor_text"]
        instruction = build_finetuning_prompt(anchor_text)
        sample["instruction_prompt"] = instruction
        return sample
@dataclass(frozen=True, slots=True)
class FrontMatterOutputFormat(PipelineStep):
    """Takes the parsed page data and renders it as front matter followed by the page text."""

    def __call__(self, sample: Sample) -> Sample:
        """Render ``sample["page_data"]`` into ``sample["output"]``.

        The output is a ``---``-delimited header listing ``primary_language``,
        ``is_rotation_valid``, ``rotation_correction``, ``is_table`` and
        ``is_diagram``, followed by ``natural_text``; surrounding whitespace is
        stripped from the final string.

        Args:
            sample: Sample dict; must contain a ``"page_data"`` entry.

        Returns:
            The same sample, with ``"output"`` set.

        Raises:
            TypeError: If ``sample["page_data"]`` is not a ``PageResponse``.
        """
        page_data = sample["page_data"]
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``,
        # and ``isinstance`` also admits PageResponse subclasses.
        if not isinstance(page_data, PageResponse):
            raise TypeError(
                f"Expected sample['page_data'] to be a PageResponse, got {type(page_data).__name__}"
            )

        # Values are interpolated with str(), so booleans render as True/False
        # and a None natural_text renders as the literal text "None".
        # NOTE(review): confirm that is the intended training target format.
        output_lines = [
            "---",
            f"primary_language: {page_data.primary_language}",
            f"is_rotation_valid: {page_data.is_rotation_valid}",
            f"rotation_correction: {page_data.rotation_correction}",
            f"is_table: {page_data.is_table}",
            f"is_diagram: {page_data.is_diagram}",
            "---",
            f"{page_data.natural_text}",
        ]
        sample["output"] = "\n".join(output_lines).strip()
        return sample
@dataclass(frozen=True, slots=True)
class InstructMessages(PipelineStep):
    """Pipeline step that assembles the user/assistant message pair used for instruction-following training."""

    def __call__(self, sample: Sample) -> Sample:
        """Build ``sample["messages"]`` from the image, instruction prompt and output.

        The user turn carries the base64-encoded PNG of ``sample["image"]`` plus
        ``sample["instruction_prompt"]``; the assistant turn carries
        ``sample["output"]``.

        Raises:
            ValueError: If the sample has no ``"image"`` entry.
        """
        # Fail fast before any encoding work when the image is missing.
        if 'image' not in sample:
            raise ValueError("Image not found in sample. Make sure PDFRenderer runs before InstructMessages.")

        # Serialize the image to PNG in memory, then base64-encode it as text.
        png_buffer = BytesIO()
        sample['image'].save(png_buffer, format="PNG")
        encoded_png = base64.b64encode(png_buffer.getvalue()).decode('utf-8')

        user_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": encoded_png},
                {"type": "text", "text": sample["instruction_prompt"]},
            ],
        }
        assistant_turn = {
            "role": "assistant",
            "content": sample["output"],
        }

        sample["messages"] = [user_turn, assistant_turn]
        return sample
class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
"""Dataset that includes front matter parsing and PDF rendering by default."""
@ -257,7 +315,11 @@ class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
# Create default pipeline steps
pipeline_steps = [
FrontMatterParser(front_matter_class),
PDFRenderer(target_longest_image_dim, image_transform)
PDFRenderer(target_longest_image_dim, image_transform),
StaticLengthDocumentAnchoring(target_anchor_text_len=6000),
FinetuningPrompt(),
FrontMatterOutputFormat(),
InstructMessages(),
]
# Initialize base class with pipeline
@ -266,6 +328,7 @@ class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
if __name__ == "__main__":
import argparse
from pathlib import Path
# Set up logging for testing
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@ -280,37 +343,10 @@ if __name__ == "__main__":
args = parser.parse_args()
# Test base dataset without any pipeline steps
print(f"\n=== Testing base dataset without pipeline steps ===")
# Quick test to ensure dataset loads
print(f"\n=== Testing dataset loading ===")
base_dataset = BaseMarkdownPDFDataset(args.root_dir)
print(f"Dataset length: {len(base_dataset)}")
if len(base_dataset) > 0:
print("\nFirst sample (no pipeline):")
sample = base_dataset[0]
print(f" Keys: {list(sample.keys())}")
print(f" Markdown: {sample['markdown_path'].name}")
print(f" PDF: {sample['pdf_path'].name}")
# Test with individual pipeline steps
print(f"\n=== Testing with individual pipeline steps ===")
pipeline_dataset = BaseMarkdownPDFDataset(
args.root_dir,
pipeline_steps=[
FrontMatterParser(front_matter_class=PageResponse),
PDFRenderer(target_longest_image_dim=1024),
PromptBuilder()
]
)
if len(pipeline_dataset) > 0:
print("\nFirst sample (with pipeline):")
sample = pipeline_dataset[0]
print(f" Keys: {list(sample.keys())}")
print(f" Front Matter: {sample['front_matter']}")
print(f" Image size: {sample['image'].size}")
print(f" Text preview: {sample['text'][:100]}...")
print(f" Prompt preview: {sample.get('prompt', 'No prompt')[:200]}...")
print(f"Found {len(base_dataset)} markdown-PDF pairs")
# Test the convenience dataset class
print(f"\n=== Testing MarkdownPDFDocumentDataset (convenience class) ===")
@ -333,10 +369,35 @@ if __name__ == "__main__":
# Test __getitem__
print("\nTesting __getitem__ on first sample:")
first_sample = dataset[0]
print(f"Image type: {type(first_sample['image'])}")
# Pretty print the message structure
print("\n=== Message Structure ===")
messages = first_sample['messages']
for i, msg in enumerate(messages):
print(f"\nMessage {i + 1}:")
print(f" role: {msg['role']}")
if msg['role'] == 'user':
print(" content:")
for j, content_item in enumerate(msg['content']):
if content_item['type'] == 'image':
# Show that there's an image without the base64 data
image_data = content_item['image']
print(f" [{j}] type: image")
print(f" image: <base64 PNG data, {len(image_data)} chars>")
elif content_item['type'] == 'text':
text_preview = content_item['text'][:200].replace('\n', '\n ')
print(f" [{j}] type: text")
print(f" text: {text_preview}...")
else:
# Assistant message
content_preview = msg['content'][:300].replace('\n', '\n ')
print(f" content: {content_preview}...")
print("\n=== Sample Metadata ===")
print(f"PDF: {Path(first_sample['pdf_path']).name}")
print(f"Image size: {first_sample['image'].size}")
print(f"PDF Path: {first_sample['pdf_path']}")
print(f"Front Matter: {first_sample['front_matter']}")
print(f"Text (first 200 chars): {first_sample['text'][:200]}...")
if hasattr(first_sample['front_matter'], 'natural_text'):
print(f"Natural Text from PageResponse: {first_sample['front_matter'].natural_text[:200] if first_sample['front_matter'].natural_text else 'None'}...")
print(f"Page data: {first_sample['page_data']}")
else:
raise AssertionError("Expected some data to be created at this point")