diff --git a/pdelfin/train/loaddataset.py b/pdelfin/train/loaddataset.py index 05d49c1..279e932 100644 --- a/pdelfin/train/loaddataset.py +++ b/pdelfin/train/loaddataset.py @@ -20,7 +20,13 @@ def main(): print("Training dataset........") print(train_dataset) - print(train_dataset[0]) + + train_example = train_dataset[0] + print(train_example) + print({(x, y.shape) for x,y in train_example.items()}) + print("\nTokens") + print(processor.tokenizer.batch_decode(train_example["input_ids"])) + print("\n\n") print("Validation dataset........") diff --git a/tests/test_molmo.py b/tests/test_molmo.py index 28df0b2..947a289 100644 --- a/tests/test_molmo.py +++ b/tests/test_molmo.py @@ -37,6 +37,7 @@ class MolmoProcessorTest(unittest.TestCase): print(inputs) print("\nShapes") + # {('input_ids', torch.Size([1, 589])), ('images', torch.Size([1, 5, 576, 588])), ('image_masks', torch.Size([1, 5, 576])), ('image_input_idx', torch.Size([1, 5, 144]))} print({(x, y.shape) for x,y in inputs.items()}) print("\nTokens")