Cleanup collator

This commit is contained in:
Jake Poznanski 2025-06-28 22:46:12 +00:00
parent 887190e961
commit c36b5df2af

View File

@ -4,6 +4,7 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
import argparse
import logging
import numpy as np
from transformers import (
AutoProcessor,
@ -28,9 +29,10 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
def create_data_collator():
"""Create a data collator for vision-language models."""
def collate_fn(examples):
class QwenDataCollator:
"""Data collator for vision-language models that handles numpy arrays."""
def __call__(self, examples):
# Filter out None values and extract the fields we need
batch = {
'input_ids': [],
@ -42,11 +44,22 @@ def create_data_collator():
for example in examples:
if example is not None:
batch['input_ids'].append(example['input_ids'])
batch['attention_mask'].append(example['attention_mask'])
batch['labels'].append(example['labels'])
batch['pixel_values'].append(example['pixel_values'])
batch['image_grid_thw'].append(example['image_grid_thw'])
# Convert numpy arrays to tensors
batch['input_ids'].append(torch.from_numpy(example['input_ids']) if isinstance(example['input_ids'], np.ndarray) else example['input_ids'])
batch['attention_mask'].append(torch.from_numpy(example['attention_mask']) if isinstance(example['attention_mask'], np.ndarray) else example['attention_mask'])
batch['labels'].append(torch.from_numpy(example['labels']) if isinstance(example['labels'], np.ndarray) else example['labels'])
# Handle pixel_values which might be numpy array or already a tensor
pixel_values = example['pixel_values']
if isinstance(pixel_values, np.ndarray):
pixel_values = torch.from_numpy(pixel_values)
batch['pixel_values'].append(pixel_values)
# Handle image_grid_thw
image_grid_thw = example['image_grid_thw']
if isinstance(image_grid_thw, np.ndarray):
image_grid_thw = torch.from_numpy(image_grid_thw)
batch['image_grid_thw'].append(image_grid_thw)
# Convert lists to tensors with proper padding
# Note: For Qwen2-VL, we typically handle variable length sequences
@ -58,8 +71,6 @@ def create_data_collator():
'pixel_values': batch['pixel_values'], # Keep as list for now
'image_grid_thw': torch.stack(batch['image_grid_thw'])
}
return collate_fn
def main():
@ -215,7 +226,7 @@ def main():
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_datasets,
data_collator=create_data_collator(),
data_collator=QwenDataCollator(),
callbacks=callbacks,
)