Jake Poznanski 2024-09-23 21:19:26 +00:00
commit d589b5651d
4 changed files with 61 additions and 19 deletions

View File

@@ -1,6 +1,7 @@
model:
name_or_path: Qwen/Qwen2-VL-2B-Instruct
arch: causal
+use_flash_attn: true
wandb:
project: pdelfin
@@ -48,7 +49,7 @@ hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
-gradient_checkpointing: true
+gradient_checkpointing: false
clip_grad_norm: 1.0
learning_rate: 3e-4
max_steps: 200
@@ -79,4 +80,4 @@ save:
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
save_every_steps: 100
-max_workers: 1
+max_workers: 30
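A minimal sketch, assuming a dataclass-based config like the TrainConfig referenced in train.py below (the actual pdelfin config classes are not shown in this diff, and ModelConfig is a hypothetical name), of how the use_flash_attn key added above could be exposed as config.model.use_flash_attn:

# Sketch only; not part of this commit.
from dataclasses import dataclass, field

@dataclass
class ModelConfig:
    name_or_path: str = "Qwen/Qwen2-VL-2B-Instruct"
    arch: str = "causal"
    use_flash_attn: bool = True  # mirrors the key added to the YAML above

@dataclass
class TrainConfig:
    model: ModelConfig = field(default_factory=ModelConfig)  # read as config.model.use_flash_attn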

View File

@@ -1,14 +1,3 @@
-# Step 1, load the data
-# Probably, we want to see just a folder with openai batch input jsonls, plus the batch output jsonls
-# TODO: Figure out hyperparameters for image sizing
-# Step 2. Load those prompts through and do a forward pass to calculate the loss
-# Step 3. Add hugging face accelerate for training
-# Step 4. Checkpointing code, both saving and reloading to restart
-# Step 5. Move over from interactive session to gantry launch script
import os
import json
import base64
@@ -121,8 +110,6 @@ def run_train(config: TrainConfig):
run_name = RunName.get(config)
accelerator = accelerate.Accelerator()
setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
dataset = make_dataset(
@@ -133,7 +120,8 @@ def run_train(config: TrainConfig):
)
model = Qwen2VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto",
_attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
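The hunk above passes the private _attn_implementation kwarg straight through from the config. A minimal sketch, assuming a recent transformers release that ships Qwen2VLForConditionalGeneration and the public attn_implementation argument (this is not how the commit does it), of guarding that choice on whether the flash-attn package is actually importable:

# Sketch, not from this commit: fall back to SDPA when flash-attn is unavailable.
import importlib.util
import torch
from transformers import Qwen2VLForConditionalGeneration

use_flash_attn = True  # stands in for config.model.use_flash_attn from the diff
has_flash_attn = importlib.util.find_spec("flash_attn") is not None

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2" if (use_flash_attn and has_flash_attn) else "sdpa",
)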
@@ -187,8 +175,7 @@ def run_train(config: TrainConfig):
save_steps=config.save.save_every_steps,
warmup_steps=config.hparams.warmup_steps,
warmup_ratio=config.hparams.warmup_ratio,
bf16=accelerator.mixed_precision == "bf16",
fp16=accelerator.mixed_precision == "fp16",
bf16=True,
label_names=["labels"], # fix from https://github.com/huggingface/transformers/issues/22885
max_grad_norm=config.hparams.clip_grad_norm,
remove_unused_columns=False,
@@ -219,13 +206,20 @@ def run_train(config: TrainConfig):
trainer.train() # pyright: ignore
with get_local_dir(join_path("", save_path, "best")) as best_dir:
if config.lora is not None:
logger.info("Merging LoRA adapters into the base model...")
model = model.merge_and_unload()
logger.info("LoRA adapters merged successfully.")
model.save_pretrained(best_dir)
logger.info("Saved best model to %s", best_dir)
# Uncomment to test speed of data loader
# train_dataloader = DataLoader(train_ds, batch_size=1, num_workers=2, shuffle=False)
# train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
# for entry in tqdm(train_dataloader):
# print("Step!")
# model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})
def main():
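The commented-out lines above hint at a data-loader throughput check. A self-contained sketch of that idea, assuming a map-style PyTorch dataset like the train_ds built earlier in run_train (the dataset construction itself is not shown in this diff):

# Sketch: time how fast batches come out of the loader, independent of the model.
import time
from torch.utils.data import DataLoader
from tqdm import tqdm

def time_dataloader(train_ds, batch_size=1, num_workers=4, max_batches=100):
    loader = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers, shuffle=False)
    start = time.perf_counter()
    seen = 0
    for _ in tqdm(loader, total=max_batches):
        seen += 1
        if seen >= max_batches:
            break
    elapsed = time.perf_counter() - start
    print(f"{seen} batches in {elapsed:.1f}s ({seen / elapsed:.2f} batches/s)")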

View File

@@ -0,0 +1,2 @@
set -ex
export NCCL_DEBUG=INFO NCCL_SOCKET_IFNAME=ib NCCL_IB_HCA="^=mlx5_bond_0"

View File

@@ -0,0 +1,45 @@
#!/usr/bin/env bash
set -ex
# check if jq is installed
if ! command -v jq &> /dev/null
then
echo "jq could not be found. Please install it."
exit
fi
EXTRA_ARGS="-c pdelfin/train/config/qwen2vl-2b.yaml --num_proc 64 --save.path \"s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/\${BEAKER_USER_ID}\""
run_name=$(basename "$0" .sh)
# --cluster 'ai2/jupiter*' \
# --cluster 'ai2/pluto*' \
# --cluster 'ai2/allennlp-cirrascale' \
# --priority high \
CLUSTER='jupiter'
gantry run \
--description "${run_name}"\
--task-name "${run_name}"\
--allow-dirty \
--host-networking \
--workspace ai2/oe-data-model-based-cleanup \
--beaker-image 'lucas/refine-axelot-vllm' \
--venv 'base' \
--priority high \
--gpus 8 \
--preemptible \
--cluster "ai2/${CLUSTER}*" \
--budget ai2/oe-data \
--env LOG_FILTER_TYPE=local_rank0_only \
--env OMP_NUM_THREADS=8 \
--env BEAKER_USER_ID=$(beaker account whoami --format json | jq '.[0].name' -cr) \
--env-secret AWS_ACCESS_KEY_ID=S2_AWS_ACCESS_KEY_ID \
--env-secret AWS_SECRET_ACCESS_KEY=S2_AWS_SECRET_ACCESS_KEY \
--env-secret WANDB_API_KEY=WANDB_API_KEY \
--shared-memory 10GiB \
--yes \
-- /bin/bash -c "source scripts/beaker/${CLUSTER}-ib.sh && accelerate launch --multi_gpu --num_processes \${BEAKER_ASSIGNED_GPU_COUNT} --mixed_precision bf16 -m pdelfin.train.train ${EXTRA_ARGS}"