# Step 1. Load the data.
#   Probably we want to see just a folder with OpenAI batch input JSONLs, plus the batch output JSONLs.
#   TODO: Figure out hyperparameters for image sizing.
# Step 2. Load those prompts through and do a forward pass to calculate the loss.
# Step 3. Add Hugging Face Accelerate for training.
# Step 4. Checkpointing code, both saving and reloading to restart.
# Step 5. Move over from interactive session to gantry launch script.

import os
import json
import base64
import logging
import time
from io import BytesIO
from PIL import Image
from functools import partial
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
from tqdm import tqdm

import accelerate
import torch
import torch.distributed
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model  # pyright: ignore
from transformers import (
    AutoModelForCausalLM,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    Qwen2VLForConditionalGeneration,
    AutoProcessor
)
from transformers.integrations import WandbCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.trainer_utils import get_last_checkpoint
from torch.utils.data import DataLoader

import wandb

from pdelfin.train.core.cli import make_cli, save_config, to_native_types
from pdelfin.train.core.config import TrainConfig
from pdelfin.train.core.loggers import get_logger
from pdelfin.train.core.paths import copy_dir, join_path
from pdelfin.train.core.state import BeakerState

from .utils import (
    RunName,
    get_local_dir,
    log_trainable_parameters,
    packing_collator,
    setup_environment,
)

from pdelfin.train.dataloader import make_dataset
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training


class CheckpointUploadCallback(TrainerCallback):
    """Copy each newly saved Trainer checkpoint to ``save_path``.

    On every save event, the most recent checkpoint directory found in
    ``args.output_dir`` (as reported by ``get_last_checkpoint``) is copied to
    ``{save_path}/{checkpoint_dir_name}`` via ``copy_dir``.
    """

    def __init__(self, save_path: str, logger: Optional[Logger] = None):
        self.save_path = save_path
        self.logger = logger or get_logger(self.__class__.__name__)

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Only the local rank-0 process of each node performs the copy.
        if state.is_local_process_zero:
            latest_checkpoint = get_last_checkpoint(args.output_dir)
            if not latest_checkpoint:
                return

            dir_name = Path(latest_checkpoint).name
            copy_dir(str(latest_checkpoint), f"{self.save_path}/{dir_name}")
            self.logger.info("Saved checkpoint to %s", f"{self.save_path}/{dir_name}")


def update_wandb_config(config: TrainConfig, trainer: Trainer, model: torch.nn.Module):
    """Wrap the trainer's WandbCallback.setup so it also records extra config in W&B.

    After the original setup runs, the PEFT config, the script config, and all
    environment variables whose name starts with "beaker" (case-insensitive) are
    added to ``wandb.config``; the Beaker URL, if any, is set as the run notes.

    Raises:
        ValueError: if the trainer has no WandbCallback registered.
    """
    # finding wandb callback
    callbacks = [c for c in trainer.callback_handler.callbacks if isinstance(c, WandbCallback)]  # pyright: ignore
    if not callbacks:
        raise ValueError("WandbCallback not found in trainer callbacks")

    wandb_callback = callbacks[0]
    peft_config = to_native_types(getattr(model, "peft_config", {}))
    script_config = to_native_types(config)
    beaker_envs = {k: v for k, v in os.environ.items() if k.lower().startswith("beaker")}

    on_setup_fn = wandb_callback.setup

    def setup_and_update(args, state, model, **kwargs):
        on_setup_fn(args=args, state=state, model=model, **kwargs)
        wandb.config.update({"peft": peft_config}, allow_val_change=True)
        wandb.config.update({"script": script_config}, allow_val_change=True)
        wandb.config.update({"beaker": beaker_envs}, allow_val_change=True)
        if (run := wandb.run) and (beaker_url := BeakerState().url):
            run.notes = beaker_url

    wandb_callback.setup = setup_and_update


def get_rank() -> int:
    """Return the torch.distributed rank, or 0 if the process group is not initialized."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0


def run_train(config: TrainConfig):
    """Fine-tune Qwen/Qwen2-VL-2B-Instruct (optionally with LoRA) as described by ``config``.

    Loads the train/validation datasets, applies the Qwen2-VL preprocessing
    transform, trains with the Hugging Face Trainer (checkpoints are copied to
    ``config.save.path/<run name>`` as they are written), and finally saves the
    model to ``<save path>/best``.
    """
    # Create the Accelerator before checking the rank: it initializes
    # torch.distributed when launched in distributed mode. Before that,
    # get_rank() returns 0 on every process, which made all ranks log at INFO.
    accelerator = accelerate.Accelerator()

    if get_rank() == 0:
        logger_level = logging.INFO
    else:
        logger_level = logging.WARN

    disable_progress_bars()

    logger = get_logger(__name__, level=logger_level)
    set_verbosity(logger_level)

    run_name = RunName.get(config)

    setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)

    dataset = make_dataset(
        train_data_config=config.train_data,
        valid_data_config=config.valid_data,
        num_proc=config.num_proc,
        logger=logger,
    )

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

    if config.lora is not None:
        peft_config = LoraConfig(
            r=config.lora.rank,
            lora_alpha=config.lora.alpha,
            lora_dropout=config.lora.dropout,
            bias=config.lora.bias,  # pyright: ignore
            task_type=config.lora.task_type,
            target_modules=list(config.lora.target_modules),
        )
        model = get_peft_model(model=model, peft_config=peft_config)
        log_trainable_parameters(model=model, logger=logger)

    # with_transform applies the preprocessing lazily, as examples are accessed.
    formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
    logger.info("Formatted dataset: %s", formatted_dataset)

    save_path = join_path("", config.save.path, run_name.run)

    save_config(config, join_path("", save_path, "config.yaml"))  # pyright: ignore

    with TemporaryDirectory() as output_dir:
        training_args = TrainingArguments(
            run_name=run_name.run,
            logging_steps=config.hparams.log_every_steps,
            output_dir=output_dir,
            eval_strategy="steps",
            report_to="wandb",
            # report_to=[],  # disable logging to wandb, we will use a custom callback
            optim=config.hparams.optim,
            eval_steps=config.hparams.eval_every_steps,
            learning_rate=config.hparams.learning_rate,
            per_device_train_batch_size=config.hparams.batch_size,
            per_device_eval_batch_size=config.hparams.eval_batch_size or config.hparams.batch_size,
            gradient_checkpointing=config.hparams.gradient_checkpointing,
            gradient_checkpointing_kwargs=(
                dict(use_reentrant=False)  # from this issue: https://github.com/huggingface/peft/issues/1142
                if config.hparams.gradient_checkpointing and config.lora is not None
                else {}
            ),
            gradient_accumulation_steps=config.hparams.gradient_accumulation_steps,
            max_steps=config.hparams.max_steps,
            weight_decay=config.hparams.weight_decay,
            dataloader_num_workers=config.max_workers,
            load_best_model_at_end=True,
            save_strategy="steps",
            ddp_find_unused_parameters=config.hparams.find_unused_parameters,
            save_steps=config.save.save_every_steps,
            warmup_steps=config.hparams.warmup_steps,
            warmup_ratio=config.hparams.warmup_ratio,
            bf16=accelerator.mixed_precision == "bf16",
            fp16=accelerator.mixed_precision == "fp16",
            label_names=["labels"],  # fix from https://github.com/huggingface/transformers/issues/22885
            max_grad_norm=config.hparams.clip_grad_norm,
            remove_unused_columns=False,
        )

        checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)

        # Initialize Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=formatted_dataset["train"],
            eval_dataset=formatted_dataset["validation"],  # pyright: ignore
            tokenizer=processor.tokenizer,
            # No data collator: batch size is 1 for now. To enable packing, pass
            # data_collator=partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
            callbacks=[checkpoint_callback],
        )

        # Could not get this to work
        # if get_rank() == 0:
        #     # this is a hack to add script and peft config to wandb config
        #     update_wandb_config(config, trainer, model)

        # Train the model
        trainer.train()  # pyright: ignore

        # With load_best_model_at_end=True, the Trainer reloads the best checkpoint
        # into the model when training finishes. Only the main process writes the
        # final weights, so ranks do not write to the same destination concurrently.
        if accelerator.is_main_process:
            with get_local_dir(join_path("", save_path, "best")) as best_dir:
                model.save_pretrained(best_dir)
                logger.info("Saved best model to %s", best_dir)

        # Uncomment to test speed of data loader
        # train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=2, shuffle=False)
        # for entry in tqdm(train_dataloader):
        #     print("Step!")


def main():
    """CLI entry point: parse a TrainConfig from the command line and run training."""
    train_config = make_cli(TrainConfig)  # pyright: ignore
    run_train(train_config)


if __name__ == "__main__":
    main()