Mirror of https://github.com/allenai/olmocr.git (synced 2025-11-09 23:18:02 +00:00)

Olmocr bench grpo stuff

commit 1dd6ff9b03
parent 6184c94c3c
@@ -11,6 +11,7 @@ import json
 import random
 from pathlib import Path
 import glob
+from functools import lru_cache
 
 import torch
 import numpy as np
@@ -28,6 +29,7 @@ from io import BytesIO
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
+from olmocr.bench.tests import load_tests
 
 # Configure logging
 logging.basicConfig(
@@ -170,11 +172,101 @@ class OlmOCRDataset(Dataset):
             # Return None if processing fails
             return None
 
-def simple_length_reward(prompts, completions, completion_ids, pdf_path, jsonl_file, test_ids, **kwargs):
-    """Reward function that assigns higher scores to longer completions (in terms of token count)."""
-    logger.info(f"Reward function called {kwargs}")
-    # return [float(len(ids)) for ids in completions_ids]
-    return [random.choice([0.1, 0.5]) for x in completions]
+@lru_cache(maxsize=128)
+def load_tests_cached(jsonl_file: str):
+    """
+    Cached version of load_tests to avoid reloading the same JSONL file multiple times.
+
+    Args:
+        jsonl_file: Path to the JSONL file containing test definitions
+
+    Returns:
+        List of test objects loaded from the file
+    """
+    logger.info(f"Loading tests from {jsonl_file} (will be cached)")
+    return load_tests(jsonl_file)
+
+
+def unit_test_reward(prompts, completions, completion_ids, pdf_path, jsonl_file, test_ids, **kwargs):
+    """
+    Reward function that runs unit tests on completions and returns average pass rate.
+
+    For each completion, loads the corresponding tests from the JSONL file and runs them.
+    Returns the proportion of tests that pass as the reward score.
+
+    Args:
+        prompts: List of prompts
+        completions: List of generated completions (model outputs)
+        completion_ids: List of completion token IDs
+        pdf_path: Path to the PDF file being processed
+        jsonl_file: Path to the JSONL file containing test definitions
+        test_ids: List of test IDs associated with this PDF page
+        **kwargs: Additional arguments
+
+    Returns:
+        List of reward scores (0.0 to 1.0) based on test pass rates
+    """
+    logger.info(f"Running unit test reward function for {len(completions)} completions")
+    logger.info(f"PDF: {pdf_path}, JSONL: {jsonl_file}, Test IDs: {test_ids}")
+
+    rewards = []
+
+    # Load all tests from the JSONL file (cached)
+    try:
+        all_tests = load_tests_cached(jsonl_file)
+        # Filter to only the tests for this specific PDF page
+        relevant_tests = [test for test in all_tests if test.id in test_ids]
+
+        if not relevant_tests:
+            logger.warning(f"No relevant tests found for test IDs: {test_ids}")
+            # Return a small positive reward to avoid training issues
+            return [0.1 for _ in completions]
+
+        logger.info(f"Found {len(relevant_tests)} relevant tests for this PDF page")
+
+        # Process each completion
+        for i, completion in enumerate(completions):
+            if not completion or not isinstance(completion, str):
+                logger.warning(f"Invalid completion at index {i}: {type(completion)}")
+                rewards.append(0.0)
+                continue
+
+            # Run all relevant tests on this completion
+            passed = 0
+            total = len(relevant_tests)
+
+            for test in relevant_tests:
+                try:
+                    test_passed, failure_reason = test.run(completion)
+                    if test_passed:
+                        passed += 1
+                    else:
+                        logger.debug(f"Test {test.id} failed: {failure_reason}")
+                except Exception as e:
+                    logger.warning(f"Error running test {test.id}: {e}")
+                    # Count errored tests as failures
+                    continue
+
+            # Calculate reward as proportion of tests passed
+            reward = passed / total if total > 0 else 0.0
+            rewards.append(reward)
+
+            logger.info(f"Completion {i}: {passed}/{total} tests passed, reward={reward:.3f}")
+
+    except Exception as e:
+        logger.error(f"Error in unit_test_reward function: {e}")
+        # Return small positive rewards to avoid training issues
+        return [0.1 for _ in completions]
+
+    # Ensure we always return rewards between 0 and 1
+    rewards = [max(0.0, min(1.0, r)) for r in rewards]
+
+    # If all rewards are 0, add a small epsilon to avoid training issues
+    if all(r == 0.0 for r in rewards):
+        logger.warning("All completions failed all tests, adding small epsilon reward")
+        rewards = [0.01 for _ in rewards]
+
+    return rewards
+
+
 def main():
@@ -370,7 +462,7 @@ def main():
         processing_class=processor,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
-        reward_funcs=simple_length_reward,
+        reward_funcs=unit_test_reward,
     )
 
     # Start training
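
Note on the caching added above: `@lru_cache(maxsize=128)` memoizes `load_tests_cached` by the exact `jsonl_file` string it is called with, which is why the new tests further down call `load_tests_cached.cache_clear()` in `setUpClass`, `setUp`, and `tearDownClass`. A minimal, self-contained sketch of that behavior (`fake_load_tests` is a hypothetical stand-in for olmocr's `load_tests`, not the real loader):

    from functools import lru_cache

    calls = {"count": 0}

    def fake_load_tests(jsonl_file: str):
        # Hypothetical stand-in: count how often the file is actually "loaded"
        calls["count"] += 1
        return [f"tests parsed from {jsonl_file}"]

    @lru_cache(maxsize=128)
    def load_tests_cached(jsonl_file: str):
        return fake_load_tests(jsonl_file)

    load_tests_cached("page1.jsonl")
    load_tests_cached("page1.jsonl")   # cache hit: underlying loader not called again
    assert calls["count"] == 1
    load_tests_cached.cache_clear()    # what the tests do to get a clean slate
    load_tests_cached("page1.jsonl")
    assert calls["count"] == 2
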
@@ -166,6 +166,7 @@ commands = [
     "pip install trl wandb",
     "pip install transformers==4.55.2", # Updated for GRPO compatibility
     "pip install flash-attn==2.8.0.post2 --no-build-isolation",
+    "pip install vllm==v0.10.1.1",
    "pip install s5cmd",
 
     # Sync the bench data from S3

@@ -16,7 +16,7 @@ import shutil
 # Add parent directory to path
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
-from olmocr.train.grpo_train import OlmOCRDataset, collate_fn
+from olmocr.train.grpo_train import OlmOCRDataset, unit_test_reward, load_tests_cached
 
 
 class TestGRPODataloader(unittest.TestCase):
@@ -179,48 +179,6 @@ class TestGRPODataloader(unittest.TestCase):
         self.assertIsInstance(item["test_ids"], list)
         self.assertTrue(len(item["test_ids"]) > 0)
 
-    @patch('olmocr.train.grpo_train.render_pdf_to_base64png')
-    @patch('olmocr.train.grpo_train.build_no_anchoring_v4_yaml_prompt')
-    def test_collate_function(self, mock_prompt, mock_render):
-        """Test that the collate function works correctly."""
-        # Mock the rendering and prompt functions
-        mock_render.return_value = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
-        mock_prompt.return_value = "Test prompt"
-
-        dataset = OlmOCRDataset(
-            bench_data_folder=self.bench_data_folder,
-            processor=None,
-            max_samples=2,
-            target_longest_image_dim=1024,
-        )
-
-        # Create a batch
-        batch = [dataset[0], dataset[1]]
-        collated = collate_fn(batch)
-
-        self.assertIsNotNone(collated)
-        self.assertIn("prompts", collated)
-        self.assertIn("images", collated)
-        self.assertIn("pdf_paths", collated)
-        self.assertIn("jsonl_files", collated)
-        self.assertIn("test_ids", collated)
-
-        # Check batch size consistency
-        self.assertEqual(len(collated["prompts"]), 2)
-        self.assertEqual(len(collated["images"]), 2)
-        self.assertEqual(len(collated["pdf_paths"]), 2)
-        self.assertEqual(len(collated["jsonl_files"]), 2)
-        self.assertEqual(len(collated["test_ids"]), 2)
-
-    def test_collate_with_none_values(self):
-        """Test that collate function handles None values correctly."""
-        batch = [None, {"prompt": [], "image": None, "pdf_path": "test.pdf",
-                        "jsonl_file": "test.jsonl", "test_ids": ["id1"]}, None]
-
-        collated = collate_fn(batch)
-
-        self.assertIsNotNone(collated)
-        self.assertEqual(len(collated["prompts"]), 1)
 
     def test_empty_jsonl_handling(self):
         """Test handling of empty JSONL files."""
@@ -300,6 +258,202 @@ class TestGRPODataloader(unittest.TestCase):
         shutil.rmtree(temp_folder)
 
 
+class TestUnitTestReward(unittest.TestCase):
+    """Test cases for the unit_test_reward function."""
+
+    @classmethod
+    def setUpClass(cls):
+        """Create temporary test files."""
+        # Clear any cached tests from previous runs
+        load_tests_cached.cache_clear()
+        cls.temp_dir = tempfile.mkdtemp()
+
+        # Create a sample JSONL test file with different test types
+        cls.jsonl_path = os.path.join(cls.temp_dir, "test.jsonl")
+        test_data = [
+            {
+                "pdf": "test.pdf",
+                "page": 0,
+                "id": "test1",
+                "type": "present",
+                "text": "Hello World",
+                "max_diffs": 0
+            },
+            {
+                "pdf": "test.pdf",
+                "page": 0,
+                "id": "test2",
+                "type": "absent",
+                "text": "Bad Text",
+                "max_diffs": 0
+            },
+            {
+                "pdf": "test.pdf",
+                "page": 0,
+                "id": "test3",
+                "type": "baseline",
+                "max_repeats": 30
+            },
+            {
+                "pdf": "test.pdf",
+                "page": 0,
+                "id": "test4",
+                "type": "order",
+                "before": "First",
+                "after": "Second",
+                "max_diffs": 0
+            }
+        ]
+
+        with open(cls.jsonl_path, 'w') as f:
+            for test in test_data:
+                f.write(json.dumps(test) + '\n')
+
+    @classmethod
+    def tearDownClass(cls):
+        """Clean up temporary files."""
+        # Clear the LRU cache before removing temp dir
+        load_tests_cached.cache_clear()
+        shutil.rmtree(cls.temp_dir)
+
+    def setUp(self):
+        """Clear cache before each test method."""
+        load_tests_cached.cache_clear()
+
+    def test_perfect_completion(self):
+        """Test reward calculation for a completion that passes all tests."""
+        completions = ["Hello World\n\nFirst paragraph.\n\nSecond paragraph.\n\nThis is a good document with no bad text."]
+        test_ids = ["test1", "test2", "test3", "test4"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"],
+            completions=completions,
+            completion_ids=[[]],
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)  # All 4 tests should pass
+
+    def test_partial_completion(self):
+        """Test reward calculation for a completion that passes some tests."""
+        completions = ["This document contains Bad Text but nothing else of note."]
+        test_ids = ["test1", "test2", "test3"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"],
+            completions=completions,
+            completion_ids=[[]],
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 1)
+        # Should pass only baseline test (1/3)
+        self.assertAlmostEqual(rewards[0], 1/3, places=2)
+
+    def test_multiple_completions(self):
+        """Test reward calculation for multiple completions."""
+        completions = [
+            "Hello World with good content. First then Second.",
+            "Bad Text only",
+            "",  # Empty completion
+        ]
+        test_ids = ["test1", "test2", "test3", "test4"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"] * 3,
+            completions=completions,
+            completion_ids=[[]] * 3,
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 3)
+        # First should pass all 4 tests
+        self.assertEqual(rewards[0], 1.0)
+        # Second should pass only baseline (1/4)
+        self.assertEqual(rewards[1], 0.25)
+        # Third (empty string) is rejected by the completion guard before any tests run
+        self.assertEqual(rewards[2], 0.0)
+
+    def test_no_relevant_tests(self):
+        """Test behavior when no relevant tests are found."""
+        completions = ["Some content"]
+        test_ids = ["nonexistent_test"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"],
+            completions=completions,
+            completion_ids=[[]],
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 0.1)  # Default reward when no tests found
+
+    def test_invalid_completion(self):
+        """Test handling of invalid completions."""
+        completions = [None, "", "Valid content with Hello World"]
+        test_ids = ["test1"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"] * 3,
+            completions=completions,
+            completion_ids=[[]] * 3,
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 3)
+        # First two should get 0 or epsilon
+        self.assertLessEqual(rewards[0], 0.01)
+        self.assertLessEqual(rewards[1], 0.01)
+        # Last should pass the test
+        self.assertEqual(rewards[2], 1.0)
+
+    def test_cache_functionality(self):
+        """Test that load_tests_cached properly caches results."""
+        # Clear cache first
+        load_tests_cached.cache_clear()
+
+        # First call should load from file
+        with patch('olmocr.train.grpo_train.load_tests') as mock_load:
+            mock_load.return_value = []
+            result1 = load_tests_cached(self.jsonl_path)
+            self.assertEqual(mock_load.call_count, 1)
+
+            # Second call should use cache
+            result2 = load_tests_cached(self.jsonl_path)
+            self.assertEqual(mock_load.call_count, 1)  # Should not increase
+
+            # Results should be the same
+            self.assertEqual(result1, result2)
+
+    def test_error_handling(self):
+        """Test error handling in reward function."""
+        # Test with non-existent file
+        rewards = unit_test_reward(
+            prompts=["prompt"],
+            completions=["content"],
+            completion_ids=[[]],
+            pdf_path="test.pdf",
+            jsonl_file="/nonexistent/file.jsonl",
+            test_ids=["test1"]
+        )
+
+        # Should return default reward on error
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 0.1)
+
+
 class TestIntegrationWithRealData(unittest.TestCase):
     """Integration tests with real bench data if available."""
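
For reference, the reward these tests exercise is just the per-page test pass rate, clamped to [0, 1], with a small epsilon floor when every completion fails everything. A toy sketch of that shaping on made-up pass counts (not olmocr's test runner, just the arithmetic from unit_test_reward):

    def shape_rewards(pass_counts, total_tests):
        # Pass rate per completion, mirroring unit_test_reward's scoring step
        rewards = [p / total_tests if total_tests > 0 else 0.0 for p in pass_counts]
        # Clamp to [0, 1]
        rewards = [max(0.0, min(1.0, r)) for r in rewards]
        # Epsilon floor when every completion scored zero
        if all(r == 0.0 for r in rewards):
            rewards = [0.01 for _ in rewards]
        return rewards

    print(shape_rewards([4, 1, 0], 4))  # [1.0, 0.25, 0.0]
    print(shape_rewards([0, 0, 0], 4))  # [0.01, 0.01, 0.01]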