From cc918ca03eeca9eec90b61ac1eeb4460ae9e2e5b Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Wed, 20 Aug 2025 22:18:38 +0000
Subject: [PATCH] Setting up GRPO trainer

---
 olmocr/train/grpo_train.py | 377 +++++++++++++++++++++++++++++++++++++
 tests/test_grpo.py         | 338 +++++++++++++++++++++++++++++++++
 2 files changed, 715 insertions(+)
 create mode 100644 olmocr/train/grpo_train.py
 create mode 100644 tests/test_grpo.py

diff --git a/olmocr/train/grpo_train.py b/olmocr/train/grpo_train.py
new file mode 100644
index 0000000..938fe3f
--- /dev/null
+++ b/olmocr/train/grpo_train.py
@@ -0,0 +1,377 @@
+"""
+GRPO (Group Relative Policy Optimization) training script for OlmOCR.
+
+Example:
+    python -m olmocr.train.grpo_train --bench_data_folder olmOCR-bench/bench_data
+"""
+
+import argparse
+import logging
+import os
+from typing import List, Dict, Any, Optional
+import json
+import glob
+
+import torch
+import numpy as np
+from torch.utils.data import Dataset
+from transformers import (
+    AutoProcessor,
+    Qwen2_5_VLForConditionalGeneration,
+    Qwen2VLForConditionalGeneration,
+)
+from trl import GRPOConfig, GRPOTrainer
+from PIL import Image
+import base64
+from io import BytesIO
+
+from olmocr.data.renderpdf import render_pdf_to_base64png
+from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
+
+# Configure logging
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+class OlmOCRDataset(Dataset):
+    """Dataset for loading PDF pages from olmOCR-bench format JSONL files."""
+
+    def __init__(
+        self,
+        bench_data_folder: str,
+        processor,
+        max_samples: Optional[int] = None,
+        target_longest_image_dim: int = 1024,
+    ):
+        self.bench_data_folder = bench_data_folder
+        self.processor = processor
+        self.target_longest_image_dim = target_longest_image_dim
+        self.max_samples = max_samples
+
+        # Find PDF folder
+        self.pdf_folder = os.path.join(bench_data_folder, "pdfs")
+        if not os.path.exists(self.pdf_folder):
+            raise ValueError(f"PDFs folder not found at {self.pdf_folder}")
+
+        # Load unique PDFs from JSONL files
+        self.samples = self._load_unique_pdfs_from_jsonl()
+
+        logger.info(f"Created dataset with {len(self.samples)} unique PDF samples")
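+
+    # A bench JSONL file has one test case per line, e.g.
+    #   {"pdf": "test_pdfs/test_0.pdf", "page": 0, "id": "test_0_001", "type": "text", ...}
+    # Several test cases may point at the same PDF page; _load_unique_pdfs_from_jsonl
+    # groups them so each page is rendered only once.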
+
+    def _load_unique_pdfs_from_jsonl(self) -> List[Dict[str, Any]]:
+        """Load unique PDFs from JSONL files in the bench_data folder, tracking all test cases per PDF."""
+        jsonl_files = glob.glob(os.path.join(self.bench_data_folder, "*.jsonl"))
+
+        if not jsonl_files:
+            raise ValueError(f"No JSONL files found in {self.bench_data_folder}")
+
+        logger.info(f"Found {len(jsonl_files)} JSONL files")
+
+        # Track unique PDFs and their test cases
+        pdf_data: Dict[str, Dict[str, Any]] = {}
+
+        for jsonl_file in jsonl_files:
+            logger.info(f"Processing {os.path.basename(jsonl_file)}")
+
+            with open(jsonl_file, 'r') as f:
+                for line in f:
+                    try:
+                        entry = json.loads(line.strip())
+                        pdf_name = entry.get("pdf")
+                        page = entry.get("page", 0)
+                        test_id = entry.get("id")
+
+                        if pdf_name and test_id:
+                            # Create unique key for PDF+page combination
+                            pdf_page_key = f"{pdf_name}::{page}"
+
+                            if pdf_page_key not in pdf_data:
+                                # First time seeing this PDF+page
+                                pdf_path = os.path.join(self.pdf_folder, pdf_name)
+                                pdf_data[pdf_page_key] = {
+                                    "pdf_path": pdf_path,
+                                    "pdf_name": pdf_name,
+                                    "page": page,
+                                    "jsonl_file": jsonl_file,
+                                    "test_ids": [test_id],
+                                    "entries": [entry]
+                                }
+                            else:
+                                # Add test case to existing PDF+page
+                                pdf_data[pdf_page_key]["test_ids"].append(test_id)
+                                pdf_data[pdf_page_key]["entries"].append(entry)
+
+                    except json.JSONDecodeError as e:
+                        logger.warning(f"Failed to parse line in {jsonl_file}: {e}")
+                        continue
+                    except Exception as e:
+                        logger.warning(f"Error processing entry in {jsonl_file}: {e}")
+                        continue
+
+        # Convert to list and apply max_samples limit
+        samples = list(pdf_data.values())
+        if self.max_samples:
+            samples = samples[:self.max_samples]
+
+        return samples
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        sample = self.samples[idx]
+        pdf_path = sample["pdf_path"]
+        page_num = sample["page"]
+        jsonl_file = sample["jsonl_file"]
+        test_ids = sample["test_ids"]
+
+        try:
+            # Render PDF page to base64 image
+            image_base64 = render_pdf_to_base64png(
+                pdf_path,
+                page_num,
+                target_longest_image_dim=self.target_longest_image_dim
+            )
+
+            # Convert base64 to PIL Image
+            image_bytes = base64.b64decode(image_base64)
+            image = Image.open(BytesIO(image_bytes)).convert("RGB")
+
+            # Build the text prompt
+            text_prompt = build_no_anchoring_v4_yaml_prompt()
+
+            # Create messages in the format expected by Qwen2-VL
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": text_prompt},
+                        {"type": "image"},
+                    ],
+                }
+            ]
+
+            # Return the required format
+            return {
+                "prompt": messages,
+                "pdf_path": pdf_path,
+                "jsonl_file": jsonl_file,
+                "test_ids": test_ids,
+                "image": image,  # Include the PIL image for processing later
+            }
+
+        except Exception as e:
+            logger.error(f"Failed to process sample {idx}: {e}")
+            # Return None if processing fails
+            return None
+
+
+def collate_fn(batch):
+    """Custom collate function that batches prompts, images, and test metadata."""
+    # Filter out None values
+    batch = [item for item in batch if item is not None]
+
+    if not batch:
+        return None
+
+    # Collect all components
+    prompts = [item["prompt"] for item in batch]
+    images = [item["image"] for item in batch]
+    pdf_paths = [item["pdf_path"] for item in batch]
+    jsonl_files = [item["jsonl_file"] for item in batch]
+    test_ids = [item["test_ids"] for item in batch]
+
+    # Return batch with all required information
+    return {
+        "prompts": prompts,
+        "images": images,
+        "pdf_paths": pdf_paths,
+        "jsonl_files": jsonl_files,
+        "test_ids": test_ids,
+    }
+
+
+def simple_length_reward(completions: List[str], **kwargs) -> List[float]:
+    """
+    Simple reward function that rewards completions close to a 100-word target length.
+    Returns higher rewards for completions closer to the target length.
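+    The reward decays linearly from 1.0 at exactly the target length down to
+    0.0 once the completion is target_length or more words away from it.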
+ """ + target_length = 100 + rewards = [] + + for completion in completions: + # Count tokens (simple word-based approximation) + tokens = completion.split() + length = len(tokens) + + # Calculate reward based on distance from target + distance = abs(length - target_length) + + # Reward function: max reward of 1.0 at target length, + # decreasing as we get further away + if distance == 0: + reward = 1.0 + else: + # Exponential decay based on distance + reward = max(0.0, 1.0 - (distance / target_length)) + + rewards.append(reward) + + logger.info(f"Reward stats: mean={np.mean(rewards):.3f}, std={np.std(rewards):.3f}") + return rewards + + +def main(): + parser = argparse.ArgumentParser(description="GRPO training for OlmOCR") + parser.add_argument( + "--bench_data_folder", + type=str, + required=True, + help="Path to bench data folder containing JSONL files and pdfs subfolder" + ) + parser.add_argument( + "--model_name", + type=str, + default="Qwen/Qwen2.5-VL-7B-Instruct", + help="Model checkpoint to load" + ) + parser.add_argument( + "--output_dir", + type=str, + default="outputs/grpo_test", + help="Output directory for checkpoints" + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-6, + help="Learning rate" + ) + parser.add_argument( + "--num_train_epochs", + type=int, + default=1, + help="Number of training epochs" + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=1, + help="Training batch size per device" + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=4, + help="Gradient accumulation steps" + ) + parser.add_argument( + "--max_samples", + type=int, + default=10, + help="Maximum number of samples to use (for testing)" + ) + + args = parser.parse_args() + + # Set up output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Verify bench_data_folder exists + if not os.path.exists(args.bench_data_folder): + logger.error(f"Bench data folder not found: {args.bench_data_folder}") + return + + # Load processor + logger.info(f"Loading processor: {args.model_name}") + processor = AutoProcessor.from_pretrained( + args.model_name, + trust_remote_code=True, + ) + + # Load model + logger.info(f"Loading model: {args.model_name}") + if "Qwen2.5-VL" in args.model_name: + model_class = Qwen2_5_VLForConditionalGeneration + elif "Qwen2-VL" in args.model_name: + model_class = Qwen2VLForConditionalGeneration + else: + raise ValueError(f"Unsupported model: {args.model_name}") + + model = model_class.from_pretrained( + args.model_name, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + ) + + # Create dataset from bench data folder + logger.info(f"Creating dataset from bench data folder: {args.bench_data_folder}") + dataset = OlmOCRDataset( + bench_data_folder=args.bench_data_folder, + processor=processor, + max_samples=args.max_samples, + target_longest_image_dim=1024, + ) + + if len(dataset) == 0: + logger.error("No samples found in dataset!") + return + + # Set up GRPO configuration + grpo_config = GRPOConfig( + output_dir=args.output_dir, + num_train_epochs=args.num_train_epochs, + per_device_train_batch_size=args.per_device_train_batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + learning_rate=args.learning_rate, + logging_steps=10, + save_steps=100, + eval_steps=50, + warmup_steps=10, + max_new_tokens=150, + temperature=0.7, + do_sample=True, + report_to=["tensorboard"], + remove_unused_columns=False, + torch_dtype=torch.bfloat16, + bf16=True, + 
+        gradient_checkpointing=True,
+        dataloader_num_workers=0,
+    )
+
+    # Initialize GRPO trainer. GRPOTrainer batches and collates prompts itself,
+    # so no data_collator is passed here; collate_fn above is exercised by the tests.
+    logger.info("Initializing GRPO trainer")
+    trainer = GRPOTrainer(
+        model=model,
+        args=grpo_config,
+        processing_class=processor,
+        train_dataset=dataset,
+        reward_funcs=simple_length_reward,  # TRL expects reward_funcs (a callable or list of callables)
+    )
+
+    # Start training
+    logger.info("Starting GRPO training")
+    try:
+        trainer.train()
+
+        # Save final model
+        logger.info(f"Saving final model to {args.output_dir}")
+        trainer.save_model()
+        processor.save_pretrained(args.output_dir)
+
+        logger.info("Training completed successfully!")
+
+    except Exception as e:
+        logger.error(f"Training failed: {e}")
+        raise
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/tests/test_grpo.py b/tests/test_grpo.py
new file mode 100644
index 0000000..8c2a235
--- /dev/null
+++ b/tests/test_grpo.py
@@ -0,0 +1,338 @@
+#!/usr/bin/env python3
+"""
+Test suite for the GRPO training dataloader.
+Tests the OlmOCRDataset class and its functionality with the olmOCR-bench format.
+"""
+
+import os
+import sys
+import json
+import tempfile
+import unittest
+from unittest.mock import patch
+import shutil
+
+# Add parent directory to path
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from olmocr.train.grpo_train import OlmOCRDataset, collate_fn
+
+
+class TestGRPODataloader(unittest.TestCase):
+    """Test cases for the GRPO dataloader."""
+
+    @classmethod
+    def setUpClass(cls):
+        """Create a temporary bench_data folder with test data."""
+        cls.temp_dir = tempfile.mkdtemp()
+        cls.bench_data_folder = cls.temp_dir
+        cls.pdfs_folder = os.path.join(cls.bench_data_folder, "pdfs")
+
+        # Create folder structure
+        os.makedirs(os.path.join(cls.pdfs_folder, "test_pdfs"), exist_ok=True)
+
+        # Create dummy PDF files
+        cls.pdf_files = []
+        for i in range(3):
+            pdf_path = os.path.join(cls.pdfs_folder, "test_pdfs", f"test_{i}.pdf")
+            # Create a minimal valid PDF
+            with open(pdf_path, "wb") as f:
+                f.write(b"%PDF-1.4\n%%EOF")
+            cls.pdf_files.append(pdf_path)
+
+        # Create test JSONL files
+        cls.jsonl_file1 = os.path.join(cls.bench_data_folder, "test1.jsonl")
+        cls.jsonl_file2 = os.path.join(cls.bench_data_folder, "test2.jsonl")
+
+        # Write test data to JSONL files
+        test_data1 = [
+            {"pdf": "test_pdfs/test_0.pdf", "page": 0, "id": "test_0_001", "type": "math", "math": "x + y = z"},
+            {"pdf": "test_pdfs/test_0.pdf", "page": 0, "id": "test_0_002", "type": "text", "text": "Sample text"},
+            {"pdf": "test_pdfs/test_1.pdf", "page": 0, "id": "test_1_001", "type": "math", "math": "a^2 + b^2 = c^2"},
+            {"pdf": "test_pdfs/test_1.pdf", "page": 1, "id": "test_1_002", "type": "text", "text": "Another sample"},
+        ]
+
+        test_data2 = [
+            {"pdf": "test_pdfs/test_2.pdf", "page": 0, "id": "test_2_001", "type": "table", "table": "col1,col2"},
+            {"pdf": "test_pdfs/test_0.pdf", "page": 0, "id": "test_0_003", "type": "text", "text": "More text"},
+            {"pdf": "test_pdfs/test_2.pdf", "page": 0, "id": "test_2_002", "type": "math", "math": "\\int_0^1 x dx"},
+        ]
+
+        with open(cls.jsonl_file1, "w") as f:
+            for entry in test_data1:
+                f.write(json.dumps(entry) + "\n")
+
+        with open(cls.jsonl_file2, "w") as f:
+            for entry in test_data2:
+                f.write(json.dumps(entry) + "\n")
+
+    @classmethod
+    def tearDownClass(cls):
+        """Clean up temporary files."""
+        shutil.rmtree(cls.temp_dir)
+
+    def test_dataset_initialization(self):
+        """Test that dataset initializes correctly."""
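+        # processor=None is deliberate throughout these tests: the dataset stores
+        # the processor but never calls it while loading, so no model download is needed.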
+        dataset = OlmOCRDataset(
+            bench_data_folder=self.bench_data_folder,
+            processor=None,
+            max_samples=None,
+            target_longest_image_dim=1024,
+        )
+
+        self.assertIsNotNone(dataset)
+        self.assertEqual(dataset.bench_data_folder, self.bench_data_folder)
+        self.assertEqual(dataset.pdf_folder, self.pdfs_folder)
+        self.assertTrue(len(dataset) > 0)
+
+    def test_unique_pdf_loading(self):
+        """Test that unique PDFs are loaded correctly."""
+        dataset = OlmOCRDataset(
+            bench_data_folder=self.bench_data_folder,
+            processor=None,
+            max_samples=None,
+            target_longest_image_dim=1024,
+        )
+
+        # Should have 4 unique PDF+page combinations:
+        # test_0.pdf page 0, test_1.pdf page 0, test_1.pdf page 1, test_2.pdf page 0
+        self.assertEqual(len(dataset), 4)
+
+        # Check that samples have correct structure
+        for sample in dataset.samples:
+            self.assertIn("pdf_path", sample)
+            self.assertIn("pdf_name", sample)
+            self.assertIn("page", sample)
+            self.assertIn("jsonl_file", sample)
+            self.assertIn("test_ids", sample)
+            self.assertIn("entries", sample)
+
+    def test_test_id_aggregation(self):
+        """Test that test IDs are correctly aggregated per PDF+page."""
+        dataset = OlmOCRDataset(
+            bench_data_folder=self.bench_data_folder,
+            processor=None,
+            max_samples=None,
+            target_longest_image_dim=1024,
+        )
+
+        # Find the sample for test_0.pdf page 0
+        test_0_sample = None
+        for sample in dataset.samples:
+            if "test_0.pdf" in sample["pdf_name"] and sample["page"] == 0:
+                test_0_sample = sample
+                break
+
+        self.assertIsNotNone(test_0_sample)
+        # Should have 3 test IDs for test_0.pdf page 0
+        self.assertEqual(len(test_0_sample["test_ids"]), 3)
+        self.assertIn("test_0_001", test_0_sample["test_ids"])
+        self.assertIn("test_0_002", test_0_sample["test_ids"])
+        self.assertIn("test_0_003", test_0_sample["test_ids"])
+
+    def test_max_samples_limit(self):
+        """Test that max_samples correctly limits the dataset size."""
+        dataset = OlmOCRDataset(
+            bench_data_folder=self.bench_data_folder,
+            processor=None,
+            max_samples=2,
+            target_longest_image_dim=1024,
+        )
+
+        self.assertEqual(len(dataset), 2)
+
+    @patch('olmocr.train.grpo_train.render_pdf_to_base64png')
+    @patch('olmocr.train.grpo_train.build_no_anchoring_v4_yaml_prompt')
+    def test_getitem_format(self, mock_prompt, mock_render):
+        """Test that __getitem__ returns the correct format."""
+        # Mock the rendering and prompt functions
+        mock_render.return_value = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="  # 1x1 white pixel PNG
+        mock_prompt.return_value = "Test prompt"
+
+        dataset = OlmOCRDataset(
+            bench_data_folder=self.bench_data_folder,
+            processor=None,
+            max_samples=1,
+            target_longest_image_dim=1024,
+        )
+
+        item = dataset[0]
+
+        self.assertIsNotNone(item)
+        self.assertIn("prompt", item)
+        self.assertIn("pdf_path", item)
+        self.assertIn("jsonl_file", item)
+        self.assertIn("test_ids", item)
+        self.assertIn("image", item)
+
+        # Check prompt structure
+        self.assertIsInstance(item["prompt"], list)
+        self.assertEqual(len(item["prompt"]), 1)
+        self.assertEqual(item["prompt"][0]["role"], "user")
+        self.assertIsInstance(item["prompt"][0]["content"], list)
+        self.assertEqual(len(item["prompt"][0]["content"]), 2)
+
+        # Check other fields
+        self.assertIsInstance(item["pdf_path"], str)
+        self.assertIsInstance(item["jsonl_file"], str)
+        self.assertIsInstance(item["test_ids"], list)
+        self.assertTrue(len(item["test_ids"]) > 0)
+
+    @patch('olmocr.train.grpo_train.render_pdf_to_base64png')
+    @patch('olmocr.train.grpo_train.build_no_anchoring_v4_yaml_prompt')
+    def test_collate_function(self, mock_prompt, mock_render):
+        """Test that the collate function works correctly."""
+        # Mock the rendering and prompt functions
+        mock_render.return_value = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
+        mock_prompt.return_value = "Test prompt"
+
+        dataset = OlmOCRDataset(
+            bench_data_folder=self.bench_data_folder,
+            processor=None,
+            max_samples=2,
+            target_longest_image_dim=1024,
+        )
+
+        # Create a batch
+        batch = [dataset[0], dataset[1]]
+        collated = collate_fn(batch)
+
+        self.assertIsNotNone(collated)
+        self.assertIn("prompts", collated)
+        self.assertIn("images", collated)
+        self.assertIn("pdf_paths", collated)
+        self.assertIn("jsonl_files", collated)
+        self.assertIn("test_ids", collated)
+
+        # Check batch size consistency
+        self.assertEqual(len(collated["prompts"]), 2)
+        self.assertEqual(len(collated["images"]), 2)
+        self.assertEqual(len(collated["pdf_paths"]), 2)
+        self.assertEqual(len(collated["jsonl_files"]), 2)
+        self.assertEqual(len(collated["test_ids"]), 2)
+
+    def test_collate_with_none_values(self):
+        """Test that collate function handles None values correctly."""
+        batch = [None, {"prompt": [], "image": None, "pdf_path": "test.pdf",
+                        "jsonl_file": "test.jsonl", "test_ids": ["id1"]}, None]
+
+        collated = collate_fn(batch)
+
+        self.assertIsNotNone(collated)
+        self.assertEqual(len(collated["prompts"]), 1)
+
+    def test_empty_jsonl_handling(self):
+        """Test handling of empty JSONL files."""
+        # Create an empty JSONL file
+        empty_jsonl = os.path.join(self.bench_data_folder, "empty.jsonl")
+        open(empty_jsonl, "w").close()
+
+        # Should still work with other non-empty files
+        dataset = OlmOCRDataset(
+            bench_data_folder=self.bench_data_folder,
+            processor=None,
+            max_samples=None,
+            target_longest_image_dim=1024,
+        )
+
+        self.assertTrue(len(dataset) > 0)
+
+        # Clean up
+        os.remove(empty_jsonl)
+
+    def test_malformed_jsonl_handling(self):
+        """Test handling of malformed JSONL entries."""
+        # Create a JSONL with some malformed entries
+        malformed_jsonl = os.path.join(self.bench_data_folder, "malformed.jsonl")
+        with open(malformed_jsonl, "w") as f:
+            f.write('{"pdf": "test.pdf", "id": "valid_1"}\n')
+            f.write('not valid json\n')
+            f.write('{"pdf": "test2.pdf", "id": "valid_2"}\n')
+
+        # Should skip malformed entries but process valid ones
+        dataset = OlmOCRDataset(
+            bench_data_folder=self.bench_data_folder,
+            processor=None,
+            max_samples=None,
+            target_longest_image_dim=1024,
+        )
+
+        # Should still have entries from valid files
+        self.assertTrue(len(dataset) > 0)
+
+        # Clean up
+        os.remove(malformed_jsonl)
+
+    def test_missing_pdf_folder(self):
+        """Test error handling when pdfs folder is missing."""
+        temp_bad_folder = tempfile.mkdtemp()
+
+        with self.assertRaises(ValueError) as context:
+            OlmOCRDataset(
+                bench_data_folder=temp_bad_folder,
+                processor=None,
+                max_samples=None,
+                target_longest_image_dim=1024,
+            )
+
+        self.assertIn("PDFs folder not found", str(context.exception))
+
+        # Clean up
+        shutil.rmtree(temp_bad_folder)
+
+    def test_no_jsonl_files(self):
+        """Test error handling when no JSONL files are present."""
+        temp_folder = tempfile.mkdtemp()
+        os.makedirs(os.path.join(temp_folder, "pdfs"))
+
+        with self.assertRaises(ValueError) as context:
+            OlmOCRDataset(
+                bench_data_folder=temp_folder,
+                processor=None,
+                max_samples=None,
+                target_longest_image_dim=1024,
+            )
+
+        self.assertIn("No JSONL files found", str(context.exception))
+
+        # Clean up
+        shutil.rmtree(temp_folder)
+
+
+class TestIntegrationWithRealData(unittest.TestCase):
+    """Integration tests with real bench data if available."""
+
+    @unittest.skipUnless(
+        os.path.exists("/home/ubuntu/olmocr/olmOCR-bench/bench_data"),
+        "Real bench data not available"
+    )
+    def test_with_real_bench_data(self):
+        """Test with real bench data if available."""
+        bench_data_folder = "/home/ubuntu/olmocr/olmOCR-bench/bench_data"
+
+        dataset = OlmOCRDataset(
+            bench_data_folder=bench_data_folder,
+            processor=None,
+            max_samples=5,
+            target_longest_image_dim=1024,
+        )
+
+        self.assertEqual(len(dataset), 5)
+
+        # Test that we can iterate through the dataset
+        for i in range(len(dataset)):
+            item = dataset[i]
+            if item is not None:  # Some PDFs might fail to render
+                self.assertIn("prompt", item)
+                self.assertIn("pdf_path", item)
+                self.assertIn("jsonl_file", item)
+                self.assertIn("test_ids", item)
+
+                # Verify paths exist
+                self.assertTrue(os.path.exists(item["pdf_path"]))
+                self.assertTrue(os.path.exists(item["jsonl_file"]))
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
\ No newline at end of file