diff --git a/pdelfin/train/config/qwen2vl-2b.yaml b/pdelfin/train/config/qwen2vl-2b.yaml
index 6da9d5f..9ef18a1 100644
--- a/pdelfin/train/config/qwen2vl-2b.yaml
+++ b/pdelfin/train/config/qwen2vl-2b.yaml
@@ -3,8 +3,8 @@
 model:
   arch: causal
 
 wandb:
-  project: refine
-  entity: pdf-qwen2vl
+  project: pdelfin
+  entity: ai2-llm # TODO This is not used
 
 format:
@@ -93,10 +93,10 @@ hparams:
   gradient_checkpointing: true
   clip_grad_norm: 1.0
   learning_rate: 3e-4
-  max_steps: 10000
+  max_steps: 200
   pad_multiple_of: 16
   log_every_steps: 5
-  eval_every_steps: 250
+  eval_every_steps: 100
   optim: adamw_torch
   lr_scheduler: cosine
   weight_decay: 0.01
@@ -118,7 +118,7 @@ lora:
     - down_proj
 
 save:
-  path: s3://ai2-tylerm-experimental/experiments/rephrase/v1/models/lucas
-  save_every_steps: 500
+  path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
+  save_every_steps: 100
   max_workers: 1
\ No newline at end of file
diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py
index 476ad87..4e9650c 100644
--- a/pdelfin/train/train.py
+++ b/pdelfin/train/train.py
@@ -123,6 +123,8 @@ def run_train(config: TrainConfig):
 
     accelerator = accelerate.Accelerator()
 
+    setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
+
    train_ds = build_batch_query_response_vision_dataset(
        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
@@ -133,10 +135,25 @@ def run_train(config: TrainConfig):
     )
     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
 
+    if config.lora is not None:
+        peft_config = LoraConfig(
+            r=config.lora.rank,
+            lora_alpha=config.lora.alpha,
+            lora_dropout=config.lora.dropout,
+            bias=config.lora.bias, # pyright: ignore
+            task_type=config.lora.task_type,
+            target_modules=list(config.lora.target_modules),
+        )
+        model = get_peft_model(model=model, peft_config=peft_config)
+        log_trainable_parameters(model=model, logger=logger)
+
     train_ds = train_ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
     print(train_ds)
     print("---------------")
 
+    save_path = join_path("", config.save.path, run_name.run)
+
+    save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore
+
     with TemporaryDirectory() as output_dir:
@@ -177,22 +194,31 @@ def run_train(config: TrainConfig):
         # Set the collator
         collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
 
-        #checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)
+        checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)
 
-        # # Initialize Trainer
-        # trainer = Trainer(
-        #     model=model,
-        #     args=training_args,
-        #     train_dataset=train_ds,
-        #     #eval_dataset=formatted_dataset["validation"], # pyright: ignore
-        #     tokenizer=processor.tokenizer,
-        #     #data_collator=collator,
-        #     #callbacks=[checkpoint_callback],
-        # )
+        # Initialize Trainer
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_ds,
+            #eval_dataset=formatted_dataset["validation"], # pyright: ignore
+            tokenizer=processor.tokenizer,
+            #data_collator=collator,
+            #callbacks=[checkpoint_callback],
+        )
+        # Could not get this to work
+        # if get_rank() == 0:
+        #     # this is a hack to add script and peft config to wandb config
+        #     update_wandb_config(config, trainer, model)
 
-        # # Train the model
-        # trainer.train() # pyright: ignore
+        # Train the model
+        trainer.train() # pyright: ignore
+
+        with get_local_dir(join_path("", save_path, "best")) as best_dir:
+            model.save_pretrained(best_dir)
+            processor.tokenizer.save_pretrained(best_dir)
+            logger.info("Saved best model to %s", best_dir)
 
     # Uncomment to test speed of data loader
     # train_dataloader = DataLoader(train_ds, batch_size=1, num_workers=2, shuffle=False)