From 1dd6ff9b031d2052817cba1afb2fd8abc22a2cc3 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Thu, 21 Aug 2025 18:17:07 +0000
Subject: [PATCH] Olmocr bench grpo stuff

---
 olmocr/train/grpo_train.py          | 104 +++++++++++-
 scripts/train/grpotrainer-beaker.sh |   1 +
 tests/test_grpo.py                  | 240 +++++++++++++++++++++++-----
 3 files changed, 296 insertions(+), 49 deletions(-)

diff --git a/olmocr/train/grpo_train.py b/olmocr/train/grpo_train.py
index d29f290..12b4561 100644
--- a/olmocr/train/grpo_train.py
+++ b/olmocr/train/grpo_train.py
@@ -11,6 +11,7 @@ import json
 import random
 from pathlib import Path
 import glob
+from functools import lru_cache
 
 import torch
 import numpy as np
@@ -28,6 +29,7 @@ from io import BytesIO
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
+from olmocr.bench.tests import load_tests
 
 # Configure logging
 logging.basicConfig(
@@ -170,11 +172,101 @@ class OlmOCRDataset(Dataset):
         # Return None if processing fails
         return None
 
-def simple_length_reward(prompts, completions, completion_ids, pdf_path, jsonl_file, test_ids, **kwargs):
-    """Reward function that assigns higher scores to longer completions (in terms of token count)."""
-    logger.info(f"Reward function called {kwargs}")
-    # return [float(len(ids)) for ids in completions_ids]
-    return [random.choice([0.1, 0.5]) for x in completions]
+@lru_cache(maxsize=128)
+def load_tests_cached(jsonl_file: str):
+    """
+    Cached version of load_tests to avoid reloading the same JSONL file multiple times.
+
+    Args:
+        jsonl_file: Path to the JSONL file containing test definitions
+
+    Returns:
+        List of test objects loaded from the file
+    """
+    logger.info(f"Loading tests from {jsonl_file} (will be cached)")
+    return load_tests(jsonl_file)
+
+
+def unit_test_reward(prompts, completions, completion_ids, pdf_path, jsonl_file, test_ids, **kwargs):
+    """
+    Reward function that runs unit tests on completions and returns average pass rate.
+
+    For each completion, loads the corresponding tests from the JSONL file and runs them.
+    Returns the proportion of tests that pass as the reward score.
+
+    Args:
+        prompts: List of prompts
+        completions: List of generated completions (model outputs)
+        completion_ids: List of completion token IDs
+        pdf_path: Path to the PDF file being processed
+        jsonl_file: Path to the JSONL file containing test definitions
+        test_ids: List of test IDs associated with this PDF page
+        **kwargs: Additional arguments
+
+    Returns:
+        List of reward scores (0.0 to 1.0) based on test pass rates
+    """
+    logger.info(f"Running unit test reward function for {len(completions)} completions")
+    logger.info(f"PDF: {pdf_path}, JSONL: {jsonl_file}, Test IDs: {test_ids}")
+
+    rewards = []
+
+    # Load all tests from the JSONL file (cached)
+    try:
+        all_tests = load_tests_cached(jsonl_file)
+        # Filter to only the tests for this specific PDF page
+        relevant_tests = [test for test in all_tests if test.id in test_ids]
+
+        if not relevant_tests:
+            logger.warning(f"No relevant tests found for test IDs: {test_ids}")
+            # Return a small positive reward to avoid training issues
+            return [0.1 for _ in completions]
+
+        logger.info(f"Found {len(relevant_tests)} relevant tests for this PDF page")
+
+        # Process each completion
+        for i, completion in enumerate(completions):
+            if not completion or not isinstance(completion, str):
+                logger.warning(f"Invalid completion at index {i}: {type(completion)}")
+                rewards.append(0.0)
+                continue
+
+            # Run all relevant tests on this completion
+            passed = 0
+            total = len(relevant_tests)
+
+            for test in relevant_tests:
+                try:
+                    test_passed, failure_reason = test.run(completion)
+                    if test_passed:
+                        passed += 1
+                    else:
+                        logger.debug(f"Test {test.id} failed: {failure_reason}")
+                except Exception as e:
+                    logger.warning(f"Error running test {test.id}: {e}")
+                    # Count errored tests as failures
+                    continue
+
+            # Calculate reward as proportion of tests passed
+            reward = passed / total if total > 0 else 0.0
+            rewards.append(reward)
+
+            logger.info(f"Completion {i}: {passed}/{total} tests passed, reward={reward:.3f}")
+
+    except Exception as e:
+        logger.error(f"Error in unit_test_reward function: {e}")
+        # Return small positive rewards to avoid training issues
+        return [0.1 for _ in completions]
+
+    # Ensure we always return rewards between 0 and 1
+    rewards = [max(0.0, min(1.0, r)) for r in rewards]
+
+    # If all rewards are 0, add a small epsilon to avoid training issues
+    if all(r == 0.0 for r in rewards):
+        logger.warning("All completions failed all tests, adding small epsilon reward")
+        rewards = [0.01 for _ in rewards]
+
+    return rewards
 
 
 def main():
@@ -370,7 +462,7 @@ def main():
         processing_class=processor,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
-        reward_funcs=simple_length_reward,
+        reward_funcs=unit_test_reward,
     )
 
     # Start training
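As a quick sanity check outside of training, the new reward function can be called directly against a bench JSONL file. The snippet below is a minimal sketch of such a check; the paths and test IDs are placeholders for illustration and are not files shipped with this patch. The signature matches the function defined above, and the return value is one score in [0.0, 1.0] per completion.

    # Standalone smoke test for unit_test_reward (hypothetical paths and IDs).
    from olmocr.train.grpo_train import unit_test_reward

    rewards = unit_test_reward(
        prompts=["<rendered page prompt>"],
        completions=["Hello World\n\nFirst paragraph.\n\nSecond paragraph."],
        completion_ids=[[]],
        pdf_path="bench_data/pdfs/example.pdf",       # placeholder PDF path
        jsonl_file="bench_data/example_tests.jsonl",  # placeholder JSONL with test definitions
        test_ids=["test1", "test2"],                  # IDs that exist in that JSONL
    )
    print(rewards)  # e.g. [0.5] -> half of the selected tests passed
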
diff --git a/scripts/train/grpotrainer-beaker.sh b/scripts/train/grpotrainer-beaker.sh
index a4fb01f..353e416 100755
--- a/scripts/train/grpotrainer-beaker.sh
+++ b/scripts/train/grpotrainer-beaker.sh
@@ -166,6 +166,7 @@ commands = [
     "pip install trl wandb",
     "pip install transformers==4.55.2",  # Updated for GRPO compatibility
     "pip install flash-attn==2.8.0.post2 --no-build-isolation",
+    "pip install vllm==v0.10.1.1",
     "pip install s5cmd",
 
     # Sync the bench data from S3
diff --git a/tests/test_grpo.py b/tests/test_grpo.py
index 8c2a235..d23d808 100644
--- a/tests/test_grpo.py
+++ b/tests/test_grpo.py
@@ -16,7 +16,7 @@ import shutil
 
 # Add parent directory to path
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from olmocr.train.grpo_train import OlmOCRDataset, collate_fn
+from olmocr.train.grpo_train import OlmOCRDataset, unit_test_reward, load_tests_cached
 
 
 class TestGRPODataloader(unittest.TestCase):
@@ -179,48 +179,6 @@ class TestGRPODataloader(unittest.TestCase):
         self.assertIsInstance(item["test_ids"], list)
         self.assertTrue(len(item["test_ids"]) > 0)
 
-    @patch('olmocr.train.grpo_train.render_pdf_to_base64png')
-    @patch('olmocr.train.grpo_train.build_no_anchoring_v4_yaml_prompt')
-    def test_collate_function(self, mock_prompt, mock_render):
-        """Test that the collate function works correctly."""
-        # Mock the rendering and prompt functions
-        mock_render.return_value = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
-        mock_prompt.return_value = "Test prompt"
-
-        dataset = OlmOCRDataset(
-            bench_data_folder=self.bench_data_folder,
-            processor=None,
-            max_samples=2,
-            target_longest_image_dim=1024,
-        )
-
-        # Create a batch
-        batch = [dataset[0], dataset[1]]
-        collated = collate_fn(batch)
-
-        self.assertIsNotNone(collated)
-        self.assertIn("prompts", collated)
-        self.assertIn("images", collated)
-        self.assertIn("pdf_paths", collated)
-        self.assertIn("jsonl_files", collated)
-        self.assertIn("test_ids", collated)
-
-        # Check batch size consistency
-        self.assertEqual(len(collated["prompts"]), 2)
-        self.assertEqual(len(collated["images"]), 2)
-        self.assertEqual(len(collated["pdf_paths"]), 2)
-        self.assertEqual(len(collated["jsonl_files"]), 2)
-        self.assertEqual(len(collated["test_ids"]), 2)
-
-    def test_collate_with_none_values(self):
-        """Test that collate function handles None values correctly."""
-        batch = [None, {"prompt": [], "image": None, "pdf_path": "test.pdf",
-                        "jsonl_file": "test.jsonl", "test_ids": ["id1"]}, None]
-
-        collated = collate_fn(batch)
-
-        self.assertIsNotNone(collated)
-        self.assertEqual(len(collated["prompts"]), 1)
 
     def test_empty_jsonl_handling(self):
         """Test handling of empty JSONL files."""
@@ -300,6 +258,202 @@ class TestGRPODataloader(unittest.TestCase):
             shutil.rmtree(temp_folder)
 
 
+class TestUnitTestReward(unittest.TestCase):
+    """Test cases for the unit_test_reward function."""
+
+    @classmethod
+    def setUpClass(cls):
+        """Create temporary test files."""
+        # Clear any cached tests from previous runs
+        load_tests_cached.cache_clear()
+        cls.temp_dir = tempfile.mkdtemp()
+
+        # Create a sample JSONL test file with different test types
+        cls.jsonl_path = os.path.join(cls.temp_dir, "test.jsonl")
+        test_data = [
+            {
+                "pdf": "test.pdf",
+                "page": 0,
+                "id": "test1",
+                "type": "present",
+                "text": "Hello World",
+                "max_diffs": 0
+            },
+            {
+                "pdf": "test.pdf",
+                "page": 0,
+                "id": "test2",
+                "type": "absent",
+                "text": "Bad Text",
+                "max_diffs": 0
+            },
+            {
+                "pdf": "test.pdf",
+                "page": 0,
+                "id": "test3",
+                "type": "baseline",
+                "max_repeats": 30
+            },
+            {
+                "pdf": "test.pdf",
+                "page": 0,
+                "id": "test4",
+                "type": "order",
+                "before": "First",
+                "after": "Second",
+                "max_diffs": 0
+            }
+        ]
+
+        with open(cls.jsonl_path, 'w') as f:
+            for test in test_data:
+                f.write(json.dumps(test) + '\n')
+
+    @classmethod
+    def tearDownClass(cls):
+        """Clean up temporary files."""
+        # Clear the LRU cache before removing temp dir
+        load_tests_cached.cache_clear()
+        shutil.rmtree(cls.temp_dir)
+
+    def setUp(self):
+        """Clear cache before each test method."""
+        load_tests_cached.cache_clear()
+
+    def test_perfect_completion(self):
+        """Test reward calculation for a completion that passes all tests."""
+        completions = ["Hello World\n\nFirst paragraph.\n\nSecond paragraph.\n\nThis is a good document with no bad text."]
+        test_ids = ["test1", "test2", "test3", "test4"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"],
+            completions=completions,
+            completion_ids=[[]],
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)  # All 4 tests should pass
+
+    def test_partial_completion(self):
+        """Test reward calculation for a completion that passes some tests."""
+        completions = ["This document contains Bad Text but nothing else of note."]
+        test_ids = ["test1", "test2", "test3"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"],
+            completions=completions,
+            completion_ids=[[]],
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 1)
+        # Should pass only baseline test (1/3)
+        self.assertAlmostEqual(rewards[0], 1/3, places=2)
+
+    def test_multiple_completions(self):
+        """Test reward calculation for multiple completions."""
+        completions = [
+            "Hello World with good content. First then Second.",
+            "Bad Text only",
+            "",  # Empty completion
+        ]
+        test_ids = ["test1", "test2", "test3", "test4"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"] * 3,
+            completions=completions,
+            completion_ids=[[]] * 3,
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 3)
+        # First should pass all 4 tests
+        self.assertEqual(rewards[0], 1.0)
+        # Second should pass only baseline (1/4)
+        self.assertEqual(rewards[1], 0.25)
+        # Third (empty completion) is rejected by the validity check before any tests run
+        self.assertEqual(rewards[2], 0.0)
+
+    def test_no_relevant_tests(self):
+        """Test behavior when no relevant tests are found."""
+        completions = ["Some content"]
+        test_ids = ["nonexistent_test"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"],
+            completions=completions,
+            completion_ids=[[]],
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 0.1)  # Default reward when no tests found
+
+    def test_invalid_completion(self):
+        """Test handling of invalid completions."""
+        completions = [None, "", "Valid content with Hello World"]
+        test_ids = ["test1"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"] * 3,
+            completions=completions,
+            completion_ids=[[]] * 3,
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 3)
+        # First two should get 0 or epsilon
+        self.assertLessEqual(rewards[0], 0.01)
+        self.assertLessEqual(rewards[1], 0.01)
+        # Last should pass the test
+        self.assertEqual(rewards[2], 1.0)
+
+    def test_cache_functionality(self):
+        """Test that load_tests_cached properly caches results."""
+        # Clear cache first
+        load_tests_cached.cache_clear()
+
+        # First call should load from file
+        with patch('olmocr.train.grpo_train.load_tests') as mock_load:
+            mock_load.return_value = []
+            result1 = load_tests_cached(self.jsonl_path)
+            self.assertEqual(mock_load.call_count, 1)
+
+            # Second call should use cache
+            result2 = load_tests_cached(self.jsonl_path)
+            self.assertEqual(mock_load.call_count, 1)  # Should not increase
+
+            # Results should be the same
+            self.assertEqual(result1, result2)
+
+    def test_error_handling(self):
+        """Test error handling in reward function."""
+        # Test with non-existent file
+        rewards = unit_test_reward(
+            prompts=["prompt"],
+            completions=["content"],
+            completion_ids=[[]],
+            pdf_path="test.pdf",
+            jsonl_file="/nonexistent/file.jsonl",
+            test_ids=["test1"]
+        )
+
+        # Should return default reward on error
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 0.1)
+
+
 class TestIntegrationWithRealData(unittest.TestCase):
     """Integration tests with real bench data if available."""