diff --git a/pdelfin/train/config/qwen2vl-7b-lora.yaml b/pdelfin/train/config/qwen2vl-7b-lora.yaml
index 80bfba0..f8d47ec 100644
--- a/pdelfin/train/config/qwen2vl-7b-lora.yaml
+++ b/pdelfin/train/config/qwen2vl-7b-lora.yaml
@@ -33,6 +33,7 @@ train_data:
       response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
 
 valid_data:
+  metric_for_best_model: openai_batch_data_v5_1_eval_loss
   sources:
     - name: openai_batch_data_v5_1_eval
       query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
@@ -51,10 +52,10 @@ hparams:
   gradient_checkpointing: false
   clip_grad_norm: 1.0
   learning_rate: 3e-4
-  max_steps: 500
+  max_steps: 50
   pad_multiple_of: 16
-  log_every_steps: 50
-  eval_every_steps: 100
+  log_every_steps: 10
+  eval_every_steps: 50
   optim: adamw_torch
   lr_scheduler: cosine
   weight_decay: 0.01
diff --git a/pdelfin/train/core/config.py b/pdelfin/train/core/config.py
index eca44cd..a1e838b 100644
--- a/pdelfin/train/core/config.py
+++ b/pdelfin/train/core/config.py
@@ -82,6 +82,7 @@ class SourceConfig:
 
 @dataclass
 class DataConfig:
     seed: int = field(default=42, help="The seed to use for data loading")
+    metric_for_best_model: Optional[str] = field(help="metric to pass to trainer args to use for picking best model checkpoint at end", default=None)
     sources: List[SourceConfig] = field(help="The source configurations")
 
diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py
index e8f4bd2..d9dae06 100644
--- a/pdelfin/train/train.py
+++ b/pdelfin/train/train.py
@@ -3,6 +3,7 @@ import json
 import base64
 import logging
 import time
+import random
 from io import BytesIO
 from PIL import Image
 from functools import partial
@@ -194,6 +195,7 @@ def run_train(config: TrainConfig):
         max_grad_norm=config.hparams.clip_grad_norm,
         remove_unused_columns=False,
         eval_on_start=True,
+        metric_for_best_model=config.valid_data.metric_for_best_model,
     )
 
     # Set the collator
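
For reference, a minimal sketch of how the new metric_for_best_model value is interpreted by the Hugging Face TrainingArguments that this diff feeds it into (assuming the transformers Trainer; output_dir and the strategy/step settings below are illustrative, not taken from this repo). When evaluation runs over named datasets, the Trainer logs each metric as eval_<dataset_name>_<metric> and prepends eval_ to metric_for_best_model when the prefix is missing, so the YAML value openai_batch_data_v5_1_eval_loss matches eval_openai_batch_data_v5_1_eval_loss:

    from transformers import TrainingArguments

    # Sketch only: best-checkpoint selection also needs load_best_model_at_end,
    # which this diff does not set, plus a save cadence aligned with evaluation.
    args = TrainingArguments(
        output_dir="/tmp/sketch",       # hypothetical path
        eval_strategy="steps",          # "evaluation_strategy" on older transformers
        eval_steps=50,                  # matches eval_every_steps in the YAML
        save_strategy="steps",
        save_steps=50,                  # must be a multiple of eval_steps here
        metric_for_best_model="openai_batch_data_v5_1_eval_loss",
        greater_is_better=False,        # a loss, so lower is better
        load_best_model_at_end=True,
    )

Without load_best_model_at_end, metric_for_best_model mainly sets the default for greater_is_better and is consulted by callbacks such as EarlyStoppingCallback rather than triggering best-checkpoint reloading.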