diff --git a/olmocr/train/grpo_train.py b/olmocr/train/grpo_train.py
index 890787a..9081e5f 100644
--- a/olmocr/train/grpo_train.py
+++ b/olmocr/train/grpo_train.py
@@ -472,52 +472,103 @@ def medoid_reward(prompts, completions: list[str] | list[list[dict]], **kwargs):
     return rewards
 
 
-def reward_format(prompts, completions: list[str] | list[list[dict]], **kwargs):
+def reward_front_matter(prompts, completions: list[str] | list[list[dict]], claude_original: Optional[list[Optional[str]]] = None, **kwargs):
     """
-    Reward function that checks if completions can be successfully parsed by FrontMatterParser.
+    Reward function that checks if completions can be successfully parsed by FrontMatterParser
+    and compares fields to claude_original values.
 
-    Returns 1.0 if the completion can be parsed without errors, 0.0 otherwise.
-    This ensures the model generates properly formatted YAML front matter that can be
-    parsed into a PageResponse object.
+    Scoring:
+    - 0.0: Cannot parse frontmatter at all
+    - 0.5: Can parse frontmatter successfully
+    - +0.1: For each matching field (primary_language, is_rotation_valid,
+      rotation_correction, is_table, is_diagram)
+
+    Maximum score: 1.0 (0.5 + 5 * 0.1)
 
     Args:
         prompts: List of prompts
         completions: List of generated completions (model outputs)
+        claude_original: List of claude_original markdown content (optional)
         **kwargs: Additional arguments
 
     Returns:
-        List of reward scores: 1.0 for successful parsing, 0.0 for errors
+        List of reward scores between 0.0 and 1.0
     """
-    logger.info(f"Running format reward function for {len(completions)} completions")
+    logger.info(f"Running front matter reward function for {len(completions)} completions")
 
     rewards = []
     parser = FrontMatterParser(front_matter_class=PageResponse)
 
+    # Fields to compare
+    fields_to_compare = ['primary_language', 'is_rotation_valid', 'rotation_correction',
+                         'is_table', 'is_diagram']
+
     for i, completion in enumerate(completions):
         # Extract text from completion
         if isinstance(completion, list):
-            model_response_markdown = completion[0]["content"] if completion else ""
+            if completion and "content" in completion[0]:
+                model_response_markdown = completion[0]["content"]
+            else:
+                model_response_markdown = ""
         elif isinstance(completion, str):
             model_response_markdown = completion
         else:
            model_response_markdown = ""
 
+        reward = 0.0
+
         try:
-            # Try to parse the completion using the same logic as in pipeline.py
+            # Try to parse the completion
             front_matter, text = parser._extract_front_matter_and_text(model_response_markdown)
-            page_response = parser._parse_front_matter(front_matter, text)
+            completion_response = parser._parse_front_matter(front_matter, text)
 
-            # If we get here without exception, parsing succeeded
-            rewards.append(1.0)
-            logger.debug(f"Completion {i}: Successfully parsed format")
+            # Parsing succeeded - base reward of 0.5
+            reward = 0.5
+            logger.debug(f"Completion {i}: Successfully parsed frontmatter (base reward: 0.5)")
+
+            # Try to compare with claude_original if available
+            if claude_original and i < len(claude_original) and claude_original[i]:
+                try:
+                    # Parse claude_original frontmatter
+                    claude_fm, claude_text = parser._extract_front_matter_and_text(claude_original[i])
+                    claude_response = parser._parse_front_matter(claude_fm, claude_text)
+
+                    # Compare each field
+                    fields_matched = 0
+                    for field in fields_to_compare:
+                        completion_value = getattr(completion_response, field, None)
+                        claude_value = getattr(claude_response, field, None)
+
+                        if completion_value == claude_value:
+                            fields_matched += 1
+                            reward += 0.1
+                            logger.debug(f"  Field {field} matches: {completion_value}")
+                        else:
+                            logger.debug(f"  Field {field} mismatch: completion={completion_value}, claude={claude_value}")
+
+                    logger.debug(f"Completion {i}: Matched {fields_matched}/{len(fields_to_compare)} fields")
+
+                except Exception as e:
+                    logger.warning(f"Failed to parse claude_original for comparison at index {i}: {e}")
+                    # Keep the base 0.5 reward for successful parsing
+            else:
+                logger.debug(f"Completion {i}: No claude_original available for comparison")
 
         except Exception as e:
             # Any parsing error results in 0 reward
-            rewards.append(0.0)
-            logger.debug(f"Completion {i}: Failed to parse format - {type(e).__name__}: {str(e)}")
+            reward = 0.0
+            logger.debug(f"Completion {i}: Failed to parse frontmatter - {type(e).__name__}: {str(e)}")
+
+        rewards.append(reward)
 
-    success_count = sum(1 for r in rewards if r == 1.0)
-    logger.info(f"Format rewards: {success_count}/{len(rewards)} successfully parsed")
+    # Log summary statistics
+    zero_rewards = sum(1 for r in rewards if r == 0.0)
+    partial_rewards = sum(1 for r in rewards if 0.0 < r < 1.0)
+    perfect_rewards = sum(1 for r in rewards if r == 1.0)
+    avg_reward = sum(rewards) / len(rewards) if rewards else 0.0
+
+    logger.info(f"Front matter rewards summary: {zero_rewards} failed, {partial_rewards} partial, "
+                f"{perfect_rewards} perfect. Average: {avg_reward:.3f}")
 
     return rewards
 
 
@@ -708,12 +759,12 @@ def main():
         help="Use bench edit distance reward with optional weight (default: 1.0)"
     )
     parser.add_argument(
-        "--reward_format",
+        "--reward_front_matter",
         nargs='?',
         const=1.0,
         type=float,
         default=None,
-        help="Use format validation reward with optional weight (default: 1.0)"
+        help="Use front matter validation and field matching reward with optional weight (default: 1.0)"
     )
 
     args = parser.parse_args()
@@ -814,14 +865,14 @@ def main():
         reward_names.append("bench_edit_distance")
         logger.info(f"Added bench edit distance reward function with weight {args.reward_bench_edit_distance}")
 
-    if args.reward_format is not None:
-        reward_funcs.append(reward_format)
-        reward_weights.append(args.reward_format)
-        reward_names.append("format")
-        logger.info(f"Added format validation reward function with weight {args.reward_format}")
+    if args.reward_front_matter is not None:
+        reward_funcs.append(reward_front_matter)
+        reward_weights.append(args.reward_front_matter)
+        reward_names.append("front_matter")
+        logger.info(f"Added front matter validation reward function with weight {args.reward_front_matter}")
 
     if not reward_funcs:
-        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_format")
+        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_front_matter")
         return
 
     # Log summary of reward configuration
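Note (not part of the patch): a minimal sketch of how the new reward is expected to score a completion against a claude_original reference, assuming the module is importable as olmocr.train.grpo_train (as the tests below assume) and that the YAML here parses the same way it does in the test fixtures. Only is_table differs from the reference, so the expected score is the 0.5 parse reward plus 4 * 0.1 for the matching fields, i.e. roughly 0.9 (up to float accumulation).

from olmocr.train.grpo_train import reward_front_matter

reference = """---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---
Reference text"""

completion = """---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: True
is_diagram: False
---
Model text"""

rewards = reward_front_matter(
    prompts=["prompt"],
    completions=[completion],
    claude_original=[reference],
)
# 0.5 (frontmatter parses) + 4 * 0.1 (all fields except is_table match) -> roughly 0.9
print(rewards)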
diff --git a/tests/test_grpo.py b/tests/test_grpo.py
index 3e30b18..e789225 100644
--- a/tests/test_grpo.py
+++ b/tests/test_grpo.py
@@ -10,13 +10,13 @@ import json
 import tempfile
 import unittest
 from pathlib import Path
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, mock_open
 import shutil
 
 # Add parent directory to path
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
-from olmocr.train.grpo_train import OlmOCRBenchDataset, olmocr_bench_reward, load_tests_cached
+from olmocr.train.grpo_train import OlmOCRBenchDataset, olmocr_bench_reward, load_specific_tests_cached, reward_front_matter
 
 
 class TestGRPODataloader(unittest.TestCase):
@@ -265,7 +265,7 @@ class TestOlmOCRBenchReward(unittest.TestCase):
     def setUpClass(cls):
         """Create temporary test files."""
         # Clear any cached tests from previous runs
-        load_tests_cached.cache_clear()
+        load_specific_tests_cached.cache_clear()
 
         cls.temp_dir = tempfile.mkdtemp()
 
         # Create a sample JSONL test file with different test types
@@ -313,12 +313,12 @@ class TestOlmOCRBenchReward(unittest.TestCase):
     def tearDownClass(cls):
         """Clean up temporary files."""
         # Clear the LRU cache before removing temp dir
-        load_tests_cached.cache_clear()
+        load_specific_tests_cached.cache_clear()
 
         shutil.rmtree(cls.temp_dir)
 
     def setUp(self):
         """Clear cache before each test method."""
-        load_tests_cached.cache_clear()
+        load_specific_tests_cached.cache_clear()
 
     def test_perfect_completion(self):
         """Test reward calculation for a completion that passes all tests."""
@@ -421,22 +421,33 @@ class TestOlmOCRBenchReward(unittest.TestCase):
         self.assertEqual(rewards[2], 1.0)
 
     def test_cache_functionality(self):
-        """Test that load_tests_cached properly caches results."""
+        """Test that load_specific_tests_cached properly caches results."""
         # Clear cache first
-        load_tests_cached.cache_clear()
+        load_specific_tests_cached.cache_clear()
 
-        # First call should load from file
-        with patch('olmocr.train.grpo_train.load_tests') as mock_load:
-            mock_load.return_value = []
-            result1 = load_tests_cached(self.jsonl_path)
-            self.assertEqual(mock_load.call_count, 1)
+        # Test that the cache works by calling the function twice
+        test_ids = ("test1", "test2")
+
+        # First call loads from file
+        with patch('builtins.open', create=True) as mock_open:
+            mock_file = MagicMock()
+            mock_file.__enter__.return_value = iter([
+                '{"id": "test1", "type": "present", "text": "hello", "pdf": "test.pdf", "page": 0}\n',
+                '{"id": "test2", "type": "absent", "text": "world", "pdf": "test.pdf", "page": 0}\n',
+                '{"id": "test3", "type": "present", "text": "foo", "pdf": "test.pdf", "page": 0}\n',
+            ])
+            mock_open.return_value = mock_file
 
-            # Second call should use cache
-            result2 = load_tests_cached(self.jsonl_path)
-            self.assertEqual(mock_load.call_count, 1)  # Should not increase
+            result1 = load_specific_tests_cached(self.jsonl_path, test_ids)
+            self.assertEqual(mock_open.call_count, 1)
+
+            # Second call should use cache, not open file again
+            result2 = load_specific_tests_cached(self.jsonl_path, test_ids)
+            self.assertEqual(mock_open.call_count, 1)  # Should not increase
 
             # Results should be the same
             self.assertEqual(result1, result2)
+            self.assertEqual(len(result1), 2)  # Should have loaded 2 tests
 
     def test_error_handling(self):
         """Test error handling in reward function."""
@@ -455,6 +466,313 @@ class TestOlmOCRBenchReward(unittest.TestCase):
         self.assertIsNone(rewards[0])
 
 
+class TestRewardFrontMatter(unittest.TestCase):
+    """Test cases for the reward_front_matter function."""
+
+    def test_no_frontmatter(self):
+        """Test that completions without frontmatter get 0.0 reward."""
+        completions = [
+            "This is just text without any frontmatter",
+            "Another completion with no YAML",
+            "# A markdown header but no frontmatter"
+        ]
+
+        rewards = reward_front_matter(
+            prompts=["prompt"] * 3,
+            completions=completions,
+            claude_original=None
+        )
+
+        self.assertEqual(len(rewards), 3)
+        for reward in rewards:
+            self.assertEqual(reward, 0.0)
+
+    def test_valid_frontmatter_no_claude(self):
+        """Test that valid frontmatter gets 0.5 reward when no claude_original."""
+        completions = [
+            """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+
+# Document content here""",
+            """---
+primary_language: fr
+is_rotation_valid: False
+rotation_correction: 90
+is_table: True
+is_diagram: False
+---
+
+Some other content"""
+        ]
+
+        rewards = reward_front_matter(
+            prompts=["prompt"] * 2,
+            completions=completions,
+            claude_original=None
+        )
+
+        self.assertEqual(len(rewards), 2)
+        for reward in rewards:
+            self.assertEqual(reward, 0.5)
+
+    def test_perfect_match_with_claude(self):
+        """Test perfect match with claude_original gets 1.0 reward."""
+        claude_original_content = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+
+Original content"""
+
+        completion = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+
+Different content but same frontmatter"""
+
+        rewards = reward_front_matter(
+            prompts=["prompt"],
+            completions=[completion],
+            claude_original=[claude_original_content]
+        )
+
+        self.assertEqual(len(rewards), 1)
+        self.assertAlmostEqual(rewards[0], 1.0, places=5)
+
+    def test_partial_match_with_claude(self):
+        """Test partial match with claude_original gets intermediate reward."""
+        claude_original_content = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+
+Original content"""
+
+        # 3 out of 5 fields match
+        completion = """---
+primary_language: en
+is_rotation_valid: False
+rotation_correction: 90
+is_table: False
+is_diagram: False
+---
+
+Different values for rotation fields"""
+
+        rewards = reward_front_matter(
+            prompts=["prompt"],
+            completions=[completion],
+            claude_original=[claude_original_content]
+        )
+
+        self.assertEqual(len(rewards), 1)
+        # 0.5 base + 3 * 0.1 = 0.8
+        self.assertAlmostEqual(rewards[0], 0.8, places=2)
+
+    def test_invalid_frontmatter_format(self):
+        """Test that invalid YAML frontmatter gets 0.0 reward."""
+        completions = [
+            """---
+this is not: valid yaml
+    because of: : bad formatting
+---
+content""",
+            """---
+primary_language: en
+unclosed_string: "this is not closed
+---""",
+            """---
+primary_language:
+    nested: "should not be nested"
+---"""
+        ]
+
+        rewards = reward_front_matter(
+            prompts=["prompt"] * 3,
+            completions=completions,
+            claude_original=None
+        )
+
+        self.assertEqual(len(rewards), 3)
+        for reward in rewards:
+            self.assertEqual(reward, 0.0)
+
+    def test_mixed_completions_with_claude(self):
+        """Test mixed completions with claude_original comparisons."""
+        claude_originals = [
+            """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+Content 1""",
+            None,  # No claude_original for this one
+            """---
+primary_language: fr
+is_rotation_valid: False
+rotation_correction: 180
+is_table: True
+is_diagram: True
+---
+Content 3"""
+        ]
+
+        completions = [
+            # Perfect match with first claude
+            """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+Model output 1""",
+            # Valid frontmatter but no claude to compare
+            """---
+primary_language: de
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+Model output 2""",
+            # No frontmatter at all
+            "Just plain text without frontmatter"
+        ]
+
+        rewards = reward_front_matter(
+            prompts=["prompt"] * 3,
+            completions=completions,
+            claude_original=claude_originals
+        )
+
+        self.assertEqual(len(rewards), 3)
+        self.assertAlmostEqual(rewards[0], 1.0, places=5)  # Perfect match
+        self.assertEqual(rewards[1], 0.5)  # Valid but no claude
+        self.assertEqual(rewards[2], 0.0)  # No frontmatter
+
+    def test_list_format_completions(self):
+        """Test handling of completions in list format."""
+        completion_content = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+Content"""
+
+        completions = [
+            # List format (as from model output)
+            [{"content": completion_content}],
+            # String format
+            completion_content,
+            # Empty list
+            [],
+            # Invalid format
+            [{}]
+        ]
+
+        rewards = reward_front_matter(
+            prompts=["prompt"] * 4,
+            completions=completions,
+            claude_original=None
+        )
+
+        self.assertEqual(len(rewards), 4)
+        self.assertEqual(rewards[0], 0.5)  # Valid from list
+        self.assertEqual(rewards[1], 0.5)  # Valid from string
+        self.assertEqual(rewards[2], 0.0)  # Empty list
+        self.assertEqual(rewards[3], 0.0)  # Invalid list format
+
+    def test_field_type_matching(self):
+        """Test that field types are correctly compared."""
+        claude_original = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---"""
+
+        completions = [
+            # Correct types
+            """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---""",
+            # String instead of boolean (might still parse correctly)
+            """---
+primary_language: en
+is_rotation_valid: "True"
+rotation_correction: "0"
+is_table: "False"
+is_diagram: "False"
+---""",
+        ]
+
+        rewards = reward_front_matter(
+            prompts=["prompt"] * 2,
+            completions=completions,
+            claude_original=[claude_original] * 2
+        )
+
+        self.assertEqual(len(rewards), 2)
+        # First should be perfect match
+        self.assertAlmostEqual(rewards[0], 1.0, places=5)
+        # Second: quoted values should still be coerced to the expected field types when parsed, so both should match
+        self.assertAlmostEqual(rewards[1], 1.0, places=5)
+
+    def test_none_values_in_fields(self):
+        """Test handling of None values in frontmatter fields."""
+        claude_original = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---"""
+
+        # rotation_correction null will fail validation (must be 0, 90, 180, or 270)
+        completion = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: null
+is_table: False
+is_diagram: False
+---"""
+
+        rewards = reward_front_matter(
+            prompts=["prompt"],
+            completions=[completion],
+            claude_original=[claude_original]
+        )
+
+        self.assertEqual(len(rewards), 1)
+        # Should fail to parse due to invalid rotation_correction
+        self.assertEqual(rewards[0], 0.0)
+
+
 class TestIntegrationWithRealData(unittest.TestCase):
     """Integration tests with real bench data if available."""
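Side note on the renamed loader exercised by test_cache_functionality: since load_specific_tests_cached exposes cache_clear(), it is presumably memoized (e.g. via functools.lru_cache), so its test-id argument has to be hashable, which is why the test passes a tuple rather than a list. A minimal usage sketch under that assumption; the JSONL path and test ids here are hypothetical:

from olmocr.train.grpo_train import load_specific_tests_cached

jsonl_path = "/path/to/bench_tests.jsonl"  # hypothetical path
test_ids = ("test1", "test2")              # tuple, not list, so the call is hashable and cacheable

first = load_specific_tests_cached(jsonl_path, test_ids)   # reads the JSONL and keeps only the requested ids
second = load_specific_tests_cached(jsonl_path, test_ids)  # same arguments, so served from the cache
assert first == second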