Mirror of https://github.com/allenai/olmocr.git
Trying a manual training loop with LoRA to reduce memory usage
Parent: f14e910175
Commit: 8fece62b8a
@@ -29,8 +29,8 @@ train_data:
   seed: 1337
   sources:
     - name: openai_batch_data_v2
-      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
-      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
+      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
       backend:
         - openai
       size: 100_000
@@ -29,8 +29,8 @@ train_data:
   seed: 1337
   sources:
     - name: openai_batch_data_v2
-      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
-      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
+      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
       backend:
         - openai
       size: 100_000
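
Each source in train_data.sources pairs a query glob (JSONL batch requests) with a response glob (completed batch responses) on S3, plus a backend list and a sample size. As a rough illustration of how such an entry can be parsed, here is a minimal sketch; the SourceConfig dataclass and load_train_sources helper are hypothetical names for this example, not olmocr's actual classes:

# Hypothetical sketch: parse train_data.sources entries like the ones above.
# SourceConfig and load_train_sources are illustrative names, not olmocr's API.
from dataclasses import dataclass, field
from typing import List

import yaml


@dataclass
class SourceConfig:
    name: str
    query_glob_path: str      # JSONL files holding the batch queries
    response_glob_path: str   # JSON files holding the matching batch responses
    backend: List[str] = field(default_factory=list)
    size: int = 0


def load_train_sources(config_path: str) -> List[SourceConfig]:
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    # train_data.sources is a list of mappings; unpack each into a SourceConfig
    return [SourceConfig(**entry) for entry in cfg["train_data"]["sources"]]
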
@@ -12,9 +12,9 @@ from tempfile import TemporaryDirectory
from typing import Optional
from tqdm import tqdm

import accelerate
import torch
import torch.distributed
from accelerate import Accelerator
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model  # pyright: ignore
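
The new imports pull in Accelerate plus PEFT's LoRA utilities. For reference, a minimal sketch of wrapping a causal LM with LoRA adapters via get_peft_model; the base checkpoint and the rank/alpha/target_modules values below are placeholders for illustration, not the settings used in this commit:

# Illustrative LoRA wrapping; checkpoint and hyperparameters are placeholders.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")  # stand-in base model
lora_config = LoraConfig(
    r=16,                                 # adapter rank
    lora_alpha=32,                        # scaling factor
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # attach adapters to attention projections
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()        # only the adapter weights are trainable
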
@@ -145,82 +145,27 @@ def run_train(config: TrainConfig):

    save_config(config, join_path("", save_path, "config.yaml"))  # pyright: ignore

    with TemporaryDirectory() as output_dir:
        train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)

        training_args = TrainingArguments(
            run_name=run_name.run,
            logging_steps=config.hparams.log_every_steps,
            output_dir=output_dir,
            eval_strategy="steps",
            report_to="wandb",
            # report_to=[],  # disable logging to wandb, we will use a custom callback
            optim=config.hparams.optim,
            eval_steps=config.hparams.eval_every_steps,
            learning_rate=config.hparams.learning_rate,
            per_device_train_batch_size=config.hparams.batch_size,
            per_device_eval_batch_size=config.hparams.eval_batch_size or config.hparams.batch_size,
            gradient_checkpointing=config.hparams.gradient_checkpointing,
            gradient_checkpointing_kwargs=(
                dict(use_reentrant=False)  # from this issue: https://github.com/huggingface/peft/issues/1142
                if config.hparams.gradient_checkpointing and config.lora is not None
                else {}
            ),
            gradient_accumulation_steps=config.hparams.gradient_accumulation_steps,
            max_steps=config.hparams.max_steps,
            weight_decay=config.hparams.weight_decay,
            dataloader_num_workers=config.max_workers,
            load_best_model_at_end=True,
            save_strategy="steps",
            ddp_find_unused_parameters=config.hparams.find_unused_parameters,
            save_steps=config.save.save_every_steps,
            warmup_steps=config.hparams.warmup_steps,
            warmup_ratio=config.hparams.warmup_ratio,
            bf16=True,
            label_names=["labels"],  # fix from https://github.com/huggingface/transformers/issues/22885
            max_grad_norm=config.hparams.clip_grad_norm,
            remove_unused_columns=False,
        )
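        # Aside (illustrative, not part of this diff): use_reentrant=False matters when gradient
        # checkpointing is combined with a mostly-frozen LoRA model, where the reentrant
        # implementation can see no inputs that require grad. When configuring the model
        # directly instead of through TrainingArguments, the equivalent is roughly:
        #
        #   model.enable_input_require_grads()
        #   model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})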
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

        # Set the collator
        collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
        checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)
        accelerator = Accelerator(mixed_precision="bf16")
        model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

        # Initialize Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=formatted_dataset["train"],
            eval_dataset=formatted_dataset["validation"],  # pyright: ignore
            tokenizer=processor.tokenizer,
            # Collator is not needed as we are doing batch size 1 for now...
            # data_collator=collator,
            callbacks=[checkpoint_callback],
        )
        steps = 0

        # Could not get this to work
        # if get_rank() == 0:
        #     # this is a hack to add script and peft config to wandb config
        #     update_wandb_config(config, trainer, model)
        for entry in tqdm(train_dataloader):
            print("Sequence len", entry["input_ids"].shape)
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                outputs = model(**entry)
                loss = outputs.loss
                accelerator.backward(loss)
                optimizer.step()
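            # Aside (illustrative, not part of this diff): accelerator.accumulate() only defers
            # gradient syncs and optimizer steps when the Accelerator is created with an
            # accumulation window, e.g. roughly:
            #
            #   accelerator = Accelerator(
            #       mixed_precision="bf16",
            #       gradient_accumulation_steps=config.hparams.gradient_accumulation_steps,
            #   )
            #
            # As constructed above with bf16 only, every iteration syncs and steps.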
        # Train the model
        trainer.train()  # pyright: ignore

        with get_local_dir(join_path("", save_path, "best")) as best_dir:
            if config.lora is not None:
                logger.info("Merging LoRA adapters into the base model...")
                model = model.merge_and_unload()
                logger.info("LoRA adapters merged successfully.")

            model.save_pretrained(best_dir)

            logger.info("Saved best model to %s", best_dir)
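        # Aside (illustrative, not part of this diff): after accelerator.prepare(), the model is
        # typically unwrapped before merging and saving, e.g. roughly:
        #
        #   unwrapped = accelerator.unwrap_model(model)
        #   if config.lora is not None:
        #       unwrapped = unwrapped.merge_and_unload()
        #   unwrapped.save_pretrained(best_dir)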
        # Uncomment to test speed of data loader
        # train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
        # for entry in tqdm(train_dataloader):
        #     print("Step!")
        #     model.forward(**{k: v.to("cuda:0") for (k, v) in entry.items()})

            steps += 1
            if accelerator.is_local_main_process:
                logger.info(f"step {steps}, training loss : {loss.item()}")


def main():
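
Taken together, the change swaps the HF Trainer call for a hand-rolled loop so memory behaviour with LoRA can be inspected step by step. A condensed, self-contained sketch of that pattern follows; the model, dataset, and hyperparameters are placeholders, and gradient clipping plus the zero_grad placement are adjusted for clarity rather than copied from the diff:

# Condensed sketch of a manual Accelerate + LoRA training loop; names and values are placeholders.
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader
from tqdm import tqdm


def manual_train(model, train_dataset, lr=3e-4, max_grad_norm=1.0):
    accelerator = Accelerator(mixed_precision="bf16")
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False)
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    model.train()
    for step, batch in enumerate(tqdm(dataloader), start=1):
        with accelerator.accumulate(model):
            outputs = model(**batch)  # batch already holds input_ids / labels tensors
            loss = outputs.loss
            accelerator.backward(loss)
            accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
        if accelerator.is_local_main_process:
            print(f"step {step}, training loss: {loss.item()}")

Running the loop by hand makes it easy to log per-step details such as sequence length (as the diff does with its print of input_ids.shape) and to watch memory grow or shrink with each batch, which is awkward to do from inside Trainer.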