Image params to loader

This commit is contained in:
Jake Poznanski 2025-06-11 21:05:23 +00:00
parent 9a390e3d58
commit aedc295e3f

View File

@@ -12,16 +12,17 @@ from olmocr.data.renderpdf import render_pdf_to_base64png
class MarkdownPDFDocumentDataset(Dataset):
def __init__(self, root_dir: str | PathLike, transform=None):
def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None):
"""
Initialize the dataset by finding all markdown files with corresponding PDFs.
Args:
root_dir: Path to the root folder containing processed markdown and PDF files
transform: Optional transform to apply to the PDF images
image_transform: Optional transform to apply to the PDF images
"""
self.root_dir = Path(root_dir)
self.transform = transform
self.image_transform = image_transform
self.target_longest_image_dim = target_longest_image_dim
self.samples = []
# Find all markdown files recursively
@@ -115,13 +116,13 @@ class MarkdownPDFDocumentDataset(Dataset):
text = parts[2].strip()
# Render PDF to image
base64_png = render_pdf_to_base64png(str(sample['pdf_path']), page_num=1)
base64_png = render_pdf_to_base64png(str(sample['pdf_path']), page_num=1, target_longest_image_dim=self.target_longest_image_dim)
png_bytes = base64.b64decode(base64_png)
image = Image.open(BytesIO(png_bytes))
# Apply transform if provided
if self.transform:
image = self.transform(image)
if self.image_transform:
image = self.image_transform(image)
return {
'image': image,
@@ -146,7 +147,7 @@ if __name__ == "__main__":
# Test dataset initialization
print(f"\nTesting dataset with root directory: {args.root_dir}")
dataset = MarkdownPDFDocumentDataset(args.root_dir)
dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=None)
print(f"\nDataset length: {len(dataset)}")
@@ -175,7 +176,7 @@ if __name__ == "__main__":
transforms.ToTensor(),
])
dataset_with_transform = MarkdownPDFDocumentDataset(args.root_dir, transform=transform)
dataset_with_transform = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=transform)
transformed_sample = dataset_with_transform[0]
print(f"Transformed image type: {type(transformed_sample['image'])}")
print(f"Transformed image shape: {transformed_sample['image'].shape}")