Testing reward eos

This commit is contained in:
Jake Poznanski 2025-09-23 20:10:11 +00:00
parent 01bc1ff7b6
commit b4b121b118

View File

@ -11,7 +11,7 @@ import json
import random
from pathlib import Path
import glob
from functools import lru_cache
from functools import lru_cache, partial
from concurrent.futures import ThreadPoolExecutor
from rapidfuzz import fuzz
import sys
@ -673,6 +673,48 @@ def reward_element_count(prompts, completions: list[str] | list[list[dict]], cla
return rewards
def reward_eos(eos_token_id: int, prompts, completions: list[str] | list[list[dict]], completion_ids: list[list[int]], **kwargs):
"""
Reward function that checks if the EOS token is the last token in completion_ids.
Returns 1.0 if the EOS token is the last token, 0.0 otherwise.
Args:
eos_token_id: The EOS token ID from the tokenizer
prompts: List of prompts
completions: List of generated completions (model outputs)
completion_ids: List of lists of token IDs for each completion
**kwargs: Additional arguments
Returns:
List of reward scores (1.0 if EOS is last, 0.0 otherwise)
"""
logger.info(f"Running EOS reward function for {len(completions)} completions (EOS token ID: {eos_token_id})")
rewards = []
for i, comp_ids in enumerate(completion_ids):
print("Testing, ", i, comp_ids[:-100])
if comp_ids and len(comp_ids) > 0:
last_token = comp_ids[-1]
if last_token == eos_token_id:
rewards.append(1.0)
logger.debug(f"Completion {i}: EOS token {last_token} found at end")
else:
rewards.append(0.0)
logger.debug(f"Completion {i}: Last token {last_token} is not EOS (expected {eos_token_id})")
else:
# Empty completion, no EOS
rewards.append(0.0)
logger.debug(f"Completion {i}: Empty completion, no EOS")
eos_count = sum(rewards)
logger.info(f"EOS rewards: {eos_count}/{len(rewards)} completions have EOS as last token")
return rewards
def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], completion_ids: list[list[int]], pdf_path: list[str], jsonl_file: list[str], test_ids: list[list[str]], **kwargs):
"""
Reward function that runs unit tests on completions and returns average pass rate.
@ -883,6 +925,14 @@ def main():
default=None,
help="Use element count matching reward (tables and math equations) with optional weight (default: 1.0)"
)
parser.add_argument(
"--reward_eos",
nargs='?',
const=1.0,
type=float,
default=None,
help="Use EOS token check reward - scores 1 if EOS is last token, 0 otherwise (default: 1.0)"
)
parser.add_argument(
"--vllm_mode",
type=str,
@ -1016,8 +1066,19 @@ def main():
reward_names.append("element_count")
logger.info(f"Added element count matching reward function with weight {args.reward_element_count}")
if args.reward_eos is not None:
# Get EOS token ID from processor's tokenizer
eos_token_id = processor.tokenizer.eos_token_id
logger.info(f"EOS token ID from tokenizer: {eos_token_id}")
# Use partial to bind the EOS token ID to the reward function
reward_eos_partial = partial(reward_eos, eos_token_id)
reward_funcs.append(reward_eos_partial)
reward_weights.append(args.reward_eos)
reward_names.append("eos")
logger.info(f"Added EOS token check reward function with weight {args.reward_eos}")
if not reward_funcs:
logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_front_matter, --reward_element_count")
logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_front_matter, --reward_element_count, --reward_eos")
return
# Log summary of reward configuration