Setting up GRPO trainer

Jake Poznanski 2025-08-20 22:18:38 +00:00
parent d046ba554a
commit cc918ca03e
2 changed files with 715 additions and 0 deletions

olmocr/train/grpo_train.py (new file, 377 additions)

@@ -0,0 +1,377 @@
"""
GRPO (Group Relative Policy Optimization) training script for OlmOCR.
"""
import argparse
import logging
import os
from typing import List, Dict, Any, Optional, Set
import asyncio
import json
import random
from pathlib import Path
import glob
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import (
AutoProcessor,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
)
from trl import GRPOConfig, GRPOTrainer
from PIL import Image
import base64
from io import BytesIO
from olmocr.train.config import Config
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
# Configure logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
class OlmOCRDataset(Dataset):
"""Dataset for loading PDF pages from Olmocr-bench format JSONL files."""
def __init__(
self,
bench_data_folder: str,
processor,
max_samples: Optional[int] = None,
target_longest_image_dim: int = 1024,
):
self.bench_data_folder = bench_data_folder
self.processor = processor
self.target_longest_image_dim = target_longest_image_dim
self.max_samples = max_samples
# Find PDF folder
self.pdf_folder = os.path.join(bench_data_folder, "pdfs")
if not os.path.exists(self.pdf_folder):
raise ValueError(f"PDFs folder not found at {self.pdf_folder}")
# Load unique PDFs from JSONL files
self.samples = self._load_unique_pdfs_from_jsonl()
logger.info(f"Created dataset with {len(self.samples)} unique PDF samples")
def _load_unique_pdfs_from_jsonl(self) -> List[Dict[str, Any]]:
"""Load unique PDFs from JSONL files in the bench_data folder, tracking all test cases per PDF."""
jsonl_files = glob.glob(os.path.join(self.bench_data_folder, "*.jsonl"))
if not jsonl_files:
raise ValueError(f"No JSONL files found in {self.bench_data_folder}")
logger.info(f"Found {len(jsonl_files)} JSONL files")
# Track unique PDFs and their test cases
pdf_data: Dict[str, Dict[str, Any]] = {}
for jsonl_file in jsonl_files:
logger.info(f"Processing {os.path.basename(jsonl_file)}")
with open(jsonl_file, 'r') as f:
for line in f:
try:
entry = json.loads(line.strip())
pdf_name = entry.get("pdf")
page = entry.get("page", 0)
test_id = entry.get("id")
if pdf_name and test_id:
# Create unique key for PDF+page combination
pdf_page_key = f"{pdf_name}::{page}"
if pdf_page_key not in pdf_data:
# First time seeing this PDF+page
pdf_path = os.path.join(self.pdf_folder, pdf_name)
pdf_data[pdf_page_key] = {
"pdf_path": pdf_path,
"pdf_name": pdf_name,
"page": page,
"jsonl_file": jsonl_file,
"test_ids": [test_id],
"entries": [entry]
}
else:
# Add test case to existing PDF+page
pdf_data[pdf_page_key]["test_ids"].append(test_id)
pdf_data[pdf_page_key]["entries"].append(entry)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse line in {jsonl_file}: {e}")
continue
except Exception as e:
logger.warning(f"Error processing entry in {jsonl_file}: {e}")
continue
# Convert to list and apply max_samples limit
samples = list(pdf_data.values())
if self.max_samples:
samples = samples[:self.max_samples]
return samples
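# Illustrative example of the aggregation above (hypothetical entries):
#   {"pdf": "doc.pdf", "page": 0, "id": "doc_math_01", ...}
#   {"pdf": "doc.pdf", "page": 0, "id": "doc_text_02", ...}
# both map to the key "doc.pdf::0" and collapse into a single sample whose
# "test_ids" list is ["doc_math_01", "doc_text_02"].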
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
sample = self.samples[idx]
pdf_path = sample["pdf_path"]
page_num = sample["page"]
jsonl_file = sample["jsonl_file"]
test_ids = sample["test_ids"]
try:
# Render PDF page to base64 image
image_base64 = render_pdf_to_base64png(
pdf_path,
page_num,
target_longest_image_dim=self.target_longest_image_dim
)
# Convert base64 to PIL Image
image_bytes = base64.b64decode(image_base64)
image = Image.open(BytesIO(image_bytes)).convert("RGB")
# Build the text prompt
text_prompt = build_no_anchoring_v4_yaml_prompt()
# Create messages in the format expected by Qwen2-VL
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": text_prompt},
{"type": "image"},
],
}
]
# Return the required format
return {
"prompt": messages,
"pdf_path": pdf_path,
"jsonl_file": jsonl_file,
"test_ids": test_ids,
"image": image, # Include the PIL image for processing later
}
except Exception as e:
logger.error(f"Failed to process sample {idx}: {e}")
# Return None if processing fails
return None
def collate_fn(batch):
"""Custom collate function to handle the new batch format with prompts and metadata."""
# Filter out None values
batch = [item for item in batch if item is not None]
if not batch:
return None
# Collect all components
prompts = [item["prompt"] for item in batch]
images = [item["image"] for item in batch]
pdf_paths = [item["pdf_path"] for item in batch]
jsonl_files = [item["jsonl_file"] for item in batch]
test_ids = [item["test_ids"] for item in batch]
# Return batch with all required information
return {
"prompts": prompts,
"images": images,
"pdf_paths": pdf_paths,
"jsonl_files": jsonl_files,
"test_ids": test_ids,
}
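# Example (hypothetical items): collate_fn([item_a, None, item_b]) drops the
# None entry and returns a dict whose "prompts", "images", "pdf_paths",
# "jsonl_files", and "test_ids" lists each have length 2.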
def simple_length_reward(completions: List[str], **kwargs) -> List[float]:
"""
Simple reward function that rewards completions close to 100 tokens.
Returns higher rewards for completions closer to the target length.
"""
target_length = 100
rewards = []
for completion in completions:
# Count tokens (simple word-based approximation)
tokens = completion.split()
length = len(tokens)
# Calculate reward based on distance from target
distance = abs(length - target_length)
# Reward function: max reward of 1.0 at target length,
# decreasing as we get further away
if distance == 0:
reward = 1.0
else:
# Exponential decay based on distance
reward = max(0.0, 1.0 - (distance / target_length))
rewards.append(reward)
logger.info(f"Reward stats: mean={np.mean(rewards):.3f}, std={np.std(rewards):.3f}")
return rewards
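# Example values from the formula above: a 100-word completion scores 1.0,
# a 50-word completion scores 1.0 - 50/100 = 0.5, and any completion 100 or
# more words away from the target scores 0.0.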
def main():
parser = argparse.ArgumentParser(description="GRPO training for OlmOCR")
parser.add_argument(
"--bench_data_folder",
type=str,
required=True,
help="Path to bench data folder containing JSONL files and pdfs subfolder"
)
parser.add_argument(
"--model_name",
type=str,
default="Qwen/Qwen2.5-VL-7B-Instruct",
help="Model checkpoint to load"
)
parser.add_argument(
"--output_dir",
type=str,
default="outputs/grpo_test",
help="Output directory for checkpoints"
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-6,
help="Learning rate"
)
parser.add_argument(
"--num_train_epochs",
type=int,
default=1,
help="Number of training epochs"
)
parser.add_argument(
"--per_device_train_batch_size",
type=int,
default=1,
help="Training batch size per device"
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=4,
help="Gradient accumulation steps"
)
parser.add_argument(
"--max_samples",
type=int,
default=10,
help="Maximum number of samples to use (for testing)"
)
args = parser.parse_args()
# Set up output directory
os.makedirs(args.output_dir, exist_ok=True)
# Verify bench_data_folder exists
if not os.path.exists(args.bench_data_folder):
logger.error(f"Bench data folder not found: {args.bench_data_folder}")
return
# Load processor
logger.info(f"Loading processor: {args.model_name}")
processor = AutoProcessor.from_pretrained(
args.model_name,
trust_remote_code=True,
)
# Load model
logger.info(f"Loading model: {args.model_name}")
if "Qwen2.5-VL" in args.model_name:
model_class = Qwen2_5_VLForConditionalGeneration
elif "Qwen2-VL" in args.model_name:
model_class = Qwen2VLForConditionalGeneration
else:
raise ValueError(f"Unsupported model: {args.model_name}")
model = model_class.from_pretrained(
args.model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
# Create dataset from bench data folder
logger.info(f"Creating dataset from bench data folder: {args.bench_data_folder}")
dataset = OlmOCRDataset(
bench_data_folder=args.bench_data_folder,
processor=processor,
max_samples=args.max_samples,
target_longest_image_dim=1024,
)
if len(dataset) == 0:
logger.error("No samples found in dataset!")
return
# Set up GRPO configuration
grpo_config = GRPOConfig(
output_dir=args.output_dir,
num_train_epochs=args.num_train_epochs,
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
logging_steps=10,
save_steps=100,
eval_steps=50,
warmup_steps=10,
max_completion_length=150,
temperature=0.7,
report_to=["tensorboard"],
remove_unused_columns=False,
bf16=True,
gradient_checkpointing=True,
dataloader_num_workers=0,
)
# Initialize GRPO trainer
logger.info("Initializing GRPO trainer")
trainer = GRPOTrainer(
model=model,
args=grpo_config,
processing_class=processor,
train_dataset=dataset,
reward_funcs=simple_length_reward,
data_collator=collate_fn,
)
# Start training
logger.info("Starting GRPO training")
try:
trainer.train()
# Save final model
logger.info(f"Saving final model to {args.output_dir}")
trainer.save_model()
processor.save_pretrained(args.output_dir)
logger.info("Training completed successfully!")
except Exception as e:
logger.error(f"Training failed: {e}")
raise
if __name__ == "__main__":
main()
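A minimal smoke-test run of the script above might look like this (paths and the sample cap are illustrative; the bench folder must contain the *.jsonl files and a pdfs/ subfolder):

python -m olmocr.train.grpo_train \
    --bench_data_folder olmOCR-bench/bench_data \
    --output_dir outputs/grpo_test \
    --max_samples 10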

tests/test_grpo.py (new file, 338 additions)

@@ -0,0 +1,338 @@
#!/usr/bin/env python3
"""
Test suite for GRPO training dataloader.
Tests the OlmOCRDataset class and its handling of the olmOCR-bench data format.
"""
import os
import sys
import json
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch
import shutil
# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from olmocr.train.grpo_train import OlmOCRDataset, collate_fn
class TestGRPODataloader(unittest.TestCase):
"""Test cases for the GRPO dataloader."""
@classmethod
def setUpClass(cls):
"""Create a temporary bench_data folder with test data."""
cls.temp_dir = tempfile.mkdtemp()
cls.bench_data_folder = cls.temp_dir
cls.pdfs_folder = os.path.join(cls.bench_data_folder, "pdfs")
# Create folder structure
os.makedirs(os.path.join(cls.pdfs_folder, "test_pdfs"), exist_ok=True)
# Create dummy PDF files
cls.pdf_files = []
for i in range(3):
pdf_path = os.path.join(cls.pdfs_folder, "test_pdfs", f"test_{i}.pdf")
# Create a minimal valid PDF
with open(pdf_path, "wb") as f:
f.write(b"%PDF-1.4\n%%EOF")
cls.pdf_files.append(pdf_path)
# Create test JSONL files
cls.jsonl_file1 = os.path.join(cls.bench_data_folder, "test1.jsonl")
cls.jsonl_file2 = os.path.join(cls.bench_data_folder, "test2.jsonl")
# Write test data to JSONL files
test_data1 = [
{"pdf": "test_pdfs/test_0.pdf", "page": 0, "id": "test_0_001", "type": "math", "math": "x + y = z"},
{"pdf": "test_pdfs/test_0.pdf", "page": 0, "id": "test_0_002", "type": "text", "text": "Sample text"},
{"pdf": "test_pdfs/test_1.pdf", "page": 0, "id": "test_1_001", "type": "math", "math": "a^2 + b^2 = c^2"},
{"pdf": "test_pdfs/test_1.pdf", "page": 1, "id": "test_1_002", "type": "text", "text": "Another sample"},
]
test_data2 = [
{"pdf": "test_pdfs/test_2.pdf", "page": 0, "id": "test_2_001", "type": "table", "table": "col1,col2"},
{"pdf": "test_pdfs/test_0.pdf", "page": 0, "id": "test_0_003", "type": "text", "text": "More text"},
{"pdf": "test_pdfs/test_2.pdf", "page": 0, "id": "test_2_002", "type": "math", "math": "\\int_0^1 x dx"},
]
with open(cls.jsonl_file1, "w") as f:
for entry in test_data1:
f.write(json.dumps(entry) + "\n")
with open(cls.jsonl_file2, "w") as f:
for entry in test_data2:
f.write(json.dumps(entry) + "\n")
@classmethod
def tearDownClass(cls):
"""Clean up temporary files."""
shutil.rmtree(cls.temp_dir)
def test_dataset_initialization(self):
"""Test that dataset initializes correctly."""
dataset = OlmOCRDataset(
bench_data_folder=self.bench_data_folder,
processor=None,
max_samples=None,
target_longest_image_dim=1024,
)
self.assertIsNotNone(dataset)
self.assertEqual(dataset.bench_data_folder, self.bench_data_folder)
self.assertEqual(dataset.pdf_folder, self.pdfs_folder)
self.assertTrue(len(dataset) > 0)
def test_unique_pdf_loading(self):
"""Test that unique PDFs are loaded correctly."""
dataset = OlmOCRDataset(
bench_data_folder=self.bench_data_folder,
processor=None,
max_samples=None,
target_longest_image_dim=1024,
)
# Should have 4 unique PDF+page combinations:
# test_0.pdf page 0, test_1.pdf page 0, test_1.pdf page 1, test_2.pdf page 0
self.assertEqual(len(dataset), 4)
# Check that samples have correct structure
for sample in dataset.samples:
self.assertIn("pdf_path", sample)
self.assertIn("pdf_name", sample)
self.assertIn("page", sample)
self.assertIn("jsonl_file", sample)
self.assertIn("test_ids", sample)
self.assertIn("entries", sample)
def test_test_id_aggregation(self):
"""Test that test IDs are correctly aggregated per PDF+page."""
dataset = OlmOCRDataset(
bench_data_folder=self.bench_data_folder,
processor=None,
max_samples=None,
target_longest_image_dim=1024,
)
# Find the sample for test_0.pdf page 0
test_0_sample = None
for sample in dataset.samples:
if "test_0.pdf" in sample["pdf_name"] and sample["page"] == 0:
test_0_sample = sample
break
self.assertIsNotNone(test_0_sample)
# Should have 3 test IDs for test_0.pdf page 0
self.assertEqual(len(test_0_sample["test_ids"]), 3)
self.assertIn("test_0_001", test_0_sample["test_ids"])
self.assertIn("test_0_002", test_0_sample["test_ids"])
self.assertIn("test_0_003", test_0_sample["test_ids"])
def test_max_samples_limit(self):
"""Test that max_samples correctly limits the dataset size."""
dataset = OlmOCRDataset(
bench_data_folder=self.bench_data_folder,
processor=None,
max_samples=2,
target_longest_image_dim=1024,
)
self.assertEqual(len(dataset), 2)
@patch('olmocr.train.grpo_train.render_pdf_to_base64png')
@patch('olmocr.train.grpo_train.build_no_anchoring_v4_yaml_prompt')
def test_getitem_format(self, mock_prompt, mock_render):
"""Test that __getitem__ returns the correct format."""
# Mock the rendering and prompt functions
mock_render.return_value = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==" # 1x1 white pixel PNG
mock_prompt.return_value = "Test prompt"
dataset = OlmOCRDataset(
bench_data_folder=self.bench_data_folder,
processor=None,
max_samples=1,
target_longest_image_dim=1024,
)
item = dataset[0]
self.assertIsNotNone(item)
self.assertIn("prompt", item)
self.assertIn("pdf_path", item)
self.assertIn("jsonl_file", item)
self.assertIn("test_ids", item)
self.assertIn("image", item)
# Check prompt structure
self.assertIsInstance(item["prompt"], list)
self.assertEqual(len(item["prompt"]), 1)
self.assertEqual(item["prompt"][0]["role"], "user")
self.assertIsInstance(item["prompt"][0]["content"], list)
self.assertEqual(len(item["prompt"][0]["content"]), 2)
# Check other fields
self.assertIsInstance(item["pdf_path"], str)
self.assertIsInstance(item["jsonl_file"], str)
self.assertIsInstance(item["test_ids"], list)
self.assertTrue(len(item["test_ids"]) > 0)
@patch('olmocr.train.grpo_train.render_pdf_to_base64png')
@patch('olmocr.train.grpo_train.build_no_anchoring_v4_yaml_prompt')
def test_collate_function(self, mock_prompt, mock_render):
"""Test that the collate function works correctly."""
# Mock the rendering and prompt functions
mock_render.return_value = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
mock_prompt.return_value = "Test prompt"
dataset = OlmOCRDataset(
bench_data_folder=self.bench_data_folder,
processor=None,
max_samples=2,
target_longest_image_dim=1024,
)
# Create a batch
batch = [dataset[0], dataset[1]]
collated = collate_fn(batch)
self.assertIsNotNone(collated)
self.assertIn("prompts", collated)
self.assertIn("images", collated)
self.assertIn("pdf_paths", collated)
self.assertIn("jsonl_files", collated)
self.assertIn("test_ids", collated)
# Check batch size consistency
self.assertEqual(len(collated["prompts"]), 2)
self.assertEqual(len(collated["images"]), 2)
self.assertEqual(len(collated["pdf_paths"]), 2)
self.assertEqual(len(collated["jsonl_files"]), 2)
self.assertEqual(len(collated["test_ids"]), 2)
def test_collate_with_none_values(self):
"""Test that collate function handles None values correctly."""
batch = [None, {"prompt": [], "image": None, "pdf_path": "test.pdf",
"jsonl_file": "test.jsonl", "test_ids": ["id1"]}, None]
collated = collate_fn(batch)
self.assertIsNotNone(collated)
self.assertEqual(len(collated["prompts"]), 1)
def test_empty_jsonl_handling(self):
"""Test handling of empty JSONL files."""
# Create an empty JSONL file
empty_jsonl = os.path.join(self.bench_data_folder, "empty.jsonl")
open(empty_jsonl, "w").close()
# Should still work with other non-empty files
dataset = OlmOCRDataset(
bench_data_folder=self.bench_data_folder,
processor=None,
max_samples=None,
target_longest_image_dim=1024,
)
self.assertTrue(len(dataset) > 0)
# Clean up
os.remove(empty_jsonl)
def test_malformed_jsonl_handling(self):
"""Test handling of malformed JSONL entries."""
# Create a JSONL with some malformed entries
malformed_jsonl = os.path.join(self.bench_data_folder, "malformed.jsonl")
with open(malformed_jsonl, "w") as f:
f.write('{"pdf": "test.pdf", "id": "valid_1"}\n')
f.write('not valid json\n')
f.write('{"pdf": "test2.pdf", "id": "valid_2"}\n')
# Should skip malformed entries but process valid ones
dataset = OlmOCRDataset(
bench_data_folder=self.bench_data_folder,
processor=None,
max_samples=None,
target_longest_image_dim=1024,
)
# Should still have entries from valid files
self.assertTrue(len(dataset) > 0)
# Clean up
os.remove(malformed_jsonl)
def test_missing_pdf_folder(self):
"""Test error handling when pdfs folder is missing."""
temp_bad_folder = tempfile.mkdtemp()
with self.assertRaises(ValueError) as context:
dataset = OlmOCRDataset(
bench_data_folder=temp_bad_folder,
processor=None,
max_samples=None,
target_longest_image_dim=1024,
)
self.assertIn("PDFs folder not found", str(context.exception))
# Clean up
shutil.rmtree(temp_bad_folder)
def test_no_jsonl_files(self):
"""Test error handling when no JSONL files are present."""
temp_folder = tempfile.mkdtemp()
os.makedirs(os.path.join(temp_folder, "pdfs"))
with self.assertRaises(ValueError) as context:
dataset = OlmOCRDataset(
bench_data_folder=temp_folder,
processor=None,
max_samples=None,
target_longest_image_dim=1024,
)
self.assertIn("No JSONL files found", str(context.exception))
# Clean up
shutil.rmtree(temp_folder)
class TestIntegrationWithRealData(unittest.TestCase):
"""Integration tests with real bench data if available."""
@unittest.skipUnless(
os.path.exists("/home/ubuntu/olmocr/olmOCR-bench/bench_data"),
"Real bench data not available"
)
def test_with_real_bench_data(self):
"""Test with real bench data if available."""
bench_data_folder = "/home/ubuntu/olmocr/olmOCR-bench/bench_data"
dataset = OlmOCRDataset(
bench_data_folder=bench_data_folder,
processor=None,
max_samples=5,
target_longest_image_dim=1024,
)
self.assertEqual(len(dataset), 5)
# Test that we can iterate through the dataset
for i in range(len(dataset)):
item = dataset[i]
if item is not None: # Some PDFs might fail to render
self.assertIn("prompt", item)
self.assertIn("pdf_path", item)
self.assertIn("jsonl_file", item)
self.assertIn("test_ids", item)
# Verify paths exist
self.assertTrue(os.path.exists(item["pdf_path"]))
self.assertTrue(os.path.exists(item["jsonl_file"]))
if __name__ == "__main__":
unittest.main(verbosity=2)
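Because the test module ends with unittest.main() and patches sys.path itself, the suite can be run directly from the repository root (assuming torch, transformers, trl, and Pillow are installed, since importing grpo_train pulls them in):

python tests/test_grpo.py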