mirror of
https://github.com/allenai/olmocr.git
synced 2025-09-18 04:58:15 +00:00
map and filter on iterable dataset
This commit is contained in:
parent
f14e910175
commit
05fdb81da2
@ -56,7 +56,7 @@ hparams:
|
||||
pad_multiple_of: 16
|
||||
log_every_steps: 50
|
||||
eval_every_steps: 500
|
||||
optim: adamw_bnb_8bit
|
||||
optim: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
weight_decay: 0.01
|
||||
warmup_ratio: 0.03
|
||||
|
@ -137,7 +137,12 @@ def run_train(config: TrainConfig):
|
||||
model = get_peft_model(model=model, peft_config=peft_config)
|
||||
log_trainable_parameters(model=model, logger=logger)
|
||||
|
||||
formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
|
||||
# formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
|
||||
|
||||
# Convert to an iteratble dataset, so we can apply map and filter without doing a full calculation in advance
|
||||
formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
|
||||
formatted_dataset = formatted_dataset.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500)
|
||||
|
||||
print(formatted_dataset)
|
||||
print("---------------")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user