mirror of
https://github.com/allenai/olmocr.git
synced 2025-07-24 17:43:34 +00:00
Image params to loader
This commit is contained in:
parent
9a390e3d58
commit
aedc295e3f
@ -12,16 +12,17 @@ from olmocr.data.renderpdf import render_pdf_to_base64png
|
|||||||
|
|
||||||
|
|
||||||
class MarkdownPDFDocumentDataset(Dataset):
|
class MarkdownPDFDocumentDataset(Dataset):
|
||||||
def __init__(self, root_dir: str | PathLike, transform=None):
|
def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None):
|
||||||
"""
|
"""
|
||||||
Initialize the dataset by finding all markdown files with corresponding PDFs.
|
Initialize the dataset by finding all markdown files with corresponding PDFs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
root_dir: Path to the root folder containing processed markdown and PDF files
|
root_dir: Path to the root folder containing processed markdown and PDF files
|
||||||
transform: Optional transform to apply to the PDF images
|
image_transform: Optional transform to apply to the PDF images
|
||||||
"""
|
"""
|
||||||
self.root_dir = Path(root_dir)
|
self.root_dir = Path(root_dir)
|
||||||
self.transform = transform
|
self.image_transform = image_transform
|
||||||
|
self.target_longest_image_dim = target_longest_image_dim
|
||||||
self.samples = []
|
self.samples = []
|
||||||
|
|
||||||
# Find all markdown files recursively
|
# Find all markdown files recursively
|
||||||
@ -115,13 +116,13 @@ class MarkdownPDFDocumentDataset(Dataset):
|
|||||||
text = parts[2].strip()
|
text = parts[2].strip()
|
||||||
|
|
||||||
# Render PDF to image
|
# Render PDF to image
|
||||||
base64_png = render_pdf_to_base64png(str(sample['pdf_path']), page_num=1)
|
base64_png = render_pdf_to_base64png(str(sample['pdf_path']), page_num=1, target_longest_image_dim=self.target_longest_image_dim)
|
||||||
png_bytes = base64.b64decode(base64_png)
|
png_bytes = base64.b64decode(base64_png)
|
||||||
image = Image.open(BytesIO(png_bytes))
|
image = Image.open(BytesIO(png_bytes))
|
||||||
|
|
||||||
# Apply transform if provided
|
# Apply transform if provided
|
||||||
if self.transform:
|
if self.image_transform:
|
||||||
image = self.transform(image)
|
image = self.image_transform(image)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'image': image,
|
'image': image,
|
||||||
@ -146,7 +147,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Test dataset initialization
|
# Test dataset initialization
|
||||||
print(f"\nTesting dataset with root directory: {args.root_dir}")
|
print(f"\nTesting dataset with root directory: {args.root_dir}")
|
||||||
dataset = MarkdownPDFDocumentDataset(args.root_dir)
|
dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=None)
|
||||||
|
|
||||||
print(f"\nDataset length: {len(dataset)}")
|
print(f"\nDataset length: {len(dataset)}")
|
||||||
|
|
||||||
@ -175,7 +176,7 @@ if __name__ == "__main__":
|
|||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
])
|
])
|
||||||
|
|
||||||
dataset_with_transform = MarkdownPDFDocumentDataset(args.root_dir, transform=transform)
|
dataset_with_transform = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=transform)
|
||||||
transformed_sample = dataset_with_transform[0]
|
transformed_sample = dataset_with_transform[0]
|
||||||
print(f"Transformed image type: {type(transformed_sample['image'])}")
|
print(f"Transformed image type: {type(transformed_sample['image'])}")
|
||||||
print(f"Transformed image shape: {transformed_sample['image'].shape}")
|
print(f"Transformed image shape: {transformed_sample['image'].shape}")
|
Loading…
x
Reference in New Issue
Block a user