map and filter on iterable dataset

This commit is contained in:
Jake Poznanski 2024-09-26 19:01:34 +00:00
parent f14e910175
commit 05fdb81da2
2 changed files with 7 additions and 2 deletions

View File

@ -56,7 +56,7 @@ hparams:
pad_multiple_of: 16 pad_multiple_of: 16
log_every_steps: 50 log_every_steps: 50
eval_every_steps: 500 eval_every_steps: 500
optim: adamw_bnb_8bit optim: adamw_torch
lr_scheduler: cosine lr_scheduler: cosine
weight_decay: 0.01 weight_decay: 0.01
warmup_ratio: 0.03 warmup_ratio: 0.03

View File

@ -137,7 +137,12 @@ def run_train(config: TrainConfig):
model = get_peft_model(model=model, peft_config=peft_config) model = get_peft_model(model=model, peft_config=peft_config)
log_trainable_parameters(model=model, logger=logger) log_trainable_parameters(model=model, logger=logger)
formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor)) # formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
# Convert to an iteratble dataset, so we can apply map and filter without doing a full calculation in advance
formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
formatted_dataset = formatted_dataset.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500)
print(formatted_dataset) print(formatted_dataset)
print("---------------") print("---------------")