mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-03 11:35:29 +00:00
Cleanup collator
This commit is contained in:
parent
887190e961
commit
c36b5df2af
@ -4,6 +4,7 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
@ -28,9 +29,10 @@ logging.basicConfig(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_data_collator():
|
||||
"""Create a data collator for vision-language models."""
|
||||
def collate_fn(examples):
|
||||
class QwenDataCollator:
|
||||
"""Data collator for vision-language models that handles numpy arrays."""
|
||||
|
||||
def __call__(self, examples):
|
||||
# Filter out None values and extract the fields we need
|
||||
batch = {
|
||||
'input_ids': [],
|
||||
@ -42,11 +44,22 @@ def create_data_collator():
|
||||
|
||||
for example in examples:
|
||||
if example is not None:
|
||||
batch['input_ids'].append(example['input_ids'])
|
||||
batch['attention_mask'].append(example['attention_mask'])
|
||||
batch['labels'].append(example['labels'])
|
||||
batch['pixel_values'].append(example['pixel_values'])
|
||||
batch['image_grid_thw'].append(example['image_grid_thw'])
|
||||
# Convert numpy arrays to tensors
|
||||
batch['input_ids'].append(torch.from_numpy(example['input_ids']) if isinstance(example['input_ids'], np.ndarray) else example['input_ids'])
|
||||
batch['attention_mask'].append(torch.from_numpy(example['attention_mask']) if isinstance(example['attention_mask'], np.ndarray) else example['attention_mask'])
|
||||
batch['labels'].append(torch.from_numpy(example['labels']) if isinstance(example['labels'], np.ndarray) else example['labels'])
|
||||
|
||||
# Handle pixel_values which might be numpy array or already a tensor
|
||||
pixel_values = example['pixel_values']
|
||||
if isinstance(pixel_values, np.ndarray):
|
||||
pixel_values = torch.from_numpy(pixel_values)
|
||||
batch['pixel_values'].append(pixel_values)
|
||||
|
||||
# Handle image_grid_thw
|
||||
image_grid_thw = example['image_grid_thw']
|
||||
if isinstance(image_grid_thw, np.ndarray):
|
||||
image_grid_thw = torch.from_numpy(image_grid_thw)
|
||||
batch['image_grid_thw'].append(image_grid_thw)
|
||||
|
||||
# Convert lists to tensors with proper padding
|
||||
# Note: For Qwen2-VL, we typically handle variable length sequences
|
||||
@ -59,8 +72,6 @@ def create_data_collator():
|
||||
'image_grid_thw': torch.stack(batch['image_grid_thw'])
|
||||
}
|
||||
|
||||
return collate_fn
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Train OlmOCR model")
|
||||
@ -215,7 +226,7 @@ def main():
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_datasets,
|
||||
data_collator=create_data_collator(),
|
||||
data_collator=QwenDataCollator(),
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user