From 05fdb81da2b1bd1a1f19c0a54c5f5d006b5f93bb Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Thu, 26 Sep 2024 19:01:34 +0000 Subject: [PATCH] map and filter on iterable dataset --- pdelfin/train/config/qwen2vl-7b-lora.yaml | 2 +- pdelfin/train/train.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pdelfin/train/config/qwen2vl-7b-lora.yaml b/pdelfin/train/config/qwen2vl-7b-lora.yaml index f399e1f..41bc8a1 100644 --- a/pdelfin/train/config/qwen2vl-7b-lora.yaml +++ b/pdelfin/train/config/qwen2vl-7b-lora.yaml @@ -56,7 +56,7 @@ hparams: pad_multiple_of: 16 log_every_steps: 50 eval_every_steps: 500 - optim: adamw_bnb_8bit + optim: adamw_torch lr_scheduler: cosine weight_decay: 0.01 warmup_ratio: 0.03 diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py index 8cd26b8..91e4540 100644 --- a/pdelfin/train/train.py +++ b/pdelfin/train/train.py @@ -137,7 +137,12 @@ def run_train(config: TrainConfig): model = get_peft_model(model=model, peft_config=peft_config) log_trainable_parameters(model=model, logger=logger) - formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor)) + # formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor)) + + # Convert to an iteratble dataset, so we can apply map and filter without doing a full calculation in advance + formatted_dataset = dataset.to_iterable_dataset(num_shards=64) + formatted_dataset = formatted_dataset.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500) + print(formatted_dataset) print("---------------")