# Step 1. Load the data.
# Probably, we want to see just a folder with OpenAI batch input jsonls, plus the batch output jsonls.
# TODO: Figure out hyperparameters for image sizing.
# Step 2. Load those prompts through and do a forward pass to calculate the loss.
# Step 3. Add Hugging Face accelerate for training.
# Step 4. Checkpointing code, both saving and reloading to restart.
# Step 5. Move over from interactive session to a gantry launch script.
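#
# Step 1 data layout (a sketch; the record fields are assumptions based on the OpenAI Batch API,
# and the actual pairing happens inside build_batch_query_response_vision_dataset below):
#   openai_batch_data_v2_mini/*.jsonl -> batch input requests, one JSON object per line with a
#       "custom_id" and a "body" holding the chat request
#   openai_batch_done_v2_mini/*.json  -> batch outputs carrying the same "custom_id" plus the
#       model "response", so each query can be joined to its completion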

import os
import json
import base64
import logging
import time
from io import BytesIO
from PIL import Image
from functools import partial
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
from tqdm import tqdm

import accelerate
import torch
import torch.distributed
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model  # pyright: ignore
from transformers import (
    AutoModelForCausalLM,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
)
from transformers.integrations import WandbCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.trainer_utils import get_last_checkpoint
from torch.utils.data import DataLoader

import wandb

from pdelfin.train.core.cli import make_cli, save_config, to_native_types
from pdelfin.train.core.config import TrainConfig
from pdelfin.train.core.loggers import get_logger
from pdelfin.train.core.paths import copy_dir, join_path
from pdelfin.train.core.state import BeakerState

from .utils import (
    RunName,
    get_local_dir,
    log_trainable_parameters,
    packing_collator,
    setup_environment,
)

from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training


def get_rank() -> int:
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0


def run_train(config: TrainConfig):
    # Only rank 0 logs at INFO; other ranks stay quiet and hide progress bars.
    if get_rank() == 0:
        logger_level = logging.INFO
    else:
        logger_level = logging.WARN
        disable_progress_bars()

    logger = get_logger(__name__, level=logger_level)
    set_verbosity(logger_level)

    run_name = RunName.get(config)

    # Step 3: the Accelerator is only queried for its mixed-precision setting below;
    # the Trainer handles device placement and distributed wrapping itself.
    accelerator = accelerate.Accelerator()

    train_ds = build_batch_query_response_vision_dataset(
        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
    )

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
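
    # Note: LoraConfig/get_peft_model are imported and TrainingArguments below branches on
    # config.lora, but the model is not wrapped yet. A minimal sketch (the config.lora field
    # names and the log_trainable_parameters signature are assumptions, not the real schema):
    #
    #   if config.lora is not None:
    #       peft_config = LoraConfig(
    #           r=config.lora.rank,
    #           lora_alpha=config.lora.alpha,
    #           lora_dropout=config.lora.dropout,
    #           target_modules=list(config.lora.target_modules),
    #           task_type="CAUSAL_LM",
    #       )
    #       model = get_peft_model(model, peft_config)
    #       log_trainable_parameters(model=model, logger=logger)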

    train_ds = train_ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
    print(train_ds)
    print("---------------")

    train_dataloader = DataLoader(train_ds, batch_size=1, num_workers=2, shuffle=False)
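
    # Step 2 sanity check (optional): a single forward pass to confirm the loss computes.
    # Sketch only; it assumes the prepared batches expose input_ids, attention_mask,
    # pixel_values, image_grid_thw, and labels as tensors with a leading batch dimension.
    #
    #   first_batch = next(iter(train_dataloader))
    #   with torch.no_grad():
    #       outputs = model(**{k: v.to(model.device) for k, v in first_batch.items()})
    #   logger.info("Forward-pass loss on one batch: %s", outputs.loss.item())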

    with TemporaryDirectory() as output_dir:

        training_args = TrainingArguments(
            run_name=run_name.run,
            logging_steps=config.hparams.log_every_steps,
            output_dir=output_dir,
            eval_strategy="steps",
            report_to="wandb",
            # report_to=[],  # disable logging to wandb, we will use a custom callback
            optim=config.hparams.optim,
            eval_steps=config.hparams.eval_every_steps,
            learning_rate=config.hparams.learning_rate,
            per_device_train_batch_size=config.hparams.batch_size,
            per_device_eval_batch_size=config.hparams.eval_batch_size or config.hparams.batch_size,
            gradient_checkpointing=config.hparams.gradient_checkpointing,
            gradient_checkpointing_kwargs=(
                dict(use_reentrant=False)  # from this issue: https://github.com/huggingface/peft/issues/1142
                if config.hparams.gradient_checkpointing and config.lora is not None
                else {}
            ),
            gradient_accumulation_steps=config.hparams.gradient_accumulation_steps,
            max_steps=config.hparams.max_steps,
            weight_decay=config.hparams.weight_decay,
            dataloader_num_workers=config.max_workers,
            load_best_model_at_end=True,
            save_strategy="steps",
            ddp_find_unused_parameters=config.hparams.find_unused_parameters,
            save_steps=config.save.save_every_steps,
            warmup_steps=config.hparams.warmup_steps,
            warmup_ratio=config.hparams.warmup_ratio,
            bf16=accelerator.mixed_precision == "bf16",
            fp16=accelerator.mixed_precision == "fp16",
            label_names=["labels"],  # fix from https://github.com/huggingface/transformers/issues/22885
            max_grad_norm=config.hparams.clip_grad_norm,
            remove_unused_columns=False,
        )

        # Set the collator
        collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
        # checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)

        # Initialize Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_ds,
            # eval_dataset=formatted_dataset["validation"],  # pyright: ignore
            tokenizer=processor.tokenizer,
            # data_collator=collator,
            # callbacks=[checkpoint_callback],
        )

        # Train the model
        trainer.train()  # pyright: ignore
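
        # Step 4 sketch: resume from the most recent checkpoint when restarting.
        # Hedged example only; it assumes checkpoints are written to a persistent
        # location (the TemporaryDirectory above is wiped on exit, so a real run
        # would point output_dir somewhere durable first):
        #
        #   last_checkpoint = get_last_checkpoint(output_dir)
        #   trainer.train(resume_from_checkpoint=last_checkpoint)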


def main():
    train_config = make_cli(TrainConfig)  # pyright: ignore
    run_train(train_config)


if __name__ == "__main__":
    main()