From 33d889c7480527ca23b1b71e3de399b2dba0c44f Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Thu, 21 Aug 2025 18:47:53 +0000 Subject: [PATCH] Fixing for conv format --- olmocr/train/grpo_train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/olmocr/train/grpo_train.py b/olmocr/train/grpo_train.py index 992a0a0..f562c03 100644 --- a/olmocr/train/grpo_train.py +++ b/olmocr/train/grpo_train.py @@ -187,7 +187,7 @@ def load_tests_cached(jsonl_file: str): return load_tests(jsonl_file) -def olmocr_bench_reward(prompts, completions: list[str], completion_ids: list[list[int]], pdf_path: list[str], jsonl_file: list[str], test_ids: list[list[str]], **kwargs): +def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], completion_ids: list[list[int]], pdf_path: list[str], jsonl_file: list[str], test_ids: list[list[str]], **kwargs): """ Reward function that runs unit tests on completions and returns average pass rate. @@ -219,7 +219,7 @@ def olmocr_bench_reward(prompts, completions: list[str], completion_ids: list[li logger.info(f"Completion {i}: PDF: {comp_pdf_path}, JSONL: {comp_jsonl_file}, Test IDs: {comp_test_ids}") - if completion is None or not isinstance(completion, str): + if completion is None or not isinstance(completion, str) or not isinstance(completion, list): logger.warning(f"Invalid completion at index {i}: {type(completion)}") logger.warning(f"completion: {completion}") rewards.append(None) @@ -229,6 +229,9 @@ def olmocr_bench_reward(prompts, completions: list[str], completion_ids: list[li logger.warning(f"Missing metadata for completion {i}") rewards.append(None) continue + + if isinstance(completion, list): + completion = completion[0]["content"] try: # Load all tests from the JSONL file (cached)