mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-06 21:30:23 +00:00
Cleanup collator
This commit is contained in:
parent
887190e961
commit
c36b5df2af
@ -4,6 +4,7 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
@ -28,9 +29,10 @@ logging.basicConfig(
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_data_collator():
|
class QwenDataCollator:
|
||||||
"""Create a data collator for vision-language models."""
|
"""Data collator for vision-language models that handles numpy arrays."""
|
||||||
def collate_fn(examples):
|
|
||||||
|
def __call__(self, examples):
|
||||||
# Filter out None values and extract the fields we need
|
# Filter out None values and extract the fields we need
|
||||||
batch = {
|
batch = {
|
||||||
'input_ids': [],
|
'input_ids': [],
|
||||||
@ -42,11 +44,22 @@ def create_data_collator():
|
|||||||
|
|
||||||
for example in examples:
|
for example in examples:
|
||||||
if example is not None:
|
if example is not None:
|
||||||
batch['input_ids'].append(example['input_ids'])
|
# Convert numpy arrays to tensors
|
||||||
batch['attention_mask'].append(example['attention_mask'])
|
batch['input_ids'].append(torch.from_numpy(example['input_ids']) if isinstance(example['input_ids'], np.ndarray) else example['input_ids'])
|
||||||
batch['labels'].append(example['labels'])
|
batch['attention_mask'].append(torch.from_numpy(example['attention_mask']) if isinstance(example['attention_mask'], np.ndarray) else example['attention_mask'])
|
||||||
batch['pixel_values'].append(example['pixel_values'])
|
batch['labels'].append(torch.from_numpy(example['labels']) if isinstance(example['labels'], np.ndarray) else example['labels'])
|
||||||
batch['image_grid_thw'].append(example['image_grid_thw'])
|
|
||||||
|
# Handle pixel_values which might be numpy array or already a tensor
|
||||||
|
pixel_values = example['pixel_values']
|
||||||
|
if isinstance(pixel_values, np.ndarray):
|
||||||
|
pixel_values = torch.from_numpy(pixel_values)
|
||||||
|
batch['pixel_values'].append(pixel_values)
|
||||||
|
|
||||||
|
# Handle image_grid_thw
|
||||||
|
image_grid_thw = example['image_grid_thw']
|
||||||
|
if isinstance(image_grid_thw, np.ndarray):
|
||||||
|
image_grid_thw = torch.from_numpy(image_grid_thw)
|
||||||
|
batch['image_grid_thw'].append(image_grid_thw)
|
||||||
|
|
||||||
# Convert lists to tensors with proper padding
|
# Convert lists to tensors with proper padding
|
||||||
# Note: For Qwen2-VL, we typically handle variable length sequences
|
# Note: For Qwen2-VL, we typically handle variable length sequences
|
||||||
@ -59,8 +72,6 @@ def create_data_collator():
|
|||||||
'image_grid_thw': torch.stack(batch['image_grid_thw'])
|
'image_grid_thw': torch.stack(batch['image_grid_thw'])
|
||||||
}
|
}
|
||||||
|
|
||||||
return collate_fn
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Train OlmOCR model")
|
parser = argparse.ArgumentParser(description="Train OlmOCR model")
|
||||||
@ -215,7 +226,7 @@ def main():
|
|||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=eval_datasets,
|
eval_dataset=eval_datasets,
|
||||||
data_collator=create_data_collator(),
|
data_collator=QwenDataCollator(),
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user