Running a mini config again with metric

This commit is contained in:
Jake Poznanski 2024-10-03 11:12:30 -07:00
parent 046d4a4534
commit 8f1fa4f796
3 changed files with 7 additions and 3 deletions

View File

@@ -33,6 +33,7 @@ train_data:
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
valid_data: valid_data:
metric_for_best_model: openai_batch_data_v5_1_eval_loss
sources: sources:
- name: openai_batch_data_v5_1_eval - name: openai_batch_data_v5_1_eval
query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
@@ -51,10 +52,10 @@ hparams:
gradient_checkpointing: false gradient_checkpointing: false
clip_grad_norm: 1.0 clip_grad_norm: 1.0
learning_rate: 3e-4 learning_rate: 3e-4
max_steps: 500 max_steps: 50
pad_multiple_of: 16 pad_multiple_of: 16
log_every_steps: 50 log_every_steps: 10
eval_every_steps: 100 eval_every_steps: 50
optim: adamw_torch optim: adamw_torch
lr_scheduler: cosine lr_scheduler: cosine
weight_decay: 0.01 weight_decay: 0.01

View File

@@ -82,6 +82,7 @@ class SourceConfig:
@dataclass @dataclass
class DataConfig: class DataConfig:
seed: int = field(default=42, help="The seed to use for data loading") seed: int = field(default=42, help="The seed to use for data loading")
metric_for_best_model: Optional[str] = field(help="metric to pass to trainer args to use for picking best model checkpoint at end", default=None)
sources: List[SourceConfig] = field(help="The source configurations") sources: List[SourceConfig] = field(help="The source configurations")

View File

@@ -3,6 +3,7 @@ import json
import base64 import base64
import logging import logging
import time import time
import random
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
from functools import partial from functools import partial
@@ -194,6 +195,7 @@ def run_train(config: TrainConfig):
max_grad_norm=config.hparams.clip_grad_norm, max_grad_norm=config.hparams.clip_grad_norm,
remove_unused_columns=False, remove_unused_columns=False,
eval_on_start=True, eval_on_start=True,
metric_for_best_model=config.valid_data.metric_for_best_model,
) )
# Set the collator # Set the collator