Cleaner front matter reward
commit 0710debf75
parent 09036b07d9
@@ -472,52 +472,103 @@ def medoid_reward(prompts, completions: list[str] | list[list[dict]], **kwargs):
     return rewards


-def reward_format(prompts, completions: list[str] | list[list[dict]], **kwargs):
+def reward_front_matter(prompts, completions: list[str] | list[list[dict]], claude_original: list[Optional[str]] = None, **kwargs):
     """
-    Reward function that checks if completions can be successfully parsed by FrontMatterParser.
+    Reward function that checks if completions can be successfully parsed by FrontMatterParser
+    and compares fields to claude_original values.

-    Returns 1.0 if the completion can be parsed without errors, 0.0 otherwise.
-    This ensures the model generates properly formatted YAML front matter that can be
-    parsed into a PageResponse object.
+    Scoring:
+    - 0.0: Cannot parse frontmatter at all
+    - 0.5: Can parse frontmatter successfully
+    - +0.1: For each matching field (primary_language, is_rotation_valid,
+            rotation_correction, is_table, is_diagram)
+
+    Maximum score: 1.0 (0.5 + 5 * 0.1)

     Args:
         prompts: List of prompts
         completions: List of generated completions (model outputs)
+        claude_original: List of claude_original markdown content (optional)
         **kwargs: Additional arguments

     Returns:
-        List of reward scores: 1.0 for successful parsing, 0.0 for errors
+        List of reward scores between 0.0 and 1.0
     """
-    logger.info(f"Running format reward function for {len(completions)} completions")
+    logger.info(f"Running front matter reward function for {len(completions)} completions")

     rewards = []
     parser = FrontMatterParser(front_matter_class=PageResponse)

+    # Fields to compare
+    fields_to_compare = ['primary_language', 'is_rotation_valid', 'rotation_correction',
+                         'is_table', 'is_diagram']
+
     for i, completion in enumerate(completions):
         # Extract text from completion
         if isinstance(completion, list):
-            model_response_markdown = completion[0]["content"] if completion else ""
+            if completion and "content" in completion[0]:
+                model_response_markdown = completion[0]["content"]
+            else:
+                model_response_markdown = ""
         elif isinstance(completion, str):
             model_response_markdown = completion
         else:
             model_response_markdown = ""

-        try:
-            # Try to parse the completion using the same logic as in pipeline.py
-            front_matter, text = parser._extract_front_matter_and_text(model_response_markdown)
-            page_response = parser._parse_front_matter(front_matter, text)
+        reward = 0.0

-            # If we get here without exception, parsing succeeded
-            rewards.append(1.0)
-            logger.debug(f"Completion {i}: Successfully parsed format")
+        try:
+            # Try to parse the completion
+            front_matter, text = parser._extract_front_matter_and_text(model_response_markdown)
+            completion_response = parser._parse_front_matter(front_matter, text)
+
+            # Parsing succeeded - base reward of 0.5
+            reward = 0.5
+            logger.debug(f"Completion {i}: Successfully parsed frontmatter (base reward: 0.5)")
+
+            # Try to compare with claude_original if available
+            if claude_original and i < len(claude_original) and claude_original[i]:
+                try:
+                    # Parse claude_original frontmatter
+                    claude_fm, claude_text = parser._extract_front_matter_and_text(claude_original[i])
+                    claude_response = parser._parse_front_matter(claude_fm, claude_text)
+
+                    # Compare each field
+                    fields_matched = 0
+                    for field in fields_to_compare:
+                        completion_value = getattr(completion_response, field, None)
+                        claude_value = getattr(claude_response, field, None)
+
+                        if completion_value == claude_value:
+                            fields_matched += 1
+                            reward += 0.1
+                            logger.debug(f"  Field {field} matches: {completion_value}")
+                        else:
+                            logger.debug(f"  Field {field} mismatch: completion={completion_value}, claude={claude_value}")
+
+                    logger.debug(f"Completion {i}: Matched {fields_matched}/{len(fields_to_compare)} fields")
+
+                except Exception as e:
+                    logger.warning(f"Failed to parse claude_original for comparison at index {i}: {e}")
+                    # Keep the base 0.5 reward for successful parsing
+            else:
+                logger.debug(f"Completion {i}: No claude_original available for comparison")

         except Exception as e:
-            # Any parsing error results in 0 reward
-            rewards.append(0.0)
-            logger.debug(f"Completion {i}: Failed to parse format - {type(e).__name__}: {str(e)}")
+            reward = 0.0
+            logger.debug(f"Completion {i}: Failed to parse frontmatter - {type(e).__name__}: {str(e)}")

-    success_count = sum(1 for r in rewards if r == 1.0)
-    logger.info(f"Format rewards: {success_count}/{len(rewards)} successfully parsed")
+        rewards.append(reward)
+
+    # Log summary statistics
+    zero_rewards = sum(1 for r in rewards if r == 0.0)
+    partial_rewards = sum(1 for r in rewards if 0.0 < r < 1.0)
+    perfect_rewards = sum(1 for r in rewards if r == 1.0)
+    avg_reward = sum(rewards) / len(rewards) if rewards else 0.0
+
+    logger.info(f"Front matter rewards summary: {zero_rewards} failed, {partial_rewards} partial, "
+                f"{perfect_rewards} perfect. Average: {avg_reward:.3f}")

     return rewards
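For a concrete feel of the new scoring, here is a minimal usage sketch, mirroring the test_partial_match_with_claude case added further down; it assumes olmocr is installed so the function can be imported, and the YAML strings are illustrative, not from a real run:

# Minimal sketch of the scoring behavior described in the docstring above.
from olmocr.train.grpo_train import reward_front_matter

claude = """---
primary_language: en
is_rotation_valid: True
rotation_correction: 0
is_table: False
is_diagram: False
---
Reference text"""

completion = """---
primary_language: en
is_rotation_valid: False
rotation_correction: 90
is_table: False
is_diagram: False
---
Model text"""

# Front matter parses (0.5 base) and 3 of 5 fields match (+0.3).
rewards = reward_front_matter(prompts=["p"], completions=[completion], claude_original=[claude])
print(rewards)  # approximately [0.8], up to float rounding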
@@ -708,12 +759,12 @@ def main():
         help="Use bench edit distance reward with optional weight (default: 1.0)"
     )
     parser.add_argument(
-        "--reward_format",
+        "--reward_front_matter",
        nargs='?',
        const=1.0,
        type=float,
        default=None,
-        help="Use format validation reward with optional weight (default: 1.0)"
+        help="Use front matter validation and field matching reward with optional weight (default: 1.0)"
    )

    args = parser.parse_args()
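Because the flag uses nargs='?' with const=1.0, it can be omitted, passed bare (default weight 1.0), or given an explicit weight. A small stdlib sketch of that parsing behavior; the flag name comes from the diff, everything else is illustrative:

import argparse

# Mirrors the add_argument call above.
parser = argparse.ArgumentParser()
parser.add_argument("--reward_front_matter", nargs="?", const=1.0, type=float, default=None)

print(parser.parse_args([]).reward_front_matter)                                 # None -> reward disabled
print(parser.parse_args(["--reward_front_matter"]).reward_front_matter)          # 1.0 -> default weight
print(parser.parse_args(["--reward_front_matter", "0.5"]).reward_front_matter)   # 0.5 -> explicit weight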
@@ -814,14 +865,14 @@ def main():
         reward_names.append("bench_edit_distance")
         logger.info(f"Added bench edit distance reward function with weight {args.reward_bench_edit_distance}")

-    if args.reward_format is not None:
-        reward_funcs.append(reward_format)
-        reward_weights.append(args.reward_format)
-        reward_names.append("format")
-        logger.info(f"Added format validation reward function with weight {args.reward_format}")
+    if args.reward_front_matter is not None:
+        reward_funcs.append(reward_front_matter)
+        reward_weights.append(args.reward_front_matter)
+        reward_names.append("front_matter")
+        logger.info(f"Added front matter validation reward function with weight {args.reward_front_matter}")

     if not reward_funcs:
-        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_format")
+        logger.error("No reward function specified. Use at least one of: --reward_bench, --reward_medoid, --reward_bench_edit_distance, --reward_front_matter")
        return

    # Log summary of reward configuration
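This hunk registers the function in parallel reward_funcs / reward_weights / reward_names lists, so presumably the GRPO trainer evaluates every enabled reward and mixes the scores by weight. A hypothetical sketch of that combination (not code from this commit; the actual mixing happens inside the trainer):

# Hypothetical weighted mixing of several reward functions' outputs.
def combine_rewards(per_func_rewards: list[list[float]], weights: list[float]) -> list[float]:
    # per_func_rewards[f][i] is function f's score for completion i.
    num_completions = len(per_func_rewards[0])
    return [
        sum(w * scores[i] for w, scores in zip(weights, per_func_rewards))
        for i in range(num_completions)
    ]

# e.g. front_matter=0.8 at weight 0.5 plus bench=1.0 at weight 1.0 -> 1.4
print(combine_rewards([[0.8], [1.0]], [0.5, 1.0]))  # [1.4]

The remaining hunks update the accompanying unit tests.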
@@ -10,13 +10,13 @@ import json
 import tempfile
 import unittest
 from pathlib import Path
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, mock_open
 import shutil

 # Add parent directory to path
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

-from olmocr.train.grpo_train import OlmOCRBenchDataset, olmocr_bench_reward, load_tests_cached
+from olmocr.train.grpo_train import OlmOCRBenchDataset, olmocr_bench_reward, load_specific_tests_cached, reward_front_matter


 class TestGRPODataloader(unittest.TestCase):
@@ -265,7 +265,7 @@ class TestOlmOCRBenchReward(unittest.TestCase):
     def setUpClass(cls):
         """Create temporary test files."""
         # Clear any cached tests from previous runs
-        load_tests_cached.cache_clear()
+        load_specific_tests_cached.cache_clear()
        cls.temp_dir = tempfile.mkdtemp()

        # Create a sample JSONL test file with different test types
@@ -313,12 +313,12 @@ class TestOlmOCRBenchReward(unittest.TestCase):
     def tearDownClass(cls):
         """Clean up temporary files."""
         # Clear the LRU cache before removing temp dir
-        load_tests_cached.cache_clear()
+        load_specific_tests_cached.cache_clear()
        shutil.rmtree(cls.temp_dir)

    def setUp(self):
        """Clear cache before each test method."""
-        load_tests_cached.cache_clear()
+        load_specific_tests_cached.cache_clear()

    def test_perfect_completion(self):
        """Test reward calculation for a completion that passes all tests."""
@@ -421,22 +421,33 @@ class TestOlmOCRBenchReward(unittest.TestCase):
         self.assertEqual(rewards[2], 1.0)

     def test_cache_functionality(self):
-        """Test that load_tests_cached properly caches results."""
+        """Test that load_specific_tests_cached properly caches results."""
         # Clear cache first
-        load_tests_cached.cache_clear()
+        load_specific_tests_cached.cache_clear()

-        # First call should load from file
-        with patch('olmocr.train.grpo_train.load_tests') as mock_load:
-            mock_load.return_value = []
-            result1 = load_tests_cached(self.jsonl_path)
-            self.assertEqual(mock_load.call_count, 1)
+        # Test that the cache works by calling the function twice
+        test_ids = ("test1", "test2")

-            # Second call should use cache
-            result2 = load_tests_cached(self.jsonl_path)
-            self.assertEqual(mock_load.call_count, 1)  # Should not increase
+        # First call loads from file
+        with patch('builtins.open', create=True) as mock_open:
+            mock_file = MagicMock()
+            mock_file.__enter__.return_value = iter([
+                '{"id": "test1", "type": "present", "text": "hello", "pdf": "test.pdf", "page": 0}\n',
+                '{"id": "test2", "type": "absent", "text": "world", "pdf": "test.pdf", "page": 0}\n',
+                '{"id": "test3", "type": "present", "text": "foo", "pdf": "test.pdf", "page": 0}\n',
+            ])
+            mock_open.return_value = mock_file
+
+            result1 = load_specific_tests_cached(self.jsonl_path, test_ids)
+            self.assertEqual(mock_open.call_count, 1)
+
+            # Second call should use cache, not open file again
+            result2 = load_specific_tests_cached(self.jsonl_path, test_ids)
+            self.assertEqual(mock_open.call_count, 1)  # Should not increase

             # Results should be the same
             self.assertEqual(result1, result2)
+            self.assertEqual(len(result1), 2)  # Should have loaded 2 tests

     def test_error_handling(self):
         """Test error handling in reward function."""
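The rewritten test passes test_ids as a tuple, which is the usual tell for an lru_cache keyed on hashable arguments. A hypothetical sketch of what the helper presumably looks like (the real implementation in grpo_train.py may differ):

import functools
import json

@functools.lru_cache(maxsize=None)
def load_specific_tests_cached(jsonl_path: str, test_ids: tuple[str, ...]) -> list[dict]:
    # Hypothetical: read the JSONL once per (path, ids) pair, keep only the requested ids.
    wanted = set(test_ids)
    tests = []
    with open(jsonl_path) as f:
        for line in f:
            record = json.loads(line)
            if record.get("id") in wanted:
                tests.append(record)
    return tests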
@@ -455,6 +466,313 @@ class TestOlmOCRBenchReward(unittest.TestCase):
         self.assertIsNone(rewards[0])


+class TestRewardFrontMatter(unittest.TestCase):
+    """Test cases for the reward_front_matter function."""
+
+    def test_no_frontmatter(self):
+        """Test that completions without frontmatter get 0.0 reward."""
+        completions = [
+            "This is just text without any frontmatter",
+            "Another completion with no YAML",
+            "# A markdown header but no frontmatter"
+        ]
+
+        rewards = reward_front_matter(
+            prompts=["prompt"] * 3,
+            completions=completions,
+            claude_original=None
+        )
+
+        self.assertEqual(len(rewards), 3)
+        for reward in rewards:
+            self.assertEqual(reward, 0.0)
+
+    def test_valid_frontmatter_no_claude(self):
+        """Test that valid frontmatter gets 0.5 reward when no claude_original."""
+        completions = [
+            """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+
+# Document content here""",
+            """---
+primary_language: fr
+is_rotation_valid: False
+rotation_correction: 90
+is_table: True
+is_diagram: False
+---
+
+Some other content"""
+        ]
+
+        rewards = reward_front_matter(
+            prompts=["prompt"] * 2,
+            completions=completions,
+            claude_original=None
+        )
+
+        self.assertEqual(len(rewards), 2)
+        for reward in rewards:
+            self.assertEqual(reward, 0.5)
+
+    def test_perfect_match_with_claude(self):
+        """Test perfect match with claude_original gets 1.0 reward."""
+        claude_original_content = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+
+Original content"""
+
+        completion = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+
+Different content but same frontmatter"""
+
+        rewards = reward_front_matter(
+            prompts=["prompt"],
+            completions=[completion],
+            claude_original=[claude_original_content]
+        )
+
+        self.assertEqual(len(rewards), 1)
+        self.assertAlmostEqual(rewards[0], 1.0, places=5)
+
+    def test_partial_match_with_claude(self):
+        """Test partial match with claude_original gets intermediate reward."""
+        claude_original_content = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+
+Original content"""
+
+        # 3 out of 5 fields match
+        completion = """---
+primary_language: en
+is_rotation_valid: False
+rotation_correction: 90
+is_table: False
+is_diagram: False
+---
+
+Different values for rotation fields"""
+
+        rewards = reward_front_matter(
+            prompts=["prompt"],
+            completions=[completion],
+            claude_original=[claude_original_content]
+        )
+
+        self.assertEqual(len(rewards), 1)
+        # 0.5 base + 3 * 0.1 = 0.8
+        self.assertAlmostEqual(rewards[0], 0.8, places=2)
+
+    def test_invalid_frontmatter_format(self):
+        """Test that invalid YAML frontmatter gets 0.0 reward."""
+        completions = [
+            """---
+this is not: valid yaml
+because of: : bad formatting
+---
+content""",
+            """---
+primary_language: en
+unclosed_string: "this is not closed
+---""",
+            """---
+primary_language:
+  nested: "should not be nested"
+---"""
+        ]
+
+        rewards = reward_front_matter(
+            prompts=["prompt"] * 3,
+            completions=completions,
+            claude_original=None
+        )
+
+        self.assertEqual(len(rewards), 3)
+        for reward in rewards:
+            self.assertEqual(reward, 0.0)
+
+    def test_mixed_completions_with_claude(self):
+        """Test mixed completions with claude_original comparisons."""
+        claude_originals = [
+            """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+Content 1""",
+            None,  # No claude_original for this one
+            """---
+primary_language: fr
+is_rotation_valid: False
+rotation_correction: 180
+is_table: True
+is_diagram: True
+---
+Content 3"""
+        ]
+
+        completions = [
+            # Perfect match with first claude
+            """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+Model output 1""",
+            # Valid frontmatter but no claude to compare
+            """---
+primary_language: de
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+Model output 2""",
+            # No frontmatter at all
+            "Just plain text without frontmatter"
+        ]
+
+        rewards = reward_front_matter(
+            prompts=["prompt"] * 3,
+            completions=completions,
+            claude_original=claude_originals
+        )
+
+        self.assertEqual(len(rewards), 3)
+        self.assertAlmostEqual(rewards[0], 1.0, places=5)  # Perfect match
+        self.assertEqual(rewards[1], 0.5)  # Valid but no claude
+        self.assertEqual(rewards[2], 0.0)  # No frontmatter
+
+    def test_list_format_completions(self):
+        """Test handling of completions in list format."""
+        completion_content = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---
+Content"""
+
+        completions = [
+            # List format (as from model output)
+            [{"content": completion_content}],
+            # String format
+            completion_content,
+            # Empty list
+            [],
+            # Invalid format
+            [{}]
+        ]
+
+        rewards = reward_front_matter(
+            prompts=["prompt"] * 4,
+            completions=completions,
+            claude_original=None
+        )
+
+        self.assertEqual(len(rewards), 4)
+        self.assertEqual(rewards[0], 0.5)  # Valid from list
+        self.assertEqual(rewards[1], 0.5)  # Valid from string
+        self.assertEqual(rewards[2], 0.0)  # Empty list
+        self.assertEqual(rewards[3], 0.0)  # Invalid list format
+
+    def test_field_type_matching(self):
+        """Test that field types are correctly compared."""
+        claude_original = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---"""
+
+        completions = [
+            # Correct types
+            """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---""",
+            # String instead of boolean (might still parse correctly)
+            """---
+primary_language: en
+is_rotation_valid: "True"
+rotation_correction: "0"
+is_table: "False"
+is_diagram: "False"
+---""",
+        ]
+
+        rewards = reward_front_matter(
+            prompts=["prompt"] * 2,
+            completions=completions,
+            claude_original=[claude_original] * 2
+        )
+
+        self.assertEqual(len(rewards), 2)
+        # First should be perfect match
+        self.assertAlmostEqual(rewards[0], 1.0, places=5)
+        # Second: YAML parses string "True" as True boolean, so both should match
+        self.assertAlmostEqual(rewards[1], 1.0, places=5)
+
+    def test_none_values_in_fields(self):
+        """Test handling of None values in frontmatter fields."""
+        claude_original = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: 0
+is_table: False
+is_diagram: False
+---"""
+
+        # rotation_correction null will fail validation (must be 0, 90, 180, or 270)
+        completion = """---
+primary_language: en
+is_rotation_valid: True
+rotation_correction: null
+is_table: False
+is_diagram: False
+---"""
+
+        rewards = reward_front_matter(
+            prompts=["prompt"],
+            completions=[completion],
+            claude_original=[claude_original]
+        )
+
+        self.assertEqual(len(rewards), 1)
+        # Should fail to parse due to invalid rotation_correction
+        self.assertEqual(rewards[0], 0.0)
+
+
 class TestIntegrationWithRealData(unittest.TestCase):
     """Integration tests with real bench data if available."""
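Taken together, the new tests pin down the front matter contract that FrontMatterParser enforces. A rough, hypothetical approximation of the PageResponse fields inferred from those tests (the real class lives in olmocr and may differ):

from dataclasses import dataclass

# Hypothetical approximation inferred from the tests above, not olmocr's actual class.
@dataclass(frozen=True)
class PageResponseSketch:
    primary_language: str | None
    is_rotation_valid: bool
    rotation_correction: int
    is_table: bool
    is_diagram: bool

    def __post_init__(self):
        # The null-rotation test above implies only these values are accepted.
        if self.rotation_correction not in (0, 90, 180, 270):
            raise ValueError("rotation_correction must be 0, 90, 180, or 270")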