mirror of https://github.com/allenai/olmocr.git
Going back to a non-iterable dataset so shuffling works better; applying a light filter
parent 65a9c9981e
commit 22b765e6be
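Context for the revert: the HF Trainer can draw a true global shuffle over a map-style Dataset through its random sampler, while an IterableDataset can only approximate shuffling with shard order plus a fixed-size buffer. A minimal sketch of the difference (illustrative only, not part of the commit; the toy data is made up):

from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(10_000))})

# Map-style: a true global shuffle, cheap because only an index permutation moves.
shuffled = ds.shuffle(seed=42)

# Iterable: shuffling mixes shard order plus a sliding buffer of rows, so
# examples far apart in the stream can never swap places.
approx = ds.to_iterable_dataset(num_shards=64).shuffle(seed=42, buffer_size=1_000)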
@@ -11,7 +11,6 @@ from logging import Logger

 import boto3

 from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict

 from .core.config import DataConfig, SourceConfig

 # Configure logging
@@ -5,9 +5,15 @@ import base64

 import torch  # Make sure to import torch as it's used in the DataCollator

-def filter_by_max_seq_len(example, max_seq_len=4500):
-    sizes = example["input_ids"].shape
-    return sizes[-1] <= max_seq_len
+def filter_by_max_seq_len(example, processor, max_prompt_len: int=2000, max_response_len: int=2000):
+    if len(processor.tokenizer.tokenize(example["input_prompt_text"])) > max_prompt_len:
+        return False
+
+    if len(processor.tokenizer.tokenize(example["response"])) > max_response_len:
+        return False
+
+    return True


 def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
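Usage sketch for the reworked filter (illustrative, not part of the diff): it assumes a DatasetDict named dataset whose rows carry the "input_prompt_text" and "response" columns the predicate reads.

from functools import partial
from transformers import AutoProcessor

from pdelfin.train.dataprep import filter_by_max_seq_len

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# The predicate runs once per example; rows whose prompt or response
# exceeds the token budget are dropped before any tensor work happens.
filtered = {
    split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor))
    for split in dataset
}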
@@ -49,7 +49,7 @@ from .utils import (

 from pdelfin.train.dataloader import make_dataset
-from pdelfin.train.dataprep import filter_by_max_seq_len, prepare_data_for_qwen2_training
+from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len


 class CheckpointUploadCallback(TrainerCallback):
@@ -137,17 +137,10 @@ def run_train(config: TrainConfig):

     model = get_peft_model(model=model, peft_config=peft_config)
     log_trainable_parameters(model=model, logger=logger)

-    # formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    filtered_dataset = {split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset}
-    # Convert to an iterable dataset, so we can apply map and filter without doing a full calculation in advance
-    train_ds = dataset["train"].to_iterable_dataset(num_shards=64)
-    validation_ds = dataset["validation"]
-
-    train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=train_ds.column_names).filter(filter_by_max_seq_len)
-    validation_ds = validation_ds.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=validation_ds.column_names)
-
-    print(train_ds)
-    print(validation_ds)
+    formatted_dataset = filtered_dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    print(formatted_dataset)
     print("---------------")

     save_path = join_path("", config.save.path, run_name.run)
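One caveat worth flagging: with_transform is a method of Dataset/DatasetDict, and if dataset is a plain DatasetDict the comprehension above yields an ordinary dict, which has no with_transform. A sketch of the same step with the filtered splits rewrapped so the call resolves:

from functools import partial

from datasets import DatasetDict

filtered_dataset = DatasetDict({
    split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor))
    for split in dataset
})

# with_transform formats batches lazily at access time instead of
# materializing a mapped copy up front; shuffling still indexes the
# full dataset, which is the point of the revert.
formatted_dataset = filtered_dataset.with_transform(
    partial(batch_prepare_data_for_qwen2_training, processor=processor)
)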
@@ -188,9 +181,6 @@ def run_train(config: TrainConfig):

         label_names=["labels"],  # fix from https://github.com/huggingface/transformers/issues/22885
         max_grad_norm=config.hparams.clip_grad_norm,
         remove_unused_columns=False,
-        accelerator_config={
-            "dispatch_batches": False
-        }
     )

     # Set the collator
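remove_unused_columns=False is load-bearing with this setup: by default the Trainer drops dataset columns that do not match the model's forward() signature, and it does so before the on-access transform runs, which would starve batch_prepare_data_for_qwen2_training of its raw columns. A sketch of the relevant knob (output_dir is a placeholder):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",  # placeholder
    # Keep raw columns such as "input_prompt_text" so the with_transform
    # formatter can still read them when a batch is fetched.
    remove_unused_columns=False,
)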
@@ -201,8 +191,8 @@ def run_train(config: TrainConfig):

     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=train_ds,
-        eval_dataset=validation_ds,
+        train_dataset=formatted_dataset["train"],
+        eval_dataset=formatted_dataset["validation"],  # pyright: ignore
         tokenizer=processor.tokenizer,
         # Collator is not needed as we are doing batch size 1 for now...
         # data_collator=collator,
@@ -40,6 +40,9 @@ class TestBatchQueryResponseDataset(unittest.TestCase):

         )
         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

+        from pdelfin.train.dataprep import filter_by_max_seq_len
+        ds = ds.filter(partial(filter_by_max_seq_len, processor=processor))
+
         formatted_dataset = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
         train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=50, shuffle=False)
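A quick sanity check one might run after the snippet above (illustrative, not part of the test): pull a single batch through the pipeline and inspect shapes.

# Fetch one formatted batch; tensors report their shape, anything else its type.
batch = next(iter(train_dataloader))
print({k: (v.shape if hasattr(v, "shape") else type(v)) for k, v in batch.items()})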