Olmocr bench grpo stuff
commit 1dd6ff9b03 (parent 6184c94c3c)
@@ -11,6 +11,7 @@ import json
import random
from pathlib import Path
import glob
from functools import lru_cache

import torch
import numpy as np
@@ -28,6 +29,7 @@ from io import BytesIO

from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
from olmocr.bench.tests import load_tests

# Configure logging
logging.basicConfig(
@@ -170,11 +172,101 @@ class OlmOCRDataset(Dataset):
            # Return None if processing fails
            return None

def simple_length_reward(prompts, completions, completion_ids, pdf_path, jsonl_file, test_ids, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of token count)."""
    logger.info(f"Reward function called {kwargs}")
    # return [float(len(ids)) for ids in completions_ids]
    return [random.choice([0.1, 0.5]) for x in completions]

@lru_cache(maxsize=128)
def load_tests_cached(jsonl_file: str):
    """
    Cached version of load_tests to avoid reloading the same JSONL file multiple times.

    Args:
        jsonl_file: Path to the JSONL file containing test definitions

    Returns:
        List of test objects loaded from the file
    """
    logger.info(f"Loading tests from {jsonl_file} (will be cached)")
    return load_tests(jsonl_file)
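Because load_tests_cached is wrapped in functools.lru_cache, repeated calls with the same JSONL path return the memoized test list instead of re-reading the file, which is why the path is passed as a plain (hashable) string. A minimal sketch of that caching behaviour, using a stand-in loader rather than the real load_tests:

from functools import lru_cache

@lru_cache(maxsize=128)
def load_tests_cached_demo(jsonl_file: str):
    # Stand-in for load_tests(); the real function parses bench test definitions.
    print(f"loading {jsonl_file}")
    return []

load_tests_cached_demo("bench/sample.jsonl")   # cache miss: prints "loading ..."
load_tests_cached_demo("bench/sample.jsonl")   # cache hit: no file access, no print
print(load_tests_cached_demo.cache_info())     # CacheInfo(hits=1, misses=1, ...)
load_tests_cached_demo.cache_clear()           # same helper the unit tests below rely on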


def unit_test_reward(prompts, completions, completion_ids, pdf_path, jsonl_file, test_ids, **kwargs):
    """
    Reward function that runs unit tests on completions and returns average pass rate.

    For each completion, loads the corresponding tests from the JSONL file and runs them.
    Returns the proportion of tests that pass as the reward score.

    Args:
        prompts: List of prompts
        completions: List of generated completions (model outputs)
        completion_ids: List of completion token IDs
        pdf_path: Path to the PDF file being processed
        jsonl_file: Path to the JSONL file containing test definitions
        test_ids: List of test IDs associated with this PDF page
        **kwargs: Additional arguments

    Returns:
        List of reward scores (0.0 to 1.0) based on test pass rates
    """
    logger.info(f"Running unit test reward function for {len(completions)} completions")
    logger.info(f"PDF: {pdf_path}, JSONL: {jsonl_file}, Test IDs: {test_ids}")

    rewards = []

    # Load all tests from the JSONL file (cached)
    try:
        all_tests = load_tests_cached(jsonl_file)
        # Filter to only the tests for this specific PDF page
        relevant_tests = [test for test in all_tests if test.id in test_ids]

        if not relevant_tests:
            logger.warning(f"No relevant tests found for test IDs: {test_ids}")
            # Return a small positive reward to avoid training issues
            return [0.1 for _ in completions]

        logger.info(f"Found {len(relevant_tests)} relevant tests for this PDF page")

        # Process each completion
        for i, completion in enumerate(completions):
            if not completion or not isinstance(completion, str):
                logger.warning(f"Invalid completion at index {i}: {type(completion)}")
                rewards.append(0.0)
                continue

            # Run all relevant tests on this completion
            passed = 0
            total = len(relevant_tests)

            for test in relevant_tests:
                try:
                    test_passed, failure_reason = test.run(completion)
                    if test_passed:
                        passed += 1
                    else:
                        logger.debug(f"Test {test.id} failed: {failure_reason}")
                except Exception as e:
                    logger.warning(f"Error running test {test.id}: {e}")
                    # Count errored tests as failures
                    continue

            # Calculate reward as proportion of tests passed
            reward = passed / total if total > 0 else 0.0
            rewards.append(reward)

            logger.info(f"Completion {i}: {passed}/{total} tests passed, reward={reward:.3f}")

    except Exception as e:
        logger.error(f"Error in unit_test_reward function: {e}")
        # Return small positive rewards to avoid training issues
        return [0.1 for _ in completions]

    # Ensure we always return rewards between 0 and 1
    rewards = [max(0.0, min(1.0, r)) for r in rewards]

    # If all rewards are 0, add a small epsilon to avoid training issues
    if all(r == 0.0 for r in rewards):
        logger.warning("All completions failed all tests, adding small epsilon reward")
        rewards = [0.01 for _ in rewards]

    return rewards


def main():
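The reward for each completion is simply the fraction of its page-level tests that pass, clamped to [0, 1], with a 0.01 floor applied only when every completion in the batch scores zero. A tiny self-contained sketch of that arithmetic with stub test objects (hypothetical names, not the real bench tests):

class StubTest:
    # Minimal stand-in mimicking the .id / .run() interface used by unit_test_reward.
    def __init__(self, test_id, should_pass):
        self.id = test_id
        self._should_pass = should_pass

    def run(self, completion):
        return self._should_pass, None if self._should_pass else "stub failure"

tests = [StubTest("t1", True), StubTest("t2", True), StubTest("t3", False), StubTest("t4", True)]
passed = sum(1 for t in tests if t.run("some completion")[0])
reward = passed / len(tests) if tests else 0.0   # 3/4 = 0.75
reward = max(0.0, min(1.0, reward))              # same clamp as in unit_test_reward
print(reward)                                    # 0.75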
@@ -370,7 +462,7 @@ def main():
        processing_class=processor,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
-        reward_funcs=simple_length_reward,
+        reward_funcs=unit_test_reward,
    )

    # Start training
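The only functional change in main() is swapping the placeholder length reward for unit_test_reward. Roughly, the TRL wiring looks like the sketch below; the model id and config values are illustrative assumptions (the diff does not show them), and the extra dataset columns (pdf_path, jsonl_file, test_ids) are expected to reach the reward function as keyword arguments, which is why unit_test_reward takes them by name:

from trl import GRPOConfig, GRPOTrainer

# Illustrative values only; processor, train_dataset and eval_dataset are built
# earlier in main(), and each dataset row carries pdf_path / jsonl_file / test_ids columns.
config = GRPOConfig(output_dir="grpo_olmocr", num_generations=4, max_completion_length=2048)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-VL-7B-Instruct",   # placeholder model id, not taken from this diff
    args=config,
    processing_class=processor,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    reward_funcs=unit_test_reward,         # a single callable or a list of reward functions
)
trainer.train()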
@@ -166,6 +166,7 @@ commands = [
    "pip install trl wandb",
    "pip install transformers==4.55.2",  # Updated for GRPO compatibility
    "pip install flash-attn==2.8.0.post2 --no-build-isolation",
    "pip install vllm==v0.10.1.1",
    "pip install s5cmd",

    # Sync the bench data from S3
@@ -16,7 +16,7 @@ import shutil
# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

-from olmocr.train.grpo_train import OlmOCRDataset, collate_fn
+from olmocr.train.grpo_train import OlmOCRDataset, unit_test_reward, load_tests_cached


class TestGRPODataloader(unittest.TestCase):
@@ -179,48 +179,6 @@ class TestGRPODataloader(unittest.TestCase):
        self.assertIsInstance(item["test_ids"], list)
        self.assertTrue(len(item["test_ids"]) > 0)

-    @patch('olmocr.train.grpo_train.render_pdf_to_base64png')
-    @patch('olmocr.train.grpo_train.build_no_anchoring_v4_yaml_prompt')
-    def test_collate_function(self, mock_prompt, mock_render):
-        """Test that the collate function works correctly."""
-        # Mock the rendering and prompt functions
-        mock_render.return_value = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
-        mock_prompt.return_value = "Test prompt"
-
-        dataset = OlmOCRDataset(
-            bench_data_folder=self.bench_data_folder,
-            processor=None,
-            max_samples=2,
-            target_longest_image_dim=1024,
-        )
-
-        # Create a batch
-        batch = [dataset[0], dataset[1]]
-        collated = collate_fn(batch)
-
-        self.assertIsNotNone(collated)
-        self.assertIn("prompts", collated)
-        self.assertIn("images", collated)
-        self.assertIn("pdf_paths", collated)
-        self.assertIn("jsonl_files", collated)
-        self.assertIn("test_ids", collated)
-
-        # Check batch size consistency
-        self.assertEqual(len(collated["prompts"]), 2)
-        self.assertEqual(len(collated["images"]), 2)
-        self.assertEqual(len(collated["pdf_paths"]), 2)
-        self.assertEqual(len(collated["jsonl_files"]), 2)
-        self.assertEqual(len(collated["test_ids"]), 2)
-
-    def test_collate_with_none_values(self):
-        """Test that collate function handles None values correctly."""
-        batch = [None, {"prompt": [], "image": None, "pdf_path": "test.pdf",
-                        "jsonl_file": "test.jsonl", "test_ids": ["id1"]}, None]
-
-        collated = collate_fn(batch)
-
-        self.assertIsNotNone(collated)
-        self.assertEqual(len(collated["prompts"]), 1)

    def test_empty_jsonl_handling(self):
        """Test handling of empty JSONL files."""
@@ -300,6 +258,202 @@ class TestGRPODataloader(unittest.TestCase):
        shutil.rmtree(temp_folder)


class TestUnitTestReward(unittest.TestCase):
    """Test cases for the unit_test_reward function."""

    @classmethod
    def setUpClass(cls):
        """Create temporary test files."""
        # Clear any cached tests from previous runs
        load_tests_cached.cache_clear()
        cls.temp_dir = tempfile.mkdtemp()

        # Create a sample JSONL test file with different test types
        cls.jsonl_path = os.path.join(cls.temp_dir, "test.jsonl")
        test_data = [
            {
                "pdf": "test.pdf",
                "page": 0,
                "id": "test1",
                "type": "present",
                "text": "Hello World",
                "max_diffs": 0
            },
            {
                "pdf": "test.pdf",
                "page": 0,
                "id": "test2",
                "type": "absent",
                "text": "Bad Text",
                "max_diffs": 0
            },
            {
                "pdf": "test.pdf",
                "page": 0,
                "id": "test3",
                "type": "baseline",
                "max_repeats": 30
            },
            {
                "pdf": "test.pdf",
                "page": 0,
                "id": "test4",
                "type": "order",
                "before": "First",
                "after": "Second",
                "max_diffs": 0
            }
        ]

        with open(cls.jsonl_path, 'w') as f:
            for test in test_data:
                f.write(json.dumps(test) + '\n')

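The four entries cover the bench test types exercised below. Their expected behaviour on a sample completion, as implied by the assertions in this test class (inferred from those expectations rather than from the bench internals), is roughly:

completion = "Hello World\n\nFirst paragraph.\n\nSecond paragraph."
# "present"  (test1): "Hello World" must appear in the output           -> passes here
# "absent"   (test2): "Bad Text" must not appear in the output          -> passes here
# "baseline" (test3): basic sanity checks, e.g. non-empty output and
#                     limited repetition (max_repeats)                  -> passes here
# "order"    (test4): "First" must occur before "Second" in the output  -> passes here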
    @classmethod
    def tearDownClass(cls):
        """Clean up temporary files."""
        # Clear the LRU cache before removing temp dir
        load_tests_cached.cache_clear()
        shutil.rmtree(cls.temp_dir)

    def setUp(self):
        """Clear cache before each test method."""
        load_tests_cached.cache_clear()

    def test_perfect_completion(self):
        """Test reward calculation for a completion that passes all tests."""
        completions = ["Hello World\n\nFirst paragraph.\n\nSecond paragraph.\n\nThis is a good document with no bad text."]
        test_ids = ["test1", "test2", "test3", "test4"]

        rewards = unit_test_reward(
            prompts=["prompt"],
            completions=completions,
            completion_ids=[[]],
            pdf_path="test.pdf",
            jsonl_file=self.jsonl_path,
            test_ids=test_ids
        )

        self.assertEqual(len(rewards), 1)
        self.assertEqual(rewards[0], 1.0)  # All 4 tests should pass

    def test_partial_completion(self):
        """Test reward calculation for a completion that passes some tests."""
        completions = ["This document contains Bad Text but nothing else of note."]
        test_ids = ["test1", "test2", "test3"]

        rewards = unit_test_reward(
            prompts=["prompt"],
            completions=completions,
            completion_ids=[[]],
            pdf_path="test.pdf",
            jsonl_file=self.jsonl_path,
            test_ids=test_ids
        )

        self.assertEqual(len(rewards), 1)
        # Should pass only baseline test (1/3)
        self.assertAlmostEqual(rewards[0], 1/3, places=2)

    def test_multiple_completions(self):
        """Test reward calculation for multiple completions."""
        completions = [
            "Hello World with good content. First then Second.",
            "Bad Text only",
            "",  # Empty completion
        ]
        test_ids = ["test1", "test2", "test3", "test4"]

        rewards = unit_test_reward(
            prompts=["prompt"] * 3,
            completions=completions,
            completion_ids=[[]] * 3,
            pdf_path="test.pdf",
            jsonl_file=self.jsonl_path,
            test_ids=test_ids
        )

        self.assertEqual(len(rewards), 3)
        # First should pass all 4 tests
        self.assertEqual(rewards[0], 1.0)
        # Second should pass only baseline (1/4)
        self.assertEqual(rewards[1], 0.25)
        # Third (empty string) passes only the "absent" test (1/4)
        self.assertEqual(rewards[2], 0.25)

    def test_no_relevant_tests(self):
        """Test behavior when no relevant tests are found."""
        completions = ["Some content"]
        test_ids = ["nonexistent_test"]

        rewards = unit_test_reward(
            prompts=["prompt"],
            completions=completions,
            completion_ids=[[]],
            pdf_path="test.pdf",
            jsonl_file=self.jsonl_path,
            test_ids=test_ids
        )

        self.assertEqual(len(rewards), 1)
        self.assertEqual(rewards[0], 0.1)  # Default reward when no tests found

    def test_invalid_completion(self):
        """Test handling of invalid completions."""
        completions = [None, "", "Valid content with Hello World"]
        test_ids = ["test1"]

        rewards = unit_test_reward(
            prompts=["prompt"] * 3,
            completions=completions,
            completion_ids=[[]] * 3,
            pdf_path="test.pdf",
            jsonl_file=self.jsonl_path,
            test_ids=test_ids
        )

        self.assertEqual(len(rewards), 3)
        # First two should get 0 or epsilon
        self.assertLessEqual(rewards[0], 0.01)
        self.assertLessEqual(rewards[1], 0.01)
        # Last should pass the test
        self.assertEqual(rewards[2], 1.0)

    def test_cache_functionality(self):
        """Test that load_tests_cached properly caches results."""
        # Clear cache first
        load_tests_cached.cache_clear()

        # First call should load from file
        with patch('olmocr.train.grpo_train.load_tests') as mock_load:
            mock_load.return_value = []
            result1 = load_tests_cached(self.jsonl_path)
            self.assertEqual(mock_load.call_count, 1)

            # Second call should use cache
            result2 = load_tests_cached(self.jsonl_path)
            self.assertEqual(mock_load.call_count, 1)  # Should not increase

            # Results should be the same
            self.assertEqual(result1, result2)

    def test_error_handling(self):
        """Test error handling in reward function."""
        # Test with non-existent file
        rewards = unit_test_reward(
            prompts=["prompt"],
            completions=["content"],
            completion_ids=[[]],
            pdf_path="test.pdf",
            jsonl_file="/nonexistent/file.jsonl",
            test_ids=["test1"]
        )

        # Should return default reward on error
        self.assertEqual(len(rewards), 1)
        self.assertEqual(rewards[0], 0.1)


class TestIntegrationWithRealData(unittest.TestCase):
    """Integration tests with real bench data if available."""
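To run just the new reward tests, something like the following works with the standard unittest runner; the test module path is an assumption, since the diff does not show the file name:

import unittest

# Module path is assumed; adjust to wherever this test file lives in the repo.
from tests.test_grpo_dataloader import TestUnitTestReward

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestUnitTestReward)
unittest.TextTestRunner(verbosity=2).run(suite)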