Trying manual training loop with lora for memory usage

This commit is contained in:
Jake Poznanski 2024-09-26 10:13:36 -07:00
parent f14e910175
commit 8fece62b8a
3 changed files with 21 additions and 76 deletions

View File

@ -29,8 +29,8 @@ train_data:
seed: 1337
sources:
- name: openai_batch_data_v2
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
backend:
- openai
size: 100_000

View File

@ -29,8 +29,8 @@ train_data:
seed: 1337
sources:
- name: openai_batch_data_v2
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
backend:
- openai
size: 100_000

View File

@ -12,9 +12,9 @@ from tempfile import TemporaryDirectory
from typing import Optional
from tqdm import tqdm
import accelerate
import torch
import torch.distributed
from accelerate import Accelerator
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model # pyright: ignore
@ -145,82 +145,27 @@ def run_train(config: TrainConfig):
save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore
with TemporaryDirectory() as output_dir:
train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
training_args = TrainingArguments(
run_name=run_name.run,
logging_steps=config.hparams.log_every_steps,
output_dir=output_dir,
eval_strategy="steps",
report_to="wandb",
# report_to=[], # disable logging to wandb, we will use a custom callback
optim=config.hparams.optim,
eval_steps=config.hparams.eval_every_steps,
learning_rate=config.hparams.learning_rate,
per_device_train_batch_size=config.hparams.batch_size,
per_device_eval_batch_size=config.hparams.eval_batch_size or config.hparams.batch_size,
gradient_checkpointing=config.hparams.gradient_checkpointing,
gradient_checkpointing_kwargs=(
dict(use_reentrant=False) # from this issue: https://github.com/huggingface/peft/issues/1142
if config.hparams.gradient_checkpointing and config.lora is not None
else {}
),
gradient_accumulation_steps=config.hparams.gradient_accumulation_steps,
max_steps=config.hparams.max_steps,
weight_decay=config.hparams.weight_decay,
dataloader_num_workers=config.max_workers,
load_best_model_at_end=True,
save_strategy="steps",
ddp_find_unused_parameters=config.hparams.find_unused_parameters,
save_steps=config.save.save_every_steps,
warmup_steps=config.hparams.warmup_steps,
warmup_ratio=config.hparams.warmup_ratio,
bf16=True,
label_names=["labels"], # fix from https://github.com/huggingface/transformers/issues/22885
max_grad_norm=config.hparams.clip_grad_norm,
remove_unused_columns=False,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# Set the collator
collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)
accelerator = Accelerator(mixed_precision="bf16")
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
# Initialize Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=formatted_dataset["train"],
eval_dataset=formatted_dataset["validation"], # pyright: ignore
tokenizer=processor.tokenizer,
#Collator is not needed as we are doing batch size 1 for now...
#data_collator=collator,
callbacks=[checkpoint_callback],
)
steps = 0
# Could not get this to work
# if get_rank() == 0:
# # this is a hack to add script and peft config to wandb config
# update_wandb_config(config, trainer, model)
for entry in tqdm(train_dataloader):
print("Sequence len", entry["input_ids"].shape)
with accelerator.accumulate(model):
optimizer.zero_grad()
outputs = model(**entry)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
# Train the model
trainer.train() # pyright: ignore
with get_local_dir(join_path("", save_path, "best")) as best_dir:
if config.lora is not None:
logger.info("Merging LoRA adapters into the base model...")
model = model.merge_and_unload()
logger.info("LoRA adapters merged successfully.")
model.save_pretrained(best_dir)
logger.info("Saved best model to %s", best_dir)
# Uncomment to test speed of data loader
# train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
# for entry in tqdm(train_dataloader):
# print("Step!")
# model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})
steps += 1
if accelerator.is_local_main_process:
logger.info(f"step {steps}, training loss : {loss.item()}")
def main():