diff --git a/olmocr/train/train.py b/olmocr/train/train.py index 3eb7ed9..8bde005 100644 --- a/olmocr/train/train.py +++ b/olmocr/train/train.py @@ -68,7 +68,7 @@ class QwenDataCollator: 'input_ids': torch.stack(batch['input_ids']), 'attention_mask': torch.stack(batch['attention_mask']), 'labels': torch.stack(batch['labels']), - 'pixel_values': batch['pixel_values'], # Keep as list for now + 'pixel_values': torch.stack(batch['pixel_values']), # Stack into tensor 'image_grid_thw': torch.stack(batch['image_grid_thw']) }