Mirror of https://github.com/allenai/olmocr.git, synced 2025-11-03 19:45:41 +00:00

Olmocr bench grpo stuff
This commit is contained in:
parent 6184c94c3c
commit 1dd6ff9b03
@@ -11,6 +11,7 @@ import json
 import random
 from pathlib import Path
 import glob
+from functools import lru_cache
 
 import torch
 import numpy as np
@@ -28,6 +29,7 @@ from io import BytesIO
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
+from olmocr.bench.tests import load_tests
 
 # Configure logging
 logging.basicConfig(
@@ -170,11 +172,101 @@ class OlmOCRDataset(Dataset):
             # Return None if processing fails
             return None
 
-def simple_length_reward(prompts, completions, completion_ids, pdf_path, jsonl_file, test_ids, **kwargs):
-    """Reward function that assigns higher scores to longer completions (in terms of token count)."""
-    logger.info(f"Reward function called {kwargs}")
-    # return [float(len(ids)) for ids in completions_ids]
-    return [random.choice([0.1, 0.5]) for x in completions]
+@lru_cache(maxsize=128)
+def load_tests_cached(jsonl_file: str):
+    """
+    Cached version of load_tests to avoid reloading the same JSONL file multiple times.
+
+    Args:
+        jsonl_file: Path to the JSONL file containing test definitions
+
+    Returns:
+        List of test objects loaded from the file
+    """
+    logger.info(f"Loading tests from {jsonl_file} (will be cached)")
+    return load_tests(jsonl_file)
+
+
+def unit_test_reward(prompts, completions, completion_ids, pdf_path, jsonl_file, test_ids, **kwargs):
+    """
+    Reward function that runs unit tests on completions and returns average pass rate.
+
+    For each completion, loads the corresponding tests from the JSONL file and runs them.
+    Returns the proportion of tests that pass as the reward score.
+
+    Args:
+        prompts: List of prompts
+        completions: List of generated completions (model outputs)
+        completion_ids: List of completion token IDs
+        pdf_path: Path to the PDF file being processed
+        jsonl_file: Path to the JSONL file containing test definitions
+        test_ids: List of test IDs associated with this PDF page
+        **kwargs: Additional arguments
+
+    Returns:
+        List of reward scores (0.0 to 1.0) based on test pass rates
+    """
+    logger.info(f"Running unit test reward function for {len(completions)} completions")
+    logger.info(f"PDF: {pdf_path}, JSONL: {jsonl_file}, Test IDs: {test_ids}")
+
+    rewards = []
+
+    # Load all tests from the JSONL file (cached)
+    try:
+        all_tests = load_tests_cached(jsonl_file)
+        # Filter to only the tests for this specific PDF page
+        relevant_tests = [test for test in all_tests if test.id in test_ids]
+
+        if not relevant_tests:
+            logger.warning(f"No relevant tests found for test IDs: {test_ids}")
+            # Return a small positive reward to avoid training issues
+            return [0.1 for _ in completions]
+
+        logger.info(f"Found {len(relevant_tests)} relevant tests for this PDF page")
+
+        # Process each completion
+        for i, completion in enumerate(completions):
+            if not completion or not isinstance(completion, str):
+                logger.warning(f"Invalid completion at index {i}: {type(completion)}")
+                rewards.append(0.0)
+                continue
+
+            # Run all relevant tests on this completion
+            passed = 0
+            total = len(relevant_tests)
+
+            for test in relevant_tests:
+                try:
+                    test_passed, failure_reason = test.run(completion)
+                    if test_passed:
+                        passed += 1
+                    else:
+                        logger.debug(f"Test {test.id} failed: {failure_reason}")
+                except Exception as e:
+                    logger.warning(f"Error running test {test.id}: {e}")
+                    # Count errored tests as failures
+                    continue
+
+            # Calculate reward as proportion of tests passed
+            reward = passed / total if total > 0 else 0.0
+            rewards.append(reward)
+
+            logger.info(f"Completion {i}: {passed}/{total} tests passed, reward={reward:.3f}")
+
+    except Exception as e:
+        logger.error(f"Error in unit_test_reward function: {e}")
+        # Return small positive rewards to avoid training issues
+        return [0.1 for _ in completions]
+
+    # Ensure we always return rewards between 0 and 1
+    rewards = [max(0.0, min(1.0, r)) for r in rewards]
+
+    # If all rewards are 0, add a small epsilon to avoid training issues
+    if all(r == 0.0 for r in rewards):
+        logger.warning("All completions failed all tests, adding small epsilon reward")
+        rewards = [0.01 for _ in rewards]
+
+    return rewards
 
 
 def main():
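A quick way to sanity-check the new reward outside the trainer is to call it directly on a throwaway bench JSONL. The sketch below mirrors the unit tests added later in this commit; the temporary path and test contents are illustrative only, and it assumes olmocr is installed so that olmocr.train.grpo_train imports cleanly.

    import json
    import os
    import tempfile

    from olmocr.train.grpo_train import unit_test_reward

    # Write a one-test bench file: a "present" test that checks the completion
    # contains the given text (schema follows the test fixtures in this commit).
    tmp_dir = tempfile.mkdtemp()
    jsonl_path = os.path.join(tmp_dir, "toy_tests.jsonl")
    with open(jsonl_path, "w") as f:
        f.write(json.dumps({
            "pdf": "toy.pdf",
            "page": 0,
            "id": "t1",
            "type": "present",
            "text": "Hello World",
            "max_diffs": 0,
        }) + "\n")

    rewards = unit_test_reward(
        prompts=["prompt"],
        completions=["Hello World plus the rest of the page text."],
        completion_ids=[[]],
        pdf_path="toy.pdf",
        jsonl_file=jsonl_path,
        test_ids=["t1"],
    )
    print(rewards)  # expected: [1.0], since the single test passes
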
@@ -370,7 +462,7 @@ def main():
         processing_class=processor,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
-        reward_funcs=simple_length_reward,
+        reward_funcs=unit_test_reward,
     )
 
     # Start training
@@ -166,6 +166,7 @@ commands = [
     "pip install trl wandb",
     "pip install transformers==4.55.2",  # Updated for GRPO compatibility
     "pip install flash-attn==2.8.0.post2 --no-build-isolation",
+    "pip install vllm==v0.10.1.1",
     "pip install s5cmd",
 
     # Sync the bench data from S3
@@ -16,7 +16,7 @@ import shutil
 # Add parent directory to path
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
-from olmocr.train.grpo_train import OlmOCRDataset, collate_fn
+from olmocr.train.grpo_train import OlmOCRDataset, unit_test_reward, load_tests_cached
 
 
 class TestGRPODataloader(unittest.TestCase):
@@ -179,48 +179,6 @@ class TestGRPODataloader(unittest.TestCase):
         self.assertIsInstance(item["test_ids"], list)
         self.assertTrue(len(item["test_ids"]) > 0)
 
-    @patch('olmocr.train.grpo_train.render_pdf_to_base64png')
-    @patch('olmocr.train.grpo_train.build_no_anchoring_v4_yaml_prompt')
-    def test_collate_function(self, mock_prompt, mock_render):
-        """Test that the collate function works correctly."""
-        # Mock the rendering and prompt functions
-        mock_render.return_value = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
-        mock_prompt.return_value = "Test prompt"
-
-        dataset = OlmOCRDataset(
-            bench_data_folder=self.bench_data_folder,
-            processor=None,
-            max_samples=2,
-            target_longest_image_dim=1024,
-        )
-
-        # Create a batch
-        batch = [dataset[0], dataset[1]]
-        collated = collate_fn(batch)
-
-        self.assertIsNotNone(collated)
-        self.assertIn("prompts", collated)
-        self.assertIn("images", collated)
-        self.assertIn("pdf_paths", collated)
-        self.assertIn("jsonl_files", collated)
-        self.assertIn("test_ids", collated)
-
-        # Check batch size consistency
-        self.assertEqual(len(collated["prompts"]), 2)
-        self.assertEqual(len(collated["images"]), 2)
-        self.assertEqual(len(collated["pdf_paths"]), 2)
-        self.assertEqual(len(collated["jsonl_files"]), 2)
-        self.assertEqual(len(collated["test_ids"]), 2)
-
-    def test_collate_with_none_values(self):
-        """Test that collate function handles None values correctly."""
-        batch = [None, {"prompt": [], "image": None, "pdf_path": "test.pdf",
-                        "jsonl_file": "test.jsonl", "test_ids": ["id1"]}, None]
-
-        collated = collate_fn(batch)
-
-        self.assertIsNotNone(collated)
-        self.assertEqual(len(collated["prompts"]), 1)
 
     def test_empty_jsonl_handling(self):
         """Test handling of empty JSONL files."""
@@ -300,6 +258,202 @@ class TestGRPODataloader(unittest.TestCase):
         shutil.rmtree(temp_folder)
 
 
+class TestUnitTestReward(unittest.TestCase):
+    """Test cases for the unit_test_reward function."""
+
+    @classmethod
+    def setUpClass(cls):
+        """Create temporary test files."""
+        # Clear any cached tests from previous runs
+        load_tests_cached.cache_clear()
+        cls.temp_dir = tempfile.mkdtemp()
+
+        # Create a sample JSONL test file with different test types
+        cls.jsonl_path = os.path.join(cls.temp_dir, "test.jsonl")
+        test_data = [
+            {
+                "pdf": "test.pdf",
+                "page": 0,
+                "id": "test1",
+                "type": "present",
+                "text": "Hello World",
+                "max_diffs": 0
+            },
+            {
+                "pdf": "test.pdf",
+                "page": 0,
+                "id": "test2",
+                "type": "absent",
+                "text": "Bad Text",
+                "max_diffs": 0
+            },
+            {
+                "pdf": "test.pdf",
+                "page": 0,
+                "id": "test3",
+                "type": "baseline",
+                "max_repeats": 30
+            },
+            {
+                "pdf": "test.pdf",
+                "page": 0,
+                "id": "test4",
+                "type": "order",
+                "before": "First",
+                "after": "Second",
+                "max_diffs": 0
+            }
+        ]
+
+        with open(cls.jsonl_path, 'w') as f:
+            for test in test_data:
+                f.write(json.dumps(test) + '\n')
+
+    @classmethod
+    def tearDownClass(cls):
+        """Clean up temporary files."""
+        # Clear the LRU cache before removing temp dir
+        load_tests_cached.cache_clear()
+        shutil.rmtree(cls.temp_dir)
+
+    def setUp(self):
+        """Clear cache before each test method."""
+        load_tests_cached.cache_clear()
+
+    def test_perfect_completion(self):
+        """Test reward calculation for a completion that passes all tests."""
+        completions = ["Hello World\n\nFirst paragraph.\n\nSecond paragraph.\n\nThis is a good document with no bad text."]
+        test_ids = ["test1", "test2", "test3", "test4"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"],
+            completions=completions,
+            completion_ids=[[]],
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)  # All 4 tests should pass
+
+    def test_partial_completion(self):
+        """Test reward calculation for a completion that passes some tests."""
+        completions = ["This document contains Bad Text but nothing else of note."]
+        test_ids = ["test1", "test2", "test3"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"],
+            completions=completions,
+            completion_ids=[[]],
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 1)
+        # Should pass only baseline test (1/3)
+        self.assertAlmostEqual(rewards[0], 1/3, places=2)
+
+    def test_multiple_completions(self):
+        """Test reward calculation for multiple completions."""
+        completions = [
+            "Hello World with good content. First then Second.",
+            "Bad Text only",
+            "",  # Empty completion
+        ]
+        test_ids = ["test1", "test2", "test3", "test4"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"] * 3,
+            completions=completions,
+            completion_ids=[[]] * 3,
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 3)
+        # First should pass all 4 tests
+        self.assertEqual(rewards[0], 1.0)
+        # Second should pass only baseline (1/4)
+        self.assertEqual(rewards[1], 0.25)
+        # Third (empty string) passes only the "absent" test (1/4)
+        self.assertEqual(rewards[2], 0.25)
+
+    def test_no_relevant_tests(self):
+        """Test behavior when no relevant tests are found."""
+        completions = ["Some content"]
+        test_ids = ["nonexistent_test"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"],
+            completions=completions,
+            completion_ids=[[]],
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 0.1)  # Default reward when no tests found
+
+    def test_invalid_completion(self):
+        """Test handling of invalid completions."""
+        completions = [None, "", "Valid content with Hello World"]
+        test_ids = ["test1"]
+
+        rewards = unit_test_reward(
+            prompts=["prompt"] * 3,
+            completions=completions,
+            completion_ids=[[]] * 3,
+            pdf_path="test.pdf",
+            jsonl_file=self.jsonl_path,
+            test_ids=test_ids
+        )
+
+        self.assertEqual(len(rewards), 3)
+        # First two should get 0 or epsilon
+        self.assertLessEqual(rewards[0], 0.01)
+        self.assertLessEqual(rewards[1], 0.01)
+        # Last should pass the test
+        self.assertEqual(rewards[2], 1.0)
+
+    def test_cache_functionality(self):
+        """Test that load_tests_cached properly caches results."""
+        # Clear cache first
+        load_tests_cached.cache_clear()
+
+        # First call should load from file
+        with patch('olmocr.train.grpo_train.load_tests') as mock_load:
+            mock_load.return_value = []
+            result1 = load_tests_cached(self.jsonl_path)
+            self.assertEqual(mock_load.call_count, 1)
+
+            # Second call should use cache
+            result2 = load_tests_cached(self.jsonl_path)
+            self.assertEqual(mock_load.call_count, 1)  # Should not increase
+
+            # Results should be the same
+            self.assertEqual(result1, result2)
+
+    def test_error_handling(self):
+        """Test error handling in reward function."""
+        # Test with non-existent file
+        rewards = unit_test_reward(
+            prompts=["prompt"],
+            completions=["content"],
+            completion_ids=[[]],
+            pdf_path="test.pdf",
+            jsonl_file="/nonexistent/file.jsonl",
+            test_ids=["test1"]
+        )
+
+        # Should return default reward on error
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 0.1)
+
+
 class TestIntegrationWithRealData(unittest.TestCase):
     """Integration tests with real bench data if available."""
 
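One caveat the tests above work around: functools.lru_cache keys on the path string, so if a JSONL file is rewritten on disk during a run, load_tests_cached keeps returning the previously loaded tests until the cache is cleared (hence the cache_clear() calls in setUp and tearDownClass). A minimal sketch using the standard lru_cache introspection helpers; the path is illustrative only:

    from olmocr.train.grpo_train import load_tests_cached

    tests = load_tests_cached("/data/bench/sample_tests.jsonl")  # illustrative path
    print(load_tests_cached.cache_info())  # e.g. CacheInfo(hits=0, misses=1, maxsize=128, currsize=1)
    load_tests_cached.cache_clear()        # force a re-read on the next call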