Trying manual training loop with LoRA for memory usage

Jake Poznanski 2024-09-26 10:13:36 -07:00
parent f14e910175
commit 8fece62b8a
3 changed files with 21 additions and 76 deletions
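In essence, this commit drops the transformers Trainer in favor of a hand-rolled loop driven by accelerate, with the model still wrapped in a PEFT LoRA adapter. For orientation, here is a minimal, self-contained sketch of that pattern; the "gpt2" base model, the "c_attn" LoRA target, and the toy texts are illustrative assumptions, not this repo's code:

    # Hedged sketch of the Trainer -> manual-loop pattern this commit moves to.
    import torch
    from accelerate import Accelerator
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # LoRA keeps optimizer state small: only the adapter weights are trainable.
    lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], task_type="CAUSAL_LM")
    model = get_peft_model(model, lora_config)

    accelerator = Accelerator(mixed_precision="bf16")
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    model, optimizer = accelerator.prepare(model, optimizer)

    model.train()
    for step, text in enumerate(["hello world", "manual loop demo"], start=1):
        batch = tokenizer(text, return_tensors="pt").to(accelerator.device)
        batch["labels"] = batch["input_ids"].clone()  # causal-LM loss needs labels
        optimizer.zero_grad()
        loss = model(**batch).loss
        accelerator.backward(loss)   # handles mixed precision / DDP scaling
        optimizer.step()
        if accelerator.is_local_main_process:
            print(f"step {step}: loss {loss.item():.4f}")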

View File

@@ -29,8 +29,8 @@ train_data:
   seed: 1337
   sources:
     - name: openai_batch_data_v2
-      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
-      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
+      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
       backend:
         - openai
       size: 100_000

View File

@@ -29,8 +29,8 @@ train_data:
   seed: 1337
   sources:
     - name: openai_batch_data_v2
-      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
-      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
+      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
       backend:
         - openai
       size: 100_000
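Both config files make the same swap: the query and response globs move from the full openai_batch_data_v2 set to the much smaller eval_mini batches, presumably to shorten iteration while debugging memory usage. As a hedged illustration of what those globs resolve to (fsspec/s3fs is an assumed tool, not taken from this repo, and requires s3 credentials):

    # Hedged sketch: expanding the config's S3 globs with fsspec/s3fs.
    import fsspec

    fs = fsspec.filesystem("s3")
    queries = fs.glob("s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl")
    responses = fs.glob("s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json")
    print(f"{len(queries)} query shards, {len(responses)} response shards")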

View File

@@ -12,9 +12,9 @@ from tempfile import TemporaryDirectory
 from typing import Optional
 from tqdm import tqdm
-import accelerate
 import torch
 import torch.distributed
+from accelerate import Accelerator
 from datasets.utils import disable_progress_bars
 from datasets.utils.logging import set_verbosity
 from peft import LoraConfig, get_peft_model  # pyright: ignore
@@ -145,82 +145,27 @@ def run_train(config: TrainConfig):
     save_config(config, join_path("", save_path, "config.yaml"))  # pyright: ignore

-    with TemporaryDirectory() as output_dir:
-        training_args = TrainingArguments(
-            run_name=run_name.run,
-            logging_steps=config.hparams.log_every_steps,
-            output_dir=output_dir,
-            eval_strategy="steps",
-            report_to="wandb",
-            # report_to=[],  # disable logging to wandb, we will use a custom callback
-            optim=config.hparams.optim,
-            eval_steps=config.hparams.eval_every_steps,
-            learning_rate=config.hparams.learning_rate,
-            per_device_train_batch_size=config.hparams.batch_size,
-            per_device_eval_batch_size=config.hparams.eval_batch_size or config.hparams.batch_size,
-            gradient_checkpointing=config.hparams.gradient_checkpointing,
-            gradient_checkpointing_kwargs=(
-                dict(use_reentrant=False)  # from this issue: https://github.com/huggingface/peft/issues/1142
-                if config.hparams.gradient_checkpointing and config.lora is not None
-                else {}
-            ),
-            gradient_accumulation_steps=config.hparams.gradient_accumulation_steps,
-            max_steps=config.hparams.max_steps,
-            weight_decay=config.hparams.weight_decay,
-            dataloader_num_workers=config.max_workers,
-            load_best_model_at_end=True,
-            save_strategy="steps",
-            ddp_find_unused_parameters=config.hparams.find_unused_parameters,
-            save_steps=config.save.save_every_steps,
-            warmup_steps=config.hparams.warmup_steps,
-            warmup_ratio=config.hparams.warmup_ratio,
-            bf16=True,
-            label_names=["labels"],  # fix from https://github.com/huggingface/transformers/issues/22885
-            max_grad_norm=config.hparams.clip_grad_norm,
-            remove_unused_columns=False,
-        )
+    train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

-        # Set the collator
-        collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
+    accelerator = Accelerator(mixed_precision="bf16")
+    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

-        checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)
+    steps = 0

-        # Initialize Trainer
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=formatted_dataset["train"],
-            eval_dataset=formatted_dataset["validation"],  # pyright: ignore
-            tokenizer=processor.tokenizer,
-            #Collator is not needed as we are doing batch size 1 for now...
-            #data_collator=collator,
-            callbacks=[checkpoint_callback],
-        )
+    for entry in tqdm(train_dataloader):
+        print("Sequence len", entry["input_ids"].shape)
+        with accelerator.accumulate(model):
+            optimizer.zero_grad()
+            outputs = model(**entry)
+            loss = outputs.loss
+            accelerator.backward(loss)
+            optimizer.step()

-        # Could not get this to work
-        # if get_rank() == 0:
-        #     # this is a hack to add script and peft config to wandb config
-        #     update_wandb_config(config, trainer, model)
+            steps += 1
+            if accelerator.is_local_main_process:
+                logger.info(f"step {steps}, training loss : {loss.item()}")

-        # Train the model
-        trainer.train()  # pyright: ignore
-
-        with get_local_dir(join_path("", save_path, "best")) as best_dir:
-            if config.lora is not None:
-                logger.info("Merging LoRA adapters into the base model...")
-                model = model.merge_and_unload()
-                logger.info("LoRA adapters merged successfully.")
-
-            model.save_pretrained(best_dir)
-            logger.info("Saved best model to %s", best_dir)
-
-    # Uncomment to test speed of data loader
-    # train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
-    # for entry in tqdm(train_dataloader):
-    #     print("Step!")
-    #     model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})

 def main():
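Note that the deleted Trainer block also carried evaluation, checkpoint uploads, warmup, gradient clipping, and the final LoRA merge-and-save; the new manual loop has none of these yet. A hedged sketch of restoring just the merge-and-save step (reusing the diff's model and accelerator names; the output path is a placeholder, not the repo's save_path logic):

    # Hedged sketch: manual equivalent of the removed merge-and-save step.
    # Assumes `model` is the PEFT-wrapped model prepared by `accelerator`.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unwrapped = accelerator.unwrap_model(model)
        merged = unwrapped.merge_and_unload()  # fold LoRA adapters into the base weights
        merged.save_pretrained("/tmp/best")  # placeholder output directory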