OlmOCR bench GRPO stuff

Jake Poznanski 2025-08-21 18:17:07 +00:00
parent 6184c94c3c
commit 1dd6ff9b03
3 changed files with 296 additions and 49 deletions

View File

@@ -11,6 +11,7 @@ import json
import random
from pathlib import Path
import glob
from functools import lru_cache
import torch
import numpy as np
@@ -28,6 +29,7 @@ from io import BytesIO
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
from olmocr.bench.tests import load_tests
# Configure logging
logging.basicConfig(
@@ -170,11 +172,101 @@ class OlmOCRDataset(Dataset):
            # Return None if processing fails
            return None

def simple_length_reward(prompts, completions, completion_ids, pdf_path, jsonl_file, test_ids, **kwargs):
    """Placeholder reward for debugging: returns a random reward per completion.
    The commented-out line is a length-based variant (reward = completion token count)."""
    logger.info(f"Reward function called {kwargs}")
    # return [float(len(ids)) for ids in completion_ids]
    return [random.choice([0.1, 0.5]) for _ in completions]

@lru_cache(maxsize=128)
def load_tests_cached(jsonl_file: str):
    """
    Cached version of load_tests to avoid reloading the same JSONL file multiple times.

    Args:
        jsonl_file: Path to the JSONL file containing test definitions

    Returns:
        List of test objects loaded from the file
    """
    logger.info(f"Loading tests from {jsonl_file} (will be cached)")
    return load_tests(jsonl_file)
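
A note on the caching above: functools.lru_cache memoizes on the exact argument value, so repeated calls with the same path string are served from memory, and cache_clear() (used by the tests below) forces a reload. A minimal standalone sketch, with a hypothetical loader name:

from functools import lru_cache

@lru_cache(maxsize=128)
def expensive_load(path: str):
    print(f"loading {path}")  # side effect runs only on a cache miss
    return []                 # stand-in for the parsed test objects

expensive_load("a.jsonl")     # miss: prints "loading a.jsonl"
expensive_load("a.jsonl")     # hit: returns the cached list, prints nothing
expensive_load.cache_clear()  # next call reloads from disk

One caveat: the cache keys on the path string, not on file contents, so if the JSONL changes on disk mid-run, stale parsed tests are returned until the cache is cleared.
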
def unit_test_reward(prompts, completions, completion_ids, pdf_path, jsonl_file, test_ids, **kwargs):
    """
    Reward function that runs unit tests on completions and returns the average pass rate.

    For each completion, loads the corresponding tests from the JSONL file and runs them.
    Returns the proportion of tests that pass as the reward score.

    Args:
        prompts: List of prompts
        completions: List of generated completions (model outputs)
        completion_ids: List of completion token IDs
        pdf_path: Path to the PDF file being processed
        jsonl_file: Path to the JSONL file containing test definitions
        test_ids: List of test IDs associated with this PDF page
        **kwargs: Additional arguments

    Returns:
        List of reward scores (0.0 to 1.0) based on test pass rates
    """
    logger.info(f"Running unit test reward function for {len(completions)} completions")
    logger.info(f"PDF: {pdf_path}, JSONL: {jsonl_file}, Test IDs: {test_ids}")

    rewards = []

    # Load all tests from the JSONL file (cached)
    try:
        all_tests = load_tests_cached(jsonl_file)

        # Filter to only the tests for this specific PDF page
        relevant_tests = [test for test in all_tests if test.id in test_ids]

        if not relevant_tests:
            logger.warning(f"No relevant tests found for test IDs: {test_ids}")
            # Return a small positive reward to avoid training issues
            return [0.1 for _ in completions]

        logger.info(f"Found {len(relevant_tests)} relevant tests for this PDF page")

        # Process each completion
        for i, completion in enumerate(completions):
            if not completion or not isinstance(completion, str):
                logger.warning(f"Invalid completion at index {i}: {type(completion)}")
                rewards.append(0.0)
                continue

            # Run all relevant tests on this completion
            passed = 0
            total = len(relevant_tests)

            for test in relevant_tests:
                try:
                    test_passed, failure_reason = test.run(completion)
                    if test_passed:
                        passed += 1
                    else:
                        logger.debug(f"Test {test.id} failed: {failure_reason}")
                except Exception as e:
                    logger.warning(f"Error running test {test.id}: {e}")
                    # Count errored tests as failures
                    continue

            # Calculate reward as proportion of tests passed
            reward = passed / total if total > 0 else 0.0
            rewards.append(reward)

            logger.info(f"Completion {i}: {passed}/{total} tests passed, reward={reward:.3f}")

    except Exception as e:
        logger.error(f"Error in unit_test_reward function: {e}")
        # Return small positive rewards to avoid training issues
        return [0.1 for _ in completions]

    # Ensure we always return rewards between 0 and 1
    rewards = [max(0.0, min(1.0, r)) for r in rewards]

    # If all rewards are 0, add a small epsilon to avoid training issues
    if all(r == 0.0 for r in rewards):
        logger.warning("All completions failed all tests, adding small epsilon reward")
        rewards = [0.01 for _ in rewards]

    return rewards
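
When called directly (as the unit tests further down do), the function takes batch-level lists for prompts and completions and per-page scalars for the metadata; a minimal sketch with hypothetical paths and IDs:

rewards = unit_test_reward(
    prompts=["<prompt>"],
    completions=["Hello World ... First ... Second"],
    completion_ids=[[]],
    pdf_path="page.pdf",
    jsonl_file="tests.jsonl",
    test_ids=["test1", "test4"],
)
# each entry is passed/total for that completion, e.g. 1 of 2 tests -> [0.5]
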
def main():
@@ -370,7 +462,7 @@ def main():
        processing_class=processor,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        reward_funcs=simple_length_reward,
        reward_funcs=unit_test_reward,
    )

    # Start training
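
For context, a minimal sketch of how a custom reward function plugs into TRL's GRPOTrainer; the config values here are illustrative, not the script's actual settings:

from trl import GRPOConfig, GRPOTrainer

trainer = GRPOTrainer(
    model=model,                      # the policy model being trained
    args=GRPOConfig(output_dir="out", num_generations=4),
    processing_class=processor,       # tokenizer/processor for the VLM
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    reward_funcs=unit_test_reward,    # a callable, or a list of callables
)
trainer.train()

TRL forwards extra dataset columns to reward functions as keyword arguments, which is presumably how pdf_path, jsonl_file, and test_ids reach unit_test_reward here.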

View File

@@ -166,6 +166,7 @@ commands = [
    "pip install trl wandb",
    "pip install transformers==4.55.2",  # Updated for GRPO compatibility
    "pip install flash-attn==2.8.0.post2 --no-build-isolation",
    "pip install vllm==v0.10.1.1",
    "pip install s5cmd",
    # Sync the bench data from S3

View File

@@ -16,7 +16,7 @@ import shutil
# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from olmocr.train.grpo_train import OlmOCRDataset, collate_fn
from olmocr.train.grpo_train import OlmOCRDataset, unit_test_reward, load_tests_cached
class TestGRPODataloader(unittest.TestCase):
@@ -179,48 +179,6 @@ class TestGRPODataloader(unittest.TestCase):
        self.assertIsInstance(item["test_ids"], list)
        self.assertTrue(len(item["test_ids"]) > 0)

    @patch('olmocr.train.grpo_train.render_pdf_to_base64png')
    @patch('olmocr.train.grpo_train.build_no_anchoring_v4_yaml_prompt')
    def test_collate_function(self, mock_prompt, mock_render):
        """Test that the collate function works correctly."""
        # Mock the rendering and prompt functions
        mock_render.return_value = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
        mock_prompt.return_value = "Test prompt"

        dataset = OlmOCRDataset(
            bench_data_folder=self.bench_data_folder,
            processor=None,
            max_samples=2,
            target_longest_image_dim=1024,
        )

        # Create a batch
        batch = [dataset[0], dataset[1]]
        collated = collate_fn(batch)

        self.assertIsNotNone(collated)
        self.assertIn("prompts", collated)
        self.assertIn("images", collated)
        self.assertIn("pdf_paths", collated)
        self.assertIn("jsonl_files", collated)
        self.assertIn("test_ids", collated)

        # Check batch size consistency
        self.assertEqual(len(collated["prompts"]), 2)
        self.assertEqual(len(collated["images"]), 2)
        self.assertEqual(len(collated["pdf_paths"]), 2)
        self.assertEqual(len(collated["jsonl_files"]), 2)
        self.assertEqual(len(collated["test_ids"]), 2)

    def test_collate_with_none_values(self):
        """Test that collate function handles None values correctly."""
        batch = [None, {"prompt": [], "image": None, "pdf_path": "test.pdf",
                        "jsonl_file": "test.jsonl", "test_ids": ["id1"]}, None]

        collated = collate_fn(batch)
        self.assertIsNotNone(collated)
        self.assertEqual(len(collated["prompts"]), 1)

    def test_empty_jsonl_handling(self):
        """Test handling of empty JSONL files."""
@@ -300,6 +258,202 @@ class TestGRPODataloader(unittest.TestCase):
        shutil.rmtree(temp_folder)

class TestUnitTestReward(unittest.TestCase):
    """Test cases for the unit_test_reward function."""

    @classmethod
    def setUpClass(cls):
        """Create temporary test files."""
        # Clear any cached tests from previous runs
        load_tests_cached.cache_clear()

        cls.temp_dir = tempfile.mkdtemp()

        # Create a sample JSONL test file with different test types
        cls.jsonl_path = os.path.join(cls.temp_dir, "test.jsonl")
        test_data = [
            {
                "pdf": "test.pdf",
                "page": 0,
                "id": "test1",
                "type": "present",
                "text": "Hello World",
                "max_diffs": 0
            },
            {
                "pdf": "test.pdf",
                "page": 0,
                "id": "test2",
                "type": "absent",
                "text": "Bad Text",
                "max_diffs": 0
            },
            {
                "pdf": "test.pdf",
                "page": 0,
                "id": "test3",
                "type": "baseline",
                "max_repeats": 30
            },
            {
                "pdf": "test.pdf",
                "page": 0,
                "id": "test4",
                "type": "order",
                "before": "First",
                "after": "Second",
                "max_diffs": 0
            }
        ]

        with open(cls.jsonl_path, 'w') as f:
            for test in test_data:
                f.write(json.dumps(test) + '\n')

    @classmethod
    def tearDownClass(cls):
        """Clean up temporary files."""
        # Clear the LRU cache before removing temp dir
        load_tests_cached.cache_clear()
        shutil.rmtree(cls.temp_dir)

    def setUp(self):
        """Clear cache before each test method."""
        load_tests_cached.cache_clear()
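
    # For orientation, an illustrative (assumed) reading of the four bench test
    # types in the JSONL above; the real checks live in olmocr.bench.tests:
    #
    #   def present(text, completion):  return text in completion
    #   def absent(text, completion):   return text not in completion
    #   def order(before, after, completion):
    #       i, j = completion.find(before), completion.find(after)
    #       return i != -1 and j != -1 and i < j
    #
    # "baseline" is taken here to be a general sanity check on the output
    # (e.g. bounded repetition via max_repeats); its exact rule is not shown.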
    def test_perfect_completion(self):
        """Test reward calculation for a completion that passes all tests."""
        completions = ["Hello World\n\nFirst paragraph.\n\nSecond paragraph.\n\nThis is a good document with no bad text."]
        test_ids = ["test1", "test2", "test3", "test4"]

        rewards = unit_test_reward(
            prompts=["prompt"],
            completions=completions,
            completion_ids=[[]],
            pdf_path="test.pdf",
            jsonl_file=self.jsonl_path,
            test_ids=test_ids
        )

        self.assertEqual(len(rewards), 1)
        self.assertEqual(rewards[0], 1.0)  # All 4 tests should pass

    def test_partial_completion(self):
        """Test reward calculation for a completion that passes some tests."""
        completions = ["This document contains Bad Text but nothing else of note."]
        test_ids = ["test1", "test2", "test3"]

        rewards = unit_test_reward(
            prompts=["prompt"],
            completions=completions,
            completion_ids=[[]],
            pdf_path="test.pdf",
            jsonl_file=self.jsonl_path,
            test_ids=test_ids
        )

        self.assertEqual(len(rewards), 1)
        # Should pass only the baseline test (1/3)
        self.assertAlmostEqual(rewards[0], 1/3, places=2)

    def test_multiple_completions(self):
        """Test reward calculation for multiple completions."""
        completions = [
            "Hello World with good content. First then Second.",
            "Bad Text only",
            "",  # Empty completion
        ]
        test_ids = ["test1", "test2", "test3", "test4"]

        rewards = unit_test_reward(
            prompts=["prompt"] * 3,
            completions=completions,
            completion_ids=[[]] * 3,
            pdf_path="test.pdf",
            jsonl_file=self.jsonl_path,
            test_ids=test_ids
        )

        self.assertEqual(len(rewards), 3)
        # First should pass all 4 tests
        self.assertEqual(rewards[0], 1.0)
        # Second should pass only baseline (1/4)
        self.assertEqual(rewards[1], 0.25)
        # Third (empty string) is rejected by the validity guard and scored 0.0
        self.assertEqual(rewards[2], 0.0)

    def test_no_relevant_tests(self):
        """Test behavior when no relevant tests are found."""
        completions = ["Some content"]
        test_ids = ["nonexistent_test"]

        rewards = unit_test_reward(
            prompts=["prompt"],
            completions=completions,
            completion_ids=[[]],
            pdf_path="test.pdf",
            jsonl_file=self.jsonl_path,
            test_ids=test_ids
        )

        self.assertEqual(len(rewards), 1)
        self.assertEqual(rewards[0], 0.1)  # Default reward when no tests found

    def test_invalid_completion(self):
        """Test handling of invalid completions."""
        completions = [None, "", "Valid content with Hello World"]
        test_ids = ["test1"]

        rewards = unit_test_reward(
            prompts=["prompt"] * 3,
            completions=completions,
            completion_ids=[[]] * 3,
            pdf_path="test.pdf",
            jsonl_file=self.jsonl_path,
            test_ids=test_ids
        )

        self.assertEqual(len(rewards), 3)
        # First two are invalid and should get 0 or epsilon
        self.assertLessEqual(rewards[0], 0.01)
        self.assertLessEqual(rewards[1], 0.01)
        # Last should pass the test
        self.assertEqual(rewards[2], 1.0)

    def test_cache_functionality(self):
        """Test that load_tests_cached properly caches results."""
        # Clear cache first
        load_tests_cached.cache_clear()

        # First call should load from file
        with patch('olmocr.train.grpo_train.load_tests') as mock_load:
            mock_load.return_value = []

            result1 = load_tests_cached(self.jsonl_path)
            self.assertEqual(mock_load.call_count, 1)

            # Second call should use the cache
            result2 = load_tests_cached(self.jsonl_path)
            self.assertEqual(mock_load.call_count, 1)  # Should not increase

            # Results should be the same
            self.assertEqual(result1, result2)

    def test_error_handling(self):
        """Test error handling in the reward function."""
        # Test with a non-existent file
        rewards = unit_test_reward(
            prompts=["prompt"],
            completions=["content"],
            completion_ids=[[]],
            pdf_path="test.pdf",
            jsonl_file="/nonexistent/file.jsonl",
            test_ids=["test1"]
        )

        # Should return the default reward on error
        self.assertEqual(len(rewards), 1)
        self.assertEqual(rewards[0], 0.1)

class TestIntegrationWithRealData(unittest.TestCase):
    """Integration tests with real bench data if available."""