mirror of https://github.com/allenai/olmocr.git (synced 2025-06-27 04:00:02 +00:00)

Dataloader progress

This commit is contained in:
parent 9f50bda6bf
commit 105d5907d6
@@ -15,6 +15,7 @@ from abc import ABC, abstractmethod
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt
+from olmocr.prompts.anchor import get_anchor_text
 
 # Type alias for samples
 Sample: TypeAlias = Dict[str, Any]
@@ -191,13 +192,12 @@ class FrontMatterParser(PipelineStep):
 
         # Parse front matter to dataclass if specified
         try:
-            parsed_front_matter = self._parse_front_matter(front_matter, text)
+            page_data = self._parse_front_matter(front_matter, text)
         except Exception as e:
             raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}")
 
-        # Update sample
-        sample['text'] = text
-        sample['front_matter'] = parsed_front_matter
+        # Only add page_data field
+        sample['page_data'] = page_data
 
         return sample
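
For orientation, a minimal sketch (editorial, not part of this commit) of a markdown page in the format this step parses; the field names mirror the PageResponse fields used by FrontMatterOutputFormat below, and the concrete values are illustrative only.

# Hypothetical markdown file contents; FrontMatterParser reads the YAML front matter
# and parses it into a dataclass (here PageResponse) stored under sample['page_data'].
example_markdown = """---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---
The natural text extracted from the page goes here.
"""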
@@ -230,14 +230,72 @@ class PDFRenderer(PipelineStep):
 
 
 @dataclass(frozen=True, slots=True)
-class PromptBuilder(PipelineStep):
-    """Pipeline step that builds prompts using the finetuning prompt template."""
-    base_text_field: str = 'text'
+class StaticLengthDocumentAnchoring(PipelineStep):
+    target_anchor_text_len: int
 
+    """Pipeline step that runs document anchoring on the PDF and puts in the data to be used by later prompting stages"""
     def __call__(self, sample: Sample) -> Sample:
-        """Build prompt from base text."""
-        base_text = sample.get(self.base_text_field, '')
-        sample['prompt'] = build_finetuning_prompt(base_text)
+        anchor_text = get_anchor_text(sample["pdf_path"], page=1, pdf_engine="pdfreport", target_length=self.target_anchor_text_len)
+        sample["anchor_text"] = anchor_text
+        return sample
+
+
+@dataclass(frozen=True, slots=True)
+class FinetuningPrompt(PipelineStep):
+    """Applies the standard fine tuning prompt"""
+
+    def __call__(self, sample: Sample) -> Sample:
+        sample["instruction_prompt"] = build_finetuning_prompt(sample["anchor_text"])
+        return sample
+
+
+@dataclass(frozen=True, slots=True)
+class FrontMatterOutputFormat(PipelineStep):
+    """Takes the output and applies the standard yaml formatting to it"""
+
+    def __call__(self, sample: Sample) -> Sample:
+        page_data = sample["page_data"]
+        assert type(page_data) == PageResponse
+
+        sample["output"] = f"""---
+primary_language: {page_data.primary_language}
+is_rotation_valid: {page_data.is_rotation_valid}
+rotation_correction: {page_data.rotation_correction}
+is_table: {page_data.is_table}
+is_diagram: {page_data.is_diagram}
+---
+{page_data.natural_text}
+""".strip()
+
+        return sample
+
+
+@dataclass(frozen=True, slots=True)
+class InstructMessages(PipelineStep):
+    """Creates instruction-following messages format for training."""
+
+    def __call__(self, sample: Sample) -> Sample:
+        # Convert PIL image to base64 string
+        if 'image' in sample:
+            buffered = BytesIO()
+            sample['image'].save(buffered, format="PNG")
+            base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
+        else:
+            raise ValueError("Image not found in sample. Make sure PDFRenderer runs before InstructMessages.")
+
+        # Prepare messages
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": base64_image},
+                    {"type": "text", "text": sample["instruction_prompt"]},
+                ],
+            },
+            {
+                "role": "assistant",
+                "content": sample["output"]
+            }
+        ]
+        sample["messages"] = messages
+
         return sample
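
To make the data flow concrete, here is a minimal sketch (editorial, not part of the commit) of the new steps applied in sequence to a single sample dict; it assumes the sample already carries 'pdf_path', 'image', and 'page_data' from the earlier FrontMatterParser and PDFRenderer steps.

# Each step reads and writes plain dict keys, so the steps compose by simple chaining.
steps = [
    StaticLengthDocumentAnchoring(target_anchor_text_len=6000),  # adds 'anchor_text'
    FinetuningPrompt(),          # adds 'instruction_prompt' built from the anchor text
    FrontMatterOutputFormat(),   # adds 'output' (YAML front matter + natural text)
    InstructMessages(),          # adds 'messages' (user image+prompt, assistant output)
]
for step in steps:
    sample = step(sample)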
@@ -257,7 +315,11 @@ class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
 
         # Create default pipeline steps
         pipeline_steps = [
             FrontMatterParser(front_matter_class),
-            PDFRenderer(target_longest_image_dim, image_transform)
+            PDFRenderer(target_longest_image_dim, image_transform),
+            StaticLengthDocumentAnchoring(target_anchor_text_len=6000),
+            FinetuningPrompt(),
+            FrontMatterOutputFormat(),
+            InstructMessages(),
         ]
 
         # Initialize base class with pipeline
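
Equivalently, a hedged sketch of assembling the same default pipeline explicitly on the base class; the root directory string is a placeholder, and target_longest_image_dim=1024 follows the value used in the test code below.

dataset = BaseMarkdownPDFDataset(
    "path/to/markdown_pdf_pairs",  # placeholder root directory of markdown-PDF pairs
    pipeline_steps=[
        FrontMatterParser(front_matter_class=PageResponse),
        PDFRenderer(target_longest_image_dim=1024),
        StaticLengthDocumentAnchoring(target_anchor_text_len=6000),
        FinetuningPrompt(),
        FrontMatterOutputFormat(),
        InstructMessages(),
    ],
)
sample = dataset[0]  # sample['messages'] is the chat-style training example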
@@ -266,6 +328,7 @@ class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
 
 if __name__ == "__main__":
     import argparse
+    from pathlib import Path
 
     # Set up logging for testing
     logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -280,37 +343,10 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
-    # Test base dataset without any pipeline steps
-    print(f"\n=== Testing base dataset without pipeline steps ===")
+    # Quick test to ensure dataset loads
+    print(f"\n=== Testing dataset loading ===")
     base_dataset = BaseMarkdownPDFDataset(args.root_dir)
-    print(f"Dataset length: {len(base_dataset)}")
-
-    if len(base_dataset) > 0:
-        print("\nFirst sample (no pipeline):")
-        sample = base_dataset[0]
-        print(f"  Keys: {list(sample.keys())}")
-        print(f"  Markdown: {sample['markdown_path'].name}")
-        print(f"  PDF: {sample['pdf_path'].name}")
-
-    # Test with individual pipeline steps
-    print(f"\n=== Testing with individual pipeline steps ===")
-    pipeline_dataset = BaseMarkdownPDFDataset(
-        args.root_dir,
-        pipeline_steps=[
-            FrontMatterParser(front_matter_class=PageResponse),
-            PDFRenderer(target_longest_image_dim=1024),
-            PromptBuilder()
-        ]
-    )
-
-    if len(pipeline_dataset) > 0:
-        print("\nFirst sample (with pipeline):")
-        sample = pipeline_dataset[0]
-        print(f"  Keys: {list(sample.keys())}")
-        print(f"  Front Matter: {sample['front_matter']}")
-        print(f"  Image size: {sample['image'].size}")
-        print(f"  Text preview: {sample['text'][:100]}...")
-        print(f"  Prompt preview: {sample.get('prompt', 'No prompt')[:200]}...")
+    print(f"Found {len(base_dataset)} markdown-PDF pairs")
 
     # Test the convenience dataset class
     print(f"\n=== Testing MarkdownPDFDocumentDataset (convenience class) ===")
@@ -333,10 +369,35 @@ if __name__ == "__main__":
 
         # Test __getitem__
         print("\nTesting __getitem__ on first sample:")
         first_sample = dataset[0]
-        print(f"Image type: {type(first_sample['image'])}")
+
+        # Pretty print the message structure
+        print("\n=== Message Structure ===")
+        messages = first_sample['messages']
+
+        for i, msg in enumerate(messages):
+            print(f"\nMessage {i + 1}:")
+            print(f"  role: {msg['role']}")
+
+            if msg['role'] == 'user':
+                print("  content:")
+                for j, content_item in enumerate(msg['content']):
+                    if content_item['type'] == 'image':
+                        # Show that there's an image without the base64 data
+                        image_data = content_item['image']
+                        print(f"    [{j}] type: image")
+                        print(f"        image: <base64 PNG data, {len(image_data)} chars>")
+                    elif content_item['type'] == 'text':
+                        text_preview = content_item['text'][:200].replace('\n', '\n ')
+                        print(f"    [{j}] type: text")
+                        print(f"        text: {text_preview}...")
+            else:
+                # Assistant message
+                content_preview = msg['content'][:300].replace('\n', '\n ')
+                print(f"  content: {content_preview}...")
+
+        print("\n=== Sample Metadata ===")
+        print(f"PDF: {Path(first_sample['pdf_path']).name}")
         print(f"Image size: {first_sample['image'].size}")
-        print(f"PDF Path: {first_sample['pdf_path']}")
-        print(f"Front Matter: {first_sample['front_matter']}")
-        print(f"Text (first 200 chars): {first_sample['text'][:200]}...")
-        if hasattr(first_sample['front_matter'], 'natural_text'):
-            print(f"Natural Text from PageResponse: {first_sample['front_matter'].natural_text[:200] if first_sample['front_matter'].natural_text else 'None'}...")
+        print(f"Page data: {first_sample['page_data']}")
+    else:
+        raise AssertionError("Expected some data to be created at this point")
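
For reference, a rough sketch (values illustrative, not taken from the commit) of the messages structure the pretty-printer above walks over, as assembled by InstructMessages:

example_messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "<base64-encoded PNG>"},
            {"type": "text", "text": "<instruction prompt built from the anchor text>"},
        ],
    },
    {
        "role": "assistant",
        "content": "---\nprimary_language: en\n...\n---\n<natural text of the page>",
    },
]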