From c36b5df2afd60b84781ccb1d99d84b33567fed46 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Sat, 28 Jun 2025 22:46:12 +0000
Subject: [PATCH] Cleanup collator

---
 olmocr/train/train.py | 33 ++++++++++++++++++++++-----------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/olmocr/train/train.py b/olmocr/train/train.py
index b9988c5..3eb7ed9 100644
--- a/olmocr/train/train.py
+++ b/olmocr/train/train.py
@@ -4,6 +4,7 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
 
 import argparse
 import logging
+import numpy as np
 
 from transformers import (
     AutoProcessor,
@@ -28,9 +29,10 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
-def create_data_collator():
-    """Create a data collator for vision-language models."""
-    def collate_fn(examples):
+class QwenDataCollator:
+    """Data collator for vision-language models that handles numpy arrays."""
+
+    def __call__(self, examples):
         # Filter out None values and extract the fields we need
         batch = {
             'input_ids': [],
@@ -42,11 +44,22 @@
 
         for example in examples:
             if example is not None:
-                batch['input_ids'].append(example['input_ids'])
-                batch['attention_mask'].append(example['attention_mask'])
-                batch['labels'].append(example['labels'])
-                batch['pixel_values'].append(example['pixel_values'])
-                batch['image_grid_thw'].append(example['image_grid_thw'])
+                # Convert numpy arrays to tensors
+                batch['input_ids'].append(torch.from_numpy(example['input_ids']) if isinstance(example['input_ids'], np.ndarray) else example['input_ids'])
+                batch['attention_mask'].append(torch.from_numpy(example['attention_mask']) if isinstance(example['attention_mask'], np.ndarray) else example['attention_mask'])
+                batch['labels'].append(torch.from_numpy(example['labels']) if isinstance(example['labels'], np.ndarray) else example['labels'])
+
+                # Handle pixel_values which might be numpy array or already a tensor
+                pixel_values = example['pixel_values']
+                if isinstance(pixel_values, np.ndarray):
+                    pixel_values = torch.from_numpy(pixel_values)
+                batch['pixel_values'].append(pixel_values)
+
+                # Handle image_grid_thw
+                image_grid_thw = example['image_grid_thw']
+                if isinstance(image_grid_thw, np.ndarray):
+                    image_grid_thw = torch.from_numpy(image_grid_thw)
+                batch['image_grid_thw'].append(image_grid_thw)
 
         # Convert lists to tensors with proper padding
         # Note: For Qwen2-VL, we typically handle variable length sequences
@@ -58,8 +71,6 @@
             'pixel_values': batch['pixel_values'], # Keep as list for now
             'image_grid_thw': torch.stack(batch['image_grid_thw'])
         }
-
-    return collate_fn
 
 
 def main():
@@ -215,7 +226,7 @@
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_datasets,
-        data_collator=create_data_collator(),
+        data_collator=QwenDataCollator(),
         callbacks=callbacks,
     )
 
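
For illustration only (not part of the patch): a minimal sketch of driving the reworked collator with toy numpy inputs. The shapes, dtypes, and the direct import of QwenDataCollator from olmocr/train/train.py are assumptions made for this example; only the field names and the numpy-to-tensor conversion behavior come from the diff above.

import numpy as np

from olmocr.train.train import QwenDataCollator  # assumes the patched module is importable

# Two toy examples with the five fields the collator expects; shapes are illustrative.
examples = [
    {
        "input_ids": np.arange(8, dtype=np.int64),
        "attention_mask": np.ones(8, dtype=np.int64),
        "labels": np.arange(8, dtype=np.int64),
        "pixel_values": np.zeros((4, 1176), dtype=np.float32),
        "image_grid_thw": np.array([[1, 2, 2]], dtype=np.int64),
    }
    for _ in range(2)
]

batch = QwenDataCollator()(examples)

# Per the diff: image_grid_thw entries are converted from numpy and stacked,
# while pixel_values stays a Python list of per-example tensors.
print(batch["image_grid_thw"].shape)   # torch.Size([2, 1, 3])
print(type(batch["pixel_values"]))     # <class 'list'>
print(type(batch["pixel_values"][0]))  # <class 'torch.Tensor'>

Because the collator is now a plain class instead of a closure, it pickles cleanly, which matters when the Trainer spawns DataLoader worker processes.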