From ebd40f908464741b2aa5251374747a03443ac4b2 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Mon, 7 Oct 2024 12:59:27 -0700
Subject: [PATCH] Hopefully fixing dataloader for now

---
 pdelfin/train/train.py   | 34 ++++++++++------------------------
 pdelfin/train/utils.py   | 32 ++++++++++++++++++++++++++++++++
 tests/test_dataloader.py | 15 +++++++++++++++
 3 files changed, 57 insertions(+), 24 deletions(-)

diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py
index b639dba..c9a6ecd 100644
--- a/pdelfin/train/train.py
+++ b/pdelfin/train/train.py
@@ -47,6 +47,7 @@ from .utils import (
     log_trainable_parameters,
     packing_collator,
     setup_environment,
+    make_dataset
 )
 
 
@@ -113,6 +114,15 @@ def run_train(config: TrainConfig):
     run_name = RunName.get(config)
 
     setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
+    accelerator = accelerate.Accelerator()
+
+    # Build and download the dataset on process 0
+    if accelerator.is_main_process:
+        make_dataset(config)
+
+    accelerator.wait_for_everyone()
+
+    train_dataset, valid_dataset = make_dataset(config)
 
     model = Qwen2VLForConditionalGeneration.from_pretrained(
         config.model.name_or_path, torch_dtype=torch.bfloat16,
@@ -132,30 +142,6 @@ def run_train(config: TrainConfig):
         model = get_peft_model(model=model, peft_config=peft_config)
         log_trainable_parameters(model=model, logger=logger)
 
-    random.seed(config.train_data.seed)
-
-    # Training sets get all concatenated and shuffled
-    train_dataset = (
-        concatenate_datasets(
-            [
-                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
-                for source in config.train_data.sources
-            ]
-        )
-        .filter(partial(filter_by_max_seq_len, processor=processor))
-        .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
-    )
-
-    # Validation sets get put into a datasetdict so each can report a loss separately
-    valid_dataset = DatasetDict(
-        **{
-            source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
-            .filter(partial(filter_by_max_seq_len, processor=processor))
-            .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
-            for source in config.valid_data.sources
-        }
-    )
-
     save_path = join_path("", config.save.path, run_name.run)
     save_config(config, join_path("", save_path, "config.yaml"))  # pyright: ignore
 

diff --git a/pdelfin/train/utils.py b/pdelfin/train/utils.py
index abbc410..0636d5d 100644
--- a/pdelfin/train/utils.py
+++ b/pdelfin/train/utils.py
@@ -1,6 +1,7 @@
 import json
 import multiprocessing
 import os
+import random
 from contextlib import contextmanager
 from dataclasses import dataclass
 from datetime import datetime
@@ -13,6 +14,7 @@ import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.utils import PrecisionType
+from datasets import Dataset, concatenate_datasets, DatasetDict
 
 from .core.cli import to_native_types
 from .core.config import AwsConfig, TrainConfig, WandbConfig
@@ -23,6 +25,9 @@ from .core.state import BeakerState
 
 T = TypeVar("T")
 
+from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
+from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len
+
 
 def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
     pt = PrecisionType(accelerator.mixed_precision)
@@ -34,6 +39,33 @@ def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
         return torch.float8_e4m3fn
     return torch.float32
 
+def make_dataset(config: TrainConfig) -> tuple[Dataset, Dataset]:
+    random.seed(config.train_data.seed)
+
+    # Training sets get all concatenated and shuffled
+    train_dataset = (
+        concatenate_datasets(
+            [
+                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
+                for source in config.train_data.sources
+            ]
+        )
+        .filter(partial(filter_by_max_seq_len, processor=processor))
+        .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    )
+
+    # Validation sets get put into a datasetdict so each can report a loss separately
+    valid_dataset = DatasetDict(
+        **{
+            source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
+            .filter(partial(filter_by_max_seq_len, processor=processor))
+            .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+            for source in config.valid_data.sources
+        }
+    )
+
+    return train_dataset, valid_dataset
+
 
 def setup_environment(
     aws_config: Optional[AwsConfig] = None, wandb_config: Optional[WandbConfig] = None, **kwargs: str

diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 4c53ee5..610ec70 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -37,6 +37,21 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
 
         print(ds[0])
 
+    def testLocalDS(self):
+        ds = build_batch_query_response_vision_dataset(
+            query_glob_path="/root/openai_batch_data_v5_1_train/*.jsonl",
+            response_glob_path="/root/openai_batch_data_v5_1_train_done/*.json",
+        )
+
+        print(ds)
+
+        ds.to_parquet("/root/trainds_parquet/bigds.parquet")
+
+        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+        from pdelfin.train.dataprep import filter_by_max_seq_len
+        ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000))
+
+        print(ds[0])
 
     def testPlotSequenceLengthHistogram(self):
         import plotly.express as px
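
Usage sketch (illustration only, not part of the commit): the patch's intent is that rank 0 calls make_dataset first so the source files are downloaded and cached, every other rank waits at the barrier, and then all ranks call make_dataset again to load from the warm cache. In the sketch below, the helper name load_datasets is hypothetical, the import paths follow the relative imports visible in the diff, and make_dataset is shown taking a processor argument as an assumption: the function body added to pdelfin/train/utils.py references `processor`, which the one-argument signature in the patch does not supply, so some such wiring is implied.

# Sketch only: assumes make_dataset(config, processor); the processor parameter is
# not in the patched signature, but the function body uses it.
import accelerate
from transformers import AutoProcessor

from pdelfin.train.core.config import TrainConfig
from pdelfin.train.utils import make_dataset


def load_datasets(config: TrainConfig):
    accelerator = accelerate.Accelerator()
    processor = AutoProcessor.from_pretrained(config.model.name_or_path)

    # Rank 0 builds the datasets first, downloading and caching the raw files.
    if accelerator.is_main_process:
        make_dataset(config, processor)

    # All other ranks block here until rank 0 has populated the cache.
    accelerator.wait_for_everyone()

    # Every rank now loads the same, already-cached datasets.
    return make_dataset(config, processor)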