Mirror of https://github.com/allenai/olmocr.git (synced 2025-06-27 04:00:02 +00:00)

Dataloader progress

parent 9f50bda6bf
commit 105d5907d6
@@ -15,6 +15,7 @@ from abc import ABC, abstractmethod

from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text

# Type alias for samples
Sample: TypeAlias = Dict[str, Any]
@@ -191,13 +192,12 @@ class FrontMatterParser(PipelineStep):

        # Parse front matter to dataclass if specified
        try:
            parsed_front_matter = self._parse_front_matter(front_matter, text)
            page_data = self._parse_front_matter(front_matter, text)
        except Exception as e:
            raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}")

        # Update sample
        sample['text'] = text
        sample['front_matter'] = parsed_front_matter
        # Only add page_data field
        sample['page_data'] = page_data

        return sample
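The practical effect of this hunk is that the parsed front-matter dataclass now travels through the pipeline under the `page_data` key instead of `front_matter`, which is what the new `FrontMatterOutputFormat` step below reads. The following sketch (not part of this commit) shows the sample shape downstream steps now expect, assuming `PageResponse` is used as the `front_matter_class`; the paths and field values are purely illustrative:

from olmocr.prompts.prompts import PageResponse

# Illustrative sample dict as later steps (e.g. FrontMatterOutputFormat)
# expect it after FrontMatterParser has run; all values are made up.
sample = {
    "markdown_path": "/data/example/0001.md",   # hypothetical path
    "text": "The quick brown fox jumps over the lazy dog.",
    "page_data": PageResponse(
        primary_language="en",
        is_rotation_valid=True,
        rotation_correction=0,
        is_table=False,
        is_diagram=False,
        natural_text="The quick brown fox jumps over the lazy dog.",
    ),
}
assert isinstance(sample["page_data"], PageResponse)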
@@ -230,16 +230,74 @@ class PDFRenderer(PipelineStep):


@dataclass(frozen=True, slots=True)
class PromptBuilder(PipelineStep):
    """Pipeline step that builds prompts using the finetuning prompt template."""
    base_text_field: str = 'text'

    def __call__(self, sample: Sample) -> Sample:
        """Build prompt from base text."""
        base_text = sample.get(self.base_text_field, '')
        sample['prompt'] = build_finetuning_prompt(base_text)
        return sample


class StaticLengthDocumentAnchoring(PipelineStep):
    target_anchor_text_len: int

    """Pipeline step that runs document anchoring on the PDF and puts in the data to be used by later prompting stages"""
    def __call__(self, sample: Sample) -> Sample:
        anchor_text = get_anchor_text(sample["pdf_path"], page=1, pdf_engine="pdfreport", target_length=self.target_anchor_text_len)
        sample["anchor_text"] = anchor_text
        return sample


@dataclass(frozen=True, slots=True)
class FinetuningPrompt(PipelineStep):
    """Applies the standard fine tuning prompt"""
    def __call__(self, sample: Sample) -> Sample:
        sample["instruction_prompt"] = build_finetuning_prompt(sample["anchor_text"])
        return sample


@dataclass(frozen=True, slots=True)
class FrontMatterOutputFormat(PipelineStep):
    """Takes the output and applies the standard yaml formatting to it"""
    def __call__(self, sample: Sample) -> Sample:
        page_data = sample["page_data"]
        assert type(page_data) == PageResponse

        sample["output"] = f"""---
primary_language: {page_data.primary_language}
is_rotation_valid: {page_data.is_rotation_valid}
rotation_correction: {page_data.rotation_correction}
is_table: {page_data.is_table}
is_diagram: {page_data.is_diagram}
---
{page_data.natural_text}
""".strip()

        return sample


@dataclass(frozen=True, slots=True)
class InstructMessages(PipelineStep):
    """Creates instruction-following messages format for training."""
    def __call__(self, sample: Sample) -> Sample:
        # Convert PIL image to base64 string
        if 'image' in sample:
            buffered = BytesIO()
            sample['image'].save(buffered, format="PNG")
            base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
        else:
            raise ValueError("Image not found in sample. Make sure PDFRenderer runs before InstructMessages.")

        # Prepare messages
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": base64_image},
                    {"type": "text", "text": sample["instruction_prompt"]},
                ],
            },
            {
                "role": "assistant",
                "content": sample["output"]
            }
        ]
        sample["messages"] = messages

        return sample


class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
    """Dataset that includes front matter parsing and PDF rendering by default."""
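Each of the new steps is a plain callable from `Sample` to `Sample`, so they compose by simple iteration, exactly as the default pipeline in the next hunk does. A hedged sketch (not part of this commit) of chaining them by hand, assuming `sample` already carries `pdf_path`, an `image` from `PDFRenderer`, and the `page_data` produced by `FrontMatterParser`:

# Sketch only: apply the new steps in order to a prepared sample dict.
steps = [
    StaticLengthDocumentAnchoring(target_anchor_text_len=6000),
    FinetuningPrompt(),
    FrontMatterOutputFormat(),
    InstructMessages(),
]
for step in steps:
    sample = step(sample)

# The sample now carries 'anchor_text', 'instruction_prompt', a YAML
# front-matter 'output' string, and chat-style 'messages' for training.
assert sample["messages"][0]["role"] == "user"
assert sample["output"].startswith("---")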
@@ -257,7 +315,11 @@ class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):

        # Create default pipeline steps
        pipeline_steps = [
            FrontMatterParser(front_matter_class),
            PDFRenderer(target_longest_image_dim, image_transform)
            PDFRenderer(target_longest_image_dim, image_transform),
            StaticLengthDocumentAnchoring(target_anchor_text_len=6000),
            FinetuningPrompt(),
            FrontMatterOutputFormat(),
            InstructMessages(),
        ]

        # Initialize base class with pipeline
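With `InstructMessages` in the default pipeline, every sample ends in a chat-format `messages` list, so the convenience dataset can be fed to a standard PyTorch `DataLoader`. The sketch below is assumed usage rather than code from this commit; the constructor keyword names, the identity `collate_fn`, the batch size, and the dataset path are all illustrative:

from torch.utils.data import DataLoader

from olmocr.prompts.prompts import PageResponse

# Hypothetical instantiation of the convenience dataset defined here.
dataset = MarkdownPDFDocumentDataset(
    "/data/markdown_pdf_pairs",          # illustrative root directory
    front_matter_class=PageResponse,
    target_longest_image_dim=1024,
)

# Samples are dicts holding PIL images and nested message lists, so skip
# tensor collation and hand the raw dicts to the training step.
loader = DataLoader(dataset, batch_size=4, collate_fn=lambda batch: batch)

for batch in loader:
    for sample in batch:
        messages = sample["messages"]    # built by InstructMessages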
@@ -266,6 +328,7 @@ class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):


if __name__ == "__main__":
    import argparse
    from pathlib import Path

    # Set up logging for testing
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -280,37 +343,10 @@ if __name__ == "__main__":

    args = parser.parse_args()

    # Test base dataset without any pipeline steps
    print(f"\n=== Testing base dataset without pipeline steps ===")
    # Quick test to ensure dataset loads
    print(f"\n=== Testing dataset loading ===")
    base_dataset = BaseMarkdownPDFDataset(args.root_dir)
    print(f"Dataset length: {len(base_dataset)}")

    if len(base_dataset) > 0:
        print("\nFirst sample (no pipeline):")
        sample = base_dataset[0]
        print(f" Keys: {list(sample.keys())}")
        print(f" Markdown: {sample['markdown_path'].name}")
        print(f" PDF: {sample['pdf_path'].name}")

    # Test with individual pipeline steps
    print(f"\n=== Testing with individual pipeline steps ===")
    pipeline_dataset = BaseMarkdownPDFDataset(
        args.root_dir,
        pipeline_steps=[
            FrontMatterParser(front_matter_class=PageResponse),
            PDFRenderer(target_longest_image_dim=1024),
            PromptBuilder()
        ]
    )

    if len(pipeline_dataset) > 0:
        print("\nFirst sample (with pipeline):")
        sample = pipeline_dataset[0]
        print(f" Keys: {list(sample.keys())}")
        print(f" Front Matter: {sample['front_matter']}")
        print(f" Image size: {sample['image'].size}")
        print(f" Text preview: {sample['text'][:100]}...")
        print(f" Prompt preview: {sample.get('prompt', 'No prompt')[:200]}...")
    print(f"Found {len(base_dataset)} markdown-PDF pairs")

    # Test the convenience dataset class
    print(f"\n=== Testing MarkdownPDFDocumentDataset (convenience class) ===")
@@ -333,10 +369,35 @@ if __name__ == "__main__":

        # Test __getitem__
        print("\nTesting __getitem__ on first sample:")
        first_sample = dataset[0]
        print(f"Image type: {type(first_sample['image'])}")

        # Pretty print the message structure
        print("\n=== Message Structure ===")
        messages = first_sample['messages']

        for i, msg in enumerate(messages):
            print(f"\nMessage {i + 1}:")
            print(f" role: {msg['role']}")

            if msg['role'] == 'user':
                print(" content:")
                for j, content_item in enumerate(msg['content']):
                    if content_item['type'] == 'image':
                        # Show that there's an image without the base64 data
                        image_data = content_item['image']
                        print(f" [{j}] type: image")
                        print(f" image: <base64 PNG data, {len(image_data)} chars>")
                    elif content_item['type'] == 'text':
                        text_preview = content_item['text'][:200].replace('\n', '\n ')
                        print(f" [{j}] type: text")
                        print(f" text: {text_preview}...")
            else:
                # Assistant message
                content_preview = msg['content'][:300].replace('\n', '\n ')
                print(f" content: {content_preview}...")

        print("\n=== Sample Metadata ===")
        print(f"PDF: {Path(first_sample['pdf_path']).name}")
        print(f"Image size: {first_sample['image'].size}")
        print(f"PDF Path: {first_sample['pdf_path']}")
        print(f"Front Matter: {first_sample['front_matter']}")
        print(f"Text (first 200 chars): {first_sample['text'][:200]}...")
        if hasattr(first_sample['front_matter'], 'natural_text'):
            print(f"Natural Text from PageResponse: {first_sample['front_matter'].natural_text[:200] if first_sample['front_matter'].natural_text else 'None'}...")
        print(f"Page data: {first_sample['page_data']}")
    else:
        raise AssertionError("Expected some data to be created at this point")