# Step 1. Load the data
# Probably, we want to see just a folder with openai batch input jsonls, plus the batch output jsonls
# TODO: Figure out hyperparameters for image sizing
# Step 2. Load those prompts through and do a forward pass to calculate the loss
# Step 3. Add Hugging Face accelerate for training
# Step 4. Checkpointing code, both saving and reloading to restart
# Step 5. Move over from interactive session to gantry launch script

import os
import json
import base64
import logging
from io import BytesIO
from PIL import Image
from functools import partial
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
from tqdm import tqdm

import accelerate
import torch
import torch.distributed
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model  # pyright: ignore
from transformers import (
    AutoModelForCausalLM,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    Qwen2VLForConditionalGeneration,
    AutoProcessor
)
from transformers.integrations import WandbCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.trainer_utils import get_last_checkpoint

import wandb

from pdelfin.train.core.cli import make_cli, save_config, to_native_types
from pdelfin.train.core.config import TrainConfig
from pdelfin.train.core.loggers import get_logger
from pdelfin.train.core.paths import copy_dir, join_path
from pdelfin.train.core.state import BeakerState

from .utils import (
    RunName,
    get_local_dir,
    log_trainable_parameters,
    packing_collator,
    setup_environment,
)

from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataprep import prepare_data_for_qwen2_training


def run_train(config: TrainConfig):
    train_ds = build_batch_query_response_vision_dataset(
        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
    )

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
    )
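
    # Hedged LoRA sketch (commented out, not wired into TrainConfig yet): the unused
    # LoraConfig / get_peft_model / log_trainable_parameters imports above suggest
    # parameter-efficient fine-tuning is planned. Rank, alpha, and target modules
    # below are illustrative placeholders, not validated hyperparameters.
    # lora_config = LoraConfig(
    #     r=8,
    #     lora_alpha=32,
    #     lora_dropout=0.05,
    #     target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    #     task_type="CAUSAL_LM",
    # )
    # model = get_peft_model(model, lora_config)
    # log_trainable_parameters(model)  # assumed signature; see .utils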
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
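
    # TODO from the plan above (image-sizing hyperparameters): one hedged option is
    # to bound how many vision tokens the Qwen2-VL processor emits per image via
    # min_pixels / max_pixels. The bounds below are illustrative, not tuned values.
    # processor = AutoProcessor.from_pretrained(
    #     "Qwen/Qwen2-VL-2B-Instruct",
    #     min_pixels=256 * 28 * 28,
    #     max_pixels=1024 * 28 * 28,
    # )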

    train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor),
                            remove_columns=train_ds.column_names)

    print(train_ds)
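
    # Step 2 (hedged sketch): run a single forward pass to confirm the prepared
    # features produce a finite loss before launching full training. This assumes
    # packing_collator from .utils is an HF-style data collator (list of rows in,
    # tensor batch with "labels" out); adjust if its real interface differs.
    debug_batch = packing_collator([train_ds[0]])
    debug_batch = {
        k: v.to(model.device) if isinstance(v, torch.Tensor) else v
        for k, v in debug_batch.items()
    }
    with torch.no_grad():
        debug_loss = model(**debug_batch).loss
    print(f"Single-example sanity-check loss: {debug_loss.item():.4f}")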
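
    # Step 3 (hedged sketch): train with the Hugging Face Trainer, which uses
    # accelerate under the hood for device placement and distributed launches.
    # All hyperparameters below are illustrative placeholders, output_dir is a
    # made-up local path, and report_to can switch to ["wandb"] once logging is
    # wired up.
    training_args = TrainingArguments(
        output_dir="qwen2vl-pdf-finetune",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=1e-5,
        num_train_epochs=1,
        bf16=True,
        logging_steps=10,
        save_steps=500,
        save_total_limit=2,
        report_to=[],
        remove_unused_columns=False,  # keep the vision columns the model forward needs
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        data_collator=packing_collator,  # assumed HF-style collator from .utils
    )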
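
    # Step 4 (hedged sketch): the Trainer already writes checkpoints every
    # save_steps; on restart, resume from the most recent one if it exists.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    trainer.train(resume_from_checkpoint=last_checkpoint)
    trainer.save_model(training_args.output_dir)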


def main():
    train_config = make_cli(TrainConfig)  # pyright: ignore
    run_train(train_config)


if __name__ == "__main__":
    main()