diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py index 2f5d785..2b9c294 100644 --- a/pdelfin/train/train.py +++ b/pdelfin/train/train.py @@ -116,7 +116,7 @@ def run_train(config: TrainConfig): setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group) processor = AutoProcessor.from_pretrained(config.model.name_or_path) - train_dataset, valid_dataset = make_dataset(config) + train_dataset, valid_dataset = make_dataset(config, processor) model = Qwen2VLForConditionalGeneration.from_pretrained( config.model.name_or_path, torch_dtype=torch.bfloat16,