From b4b121b11877700700ea0da2d1d6f59ef2fa28d0 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Tue, 23 Sep 2025 20:10:11 +0000
Subject: [PATCH] Testing reward eos

---
 olmocr/train/grpo_train.py | 67 ++++++++++++++++++++++++++++++++++++--
 1 file changed, 64 insertions(+), 3 deletions(-)

diff --git a/olmocr/train/grpo_train.py b/olmocr/train/grpo_train.py
index a472e9a..764fde3 100644
--- a/olmocr/train/grpo_train.py
+++ b/olmocr/train/grpo_train.py
@@ -11,7 +11,7 @@ import json
 import random
 from pathlib import Path
 import glob
-from functools import lru_cache
+from functools import lru_cache, partial
 from concurrent.futures import ThreadPoolExecutor
 from rapidfuzz import fuzz
 import sys
@@ -673,6 +673,48 @@ def reward_element_count(prompts, completions: list[str] | list[list[dict]], cla
     return rewards
 
+
+def reward_eos(eos_token_id: int, prompts, completions: list[str] | list[list[dict]], completion_ids: list[list[int]], **kwargs):
+    """
+    Reward function that checks if the EOS token is the last token in completion_ids.
+
+    Returns 1.0 if the EOS token is the last token, 0.0 otherwise.
+
+    Args:
+        eos_token_id: The EOS token ID from the tokenizer
+        prompts: List of prompts
+        completions: List of generated completions (model outputs)
+        completion_ids: List of lists of token IDs for each completion
+        **kwargs: Additional arguments
+
+    Returns:
+        List of reward scores (1.0 if EOS is last, 0.0 otherwise)
+    """
+    logger.info(f"Running EOS reward function for {len(completions)} completions (EOS token ID: {eos_token_id})")
+
+    rewards = []
+
+    for i, comp_ids in enumerate(completion_ids):
+        logger.debug(f"Completion {i} tail tokens: {comp_ids[-100:]}")
+
+        if comp_ids:
+            last_token = comp_ids[-1]
+            if last_token == eos_token_id:
+                rewards.append(1.0)
+                logger.debug(f"Completion {i}: EOS token {last_token} found at end")
+            else:
+                rewards.append(0.0)
+                logger.debug(f"Completion {i}: Last token {last_token} is not EOS (expected {eos_token_id})")
+        else:
+            # Empty completion, no EOS
+            rewards.append(0.0)
+            logger.debug(f"Completion {i}: Empty completion, no EOS")
+
+    eos_count = sum(rewards)
+    logger.info(f"EOS rewards: {eos_count}/{len(rewards)} completions have EOS as last token")
+
+    return rewards
+
 
 def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], completion_ids: list[list[int]], pdf_path: list[str], jsonl_file: list[str], test_ids: list[list[str]], **kwargs):
     """
     Reward function that runs unit tests on completions and returns average pass rate.
@@ -883,6 +925,14 @@ def main():
         default=None,
         help="Use element count matching reward (tables and math equations) with optional weight (default: 1.0)"
     )
+    parser.add_argument(
+        "--reward_eos",
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use EOS token check reward - scores 1 if EOS is last token, 0 otherwise - with optional weight (default: 1.0)"
+    )
     parser.add_argument(
         "--vllm_mode",
         type=str,
@@ -1015,9 +1065,20 @@
         reward_weights.append(args.reward_element_count)
         reward_names.append("element_count")
         logger.info(f"Added element count matching reward function with weight {args.reward_element_count}")
-    
+
+    if args.reward_eos is not None:
+        # Get EOS token ID from processor's tokenizer
+        eos_token_id = processor.tokenizer.eos_token_id
+        logger.info(f"EOS token ID from tokenizer: {eos_token_id}")
+        # Use partial to bind the EOS token ID to the reward function
+        reward_eos_partial = partial(reward_eos, eos_token_id)
+        reward_funcs.append(reward_eos_partial)
+        reward_weights.append(args.reward_eos)
+        reward_names.append("eos")
+        logger.info(f"Added EOS token check reward function with weight {args.reward_eos}")
+
     if not reward_funcs:
-        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_front_matter, --reward_element_count")
+        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_front_matter, --reward_element_count, --reward_eos")
         return
 
     # Log summary of reward configuration
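Reviewer note: the mechanism worth calling out is the functools.partial binding in main(). It pre-binds the tokenizer's EOS id so the resulting callable keeps the (prompts, completions, completion_ids, **kwargs) signature that the trainer invokes reward functions with. Below is a minimal standalone sketch of that behavior; the EOS id 151645 and the stripped-down body are hypothetical stand-ins, not the patched module itself:

    from functools import partial

    def reward_eos(eos_token_id, prompts, completions, completion_ids, **kwargs):
        # 1.0 when a completion's final token is EOS, else 0.0 (empty -> 0.0),
        # mirroring the logic added in the patch.
        return [1.0 if ids and ids[-1] == eos_token_id else 0.0 for ids in completion_ids]

    # Bind the EOS id up front, as main() does with processor.tokenizer.eos_token_id.
    fn = partial(reward_eos, 151645)  # 151645 is a placeholder EOS id

    print(fn(
        prompts=["p1", "p2", "p3"],
        completions=["ends cleanly", "truncated", ""],
        completion_ids=[[11, 22, 151645], [11, 22, 33], []],
    ))  # -> [1.0, 0.0, 0.0]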
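On the flag semantics: because --reward_eos is declared with nargs='?' and const=1.0, a bare --reward_eos enables the check at weight 1.0, an explicit value such as --reward_eos 0.5 down-weights it relative to the other reward terms, and omitting the flag (default=None) keeps it out of reward_funcs entirely.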