mirror of
https://github.com/allenai/olmocr.git
synced 2025-12-07 20:52:08 +00:00
Removing lambda due to pickling errors
This commit is contained in:
parent
61dd7bb61f
commit
84e9da637c
@ -4,6 +4,12 @@ from PIL import Image
|
||||
import base64
|
||||
import torch # Make sure to import torch as it's used in the DataCollator
|
||||
|
||||
|
||||
def filter_by_max_seq_len(example, max_seq_len=4500):
    """Return True when the example's tokenized sequence fits within max_seq_len.

    Module-level (not a lambda) so it can be pickled by dataset workers —
    the lambda it replaced caused pickling errors.

    Args:
        example: mapping with an "input_ids" array/tensor; only its ``.shape``
            is read, so any object exposing ``shape`` works.
        max_seq_len: inclusive upper bound on the sequence length.

    Returns:
        bool: True if the last dimension of ``input_ids`` is <= max_seq_len.
    """
    # Use the last axis so this works whether input_ids is 1-D (seq,)
    # or batched (batch, seq).
    sizes = example["input_ids"].shape
    return sizes[-1] <= max_seq_len
|
||||
|
||||
|
||||
def prepare_data_for_qwen2_training(example, processor):
|
||||
# Prepare messages
|
||||
messages = [
|
||||
|
||||
@ -49,7 +49,7 @@ from .utils import (
|
||||
|
||||
|
||||
from pdelfin.train.dataloader import make_dataset
|
||||
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
|
||||
from pdelfin.train.dataprep import filter_by_max_seq_len, prepare_data_for_qwen2_training
|
||||
|
||||
|
||||
class CheckpointUploadCallback(TrainerCallback):
|
||||
@ -143,7 +143,7 @@ def run_train(config: TrainConfig):
|
||||
train_ds = dataset["train"].to_iterable_dataset(num_shards=64)
|
||||
validation_ds = dataset["validation"]
|
||||
|
||||
train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[0] < 4500)
|
||||
train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor)).filter(filter_by_max_seq_len)
|
||||
validation_ds = validation_ds.map(partial(prepare_data_for_qwen2_training, processor=processor))
|
||||
|
||||
print(train_ds)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user