Enable vLLM for generation

This commit is contained in:
Jake Poznanski 2025-08-21 17:33:56 +00:00
parent 6fb136deee
commit 6184c94c3c
2 changed files with 8 additions and 3 deletions

View File

@ -207,7 +207,7 @@ def main():
parser.add_argument(
"--learning_rate",
type=float,
default=1e-6,
default=1e-5,
help="Learning rate"
)
parser.add_argument(
@ -355,6 +355,11 @@ def main():
remove_unused_columns=False,
bf16=True,
dataloader_num_workers=0,
# Vllm setup to speed up generation
use_vllm=True,
vllm_mode="colocate",
vllm_gpu_memory_utilization=0.15,
)
# Initialize GRPO trainer

View File

@ -9,7 +9,7 @@ PREEMPTIBLE=false
MAX_TRAIN_SAMPLES=""
MAX_EVAL_SAMPLES=""
NUM_EPOCHS=1
LEARNING_RATE="1e-6"
LEARNING_RATE="1e-5"
BATCH_SIZE=1
GRAD_ACCUM_STEPS=4
USE_WANDB=false
@ -183,7 +183,7 @@ grpo_cmd = [
"--train_bench_data_folder /data/olmOCR-bench/bench_data",
"--eval_bench_data_folder /data/olmOCR-bench/bench_data", # Using same data for now
f"--model_name {model_name}",
"--output_dir /weka/oe-training-default/olmocr-grpo-checkpoints",
"--output_dir /weka/oe-training-default/jakep/olmocr-grpo-checkpoints",
f"--num_train_epochs {num_epochs}",
f"--learning_rate {learning_rate}",
f"--per_device_train_batch_size {batch_size}",