Hopefully fixing dataloader for now

This commit is contained in:
Jake Poznanski 2024-10-07 12:59:27 -07:00
parent 5d35461dd2
commit ebd40f9084
3 changed files with 57 additions and 24 deletions

View File

@ -47,6 +47,7 @@ from .utils import (
log_trainable_parameters,
packing_collator,
setup_environment,
make_dataset
)
@ -113,6 +114,15 @@ def run_train(config: TrainConfig):
run_name = RunName.get(config)
setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
accelerator = accelerate.Accelerator()
# Build and download the dataset on process 0
if accelerator.is_main_process:
make_dataset(config)
accelerator.wait_for_everyone()
train_dataset, valid_dataset = make_dataset(config)
model = Qwen2VLForConditionalGeneration.from_pretrained(
config.model.name_or_path, torch_dtype=torch.bfloat16,
@ -132,30 +142,6 @@ def run_train(config: TrainConfig):
model = get_peft_model(model=model, peft_config=peft_config)
log_trainable_parameters(model=model, logger=logger)
random.seed(config.train_data.seed)
# Training sets get all concatenated and shuffled
train_dataset = (
concatenate_datasets(
[
build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
for source in config.train_data.sources
]
)
.filter(partial(filter_by_max_seq_len, processor=processor))
.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
)
# Validation sets get put into a datasetdict so each can report a loss separately
valid_dataset = DatasetDict(
**{
source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
.filter(partial(filter_by_max_seq_len, processor=processor))
.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
for source in config.valid_data.sources
}
)
save_path = join_path("", config.save.path, run_name.run)
save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore

View File

@ -1,6 +1,7 @@
import json
import multiprocessing
import os
import random
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
@ -13,6 +14,7 @@ import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import PrecisionType
from datasets import Dataset, concatenate_datasets, DatasetDict
from .core.cli import to_native_types
from .core.config import AwsConfig, TrainConfig, WandbConfig
@ -23,6 +25,9 @@ from .core.state import BeakerState
T = TypeVar("T")
from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len
def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
pt = PrecisionType(accelerator.mixed_precision)
@ -34,6 +39,33 @@ def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
return torch.float8_e4m3fn
return torch.float32
def make_dataset(config: TrainConfig) -> tuple[Dataset, Dataset]:
random.seed(config.train_data.seed)
# Training sets get all concatenated and shuffled
train_dataset = (
concatenate_datasets(
[
build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
for source in config.train_data.sources
]
)
.filter(partial(filter_by_max_seq_len, processor=processor))
.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
)
# Validation sets get put into a datasetdict so each can report a loss separately
valid_dataset = DatasetDict(
**{
source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
.filter(partial(filter_by_max_seq_len, processor=processor))
.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
for source in config.valid_data.sources
}
)
return train_dataset, valid_dataset
def setup_environment(
aws_config: Optional[AwsConfig] = None, wandb_config: Optional[WandbConfig] = None, **kwargs: str

View File

@ -37,6 +37,21 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
print(ds[0])
def testLocalDS(self):
ds = build_batch_query_response_vision_dataset(
query_glob_path="/root/openai_batch_data_v5_1_train/*.jsonl",
response_glob_path="/root/openai_batch_data_v5_1_train_done/*.json",
)
print(ds)
ds.to_parquet("/root/trainds_parquet/bigds.parquet")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
from pdelfin.train.dataprep import filter_by_max_seq_len
ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000))
print(ds[0])
def testPlotSequenceLengthHistogram(self):
import plotly.express as px