mirror of https://github.com/allenai/olmocr.git (synced 2025-11-03 03:25:22 +00:00)
Moving test code around; adding a format reward, since some runs stop outputting the YAML front-matter block during GRPO training
commit d70208d98a (parent 8383865392)
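For context on the commit message: the "front matter" is a leading YAML block, fenced by --- lines, that the model must emit before the page text; FrontMatterParser (imported in the diff below) validates it into a PageResponse. A rough sketch of that contract, with the fence handling and the field name assumed rather than taken from olmocr's actual parser:

```python
# Illustrative sketch only; olmocr's real logic is FrontMatterParser in
# olmocr.train.dataloader. The field name below is an assumption.
import yaml  # pip install pyyaml

def split_front_matter(completion: str) -> tuple[dict, str]:
    """Return (front_matter, body) or raise if the YAML fences are missing."""
    if not completion.startswith("---\n"):
        raise ValueError("missing opening front-matter fence")
    end = completion.find("\n---", 4)
    if end == -1:
        raise ValueError("missing closing front-matter fence")
    front = yaml.safe_load(completion[4:end]) or {}
    body = completion[end + 4:].lstrip("\n")
    return front, body

front, body = split_front_matter("---\nprimary_language: en\n---\nPage text.")
assert front == {"primary_language": "en"} and body == "Page text."
```

When a GRPO run drifts into skipping this block entirely, every downstream parse fails, which is exactly the failure mode the new reward below penalizes.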
[File diff suppressed because it is too large]
@@ -32,6 +32,8 @@ from io import BytesIO
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
 from olmocr.bench.tests import load_single_test
+from olmocr.train.dataloader import FrontMatterParser
+from olmocr.prompts import PageResponse
 
 # Configure logging
 logging.basicConfig(
@@ -455,6 +457,56 @@ def medoid_reward(prompts, completions: list[str] | list[list[dict]], **kwargs):
     return rewards
 
 
+def reward_format(prompts, completions: list[str] | list[list[dict]], **kwargs):
+    """
+    Reward function that checks if completions can be successfully parsed by FrontMatterParser.
+
+    Returns 1.0 if the completion can be parsed without errors, 0.0 otherwise.
+    This ensures the model generates properly formatted YAML front matter that can be
+    parsed into a PageResponse object.
+
+    Args:
+        prompts: List of prompts
+        completions: List of generated completions (model outputs)
+        **kwargs: Additional arguments
+
+    Returns:
+        List of reward scores: 1.0 for successful parsing, 0.0 for errors
+    """
+    logger.info(f"Running format reward function for {len(completions)} completions")
+
+    rewards = []
+    parser = FrontMatterParser(front_matter_class=PageResponse)
+
+    for i, completion in enumerate(completions):
+        # Extract text from completion
+        if isinstance(completion, list):
+            model_response_markdown = completion[0]["content"] if completion else ""
+        elif isinstance(completion, str):
+            model_response_markdown = completion
+        else:
+            model_response_markdown = ""
+
+        try:
+            # Try to parse the completion using the same logic as in pipeline.py
+            front_matter, text = parser._extract_front_matter_and_text(model_response_markdown)
+            page_response = parser._parse_front_matter(front_matter, text)
+
+            # If we get here without exception, parsing succeeded
+            rewards.append(1.0)
+            logger.debug(f"Completion {i}: Successfully parsed format")
+
+        except Exception as e:
+            # Any parsing error results in 0 reward
+            rewards.append(0.0)
+            logger.debug(f"Completion {i}: Failed to parse format - {type(e).__name__}: {str(e)}")
+
+    success_count = sum(1 for r in rewards if r == 1.0)
+    logger.info(f"Format rewards: {success_count}/{len(rewards)} successfully parsed")
+
+    return rewards
+
+
 def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], completion_ids: list[list[int]], pdf_path: list[str], jsonl_file: list[str], test_ids: list[list[str]], **kwargs):
     """
     Reward function that runs unit tests on completions and returns average pass rate.
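Because the reward functions share the (prompts, completions) -> list[float] convention, reward_format can be exercised on its own. A sketch, with the caveat that the YAML fields a PageResponse actually requires aren't visible in this diff, so the well-formed example is an assumption:

```python
# Sketch of calling reward_format directly (run inside this module, where
# reward_format and FrontMatterParser are defined/imported).
good = "---\nprimary_language: en\n---\nThe page text."  # field name assumed
bad = "No front matter at all."

# Plain-string completions:
print(reward_format(prompts=["p1", "p2"], completions=[good, bad]))
# -> [1.0, 0.0] if `good` satisfies PageResponse's schema, [0.0, 0.0] otherwise

# Chat-style completions are unwrapped via completion[0]["content"]:
print(reward_format(prompts=["p1"], completions=[[{"role": "assistant", "content": good}]]))
```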
@@ -640,6 +692,14 @@ def main():
         default=None,
         help="Use bench edit distance reward with optional weight (default: 1.0)"
     )
+    parser.add_argument(
+        "--reward_format",
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use format validation reward with optional weight (default: 1.0)"
+    )
 
     args = parser.parse_args()
 
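The nargs='?' plus const=1.0 combination gives the flag three distinct behaviors; a self-contained demonstration of the pattern:

```python
# Standalone demo of the nargs='?'/const idiom used for the reward flags:
# flag omitted -> None (reward disabled), bare flag -> const (weight 1.0),
# flag with a value -> that value used as the weight.
import argparse

p = argparse.ArgumentParser()
p.add_argument("--reward_format", nargs="?", const=1.0, type=float, default=None)

assert p.parse_args([]).reward_format is None
assert p.parse_args(["--reward_format"]).reward_format == 1.0
assert p.parse_args(["--reward_format", "0.5"]).reward_format == 0.5
```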
@@ -739,8 +799,14 @@ def main():
         reward_names.append("bench_edit_distance")
         logger.info(f"Added bench edit distance reward function with weight {args.reward_bench_edit_distance}")
 
+    if args.reward_format is not None:
+        reward_funcs.append(reward_format)
+        reward_weights.append(args.reward_format)
+        reward_names.append("format")
+        logger.info(f"Added format validation reward function with weight {args.reward_format}")
+
     if not reward_funcs:
-        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance")
+        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_format")
         return
 
     # Log summary of reward configuration
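The parallel reward_funcs / reward_weights lists assembled here are presumably handed to the trainer further down. A minimal sketch of how that wiring typically looks with trl's GRPOTrainer, with the model name, output path, and dataset as placeholders (the actual setup is outside this hunk):

```python
# Assumed downstream wiring; the real trainer setup is not shown in this diff.
from trl import GRPOConfig, GRPOTrainer

config = GRPOConfig(
    output_dir="grpo_out",          # hypothetical path
    reward_weights=reward_weights,  # e.g. [1.0, 0.5] for [olmocr_bench_reward, reward_format]
)
trainer = GRPOTrainer(
    model="org/checkpoint",         # hypothetical model id
    reward_funcs=reward_funcs,      # e.g. [olmocr_bench_reward, reward_format]
    args=config,
    train_dataset=train_dataset,    # built elsewhere in main()
)
trainer.train()
```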
[File diff suppressed because it is too large]