Removing lambda due to pickling errors

Jake Poznanski 2024-09-26 21:39:08 +00:00
parent 61dd7bb61f
commit 84e9da637c
2 changed files with 8 additions and 2 deletions
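
Background on the fix: pickle serializes a function by its module-qualified name, so a lambda, which has no importable name, raises a PicklingError as soon as the datasets map/filter pipeline has to be shipped to worker processes (e.g., a DataLoader with num_workers > 0). A module-level function like filter_by_max_seq_len serializes cleanly. A minimal sketch of the failure mode (the function name here is illustrative, not from this commit):

import pickle

def keep_short(n):
    # Module-level functions pickle by qualified name, so this succeeds.
    return n <= 4500

pickle.dumps(keep_short)

try:
    pickle.dumps(lambda n: n <= 4500)
except pickle.PicklingError as e:
    # Lambdas have no importable name, so pickle cannot reconstruct them.
    print("lambda is not picklable:", e)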

pdelfin/train/dataprep.py

@@ -4,6 +4,12 @@ from PIL import Image
 import base64
 import torch  # Make sure to import torch as it's used in the DataCollator
 
+
+def filter_by_max_seq_len(example, max_seq_len=4500):
+    sizes = example["input_ids"].shape
+    return sizes[-1] <= max_seq_len
+
+
 def prepare_data_for_qwen2_training(example, processor):
     # Prepare messages
     messages = [
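
One subtlety worth flagging: the removed lambda tested input_ids.shape[0] strictly (< 4500), while the new predicate tests the last dimension inclusively (<= 4500); for 1-D input_ids the two differ only at exactly 4500 tokens. A quick check of the new behavior, with the function copied from the diff above and placeholder tensors:

import torch

def filter_by_max_seq_len(example, max_seq_len=4500):
    sizes = example["input_ids"].shape
    return sizes[-1] <= max_seq_len

print(filter_by_max_seq_len({"input_ids": torch.zeros(4500, dtype=torch.long)}))  # True: kept at the cap
print(filter_by_max_seq_len({"input_ids": torch.zeros(4501, dtype=torch.long)}))  # False: filtered out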

pdelfin/train/train.py

@@ -49,7 +49,7 @@ from .utils import (
 from pdelfin.train.dataloader import make_dataset
-from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
+from pdelfin.train.dataprep import filter_by_max_seq_len, prepare_data_for_qwen2_training
 
 
 class CheckpointUploadCallback(TrainerCallback):
@@ -143,7 +143,7 @@ def run_train(config: TrainConfig):
     train_ds = dataset["train"].to_iterable_dataset(num_shards=64)
     validation_ds = dataset["validation"]
 
-    train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[0] < 4500)
+    train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor)).filter(filter_by_max_seq_len)
     validation_ds = validation_ds.map(partial(prepare_data_for_qwen2_training, processor=processor))
 
     print(train_ds)
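
Because the default max_seq_len=4500 matches the old inline constant, .filter(filter_by_max_seq_len) needs no extra arguments here. If a different cap were ever needed, functools.partial over the module-level function stays picklable, unlike a lambda; a sketch under that assumption (the 2048 cap is hypothetical):

import pickle
from functools import partial

def filter_by_max_seq_len(example, max_seq_len=4500):
    sizes = example["input_ids"].shape
    return sizes[-1] <= max_seq_len

# partial wraps the named function by reference, so the whole object
# round-trips through pickle and can be shipped to dataloader workers.
shorter = partial(filter_by_max_seq_len, max_seq_len=2048)
assert pickle.loads(pickle.dumps(shorter)).keywords["max_seq_len"] == 2048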