Cleaning data loader

This commit is contained in:
Jake Poznanski 2025-06-11 21:41:18 +00:00
parent 5bbc1ffff7
commit d0df380ae9

View File

@@ -19,6 +19,23 @@ class StandardFrontMatter:
is_table: bool
is_diagram: bool
def __post_init__(self):
    """Validate field types, then the allowed rotation values.

    Type checks run before the value check so that a wrongly typed
    ``rotation_correction`` (for example the string ``"90"`` or an
    unhashable list) is reported as a ``TypeError`` carrying this class's
    own message, instead of a ``ValueError`` or a generic
    "unhashable type" error from the set lookup.

    Raises:
        TypeError: If any field does not have its declared type.
        ValueError: If ``rotation_correction`` is an int other than
            0, 90, 180 or 270.
    """
    # Type checks
    if not isinstance(self.primary_language, (str, type(None))):
        raise TypeError("primary_language must be of type Optional[str].")
    if not isinstance(self.is_rotation_valid, bool):
        raise TypeError("is_rotation_valid must be of type bool.")
    # bool is a subclass of int, so it is excluded explicitly; otherwise
    # False would slip through as a valid rotation of 0.
    if isinstance(self.rotation_correction, bool) or not isinstance(self.rotation_correction, int):
        raise TypeError("rotation_correction must be of type int.")
    if not isinstance(self.is_table, bool):
        raise TypeError("is_table must be of type bool.")
    if not isinstance(self.is_diagram, bool):
        raise TypeError("is_diagram must be of type bool.")

    # Value check: safe now that rotation_correction is known to be an int.
    if self.rotation_correction not in {0, 90, 180, 270}:
        raise ValueError("rotation_correction must be one of [0, 90, 180, 270].")
class MarkdownPDFDocumentDataset(Dataset):
def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
@@ -207,7 +224,7 @@ if __name__ == "__main__":
# Test dataset initialization
print(f"\nTesting dataset with root directory: {args.root_dir}")
dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=None)
dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, front_matter_class=StandardFrontMatter, image_transform=None)
print(f"\nDataset length: {len(dataset)}")
@@ -225,39 +242,4 @@ if __name__ == "__main__":
print(f"Image size: {first_sample['image'].size}")
print(f"PDF Path: {first_sample['pdf_path']}")
print(f"Front Matter: {first_sample['front_matter']}")
print(f"Text preview (first 200 chars): {first_sample['text'][:200]}...")
# Test with transforms
# Smoke test: build a second dataset over the same root directory, this time
# passing a torchvision pipeline as image_transform, and inspect sample 0.
print("\nTesting with torchvision transforms:")
# Imported here, at the point of use, rather than at the top of the file.
import torchvision.transforms as transforms
# Pipeline: resize to a fixed 1024x1024, then convert to a tensor.
transform = transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
])
dataset_with_transform = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=transform)
# Fetch the first sample and report the type and shape of its 'image' entry.
transformed_sample = dataset_with_transform[0]
print(f"Transformed image type: {type(transformed_sample['image'])}")
print(f"Transformed image shape: {transformed_sample['image'].shape}")
# Test with front matter validation
# Smoke test: build a dataset with front_matter_class=StandardFrontMatter and
# print the 'front_matter' entry of sample 0, first whole and then field by field.
print("\n\nTesting with front matter validation:")
dataset_with_validation = MarkdownPDFDocumentDataset(
args.root_dir,
target_longest_image_dim=1024,
front_matter_class=StandardFrontMatter
)
validated_sample = dataset_with_validation[0]
print(f"Front matter type: {type(validated_sample['front_matter'])}")
print(f"Front matter: {validated_sample['front_matter']}")
# Access fields directly
# The attribute reads below use the same five field names that
# StandardFrontMatter validates (primary_language, is_rotation_valid,
# rotation_correction, is_table, is_diagram).
# NOTE(review): presumably 'front_matter' is a StandardFrontMatter instance
# here; the dataset code that builds it is not visible, so confirm.
fm = validated_sample['front_matter']
print(f"\nAccessing fields:")
print(f" primary_language: {fm.primary_language}")
print(f" is_rotation_valid: {fm.is_rotation_valid}")
print(f" rotation_correction: {fm.rotation_correction}")
print(f" is_table: {fm.is_table}")
print(f" is_diagram: {fm.is_diagram}")
print(f"Text: {first_sample['text']}...")