From d0df380ae9a563b8887649335b01545ff76cccfc Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Wed, 11 Jun 2025 21:41:18 +0000
Subject: [PATCH] Cleaning data loader

---
 olmocr/train/dataloader.py | 65 +++++++++++++++++----------------------
 1 file changed, 28 insertions(+), 37 deletions(-)

diff --git a/olmocr/train/dataloader.py b/olmocr/train/dataloader.py
index a0f71b2..a9f7120 100644
--- a/olmocr/train/dataloader.py
+++ b/olmocr/train/dataloader.py
@@ -19,6 +19,32 @@ class StandardFrontMatter:
     is_table: bool
     is_diagram: bool
 
+    def __post_init__(self):
+        """Validate field types first, then the rotation_correction value.
+
+        Raises:
+            TypeError: If a field does not have its declared type.
+            ValueError: If rotation_correction is not 0, 90, 180 or 270.
+        """
+        # Type checks run before the value check so that a wrongly typed
+        # rotation_correction is reported as a TypeError rather than as an
+        # out-of-range value (or an "unhashable type" error from the set lookup).
+        if not isinstance(self.primary_language, (str, type(None))):
+            raise TypeError("primary_language must be of type Optional[str].")
+        if not isinstance(self.is_rotation_valid, bool):
+            raise TypeError("is_rotation_valid must be of type bool.")
+        # bool is a subclass of int, so False would otherwise pass as 0.
+        if isinstance(self.rotation_correction, bool) or not isinstance(self.rotation_correction, int):
+            raise TypeError("rotation_correction must be of type int.")
+        if not isinstance(self.is_table, bool):
+            raise TypeError("is_table must be of type bool.")
+        if not isinstance(self.is_diagram, bool):
+            raise TypeError("is_diagram must be of type bool.")
+
+        # Validate rotation_correction is one of the allowed values
+        if self.rotation_correction not in {0, 90, 180, 270}:
+            raise ValueError("rotation_correction must be one of [0, 90, 180, 270].")
+
 
 class MarkdownPDFDocumentDataset(Dataset):
     def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
@@ -207,7 +233,7 @@ if __name__ == "__main__":
 
     # Test dataset initialization
     print(f"\nTesting dataset with root directory: {args.root_dir}")
-    dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=None)
+    dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, front_matter_class=StandardFrontMatter, image_transform=None)
 
     print(f"\nDataset length: {len(dataset)}")
 
@@ -225,39 +251,4 @@ if __name__ == "__main__":
         print(f"Image size: {first_sample['image'].size}")
         print(f"PDF Path: {first_sample['pdf_path']}")
         print(f"Front Matter: {first_sample['front_matter']}")
-        print(f"Text preview (first 200 chars): {first_sample['text'][:200]}...")
-
-        # Test with transforms
-        print("\nTesting with torchvision transforms:")
-        import torchvision.transforms as transforms
-
-        transform = transforms.Compose([
-            transforms.Resize((1024, 1024)),
-            transforms.ToTensor(),
-        ])
-
-        dataset_with_transform = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=transform)
-        transformed_sample = dataset_with_transform[0]
-        print(f"Transformed image type: {type(transformed_sample['image'])}")
-        print(f"Transformed image shape: {transformed_sample['image'].shape}")
-
-        # Test with front matter validation
-        print("\n\nTesting with front matter validation:")
-        dataset_with_validation = MarkdownPDFDocumentDataset(
-            args.root_dir,
-            target_longest_image_dim=1024,
-            front_matter_class=StandardFrontMatter
-        )
-
-        validated_sample = dataset_with_validation[0]
-        print(f"Front matter type: {type(validated_sample['front_matter'])}")
-        print(f"Front matter: {validated_sample['front_matter']}")
-
-        # Access fields directly
-        fm = validated_sample['front_matter']
-        print(f"\nAccessing fields:")
-        print(f"  primary_language: {fm.primary_language}")
-        print(f"  is_rotation_valid: {fm.is_rotation_valid}")
-        print(f"  rotation_correction: {fm.rotation_correction}")
-        print(f"  is_table: {fm.is_table}")
-        print(f"  is_diagram: {fm.is_diagram}")
\ No newline at end of file
+        print(f"Text: {first_sample['text']}...")