mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-12 16:52:20 +00:00
Fixing for conv format
This commit is contained in:
parent
0f8d515d8c
commit
33d889c748
@ -187,7 +187,7 @@ def load_tests_cached(jsonl_file: str):
|
||||
return load_tests(jsonl_file)
|
||||
|
||||
|
||||
def olmocr_bench_reward(prompts, completions: list[str], completion_ids: list[list[int]], pdf_path: list[str], jsonl_file: list[str], test_ids: list[list[str]], **kwargs):
|
||||
def olmocr_bench_reward(prompts, completions: list[str] | list[list[dict]], completion_ids: list[list[int]], pdf_path: list[str], jsonl_file: list[str], test_ids: list[list[str]], **kwargs):
|
||||
"""
|
||||
Reward function that runs unit tests on completions and returns average pass rate.
|
||||
|
||||
@ -219,7 +219,7 @@ def olmocr_bench_reward(prompts, completions: list[str], completion_ids: list[li
|
||||
|
||||
logger.info(f"Completion {i}: PDF: {comp_pdf_path}, JSONL: {comp_jsonl_file}, Test IDs: {comp_test_ids}")
|
||||
|
||||
if completion is None or not isinstance(completion, str):
|
||||
if completion is None or not isinstance(completion, str) or not isinstance(completion, list):
|
||||
logger.warning(f"Invalid completion at index {i}: {type(completion)}")
|
||||
logger.warning(f"completion: {completion}")
|
||||
rewards.append(None)
|
||||
@ -229,6 +229,9 @@ def olmocr_bench_reward(prompts, completions: list[str], completion_ids: list[li
|
||||
logger.warning(f"Missing metadata for completion {i}")
|
||||
rewards.append(None)
|
||||
continue
|
||||
|
||||
if isinstance(completion, list):
|
||||
completion = completion[0]["content"]
|
||||
|
||||
try:
|
||||
# Load all tests from the JSONL file (cached)
|
||||
|
Loading…
x
Reference in New Issue
Block a user