diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py index c53bb4c..3a92f6b 100644 --- a/pdelfin/train/train.py +++ b/pdelfin/train/train.py @@ -188,6 +188,9 @@ def run_train(config: TrainConfig): label_names=["labels"], # fix from https://github.com/huggingface/transformers/issues/22885 max_grad_norm=config.hparams.clip_grad_norm, remove_unused_columns=False, + accelerator_config={ + "dispatch_batches": False + } ) # Set the collator