# Mirror of https://github.com/allenai/olmocr.git
# Synced 2025-08-15 12:21:44 +00:00
# Standard library
import base64
import json
import logging
import os
import time
from functools import partial
from io import BytesIO
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

# Third-party
import accelerate
import torch
import torch.distributed
import wandb
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model  # pyright: ignore
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    Qwen2VLForConditionalGeneration,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
from transformers.integrations import WandbCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.trainer_utils import get_last_checkpoint

# Local
from pdelfin.train.core.cli import make_cli, save_config, to_native_types
from pdelfin.train.core.config import TrainConfig
from pdelfin.train.core.loggers import get_logger
from pdelfin.train.core.paths import copy_dir, join_path
from pdelfin.train.core.state import BeakerState
from pdelfin.train.dataloader import make_dataset
from pdelfin.train.dataprep import filter_by_max_seq_len, prepare_data_for_qwen2_training

from .utils import (
    RunName,
    get_local_dir,
    log_trainable_parameters,
    packing_collator,
    setup_environment,
)
|
class CheckpointUploadCallback(TrainerCallback):
    """Trainer callback that mirrors each saved checkpoint to an external save path."""

    def __init__(self, save_path: str, logger: Optional[Logger] = None):
        # Destination prefix (e.g. a remote storage path understood by copy_dir).
        self.save_path = save_path
        self.logger = logger or get_logger(self.__class__.__name__)

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """After each Trainer save, copy the newest checkpoint dir to ``save_path``."""
        # Only the local rank-0 process uploads, to avoid duplicate copies.
        if not state.is_local_process_zero:
            return

        checkpoint = get_last_checkpoint(args.output_dir)
        if not checkpoint:
            return

        destination = f"{self.save_path}/{Path(checkpoint).name}"
        copy_dir(str(checkpoint), destination)
        self.logger.info("Saved checkpoint to %s", destination)
|
|
|
|
|
|
def update_wandb_config(config: TrainConfig, trainer: Trainer, model: torch.nn.Module):
    """Wrap the trainer's WandbCallback.setup so extra config sections are pushed to wandb.

    Adds the model's PEFT config, the script's TrainConfig, and any BEAKER_*
    environment variables to ``wandb.config`` right after the stock setup runs.
    Raises ``ValueError`` if the trainer has no ``WandbCallback``.
    """
    wandb_callback = next(
        (c for c in trainer.callback_handler.callbacks if isinstance(c, WandbCallback)),  # pyright: ignore
        None,
    )
    if wandb_callback is None:
        raise ValueError("WandbCallback not found in trainer callbacks")

    # Snapshot everything we want logged now; the closure below replays it at setup time.
    extra_sections = {
        "peft": to_native_types(getattr(model, "peft_config", {})),
        "script": to_native_types(config),
        "beaker": {k: v for k, v in os.environ.items() if k.lower().startswith("beaker")},
    }

    original_setup = wandb_callback.setup

    def setup_and_update(args, state, model, **kwargs):
        # Run the stock setup first so the wandb run exists, then layer on our sections.
        original_setup(args=args, state=state, model=model, **kwargs)
        for section, payload in extra_sections.items():
            wandb.config.update({section: payload}, allow_val_change=True)
        if (run := wandb.run) and (beaker_url := BeakerState().url):
            run.notes = beaker_url

    wandb_callback.setup = setup_and_update
|
|
|
|
|
|
def get_rank() -> int:
    """Return this process's rank in the distributed group, or 0 when not distributed."""
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_rank()
|
|
|
|
|
|
def run_train(config: TrainConfig):
    """Fine-tune a Qwen2-VL model according to *config*.

    Builds lazily-preprocessed train/validation datasets, optionally wraps the
    model with LoRA adapters, trains with HF ``Trainer`` in a temporary output
    directory (checkpoints are mirrored to ``config.save.path`` by
    ``CheckpointUploadCallback``), and finally saves the best model — with LoRA
    adapters merged, if LoRA was used — under ``<save_path>/best``.
    """
    # Keep logging and progress bars quiet on non-zero ranks so multi-GPU logs stay readable.
    if get_rank() == 0:
        logger_level = logging.INFO
    else:
        logger_level = logging.WARN
        disable_progress_bars()

    logger = get_logger(__name__, level=logger_level)
    set_verbosity(logger_level)

    run_name = RunName.get(config)

    setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)

    dataset = make_dataset(
        train_data_config=config.train_data,
        valid_data_config=config.valid_data,
        num_proc=config.num_proc,
        logger=logger,
    )

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        config.model.name_or_path,
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None,
    )
    processor = AutoProcessor.from_pretrained(config.model.name_or_path)

    if config.lora is not None:
        peft_config = LoraConfig(
            r=config.lora.rank,
            lora_alpha=config.lora.alpha,
            lora_dropout=config.lora.dropout,
            bias=config.lora.bias,  # pyright: ignore
            task_type=config.lora.task_type,
            target_modules=list(config.lora.target_modules),
        )
        model = get_peft_model(model=model, peft_config=peft_config)
        log_trainable_parameters(model=model, logger=logger)

    # Convert to an iterable dataset so map/filter are applied lazily instead of
    # materializing the full preprocessed dataset up front.
    train_ds = dataset["train"].to_iterable_dataset(num_shards=64)
    validation_ds = dataset["validation"]

    train_ds = train_ds.map(
        partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True),
        remove_columns=train_ds.column_names,
    ).filter(filter_by_max_seq_len)
    # BUGFIX: a stray ")" previously closed the .map() call before remove_columns,
    # making this line a syntax error; remove_columns is a keyword of .map(), as above.
    validation_ds = validation_ds.map(
        partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True),
        remove_columns=validation_ds.column_names,
    )

    print(train_ds)
    print(validation_ds)
    print("---------------")

    save_path = join_path("", config.save.path, run_name.run)
    save_config(config, join_path("", save_path, "config.yaml"))  # pyright: ignore

    with TemporaryDirectory() as output_dir:
        training_args = TrainingArguments(
            run_name=run_name.run,
            logging_steps=config.hparams.log_every_steps,
            output_dir=output_dir,
            eval_strategy="steps",
            report_to="wandb",
            optim=config.hparams.optim,
            eval_steps=config.hparams.eval_every_steps,
            learning_rate=config.hparams.learning_rate,
            per_device_train_batch_size=config.hparams.batch_size,
            per_device_eval_batch_size=config.hparams.eval_batch_size or config.hparams.batch_size,
            gradient_checkpointing=config.hparams.gradient_checkpointing,
            gradient_checkpointing_kwargs=(
                dict(use_reentrant=False)  # from this issue: https://github.com/huggingface/peft/issues/1142
                if config.hparams.gradient_checkpointing and config.lora is not None
                else {}
            ),
            gradient_accumulation_steps=config.hparams.gradient_accumulation_steps,
            max_steps=config.hparams.max_steps,
            weight_decay=config.hparams.weight_decay,
            dataloader_num_workers=config.max_workers,
            load_best_model_at_end=True,
            save_strategy="steps",
            ddp_find_unused_parameters=config.hparams.find_unused_parameters,
            save_steps=config.save.save_every_steps,
            warmup_steps=config.hparams.warmup_steps,
            warmup_ratio=config.hparams.warmup_ratio,
            bf16=True,
            label_names=["labels"],  # fix from https://github.com/huggingface/transformers/issues/22885
            max_grad_norm=config.hparams.clip_grad_norm,
            remove_unused_columns=False,
        )

        # NOTE(review): collator is currently unused (Trainer runs with batch size 1);
        # kept so it is ready when data_collator is re-enabled below.
        collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
        checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=validation_ds,
            tokenizer=processor.tokenizer,
            # data_collator=collator,  # not needed while effective batch size is 1
            callbacks=[checkpoint_callback],
        )

        # Train the model (resumes/schedules per training_args above).
        trainer.train()  # pyright: ignore

        with get_local_dir(join_path("", save_path, "best")) as best_dir:
            if config.lora is not None:
                logger.info("Merging LoRA adapters into the base model...")
                model = model.merge_and_unload()
                logger.info("LoRA adapters merged successfully.")

            model.save_pretrained(best_dir)
            logger.info("Saved best model to %s", best_dir)
|
|
|
|
|
|
def main():
    """CLI entry point: parse a ``TrainConfig`` from argv and launch training."""
    run_train(make_cli(TrainConfig))  # pyright: ignore


if __name__ == "__main__":
    main()