mirror of
https://github.com/allenai/olmocr.git
synced 2025-06-27 04:00:02 +00:00
Image params to loader
This commit is contained in:
parent
9a390e3d58
commit
aedc295e3f
@ -12,16 +12,17 @@ from olmocr.data.renderpdf import render_pdf_to_base64png
|
||||
|
||||
|
||||
class MarkdownPDFDocumentDataset(Dataset):
|
||||
def __init__(self, root_dir: str | PathLike, transform=None):
|
||||
def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None):
|
||||
"""
|
||||
Initialize the dataset by finding all markdown files with corresponding PDFs.
|
||||
|
||||
Args:
|
||||
root_dir: Path to the root folder containing processed markdown and PDF files
|
||||
transform: Optional transform to apply to the PDF images
|
||||
image_transform: Optional transform to apply to the PDF images
|
||||
"""
|
||||
self.root_dir = Path(root_dir)
|
||||
self.transform = transform
|
||||
self.image_transform = image_transform
|
||||
self.target_longest_image_dim = target_longest_image_dim
|
||||
self.samples = []
|
||||
|
||||
# Find all markdown files recursively
|
||||
@ -115,13 +116,13 @@ class MarkdownPDFDocumentDataset(Dataset):
|
||||
text = parts[2].strip()
|
||||
|
||||
# Render PDF to image
|
||||
base64_png = render_pdf_to_base64png(str(sample['pdf_path']), page_num=1)
|
||||
base64_png = render_pdf_to_base64png(str(sample['pdf_path']), page_num=1, target_longest_image_dim=self.target_longest_image_dim)
|
||||
png_bytes = base64.b64decode(base64_png)
|
||||
image = Image.open(BytesIO(png_bytes))
|
||||
|
||||
# Apply transform if provided
|
||||
if self.transform:
|
||||
image = self.transform(image)
|
||||
if self.image_transform:
|
||||
image = self.image_transform(image)
|
||||
|
||||
return {
|
||||
'image': image,
|
||||
@ -146,7 +147,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Test dataset initialization
|
||||
print(f"\nTesting dataset with root directory: {args.root_dir}")
|
||||
dataset = MarkdownPDFDocumentDataset(args.root_dir)
|
||||
dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=None)
|
||||
|
||||
print(f"\nDataset length: {len(dataset)}")
|
||||
|
||||
@ -175,7 +176,7 @@ if __name__ == "__main__":
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
dataset_with_transform = MarkdownPDFDocumentDataset(args.root_dir, transform=transform)
|
||||
dataset_with_transform = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=transform)
|
||||
transformed_sample = dataset_with_transform[0]
|
||||
print(f"Transformed image type: {type(transformed_sample['image'])}")
|
||||
print(f"Transformed image shape: {transformed_sample['image'].shape}")
|
Loading…
x
Reference in New Issue
Block a user