diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py
index 0d0d9b9..a3be3ee 100644
--- a/pdelfin/train/train.py
+++ b/pdelfin/train/train.py
@@ -15,6 +15,7 @@ from tqdm import tqdm
 import accelerate
 import torch
 import torch.distributed
+from datasets import DatasetDict
 from datasets.utils import disable_progress_bars
 from datasets.utils.logging import set_verbosity
 from peft import LoraConfig, get_peft_model # pyright: ignore
@@ -137,8 +138,8 @@ def run_train(config: TrainConfig):
     model = get_peft_model(model=model, peft_config=peft_config)
     log_trainable_parameters(model=model, logger=logger)
 
-    filtered_dataset = {split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset}
-
+    # Do final filtering, and prep for running model forward()
+    filtered_dataset = DatasetDict(**{split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset})
     formatted_dataset = filtered_dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
     print(formatted_dataset)
     print("---------------")