diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py index c803191..a1b4e2f 100644 --- a/pdelfin/train/train.py +++ b/pdelfin/train/train.py @@ -198,8 +198,8 @@ def run_train(config: TrainConfig): trainer = Trainer( model=model, args=training_args, - train_dataset=formatted_dataset["train"], - eval_dataset=formatted_dataset["validation"], # pyright: ignore + train_dataset=train_ds, + eval_dataset=validation_ds, tokenizer=processor.tokenizer, #Collator is not needed as we are doing batch size 1 for now... #data_collator=collator,