Going back to a non-iterable dataset so shuffling works better; applying a light filter

Jake Poznanski 2024-09-27 15:48:56 +00:00
parent 65a9c9981e
commit 22b765e6be
4 changed files with 18 additions and 20 deletions
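The shuffling point in the commit message comes down to map-style vs. iterable datasets in the Hugging Face datasets library: a map-style Dataset supports random access by index, so a global shuffle (and the Trainer's default random sampler) sees every row, while an IterableDataset can only shuffle shard order plus a fixed-size buffer. A minimal sketch of that difference, using a toy dataset rather than the repository's data:

from datasets import Dataset

toy = Dataset.from_dict({"idx": list(range(10_000))})

# Map-style: rows are addressable by index, so this is a true global shuffle.
shuffled = toy.shuffle(seed=0)

# Iterable (the approach this commit backs out of): only shard order and a
# fixed-size buffer get shuffled, so ordering stays partly tied to the shards.
approx = toy.to_iterable_dataset(num_shards=64).shuffle(seed=0, buffer_size=1_000)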

View File

@@ -11,7 +11,6 @@ from logging import Logger
import boto3
from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict
from .core.config import DataConfig, SourceConfig
# Configure logging

View File

@@ -5,9 +5,15 @@ import base64
import torch # Make sure to import torch as it's used in the DataCollator

def filter_by_max_seq_len(example, max_seq_len=4500):
    sizes = example["input_ids"].shape
    return sizes[-1] <= max_seq_len

def filter_by_max_seq_len(example, processor, max_prompt_len: int=2000, max_response_len: int=2000):
    if len(processor.tokenizer.tokenize(example["input_prompt_text"])) > max_prompt_len:
        return False
    if len(processor.tokenizer.tokenize(example["response"])) > max_response_len:
        return False
    return True

def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
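For reference, a hedged, self-contained sketch of how the new filter_by_max_seq_len is driven through datasets.Dataset.filter (it assumes pdelfin is importable; the real code binds the Qwen2-VL processor, while the tiny whitespace tokenizer below is a stand-in invented for the example):

from functools import partial
from datasets import Dataset
from pdelfin.train.dataprep import filter_by_max_seq_len

class _FakeTokenizer:
    # Stand-in for processor.tokenizer: only .tokenize() is used by the filter.
    def tokenize(self, text):
        return text.split()

class _FakeProcessor:
    tokenizer = _FakeTokenizer()

ds = Dataset.from_dict({
    "input_prompt_text": ["short prompt", "word " * 3000],
    "response": ["ok", "ok"],
})

# Rows whose prompt or response tokenizes past the limits are dropped up front,
# before any expensive prompt/image formatting happens.
kept = ds.filter(partial(filter_by_max_seq_len, processor=_FakeProcessor()))
print(len(kept))  # 1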

View File

@@ -49,7 +49,7 @@ from .utils import (
from pdelfin.train.dataloader import make_dataset
from pdelfin.train.dataprep import filter_by_max_seq_len, prepare_data_for_qwen2_training
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len
class CheckpointUploadCallback(TrainerCallback):
@@ -137,17 +137,10 @@ def run_train(config: TrainConfig):
model = get_peft_model(model=model, peft_config=peft_config)
log_trainable_parameters(model=model, logger=logger)
# formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
filtered_dataset = {split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset}
# Convert to an iterable dataset, so we can apply map and filter without doing a full calculation in advance
train_ds = dataset["train"].to_iterable_dataset(num_shards=64)
validation_ds = dataset["validation"]
train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=train_ds.column_names).filter(filter_by_max_seq_len)
validation_ds = validation_ds.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=validation_ds.column_names)
print(train_ds)
print(validation_ds)
formatted_dataset = filtered_dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
print(formatted_dataset)
print("---------------")
save_path = join_path("", config.save.path, run_name.run)
@@ -188,9 +181,6 @@ def run_train(config: TrainConfig):
label_names=["labels"], # fix from https://github.com/huggingface/transformers/issues/22885
max_grad_norm=config.hparams.clip_grad_norm,
remove_unused_columns=False,
accelerator_config={
    "dispatch_batches": False
}
)
# Set the collator
@@ -201,8 +191,8 @@ def run_train(config: TrainConfig):
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=validation_ds,
    train_dataset=formatted_dataset["train"],
    eval_dataset=formatted_dataset["validation"], # pyright: ignore
    tokenizer=processor.tokenizer,
    #Collator is not needed as we are doing batch size 1 for now...
    #data_collator=collator,
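The with_transform call above is what keeps the non-iterable approach cheap: the formatting transform is not precomputed over the whole dataset, it runs lazily whenever rows are fetched. A small sketch of that behaviour with a dummy transform (names here are illustrative, not the repository's):

from functools import partial
from datasets import Dataset

def dummy_prepare(batch, processor=None):
    # Stand-in for batch_prepare_data_for_qwen2_training: receives a batch dict.
    return {"n_chars": [len(t) for t in batch["input_prompt_text"]]}

ds = Dataset.from_dict({"input_prompt_text": ["a", "bb", "ccc"]})
formatted = ds.with_transform(partial(dummy_prepare, processor=None))

print(len(formatted))   # 3 -- length is known up front, unlike an IterableDataset
print(formatted[:2])    # transform runs here, only on the requested rows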

View File

@@ -40,6 +40,9 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
from pdelfin.train.dataprep import filter_by_max_seq_len
ds = ds.filter(partial(filter_by_max_seq_len, processor=processor))
formatted_dataset = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=50, shuffle=False)