diff --git a/pdelfin/train/config/qwen2vl-7b-lora.yaml b/pdelfin/train/config/qwen2vl-7b-lora.yaml
index f399e1f..41bc8a1 100644
--- a/pdelfin/train/config/qwen2vl-7b-lora.yaml
+++ b/pdelfin/train/config/qwen2vl-7b-lora.yaml
@@ -56,7 +56,7 @@ hparams:
   pad_multiple_of: 16
   log_every_steps: 50
   eval_every_steps: 500
-  optim: adamw_bnb_8bit
+  optim: adamw_torch
   lr_scheduler: cosine
   weight_decay: 0.01
   warmup_ratio: 0.03
diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py
index 8cd26b8..91e4540 100644
--- a/pdelfin/train/train.py
+++ b/pdelfin/train/train.py
@@ -137,7 +137,12 @@ def run_train(config: TrainConfig):
         model = get_peft_model(model=model, peft_config=peft_config)
         log_trainable_parameters(model=model, logger=logger)
 
-    formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    # formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+
+    # Convert to an iterable dataset, so we can apply map and filter without doing a full preprocessing pass in advance
+    formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
+    formatted_dataset = formatted_dataset.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500)
+
     print(formatted_dataset)
     print("---------------")
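
For context, here is a minimal, standalone sketch of the lazy map/filter pattern the train.py hunk introduces. The toy text dataset and the `prepare` function are hypothetical stand-ins for the real page-image examples and `batch_prepare_data_for_qwen2_training`; only the `to_iterable_dataset` / `map` / `filter` chain mirrors the patch.

```python
from datasets import Dataset

# Toy stand-in for the real training set (hypothetical data; in train.py the
# dataset holds page images and prompts, not plain text).
dataset = Dataset.from_dict({"text": ["short sample"] * 7 + ["long sample " * 5000]})

def prepare(example):
    # Hypothetical stand-in for batch_prepare_data_for_qwen2_training, which in
    # the real code runs the Qwen2-VL processor and returns input_ids, etc.
    tokens = example["text"].split()
    return {"num_tokens": len(tokens)}

# to_iterable_dataset() makes the following map/filter lazy: nothing is
# preprocessed until the trainer iterates, so there is no full pass up front.
# The patch uses num_shards=64 so multiple dataloader workers can each take a
# shard; 2 is enough for this toy dataset.
stream = dataset.to_iterable_dataset(num_shards=2)
stream = stream.map(prepare).filter(lambda x: x["num_tokens"] < 4500)

for row in stream:
    print(row["num_tokens"])  # the over-length row is dropped lazily during iteration
```

In the actual change the filter inspects the processor output instead, `x["input_ids"].shape[1] < 4500`, so pages whose tokenized length would exceed the training context are skipped without materializing the whole preprocessed dataset.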