Cleaner front matter reward

Jake Poznanski 2025-08-27 19:49:42 +00:00
parent 09036b07d9
commit 0710debf75
2 changed files with 409 additions and 40 deletions

@@ -472,52 +472,103 @@ def medoid_reward(prompts, completions: list[str] | list[list[dict]], **kwargs):
return rewards
def reward_format(prompts, completions: list[str] | list[list[dict]], **kwargs):
def reward_front_matter(prompts, completions: list[str] | list[list[dict]], claude_original: list[Optional[str]] = None, **kwargs):
"""
Reward function that checks if completions can be successfully parsed by FrontMatterParser.
Reward function that checks if completions can be successfully parsed by FrontMatterParser
and compares fields to claude_original values.
Returns 1.0 if the completion can be parsed without errors, 0.0 otherwise.
This ensures the model generates properly formatted YAML front matter that can be
parsed into a PageResponse object.
Scoring:
- 0.0: Cannot parse frontmatter at all
- 0.5: Can parse frontmatter successfully
- +0.1: For each matching field (primary_language, is_rotation_valid,
rotation_correction, is_table, is_diagram)
Maximum score: 1.0 (0.5 + 5 * 0.1)
Args:
prompts: List of prompts
completions: List of generated completions (model outputs)
claude_original: List of claude_original markdown content (optional)
**kwargs: Additional arguments
Returns:
List of reward scores: 1.0 for successful parsing, 0.0 for errors
List of reward scores between 0.0 and 1.0
"""
logger.info(f"Running format reward function for {len(completions)} completions")
logger.info(f"Running front matter reward function for {len(completions)} completions")
rewards = []
parser = FrontMatterParser(front_matter_class=PageResponse)
# Fields to compare
fields_to_compare = ['primary_language', 'is_rotation_valid', 'rotation_correction',
'is_table', 'is_diagram']
for i, completion in enumerate(completions):
# Extract text from completion
if isinstance(completion, list):
model_response_markdown = completion[0]["content"] if completion else ""
if completion and "content" in completion[0]:
model_response_markdown = completion[0]["content"]
else:
model_response_markdown = ""
elif isinstance(completion, str):
model_response_markdown = completion
else:
model_response_markdown = ""
try:
# Try to parse the completion using the same logic as in pipeline.py
front_matter, text = parser._extract_front_matter_and_text(model_response_markdown)
page_response = parser._parse_front_matter(front_matter, text)
reward = 0.0
# If we get here without exception, parsing succeeded
rewards.append(1.0)
logger.debug(f"Completion {i}: Successfully parsed format")
try:
# Try to parse the completion
front_matter, text = parser._extract_front_matter_and_text(model_response_markdown)
completion_response = parser._parse_front_matter(front_matter, text)
# Parsing succeeded - base reward of 0.5
reward = 0.5
logger.debug(f"Completion {i}: Successfully parsed frontmatter (base reward: 0.5)")
# Try to compare with claude_original if available
if claude_original and i < len(claude_original) and claude_original[i]:
try:
# Parse claude_original frontmatter
claude_fm, claude_text = parser._extract_front_matter_and_text(claude_original[i])
claude_response = parser._parse_front_matter(claude_fm, claude_text)
# Compare each field
fields_matched = 0
for field in fields_to_compare:
completion_value = getattr(completion_response, field, None)
claude_value = getattr(claude_response, field, None)
if completion_value == claude_value:
fields_matched += 1
reward += 0.1
logger.debug(f" Field {field} matches: {completion_value}")
else:
logger.debug(f" Field {field} mismatch: completion={completion_value}, claude={claude_value}")
logger.debug(f"Completion {i}: Matched {fields_matched}/{len(fields_to_compare)} fields")
except Exception as e:
logger.warning(f"Failed to parse claude_original for comparison at index {i}: {e}")
# Keep the base 0.5 reward for successful parsing
else:
logger.debug(f"Completion {i}: No claude_original available for comparison")
except Exception as e:
# Any parsing error results in 0 reward
rewards.append(0.0)
logger.debug(f"Completion {i}: Failed to parse format - {type(e).__name__}: {str(e)}")
reward = 0.0
logger.debug(f"Completion {i}: Failed to parse frontmatter - {type(e).__name__}: {str(e)}")
success_count = sum(1 for r in rewards if r == 1.0)
logger.info(f"Format rewards: {success_count}/{len(rewards)} successfully parsed")
rewards.append(reward)
# Log summary statistics
zero_rewards = sum(1 for r in rewards if r == 0.0)
partial_rewards = sum(1 for r in rewards if 0.0 < r < 1.0)
perfect_rewards = sum(1 for r in rewards if r == 1.0)
avg_reward = sum(rewards) / len(rewards) if rewards else 0.0
logger.info(f"Front matter rewards summary: {zero_rewards} failed, {partial_rewards} partial, "
f"{perfect_rewards} perfect. Average: {avg_reward:.3f}")
return rewards
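
The scoring described in the docstring can be sanity-checked with a small standalone sketch. This is only the arithmetic implied above (0.0 when the front matter cannot be parsed, a 0.5 base once it parses, plus 0.1 per matching field); it does not call the real FrontMatterParser, and the helper name is made up for illustration.

# Hypothetical sketch of the reward arithmetic only, not the real parser.
FIELDS = ["primary_language", "is_rotation_valid", "rotation_correction", "is_table", "is_diagram"]

def sketch_front_matter_score(parsed_ok, completion_fields=None, claude_fields=None):
    if not parsed_ok:
        return 0.0
    score = 0.5
    if completion_fields is not None and claude_fields is not None:
        for field in FIELDS:
            if completion_fields.get(field) == claude_fields.get(field):
                score += 0.1
    return score

claude = {"primary_language": "en", "is_rotation_valid": True, "rotation_correction": 0,
          "is_table": False, "is_diagram": False}
model = dict(claude, is_rotation_valid=False, rotation_correction=90)  # 3 of 5 fields agree

assert sketch_front_matter_score(False) == 0.0                            # unparseable front matter
assert sketch_front_matter_score(True, model, None) == 0.5                # parses, nothing to compare against
assert abs(sketch_front_matter_score(True, model, claude) - 0.8) < 1e-9   # 0.5 + 3 * 0.1
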
@@ -708,12 +759,12 @@ def main():
help="Use bench edit distance reward with optional weight (default: 1.0)"
)
parser.add_argument(
"--reward_format",
"--reward_front_matter",
nargs='?',
const=1.0,
type=float,
default=None,
help="Use format validation reward with optional weight (default: 1.0)"
help="Use front matter validation and field matching reward with optional weight (default: 1.0)"
)
args = parser.parse_args()
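
A side note on the flag pattern above: nargs='?' combined with const and default gives three distinct behaviors, which a minimal standalone parser (hypothetical, not the training script itself) demonstrates.

import argparse

p = argparse.ArgumentParser()
p.add_argument("--reward_front_matter", nargs="?", const=1.0, type=float, default=None)

assert p.parse_args([]).reward_front_matter is None                                 # flag omitted: reward disabled
assert p.parse_args(["--reward_front_matter"]).reward_front_matter == 1.0           # bare flag: default weight 1.0
assert p.parse_args(["--reward_front_matter", "0.25"]).reward_front_matter == 0.25  # explicit weight
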
@@ -814,14 +865,14 @@ def main():
reward_names.append("bench_edit_distance")
logger.info(f"Added bench edit distance reward function with weight {args.reward_bench_edit_distance}")
if args.reward_format is not None:
reward_funcs.append(reward_format)
reward_weights.append(args.reward_format)
reward_names.append("format")
logger.info(f"Added format validation reward function with weight {args.reward_format}")
if args.reward_front_matter is not None:
reward_funcs.append(reward_front_matter)
reward_weights.append(args.reward_front_matter)
reward_names.append("front_matter")
logger.info(f"Added front matter validation reward function with weight {args.reward_front_matter}")
if not reward_funcs:
logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_format")
logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_front_matter")
return
# Log summary of reward configuration
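
How the reward_funcs, reward_weights, and reward_names assembled above are combined is up to the trainer and is not shown in this diff. A rough sketch, assuming a simple weighted sum of each function's per-completion scores, might look like the following; the function and variable names here are illustrative only.

from typing import Callable, Sequence

def combine_rewards(reward_funcs: Sequence[Callable], reward_weights: Sequence[float],
                    prompts: list, completions: list, **kwargs) -> list:
    totals = [0.0] * len(completions)
    for func, weight in zip(reward_funcs, reward_weights):
        scores = func(prompts=prompts, completions=completions, **kwargs)
        for i, score in enumerate(scores):
            if score is not None:  # some reward functions return None when they cannot score a completion
                totals[i] += weight * score
    return totals

# Toy usage: a length-based reward weighted 1.0 plus a constant reward weighted 0.5.
def length_reward(prompts, completions, **kw):
    return [min(len(c) / 100, 1.0) for c in completions]

def constant_reward(prompts, completions, **kw):
    return [0.5 for _ in completions]

print(combine_rewards([length_reward, constant_reward], [1.0, 0.5], ["p"], ["some completion text"]))  # [0.45]
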

@@ -10,13 +10,13 @@ import json
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, patch, mock_open
import shutil
# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from olmocr.train.grpo_train import OlmOCRBenchDataset, olmocr_bench_reward, load_tests_cached
from olmocr.train.grpo_train import OlmOCRBenchDataset, olmocr_bench_reward, load_specific_tests_cached, reward_front_matter
class TestGRPODataloader(unittest.TestCase):
@@ -265,7 +265,7 @@ class TestOlmOCRBenchReward(unittest.TestCase):
def setUpClass(cls):
"""Create temporary test files."""
# Clear any cached tests from previous runs
load_tests_cached.cache_clear()
load_specific_tests_cached.cache_clear()
cls.temp_dir = tempfile.mkdtemp()
# Create a sample JSONL test file with different test types
@@ -313,12 +313,12 @@
def tearDownClass(cls):
"""Clean up temporary files."""
# Clear the LRU cache before removing temp dir
load_tests_cached.cache_clear()
load_specific_tests_cached.cache_clear()
shutil.rmtree(cls.temp_dir)
def setUp(self):
"""Clear cache before each test method."""
load_tests_cached.cache_clear()
load_specific_tests_cached.cache_clear()
def test_perfect_completion(self):
"""Test reward calculation for a completion that passes all tests."""
@@ -421,22 +421,33 @@ class TestOlmOCRBenchReward(unittest.TestCase):
self.assertEqual(rewards[2], 1.0)
def test_cache_functionality(self):
"""Test that load_tests_cached properly caches results."""
"""Test that load_specific_tests_cached properly caches results."""
# Clear cache first
load_tests_cached.cache_clear()
load_specific_tests_cached.cache_clear()
# First call should load from file
with patch('olmocr.train.grpo_train.load_tests') as mock_load:
mock_load.return_value = []
result1 = load_tests_cached(self.jsonl_path)
self.assertEqual(mock_load.call_count, 1)
# Test that the cache works by calling the function twice
test_ids = ("test1", "test2")
# Second call should use cache
result2 = load_tests_cached(self.jsonl_path)
self.assertEqual(mock_load.call_count, 1) # Should not increase
# First call loads from file
with patch('builtins.open', create=True) as mock_open:
mock_file = MagicMock()
mock_file.__enter__.return_value = iter([
'{"id": "test1", "type": "present", "text": "hello", "pdf": "test.pdf", "page": 0}\n',
'{"id": "test2", "type": "absent", "text": "world", "pdf": "test.pdf", "page": 0}\n',
'{"id": "test3", "type": "present", "text": "foo", "pdf": "test.pdf", "page": 0}\n',
])
mock_open.return_value = mock_file
result1 = load_specific_tests_cached(self.jsonl_path, test_ids)
self.assertEqual(mock_open.call_count, 1)
# Second call should use cache, not open file again
result2 = load_specific_tests_cached(self.jsonl_path, test_ids)
self.assertEqual(mock_open.call_count, 1) # Should not increase
# Results should be the same
self.assertEqual(result1, result2)
self.assertEqual(len(result1), 2) # Should have loaded 2 tests
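
The call-count assertions above are consistent with an lru_cache-style loader keyed by the JSONL path and a tuple of test ids. The real load_specific_tests_cached lives in olmocr.train.grpo_train and may differ in detail; the sketch below only illustrates the assumed pattern. Passing test_ids as a tuple matters because lru_cache arguments must be hashable, so a list would raise TypeError.

import json
from functools import lru_cache

@lru_cache(maxsize=32)
def load_selected_tests(jsonl_path: str, test_ids: tuple) -> tuple:
    """Load only the JSONL records whose 'id' is in test_ids, cached per (path, ids) pair."""
    wanted = set(test_ids)
    selected = []
    with open(jsonl_path) as f:
        for line in f:
            record = json.loads(line)
            if record.get("id") in wanted:
                selected.append(record)
    return tuple(selected)  # immutable, so cached results cannot be mutated by callers
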
def test_error_handling(self):
"""Test error handling in reward function."""
@@ -455,6 +466,313 @@ class TestOlmOCRBenchReward(unittest.TestCase):
self.assertIsNone(rewards[0])
class TestRewardFrontMatter(unittest.TestCase):
"""Test cases for the reward_front_matter function."""
def test_no_frontmatter(self):
"""Test that completions without frontmatter get 0.0 reward."""
completions = [
"This is just text without any frontmatter",
"Another completion with no YAML",
"# A markdown header but no frontmatter"
]
rewards = reward_front_matter(
prompts=["prompt"] * 3,
completions=completions,
claude_original=None
)
self.assertEqual(len(rewards), 3)
for reward in rewards:
self.assertEqual(reward, 0.0)
def test_valid_frontmatter_no_claude(self):
"""Test that valid frontmatter gets 0.5 reward when no claude_original."""
completions = [
"""---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---
# Document content here""",
"""---
primary_language: fr
is_rotation_valid: False
rotation_correction: 90
is_table: True
is_diagram: False
---
Some other content"""
]
rewards = reward_front_matter(
prompts=["prompt"] * 2,
completions=completions,
claude_original=None
)
self.assertEqual(len(rewards), 2)
for reward in rewards:
self.assertEqual(reward, 0.5)
def test_perfect_match_with_claude(self):
"""Test perfect match with claude_original gets 1.0 reward."""
claude_original_content = """---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---
Original content"""
completion = """---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---
Different content but same frontmatter"""
rewards = reward_front_matter(
prompts=["prompt"],
completions=[completion],
claude_original=[claude_original_content]
)
self.assertEqual(len(rewards), 1)
self.assertAlmostEqual(rewards[0], 1.0, places=5)
def test_partial_match_with_claude(self):
"""Test partial match with claude_original gets intermediate reward."""
claude_original_content = """---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---
Original content"""
# 3 out of 5 fields match
completion = """---
primary_language: en
is_rotation_valid: False
rotation_correction: 90
is_table: False
is_diagram: False
---
Different values for rotation fields"""
rewards = reward_front_matter(
prompts=["prompt"],
completions=[completion],
claude_original=[claude_original_content]
)
self.assertEqual(len(rewards), 1)
# 0.5 base + 3 * 0.1 = 0.8
self.assertAlmostEqual(rewards[0], 0.8, places=2)
def test_invalid_frontmatter_format(self):
"""Test that invalid YAML frontmatter gets 0.0 reward."""
completions = [
"""---
this is not: valid yaml
because of: : bad formatting
---
content""",
"""---
primary_language: en
unclosed_string: "this is not closed
---""",
"""---
primary_language:
nested: "should not be nested"
---"""
]
rewards = reward_front_matter(
prompts=["prompt"] * 3,
completions=completions,
claude_original=None
)
self.assertEqual(len(rewards), 3)
for reward in rewards:
self.assertEqual(reward, 0.0)
def test_mixed_completions_with_claude(self):
"""Test mixed completions with claude_original comparisons."""
claude_originals = [
"""---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---
Content 1""",
None, # No claude_original for this one
"""---
primary_language: fr
is_rotation_valid: False
rotation_correction: 180
is_table: True
is_diagram: True
---
Content 3"""
]
completions = [
# Perfect match with first claude
"""---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---
Model output 1""",
# Valid frontmatter but no claude to compare
"""---
primary_language: de
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---
Model output 2""",
# No frontmatter at all
"Just plain text without frontmatter"
]
rewards = reward_front_matter(
prompts=["prompt"] * 3,
completions=completions,
claude_original=claude_originals
)
self.assertEqual(len(rewards), 3)
self.assertAlmostEqual(rewards[0], 1.0, places=5) # Perfect match
self.assertEqual(rewards[1], 0.5) # Valid but no claude
self.assertEqual(rewards[2], 0.0) # No frontmatter
def test_list_format_completions(self):
"""Test handling of completions in list format."""
completion_content = """---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---
Content"""
completions = [
# List format (as from model output)
[{"content": completion_content}],
# String format
completion_content,
# Empty list
[],
# Invalid format
[{}]
]
rewards = reward_front_matter(
prompts=["prompt"] * 4,
completions=completions,
claude_original=None
)
self.assertEqual(len(rewards), 4)
self.assertEqual(rewards[0], 0.5) # Valid from list
self.assertEqual(rewards[1], 0.5) # Valid from string
self.assertEqual(rewards[2], 0.0) # Empty list
self.assertEqual(rewards[3], 0.0) # Invalid list format
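
The four cases above mirror the completion-extraction branch added to reward_front_matter earlier in this commit. Condensed into a standalone helper (the name is made up), the normalization is:

def extract_markdown(completion) -> str:
    if isinstance(completion, list):
        if completion and "content" in completion[0]:
            return completion[0]["content"]
        return ""  # empty list, or first message has no "content" key
    if isinstance(completion, str):
        return completion
    return ""

assert extract_markdown([{"content": "---\nprimary_language: en\n---\nText"}]).startswith("---")
assert extract_markdown("plain string") == "plain string"
assert extract_markdown([]) == ""    # empty list
assert extract_markdown([{}]) == ""  # dict without "content"
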
def test_field_type_matching(self):
"""Test that field types are correctly compared."""
claude_original = """---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---"""
completions = [
# Correct types
"""---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---""",
# String instead of boolean (might still parse correctly)
"""---
primary_language: en
is_rotation_valid: "True"
rotation_correction: "0"
is_table: "False"
is_diagram: "False"
---""",
]
rewards = reward_front_matter(
prompts=["prompt"] * 2,
completions=completions,
claude_original=[claude_original] * 2
)
self.assertEqual(len(rewards), 2)
# First should be perfect match
self.assertAlmostEqual(rewards[0], 1.0, places=5)
# Second: YAML parses string "True" as True boolean, so both should match
self.assertAlmostEqual(rewards[1], 1.0, places=5)
def test_none_values_in_fields(self):
"""Test handling of None values in frontmatter fields."""
claude_original = """---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---"""
# rotation_correction null will fail validation (must be 0, 90, 180, or 270)
completion = """---
primary_language: en
is_rotation_valid: True
rotation_correction: null
is_table: False
is_diagram: False
---"""
rewards = reward_front_matter(
prompts=["prompt"],
completions=[completion],
claude_original=[claude_original]
)
self.assertEqual(len(rewards), 1)
# Should fail to parse due to invalid rotation_correction
self.assertEqual(rewards[0], 0.0)
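
The expectation that a null rotation_correction scores 0.0 rests on the parsed response object rejecting values outside {0, 90, 180, 270}. The real validation lives in olmocr's PageResponse/FrontMatterParser; the sketch below is only an assumed illustration of that behavior.

from dataclasses import dataclass
from typing import Optional

@dataclass
class SketchPageResponse:  # illustrative stand-in, not the real olmocr PageResponse
    primary_language: Optional[str]
    is_rotation_valid: bool
    rotation_correction: int
    is_table: bool
    is_diagram: bool

    def __post_init__(self):
        if self.rotation_correction not in (0, 90, 180, 270):
            raise ValueError(f"rotation_correction must be 0, 90, 180, or 270, got {self.rotation_correction!r}")

try:
    SketchPageResponse("en", True, None, False, False)  # null rotation_correction
except ValueError as exc:
    print(exc)  # construction fails, so the reward function scores this completion 0.0
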
class TestIntegrationWithRealData(unittest.TestCase):
"""Integration tests with real bench data if available."""