mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-10 15:52:31 +00:00
Testing reward eos
This commit is contained in:
parent
01bc1ff7b6
commit
b4b121b118
@ -11,7 +11,7 @@ import json
|
|||||||
import random
|
import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import glob
|
import glob
|
||||||
from functools import lru_cache
|
from functools import lru_cache, partial
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from rapidfuzz import fuzz
|
from rapidfuzz import fuzz
|
||||||
import sys
|
import sys
|
||||||
@ -673,6 +673,48 @@ def reward_element_count(prompts, completions: list[str] | list[list[dict]], cla
|
|||||||
return rewards
|
return rewards
|
||||||
|
|
||||||
|
|
||||||
|
def reward_eos(eos_token_id: int, prompts, completions: list[str] | list[list[dict]], completion_ids: list[list[int]], **kwargs) -> list[float]:
    """
    Reward function that checks if the EOS token is the last token in completion_ids.

    Returns 1.0 if the EOS token is the last token, 0.0 otherwise. An empty
    completion has no last token and therefore scores 0.0. If ``eos_token_id``
    is ``None`` no integer token can equal it, so every completion scores 0.0.

    ``eos_token_id`` is the first positional parameter so that it can be bound
    ahead of time (e.g. with ``functools.partial``), leaving the remaining
    parameters to be supplied by the caller.

    Args:
        eos_token_id: The EOS token ID from the tokenizer
        prompts: List of prompts (unused; accepted for interface compatibility)
        completions: List of generated completions (model outputs); only its
            length is used, for logging
        completion_ids: List of lists of token IDs for each completion
        **kwargs: Additional arguments (ignored)

    Returns:
        List of reward scores (1.0 if EOS is last, 0.0 otherwise), one entry
        per element of ``completion_ids`` and in the same order
    """
    logger.info(f"Running EOS reward function for {len(completions)} completions (EOS token ID: {eos_token_id})")

    rewards = []

    for i, comp_ids in enumerate(completion_ids):
        # Guard clause: an empty (or None) completion cannot end with EOS.
        if not comp_ids:
            rewards.append(0.0)
            logger.debug(f"Completion {i}: Empty completion, no EOS")
            continue

        # Only the final token matters; an EOS earlier in the sequence does not count.
        last_token = comp_ids[-1]
        if last_token == eos_token_id:
            rewards.append(1.0)
            logger.debug(f"Completion {i}: EOS token {last_token} found at end")
        else:
            rewards.append(0.0)
            logger.debug(f"Completion {i}: Last token {last_token} is not EOS (expected {eos_token_id})")

    # Count as an integer so the summary reads "3/8" rather than "3.0/8".
    eos_count = sum(1 for reward in rewards if reward == 1.0)
    logger.info(f"EOS rewards: {eos_count}/{len(rewards)} completions have EOS as last token")

    return rewards
|
||||||
|
|
||||||
|
|
||||||
def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], completion_ids: list[list[int]], pdf_path: list[str], jsonl_file: list[str], test_ids: list[list[str]], **kwargs):
|
def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], completion_ids: list[list[int]], pdf_path: list[str], jsonl_file: list[str], test_ids: list[list[str]], **kwargs):
|
||||||
"""
|
"""
|
||||||
Reward function that runs unit tests on completions and returns average pass rate.
|
Reward function that runs unit tests on completions and returns average pass rate.
|
||||||
@ -883,6 +925,14 @@ def main():
|
|||||||
default=None,
|
default=None,
|
||||||
help="Use element count matching reward (tables and math equations) with optional weight (default: 1.0)"
|
help="Use element count matching reward (tables and math equations) with optional weight (default: 1.0)"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--reward_eos",
|
||||||
|
nargs='?',
|
||||||
|
const=1.0,
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Use EOS token check reward - scores 1 if EOS is last token, 0 otherwise (default: 1.0)"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vllm_mode",
|
"--vllm_mode",
|
||||||
type=str,
|
type=str,
|
||||||
@ -1016,8 +1066,19 @@ def main():
|
|||||||
reward_names.append("element_count")
|
reward_names.append("element_count")
|
||||||
logger.info(f"Added element count matching reward function with weight {args.reward_element_count}")
|
logger.info(f"Added element count matching reward function with weight {args.reward_element_count}")
|
||||||
|
|
||||||
|
if args.reward_eos is not None:
|
||||||
|
# Get EOS token ID from processor's tokenizer
|
||||||
|
eos_token_id = processor.tokenizer.eos_token_id
|
||||||
|
logger.info(f"EOS token ID from tokenizer: {eos_token_id}")
|
||||||
|
# Use partial to bind the EOS token ID to the reward function
|
||||||
|
reward_eos_partial = partial(reward_eos, eos_token_id)
|
||||||
|
reward_funcs.append(reward_eos_partial)
|
||||||
|
reward_weights.append(args.reward_eos)
|
||||||
|
reward_names.append("eos")
|
||||||
|
logger.info(f"Added EOS token check reward function with weight {args.reward_eos}")
|
||||||
|
|
||||||
if not reward_funcs:
|
if not reward_funcs:
|
||||||
logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_front_matter, --reward_element_count")
|
logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_front_matter, --reward_element_count, --reward_eos")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Log summary of reward configuration
|
# Log summary of reward configuration
|
||||||
|
Loading…
x
Reference in New Issue
Block a user