Mirror of https://github.com/allenai/olmocr.git (synced 2025-11-08 14:40:24 +00:00)
Moving test code around, adding a format reward since some runs stop outputting the YAML front matter during GRPO training
parent 8383865392
commit d70208d98a
File diff suppressed because it is too large
@@ -32,6 +32,8 @@ from io import BytesIO
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
 from olmocr.bench.tests import load_single_test
+from olmocr.train.dataloader import FrontMatterParser
+from olmocr.prompts import PageResponse
 
 # Configure logging
 logging.basicConfig(
@@ -455,6 +457,56 @@ def medoid_reward(prompts, completions: list[str] | list[list[dict]], **kwargs):
     return rewards
 
 
+def reward_format(prompts, completions: list[str] | list[list[dict]], **kwargs):
+    """
+    Reward function that checks if completions can be successfully parsed by FrontMatterParser.
+
+    Returns 1.0 if the completion can be parsed without errors, 0.0 otherwise.
+    This ensures the model generates properly formatted YAML front matter that can be
+    parsed into a PageResponse object.
+
+    Args:
+        prompts: List of prompts
+        completions: List of generated completions (model outputs)
+        **kwargs: Additional arguments
+
+    Returns:
+        List of reward scores: 1.0 for successful parsing, 0.0 for errors
+    """
+    logger.info(f"Running format reward function for {len(completions)} completions")
+
+    rewards = []
+    parser = FrontMatterParser(front_matter_class=PageResponse)
+
+    for i, completion in enumerate(completions):
+        # Extract text from completion
+        if isinstance(completion, list):
+            model_response_markdown = completion[0]["content"] if completion else ""
+        elif isinstance(completion, str):
+            model_response_markdown = completion
+        else:
+            model_response_markdown = ""
+
+        try:
+            # Try to parse the completion using the same logic as in pipeline.py
+            front_matter, text = parser._extract_front_matter_and_text(model_response_markdown)
+            page_response = parser._parse_front_matter(front_matter, text)
+
+            # If we get here without exception, parsing succeeded
+            rewards.append(1.0)
+            logger.debug(f"Completion {i}: Successfully parsed format")
+
+        except Exception as e:
+            # Any parsing error results in 0 reward
+            rewards.append(0.0)
+            logger.debug(f"Completion {i}: Failed to parse format - {type(e).__name__}: {str(e)}")
+
+    success_count = sum(1 for r in rewards if r == 1.0)
+    logger.info(f"Format rewards: {success_count}/{len(rewards)} successfully parsed")
+
+    return rewards
+
+
 def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], completion_ids: list[list[int]], pdf_path: list[str], jsonl_file: list[str], test_ids: list[list[str]], **kwargs):
     """
     Reward function that runs unit tests on completions and returns average pass rate.
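
For reference, a minimal sketch (not part of this commit) of how reward_format behaves, assuming the training module is importable and that the front matter is a YAML block between --- delimiters; the field names in the well-formed example are illustrative assumptions, not the canonical PageResponse schema:

    # Illustrative only: one completion with an assumed-valid YAML front matter
    # block and one without; reward_format should score them 1.0 and 0.0
    # respectively if the assumed fields satisfy PageResponse.
    well_formed = (
        "---\n"
        "primary_language: en\n"
        "is_rotation_valid: true\n"
        "rotation_correction: 0\n"
        "is_table: false\n"
        "is_diagram: false\n"
        "---\n"
        "The recovered page text goes here.\n"
    )
    malformed = "Plain text with no front matter block at all."

    scores = reward_format(prompts=["", ""], completions=[well_formed, malformed])
    print(scores)  # expected [1.0, 0.0] under the assumptions above
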
@@ -640,6 +692,14 @@ def main():
         default=None,
         help="Use bench edit distance reward with optional weight (default: 1.0)"
     )
+    parser.add_argument(
+        "--reward_format",
+        nargs='?',
+        const=1.0,
+        type=float,
+        default=None,
+        help="Use format validation reward with optional weight (default: 1.0)"
+    )
 
     args = parser.parse_args()
 
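
The new --reward_format flag uses argparse's nargs='?' with const=1.0: passing the flag alone enables the format reward at weight 1.0, passing a value sets a custom weight, and omitting it leaves the reward disabled (None). A standalone sketch of that behavior, separate from the training script:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--reward_format",
        nargs='?',
        const=1.0,
        type=float,
        default=None,
        help="Use format validation reward with optional weight (default: 1.0)",
    )

    print(parser.parse_args([]).reward_format)                          # None: reward disabled
    print(parser.parse_args(["--reward_format"]).reward_format)         # 1.0: enabled at default weight
    print(parser.parse_args(["--reward_format", "0.5"]).reward_format)  # 0.5: enabled with explicit weight
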
@@ -739,8 +799,14 @@ def main():
         reward_names.append("bench_edit_distance")
         logger.info(f"Added bench edit distance reward function with weight {args.reward_bench_edit_distance}")
 
+    if args.reward_format is not None:
+        reward_funcs.append(reward_format)
+        reward_weights.append(args.reward_format)
+        reward_names.append("format")
+        logger.info(f"Added format validation reward function with weight {args.reward_format}")
+
     if not reward_funcs:
-        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance")
+        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_format")
         return
 
     # Log summary of reward configuration
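
The reward_funcs, reward_weights, and reward_names lists assembled above are handed to the GRPO trainer, which is outside this diff; as an assumption about what the weights mean, a typical combination is a per-completion weighted sum of the individual reward functions, roughly:

    # Assumed combination scheme, shown only to illustrate the role of
    # reward_weights; the actual trainer wiring is not part of this commit.
    def combined_reward(prompts, completions, reward_funcs, reward_weights, **kwargs):
        per_func_scores = [func(prompts, completions, **kwargs) for func in reward_funcs]
        return [
            sum(weight * scores[i] for weight, scores in zip(reward_weights, per_func_scores))
            for i in range(len(completions))
        ]
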
File diff suppressed because it is too large