Image params to loader

This commit is contained in:
Jake Poznanski 2025-06-11 21:05:23 +00:00
parent 9a390e3d58
commit aedc295e3f

View File

@ -12,16 +12,17 @@ from olmocr.data.renderpdf import render_pdf_to_base64png
class MarkdownPDFDocumentDataset(Dataset): class MarkdownPDFDocumentDataset(Dataset):
def __init__(self, root_dir: str | PathLike, transform=None): def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None):
""" """
Initialize the dataset by finding all markdown files with corresponding PDFs. Initialize the dataset by finding all markdown files with corresponding PDFs.
Args: Args:
root_dir: Path to the root folder containing processed markdown and PDF files root_dir: Path to the root folder containing processed markdown and PDF files
transform: Optional transform to apply to the PDF images image_transform: Optional transform to apply to the PDF images
""" """
self.root_dir = Path(root_dir) self.root_dir = Path(root_dir)
self.transform = transform self.image_transform = image_transform
self.target_longest_image_dim = target_longest_image_dim
self.samples = [] self.samples = []
# Find all markdown files recursively # Find all markdown files recursively
@ -115,13 +116,13 @@ class MarkdownPDFDocumentDataset(Dataset):
text = parts[2].strip() text = parts[2].strip()
# Render PDF to image # Render PDF to image
base64_png = render_pdf_to_base64png(str(sample['pdf_path']), page_num=1) base64_png = render_pdf_to_base64png(str(sample['pdf_path']), page_num=1, target_longest_image_dim=self.target_longest_image_dim)
png_bytes = base64.b64decode(base64_png) png_bytes = base64.b64decode(base64_png)
image = Image.open(BytesIO(png_bytes)) image = Image.open(BytesIO(png_bytes))
# Apply transform if provided # Apply transform if provided
if self.transform: if self.image_transform:
image = self.transform(image) image = self.image_transform(image)
return { return {
'image': image, 'image': image,
@ -146,7 +147,7 @@ if __name__ == "__main__":
# Test dataset initialization # Test dataset initialization
print(f"\nTesting dataset with root directory: {args.root_dir}") print(f"\nTesting dataset with root directory: {args.root_dir}")
dataset = MarkdownPDFDocumentDataset(args.root_dir) dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=None)
print(f"\nDataset length: {len(dataset)}") print(f"\nDataset length: {len(dataset)}")
@ -175,7 +176,7 @@ if __name__ == "__main__":
transforms.ToTensor(), transforms.ToTensor(),
]) ])
dataset_with_transform = MarkdownPDFDocumentDataset(args.root_dir, transform=transform) dataset_with_transform = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=transform)
transformed_sample = dataset_with_transform[0] transformed_sample = dataset_with_transform[0]
print(f"Transformed image type: {type(transformed_sample['image'])}") print(f"Transformed image type: {type(transformed_sample['image'])}")
print(f"Transformed image shape: {transformed_sample['image'].shape}") print(f"Transformed image shape: {transformed_sample['image'].shape}")