From 8fece62b8ab24df6dc6384672b390e7dfdb4cbfe Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Thu, 26 Sep 2024 10:13:36 -0700
Subject: [PATCH] Trying manual training loop with lora for memory usage

---
 pdelfin/train/config/qwen2vl-7b-lora.yaml |  4 +-
 pdelfin/train/config/qwen2vl-7b.yaml      |  4 +-
 pdelfin/train/train.py                    | 89 +++++------------------
 3 files changed, 21 insertions(+), 76 deletions(-)

diff --git a/pdelfin/train/config/qwen2vl-7b-lora.yaml b/pdelfin/train/config/qwen2vl-7b-lora.yaml
index f399e1f..06cea09 100644
--- a/pdelfin/train/config/qwen2vl-7b-lora.yaml
+++ b/pdelfin/train/config/qwen2vl-7b-lora.yaml
@@ -29,8 +29,8 @@ train_data:
   seed: 1337
   sources:
     - name: openai_batch_data_v2
-      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
-      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
+      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
       backend:
         - openai
       size: 100_000
diff --git a/pdelfin/train/config/qwen2vl-7b.yaml b/pdelfin/train/config/qwen2vl-7b.yaml
index d973f5e..a6b04cb 100644
--- a/pdelfin/train/config/qwen2vl-7b.yaml
+++ b/pdelfin/train/config/qwen2vl-7b.yaml
@@ -29,8 +29,8 @@ train_data:
   seed: 1337
   sources:
     - name: openai_batch_data_v2
-      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
-      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
+      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
       backend:
         - openai
       size: 100_000
diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py
index 8cd26b8..f8c0f68 100644
--- a/pdelfin/train/train.py
+++ b/pdelfin/train/train.py
@@ -12,9 +12,9 @@ from tempfile import TemporaryDirectory
 from typing import Optional
 from tqdm import tqdm
 
-import accelerate
 import torch
 import torch.distributed
+from accelerate import Accelerator
 from datasets.utils import disable_progress_bars
 from datasets.utils.logging import set_verbosity
 from peft import LoraConfig, get_peft_model  # pyright: ignore
@@ -145,82 +145,27 @@ def run_train(config: TrainConfig):
 
     save_config(config, join_path("", save_path, "config.yaml"))  # pyright: ignore
 
-    with TemporaryDirectory() as output_dir:
+    train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
 
-        training_args = TrainingArguments(
-            run_name=run_name.run,
-            logging_steps=config.hparams.log_every_steps,
-            output_dir=output_dir,
-            eval_strategy="steps",
-            report_to="wandb",
-            # report_to=[], # disable logging to wandb, we will use a custom callback
-            optim=config.hparams.optim,
-            eval_steps=config.hparams.eval_every_steps,
-            learning_rate=config.hparams.learning_rate,
-            per_device_train_batch_size=config.hparams.batch_size,
-            per_device_eval_batch_size=config.hparams.eval_batch_size or config.hparams.batch_size,
-            gradient_checkpointing=config.hparams.gradient_checkpointing,
-            gradient_checkpointing_kwargs=(
-                dict(use_reentrant=False)  # from this issue: https://github.com/huggingface/peft/issues/1142
-                if config.hparams.gradient_checkpointing and config.lora is not None
-                else {}
-            ),
-            gradient_accumulation_steps=config.hparams.gradient_accumulation_steps,
-            max_steps=config.hparams.max_steps,
-            weight_decay=config.hparams.weight_decay,
-            dataloader_num_workers=config.max_workers,
-            load_best_model_at_end=True,
-            save_strategy="steps",
-            ddp_find_unused_parameters=config.hparams.find_unused_parameters,
-            save_steps=config.save.save_every_steps,
-            warmup_steps=config.hparams.warmup_steps,
-            warmup_ratio=config.hparams.warmup_ratio,
-            bf16=True,
-            label_names=["labels"],  # fix from https://github.com/huggingface/transformers/issues/22885
-            max_grad_norm=config.hparams.clip_grad_norm,
-            remove_unused_columns=False,
-        )
+    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
 
-        # Set the collator
-        collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
-        checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)
+    accelerator = Accelerator(mixed_precision="bf16")
+    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
 
-        # Initialize Trainer
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=formatted_dataset["train"],
-            eval_dataset=formatted_dataset["validation"],  # pyright: ignore
-            tokenizer=processor.tokenizer,
-            # Collator is not needed as we are doing batch size 1 for now...
-            # data_collator=collator,
-            callbacks=[checkpoint_callback],
-        )
+    steps = 0
 
-        # Could not get this to work
-        # if get_rank() == 0:
-        #     # this is a hack to add script and peft config to wandb config
-        #     update_wandb_config(config, trainer, model)
+    for entry in tqdm(train_dataloader):
+        print("Sequence len", entry["input_ids"].shape)
+        with accelerator.accumulate(model):
+            optimizer.zero_grad()
+            outputs = model(**entry)
+            loss = outputs.loss
+            accelerator.backward(loss)
+            optimizer.step()
 
-        # Train the model
-        trainer.train()  # pyright: ignore
-
-        with get_local_dir(join_path("", save_path, "best")) as best_dir:
-            if config.lora is not None:
-                logger.info("Merging LoRA adapters into the base model...")
-                model = model.merge_and_unload()
-                logger.info("LoRA adapters merged successfully.")
-
-            model.save_pretrained(best_dir)
-
-            logger.info("Saved best model to %s", best_dir)
-
-
-    # Uncomment to test speed of data loader
-    # train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
-    # for entry in tqdm(train_dataloader):
-    #     print("Step!")
-    #     model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})
+        steps += 1
+        if accelerator.is_local_main_process:
+            logger.info(f"step {steps}, training loss : {loss.item()}")
 
 
 def main():
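
For reference, below is a minimal sketch of the manual Accelerate loop this patch introduces, not the committed code: `model`, `formatted_dataset`, and `logger` stand in for objects built earlier in `run_train`, and the `manual_train_loop` wrapper and `accum_steps` parameter are hypothetical. Unlike the patch, it follows the Accelerate-documented accumulation pattern of calling `optimizer.zero_grad()` after `optimizer.step()` inside `accelerator.accumulate(...)`, so gradients survive across micro-batches; zeroing at the top of the block, as the patch does, is equivalent only while accumulation stays at its default of 1.

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from accelerate import Accelerator

def manual_train_loop(model, formatted_dataset, logger, lr=3e-4, accum_steps=1):
    # Batch size 1, so no collator is needed; workers mirror the patch's DataLoader.
    train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # bf16 autocast is handled by Accelerate; accum_steps > 1 is hypothetical here.
    accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=accum_steps)
    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

    model.train()
    for step, entry in enumerate(tqdm(train_dataloader), start=1):
        with accelerator.accumulate(model):
            outputs = model(**entry)    # forward pass; loss comes from the labels in the batch
            loss = outputs.loss
            accelerator.backward(loss)  # backward via Accelerate (mixed precision, distributed)
            optimizer.step()
            optimizer.zero_grad()       # zero AFTER stepping so micro-batch grads accumulate

        if accelerator.is_local_main_process:
            logger.info(f"step {step}, training loss: {loss.item()}")

With this ordering, raising accum_steps trades memory for larger effective batches without any other change to the loop, since accelerator.accumulate() skips optimizer.step() (and the subsequent zero_grad()) on non-sync micro-batches.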