From e53f782b0fd24bd7791ad0d6e42bc56fdc40d449 Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Sat, 28 Sep 2024 03:38:29 +0000 Subject: [PATCH] Datasetdict fix --- pdelfin/train/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py index 0d0d9b9..a3be3ee 100644 --- a/pdelfin/train/train.py +++ b/pdelfin/train/train.py @@ -15,6 +15,7 @@ from tqdm import tqdm import accelerate import torch import torch.distributed +from datasets import DatasetDict from datasets.utils import disable_progress_bars from datasets.utils.logging import set_verbosity from peft import LoraConfig, get_peft_model # pyright: ignore @@ -137,8 +138,8 @@ def run_train(config: TrainConfig): model = get_peft_model(model=model, peft_config=peft_config) log_trainable_parameters(model=model, logger=logger) - filtered_dataset = {split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset} - + # Do final filtering, and prep for running model forward() + filtered_dataset = DatasetDict(**{split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset}) formatted_dataset = filtered_dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor)) print(formatted_dataset) print("---------------")