Mirror of https://github.com/allenai/olmocr.git
Basic LoRA trainer; doesn't seem to make any speed difference
commit ab9458b913 (parent 3ed14a9ea5)
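
A note on the "no speed difference" observation: LoRA shrinks the set of trainable parameters (and with it optimizer state and gradient memory), but every step still runs the full forward and backward pass through the frozen base weights, so per-step wall-clock time barely changes. Below is a minimal, hypothetical sketch of the kind of PEFT wrapping this commit adds, using gpt2 as a small stand-in for the Qwen2-VL base model; the rank, alpha, dropout, and target modules are illustrative values, not the ones from the config further down.

# Hypothetical stand-alone sketch (not the repo's code): wrap a small base model
# with a LoRA adapter and count trainable parameters.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("gpt2")
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["c_attn", "c_proj"],  # GPT-2's attention projection layers
)
model = get_peft_model(base, peft_config)

# Only the adapter weights are trainable, but the forward/backward pass still
# touches every frozen base weight, so step time stays roughly the same.
model.print_trainable_parameters()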
@@ -3,8 +3,8 @@ model:
   arch: causal
 
 wandb:
-  project: refine
-  entity: pdf-qwen2vl
+  project: pdelfin
+  entity: ai2-llm
 
 # TODO This is not used
 format:
@@ -93,10 +93,10 @@ hparams:
   gradient_checkpointing: true
   clip_grad_norm: 1.0
   learning_rate: 3e-4
-  max_steps: 10000
+  max_steps: 200
   pad_multiple_of: 16
   log_every_steps: 5
-  eval_every_steps: 250
+  eval_every_steps: 100
   optim: adamw_torch
   lr_scheduler: cosine
   weight_decay: 0.01
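
For orientation, the hparams block above would map roughly onto transformers.TrainingArguments as sketched below when handed to the Hugging Face Trainer used later in this commit. This is an illustrative sketch, not the repo's config loader; output_dir and eval_strategy are assumptions, and pad_multiple_of is consumed by the custom collator shown further down rather than by TrainingArguments.

# Illustrative mapping of the hparams above onto transformers.TrainingArguments.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="/tmp/qwen2vl-pdf",   # placeholder; the script writes into a TemporaryDirectory
    gradient_checkpointing=True,
    max_grad_norm=1.0,               # clip_grad_norm
    learning_rate=3e-4,
    max_steps=200,                   # lowered from 10000 in this commit
    logging_steps=5,                 # log_every_steps
    eval_strategy="steps",           # assumption; needed for eval_steps to apply
    eval_steps=100,                  # eval_every_steps, lowered from 250
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    weight_decay=0.01,
)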
@@ -118,7 +118,7 @@ lora:
     - down_proj
 
 save:
-  path: s3://ai2-tylerm-experimental/experiments/rephrase/v1/models/lucas
-  save_every_steps: 500
+  path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
+  save_every_steps: 100
 
 max_workers: 1
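
The config changes above are plain YAML; a minimal sketch of inspecting such a file with PyYAML follows. The repo's own TrainConfig loader is not part of this diff, so the key access here is an assumption based only on the structure visible above.

# Hypothetical sketch: read a config like the one above with PyYAML and check
# the values this commit changes. Assumes the file is saved as config.yaml.
import yaml

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["hparams"]["max_steps"] == 200
assert cfg["hparams"]["eval_every_steps"] == 100
assert cfg["save"]["save_every_steps"] == 100
print(cfg["save"]["path"])  # s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/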
@@ -123,6 +123,8 @@ def run_train(config: TrainConfig):
 
+    accelerator = accelerate.Accelerator()
+
     setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
 
     train_ds = build_batch_query_response_vision_dataset(
         query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
         response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
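
The hunk above only instantiates accelerate.Accelerator(); the Trainer added later in this commit manages devices itself. For reference, the usual pattern when driving Accelerate directly looks like the hypothetical sketch below, with a toy model and dataset that are not the repo's.

# Hypothetical sketch of the standard Accelerate training loop, for contrast
# with the bare Accelerator() instantiation above.
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator()

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
loader = DataLoader(dataset, batch_size=4)

# prepare() moves everything to the right device(s) and wraps for DDP when launched distributed
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    loss = torch.nn.functional.cross_entropy(model(x), y)
    accelerator.backward(loss)  # replaces loss.backward()
    optimizer.step()
    optimizer.zero_grad()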
@@ -133,10 +135,25 @@ def run_train(config: TrainConfig):
     )
     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+
+    if config.lora is not None:
+        peft_config = LoraConfig(
+            r=config.lora.rank,
+            lora_alpha=config.lora.alpha,
+            lora_dropout=config.lora.dropout,
+            bias=config.lora.bias, # pyright: ignore
+            task_type=config.lora.task_type,
+            target_modules=list(config.lora.target_modules),
+        )
+        model = get_peft_model(model=model, peft_config=peft_config)
+        log_trainable_parameters(model=model, logger=logger)
+
     train_ds = train_ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    print(train_ds)
+    print("---------------")
 
     save_path = join_path("", config.save.path, run_name.run)
 
     save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore
 
     with TemporaryDirectory() as output_dir:
 
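
The with_transform call above applies the Qwen2 batch-preparation function lazily, per accessed batch, instead of preprocessing the whole dataset up front. Below is a small generic sketch of that datasets API, with a toy column and transform rather than the repo's prepare function.

# Hypothetical sketch of datasets.Dataset.with_transform: the transform runs
# on the fly whenever rows are accessed, keeping preprocessing off the critical
# path until batches are actually requested.
from datasets import Dataset

ds = Dataset.from_dict({"text": ["hello", "world"]})

def shout(batch):
    # batch is a dict of column -> list for the rows being accessed
    return {"text": [t.upper() for t in batch["text"]]}

lazy_ds = ds.with_transform(shout)
print(lazy_ds[0])    # transformed on access, e.g. {'text': 'HELLO'}
print(lazy_ds[0:2])  # {'text': ['HELLO', 'WORLD']}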
@@ -177,22 +194,31 @@ def run_train(config: TrainConfig):
 
     # Set the collator
     collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
-    #checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)
+    checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)
 
-    # # Initialize Trainer
-    # trainer = Trainer(
-    #     model=model,
-    #     args=training_args,
-    #     train_dataset=train_ds,
-    #     #eval_dataset=formatted_dataset["validation"], # pyright: ignore
-    #     tokenizer=processor.tokenizer,
-    #     #data_collator=collator,
-    #     #callbacks=[checkpoint_callback],
-    # )
+    # Initialize Trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_ds,
+        #eval_dataset=formatted_dataset["validation"], # pyright: ignore
+        tokenizer=processor.tokenizer,
+        #data_collator=collator,
+        #callbacks=[checkpoint_callback],
+    )
 
+    # Could not get this to work
+    # if get_rank() == 0:
+    #     # this is a hack to add script and peft config to wandb config
+    #     update_wandb_config(config, trainer, model)
 
-    # # Train the model
-    # trainer.train() # pyright: ignore
+    # Train the model
+    trainer.train() # pyright: ignore
 
+    with get_local_dir(join_path("", save_path, "best")) as best_dir:
+        model.save_pretrained(best_dir)
+        tokenizer.tokenizer.save_pretrained(best_dir)
+        logger.info("Saved best model to %s", best_dir)
+
     # Uncomment to test speed of data loader
     # train_dataloader = DataLoader(train_ds, batch_size=1, num_workers=2, shuffle=False)
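
For completeness, the save block above writes only the LoRA adapter (save_pretrained on a PEFT-wrapped model) plus the tokenizer. The sketch below shows how such a checkpoint would typically be loaded back; the paths are placeholders, Qwen2VLForConditionalGeneration is the transformers class for this base model, and the merge step is optional.

# Hypothetical reload sketch for the checkpoint written above: re-attach the
# saved LoRA adapter to the frozen base model. "/path/to/best" stands in for a
# local copy of the "best" directory.
from peft import PeftModel
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration

base = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
model = PeftModel.from_pretrained(base, "/path/to/best")    # adapter_config.json + adapter weights
tokenizer = AutoTokenizer.from_pretrained("/path/to/best")  # tokenizer files saved alongside

# Optionally fold the adapter back into the base weights for plain inference:
merged = model.merge_and_unload()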