DatasetDict fix

This commit is contained in:
Jake Poznanski 2024-09-28 03:38:29 +00:00
parent decfd7fbc1
commit e53f782b0f

@@ -15,6 +15,7 @@ from tqdm import tqdm
 import accelerate
 import torch
 import torch.distributed
+from datasets import DatasetDict
 from datasets.utils import disable_progress_bars
 from datasets.utils.logging import set_verbosity
 from peft import LoraConfig, get_peft_model # pyright: ignore
@@ -137,8 +138,8 @@ def run_train(config: TrainConfig):
     model = get_peft_model(model=model, peft_config=peft_config)
     log_trainable_parameters(model=model, logger=logger)
-    filtered_dataset = {split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset}
+    # Do final filtering, and prep for running model forward()
+    filtered_dataset = DatasetDict(**{split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset})
     formatted_dataset = filtered_dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
     print(formatted_dataset)
     print("---------------")