mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-10 07:42:13 +00:00
Testing reward eos
This commit is contained in:
parent
01bc1ff7b6
commit
b4b121b118
@ -11,7 +11,7 @@ import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
import glob
|
||||
from functools import lru_cache
|
||||
from functools import lru_cache, partial
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from rapidfuzz import fuzz
|
||||
import sys
|
||||
@ -673,6 +673,48 @@ def reward_element_count(prompts, completions: list[str] | list[list[dict]], cla
|
||||
return rewards
|
||||
|
||||
|
||||
def reward_eos(eos_token_id: int, prompts, completions: list[str] | list[list[dict]], completion_ids: list[list[int]], **kwargs):
|
||||
"""
|
||||
Reward function that checks if the EOS token is the last token in completion_ids.
|
||||
|
||||
Returns 1.0 if the EOS token is the last token, 0.0 otherwise.
|
||||
|
||||
Args:
|
||||
eos_token_id: The EOS token ID from the tokenizer
|
||||
prompts: List of prompts
|
||||
completions: List of generated completions (model outputs)
|
||||
completion_ids: List of lists of token IDs for each completion
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
List of reward scores (1.0 if EOS is last, 0.0 otherwise)
|
||||
"""
|
||||
logger.info(f"Running EOS reward function for {len(completions)} completions (EOS token ID: {eos_token_id})")
|
||||
|
||||
rewards = []
|
||||
|
||||
for i, comp_ids in enumerate(completion_ids):
|
||||
print("Testing, ", i, comp_ids[:-100])
|
||||
|
||||
if comp_ids and len(comp_ids) > 0:
|
||||
last_token = comp_ids[-1]
|
||||
if last_token == eos_token_id:
|
||||
rewards.append(1.0)
|
||||
logger.debug(f"Completion {i}: EOS token {last_token} found at end")
|
||||
else:
|
||||
rewards.append(0.0)
|
||||
logger.debug(f"Completion {i}: Last token {last_token} is not EOS (expected {eos_token_id})")
|
||||
else:
|
||||
# Empty completion, no EOS
|
||||
rewards.append(0.0)
|
||||
logger.debug(f"Completion {i}: Empty completion, no EOS")
|
||||
|
||||
eos_count = sum(rewards)
|
||||
logger.info(f"EOS rewards: {eos_count}/{len(rewards)} completions have EOS as last token")
|
||||
|
||||
return rewards
|
||||
|
||||
|
||||
def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], completion_ids: list[list[int]], pdf_path: list[str], jsonl_file: list[str], test_ids: list[list[str]], **kwargs):
|
||||
"""
|
||||
Reward function that runs unit tests on completions and returns average pass rate.
|
||||
@ -883,6 +925,14 @@ def main():
|
||||
default=None,
|
||||
help="Use element count matching reward (tables and math equations) with optional weight (default: 1.0)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward_eos",
|
||||
nargs='?',
|
||||
const=1.0,
|
||||
type=float,
|
||||
default=None,
|
||||
help="Use EOS token check reward - scores 1 if EOS is last token, 0 otherwise (default: 1.0)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vllm_mode",
|
||||
type=str,
|
||||
@ -1015,9 +1065,20 @@ def main():
|
||||
reward_weights.append(args.reward_element_count)
|
||||
reward_names.append("element_count")
|
||||
logger.info(f"Added element count matching reward function with weight {args.reward_element_count}")
|
||||
|
||||
|
||||
if args.reward_eos is not None:
|
||||
# Get EOS token ID from processor's tokenizer
|
||||
eos_token_id = processor.tokenizer.eos_token_id
|
||||
logger.info(f"EOS token ID from tokenizer: {eos_token_id}")
|
||||
# Use partial to bind the EOS token ID to the reward function
|
||||
reward_eos_partial = partial(reward_eos, eos_token_id)
|
||||
reward_funcs.append(reward_eos_partial)
|
||||
reward_weights.append(args.reward_eos)
|
||||
reward_names.append("eos")
|
||||
logger.info(f"Added EOS token check reward function with weight {args.reward_eos}")
|
||||
|
||||
if not reward_funcs:
|
||||
logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_front_matter, --reward_element_count")
|
||||
logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_front_matter, --reward_element_count, --reward_eos")
|
||||
return
|
||||
|
||||
# Log summary of reward configuration
|
||||
|
Loading…
x
Reference in New Issue
Block a user