Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-11 08:12:22 +00:00)
Setting up GRPO trainer
This commit is contained in:
parent d046ba554a
commit cc918ca03e
olmocr/train/grpo_train.py (new file, 377 lines)
@@ -0,0 +1,377 @@
"""
GRPO (Group Relative Policy Optimization) training script for OlmOCR.
"""

import argparse
import base64
import glob
import json
import logging
import os
from io import BytesIO
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
)
from trl import GRPOConfig, GRPOTrainer

from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_no_anchoring_v4_yaml_prompt

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class OlmOCRDataset(Dataset):
    """Dataset for loading PDF pages from olmOCR-bench format JSONL files."""

    def __init__(
        self,
        bench_data_folder: str,
        processor,
        max_samples: Optional[int] = None,
        target_longest_image_dim: int = 1024,
    ):
        self.bench_data_folder = bench_data_folder
        self.processor = processor
        self.target_longest_image_dim = target_longest_image_dim
        self.max_samples = max_samples

        # Find the PDF folder
        self.pdf_folder = os.path.join(bench_data_folder, "pdfs")
        if not os.path.exists(self.pdf_folder):
            raise ValueError(f"PDFs folder not found at {self.pdf_folder}")

        # Load unique PDFs from the JSONL files
        self.samples = self._load_unique_pdfs_from_jsonl()

        logger.info(f"Created dataset with {len(self.samples)} unique PDF samples")

    def _load_unique_pdfs_from_jsonl(self) -> List[Dict[str, Any]]:
        """Load unique PDFs from JSONL files in the bench_data folder, tracking all test cases per PDF."""
        jsonl_files = glob.glob(os.path.join(self.bench_data_folder, "*.jsonl"))

        if not jsonl_files:
            raise ValueError(f"No JSONL files found in {self.bench_data_folder}")

        logger.info(f"Found {len(jsonl_files)} JSONL files")

        # Track unique PDFs and their test cases
        pdf_data: Dict[str, Dict[str, Any]] = {}

        for jsonl_file in jsonl_files:
            logger.info(f"Processing {os.path.basename(jsonl_file)}")

            with open(jsonl_file, "r") as f:
                for line in f:
                    try:
                        entry = json.loads(line.strip())
                        pdf_name = entry.get("pdf")
                        page = entry.get("page", 0)
                        test_id = entry.get("id")

                        if pdf_name and test_id:
                            # Create a unique key for the PDF+page combination
                            pdf_page_key = f"{pdf_name}::{page}"

                            if pdf_page_key not in pdf_data:
                                # First time seeing this PDF+page
                                pdf_path = os.path.join(self.pdf_folder, pdf_name)
                                pdf_data[pdf_page_key] = {
                                    "pdf_path": pdf_path,
                                    "pdf_name": pdf_name,
                                    "page": page,
                                    "jsonl_file": jsonl_file,
                                    "test_ids": [test_id],
                                    "entries": [entry],
                                }
                            else:
                                # Add the test case to the existing PDF+page
                                pdf_data[pdf_page_key]["test_ids"].append(test_id)
                                pdf_data[pdf_page_key]["entries"].append(entry)

                    except json.JSONDecodeError as e:
                        logger.warning(f"Failed to parse line in {jsonl_file}: {e}")
                        continue
                    except Exception as e:
                        logger.warning(f"Error processing entry in {jsonl_file}: {e}")
                        continue

        # Convert to a list and apply the max_samples limit
        samples = list(pdf_data.values())
        if self.max_samples:
            samples = samples[: self.max_samples]

        return samples
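
    # For reference, each line in these JSONL files is a single test case; only
    # the "pdf", "page", and "id" fields are consumed above, and the remaining
    # fields vary by test type. A representative line (adapted from the unit
    # tests added in this commit) looks like:
    #
    #   {"pdf": "test_pdfs/test_0.pdf", "page": 0, "id": "test_0_001", "type": "math", "math": "x + y = z"}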

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        pdf_path = sample["pdf_path"]
        page_num = sample["page"]
        jsonl_file = sample["jsonl_file"]
        test_ids = sample["test_ids"]

        try:
            # Render the PDF page to a base64-encoded PNG
            image_base64 = render_pdf_to_base64png(
                pdf_path,
                page_num,
                target_longest_image_dim=self.target_longest_image_dim,
            )

            # Convert base64 to a PIL Image
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")

            # Build the text prompt
            text_prompt = build_no_anchoring_v4_yaml_prompt()

            # Create messages in the format expected by Qwen2-VL
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": text_prompt},
                        {"type": "image"},
                    ],
                }
            ]

            # Return the required format
            return {
                "prompt": messages,
                "pdf_path": pdf_path,
                "jsonl_file": jsonl_file,
                "test_ids": test_ids,
                "image": image,  # Include the PIL image for processing later
            }

        except Exception as e:
            logger.error(f"Failed to process sample {idx}: {e}")
            # Return None if processing fails
            return None

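# A quick way to sanity-check the dataset on its own is a sketch like the one
# below (the bench_data path is illustrative, and rendering a page assumes the
# PDF dependencies olmocr normally uses are installed):
#
#   dataset = OlmOCRDataset(
#       bench_data_folder="olmOCR-bench/bench_data",
#       processor=None,  # the processor is not used during loading
#       max_samples=3,
#   )
#   item = dataset[0]
#   if item is not None:  # rendering can fail for individual PDFs
#       print(item["pdf_path"], item["test_ids"], item["image"].size)
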
def collate_fn(batch):
    """Custom collate function to handle the batch format with prompts and metadata."""
    # Filter out None values
    batch = [item for item in batch if item is not None]

    if not batch:
        return None

    # Collect all components
    prompts = [item["prompt"] for item in batch]
    images = [item["image"] for item in batch]
    pdf_paths = [item["pdf_path"] for item in batch]
    jsonl_files = [item["jsonl_file"] for item in batch]
    test_ids = [item["test_ids"] for item in batch]

    # Return the batch with all required information
    return {
        "prompts": prompts,
        "images": images,
        "pdf_paths": pdf_paths,
        "jsonl_files": jsonl_files,
        "test_ids": test_ids,
    }

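# If batches are needed outside the trainer, collate_fn plugs directly into a
# standard PyTorch DataLoader; a minimal sketch, assuming a dataset built as
# above:
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
#   for batch in loader:
#       if batch is None:  # every item in this batch failed to load
#           continue
#       print(len(batch["prompts"]), batch["pdf_paths"])
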
def simple_length_reward(completions: List[str], **kwargs) -> List[float]:
    """
    Simple reward function that rewards completions close to 100 tokens.
    Returns higher rewards for completions closer to the target length.
    """
    target_length = 100
    rewards = []

    for completion in completions:
        # Count tokens (simple word-based approximation)
        tokens = completion.split()
        length = len(tokens)

        # Calculate the reward based on distance from the target:
        # linear decay from 1.0 at the target length down to 0.0 once
        # the length is a full target_length away
        distance = abs(length - target_length)
        reward = max(0.0, 1.0 - (distance / target_length))

        rewards.append(reward)

    logger.info(f"Reward stats: mean={np.mean(rewards):.3f}, std={np.std(rewards):.3f}")
    return rewards

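# Quick sanity check of the reward shape; the values follow directly from the
# formula above:
#
#   100 tokens -> distance 0   -> reward 1.0
#   50 tokens  -> distance 50  -> reward 1.0 - 50/100 = 0.5
#   250 tokens -> distance 150 -> reward max(0.0, 1.0 - 1.5) = 0.0
#
#   simple_length_reward(["w " * 100, "w " * 50, "w " * 250]) == [1.0, 0.5, 0.0]
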
def main():
    parser = argparse.ArgumentParser(description="GRPO training for OlmOCR")
    parser.add_argument(
        "--bench_data_folder",
        type=str,
        required=True,
        help="Path to bench data folder containing JSONL files and a pdfs subfolder",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="Qwen/Qwen2.5-VL-7B-Instruct",
        help="Model checkpoint to load",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs/grpo_test",
        help="Output directory for checkpoints",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-6,
        help="Learning rate",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=1,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=1,
        help="Training batch size per device",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=4,
        help="Gradient accumulation steps",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=10,
        help="Maximum number of samples to use (for testing)",
    )

    args = parser.parse_args()

    # Set up the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Verify that bench_data_folder exists
    if not os.path.exists(args.bench_data_folder):
        logger.error(f"Bench data folder not found: {args.bench_data_folder}")
        return

    # Load the processor
    logger.info(f"Loading processor: {args.model_name}")
    processor = AutoProcessor.from_pretrained(
        args.model_name,
        trust_remote_code=True,
    )

    # Load the model
    logger.info(f"Loading model: {args.model_name}")
    if "Qwen2.5-VL" in args.model_name:
        model_class = Qwen2_5_VLForConditionalGeneration
    elif "Qwen2-VL" in args.model_name:
        model_class = Qwen2VLForConditionalGeneration
    else:
        raise ValueError(f"Unsupported model: {args.model_name}")

    model = model_class.from_pretrained(
        args.model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    # Create the dataset from the bench data folder
    logger.info(f"Creating dataset from bench data folder: {args.bench_data_folder}")
    dataset = OlmOCRDataset(
        bench_data_folder=args.bench_data_folder,
        processor=processor,
        max_samples=args.max_samples,
        target_longest_image_dim=1024,
    )

    if len(dataset) == 0:
        logger.error("No samples found in dataset!")
        return

    # Set up the GRPO configuration
    grpo_config = GRPOConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        logging_steps=10,
        save_steps=100,
        eval_steps=50,
        warmup_steps=10,
        max_completion_length=150,  # cap on the number of tokens sampled per completion
        temperature=0.7,
        report_to=["tensorboard"],
        remove_unused_columns=False,
        bf16=True,
        gradient_checkpointing=True,
        dataloader_num_workers=0,
    )

    # Initialize the GRPO trainer
    logger.info("Initializing GRPO trainer")
    trainer = GRPOTrainer(
        model=model,
        args=grpo_config,
        processing_class=processor,
        train_dataset=dataset,
        reward_funcs=simple_length_reward,  # a single callable or a list of reward functions
    )

    # Start training
    logger.info("Starting GRPO training")
    try:
        trainer.train()

        # Save the final model
        logger.info(f"Saving final model to {args.output_dir}")
        trainer.save_model()
        processor.save_pretrained(args.output_dir)

        logger.info("Training completed successfully!")

    except Exception as e:
        logger.error(f"Training failed: {e}")
        raise


if __name__ == "__main__":
    main()
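
A smoke-test run of this script might look like the following; the bench data path is illustrative, not part of this commit:

    python -m olmocr.train.grpo_train \
        --bench_data_folder olmOCR-bench/bench_data \
        --output_dir outputs/grpo_test \
        --max_samples 10
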
tests/test_grpo.py (new file, 338 lines)
@@ -0,0 +1,338 @@
#!/usr/bin/env python3
"""
Test suite for the GRPO training dataloader.
Tests the OlmOCRDataset class and its functionality with the olmOCR-bench format.
"""

import json
import os
import shutil
import sys
import tempfile
import unittest
from unittest.mock import patch

# Add the parent directory to the path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from olmocr.train.grpo_train import OlmOCRDataset, collate_fn

class TestGRPODataloader(unittest.TestCase):
    """Test cases for the GRPO dataloader."""

    @classmethod
    def setUpClass(cls):
        """Create a temporary bench_data folder with test data."""
        cls.temp_dir = tempfile.mkdtemp()
        cls.bench_data_folder = cls.temp_dir
        cls.pdfs_folder = os.path.join(cls.bench_data_folder, "pdfs")

        # Create the folder structure
        os.makedirs(os.path.join(cls.pdfs_folder, "test_pdfs"), exist_ok=True)

        # Create dummy PDF files
        cls.pdf_files = []
        for i in range(3):
            pdf_path = os.path.join(cls.pdfs_folder, "test_pdfs", f"test_{i}.pdf")
            # Create a minimal valid PDF
            with open(pdf_path, "wb") as f:
                f.write(b"%PDF-1.4\n%%EOF")
            cls.pdf_files.append(pdf_path)

        # Create test JSONL files
        cls.jsonl_file1 = os.path.join(cls.bench_data_folder, "test1.jsonl")
        cls.jsonl_file2 = os.path.join(cls.bench_data_folder, "test2.jsonl")

        # Write test data to the JSONL files
        test_data1 = [
            {"pdf": "test_pdfs/test_0.pdf", "page": 0, "id": "test_0_001", "type": "math", "math": "x + y = z"},
            {"pdf": "test_pdfs/test_0.pdf", "page": 0, "id": "test_0_002", "type": "text", "text": "Sample text"},
            {"pdf": "test_pdfs/test_1.pdf", "page": 0, "id": "test_1_001", "type": "math", "math": "a^2 + b^2 = c^2"},
            {"pdf": "test_pdfs/test_1.pdf", "page": 1, "id": "test_1_002", "type": "text", "text": "Another sample"},
        ]

        test_data2 = [
            {"pdf": "test_pdfs/test_2.pdf", "page": 0, "id": "test_2_001", "type": "table", "table": "col1,col2"},
            {"pdf": "test_pdfs/test_0.pdf", "page": 0, "id": "test_0_003", "type": "text", "text": "More text"},
            {"pdf": "test_pdfs/test_2.pdf", "page": 0, "id": "test_2_002", "type": "math", "math": "\\int_0^1 x dx"},
        ]

        with open(cls.jsonl_file1, "w") as f:
            for entry in test_data1:
                f.write(json.dumps(entry) + "\n")

        with open(cls.jsonl_file2, "w") as f:
            for entry in test_data2:
                f.write(json.dumps(entry) + "\n")

    @classmethod
    def tearDownClass(cls):
        """Clean up temporary files."""
        shutil.rmtree(cls.temp_dir)

    def test_dataset_initialization(self):
        """Test that the dataset initializes correctly."""
        dataset = OlmOCRDataset(
            bench_data_folder=self.bench_data_folder,
            processor=None,
            max_samples=None,
            target_longest_image_dim=1024,
        )

        self.assertIsNotNone(dataset)
        self.assertEqual(dataset.bench_data_folder, self.bench_data_folder)
        self.assertEqual(dataset.pdf_folder, self.pdfs_folder)
        self.assertTrue(len(dataset) > 0)

    def test_unique_pdf_loading(self):
        """Test that unique PDFs are loaded correctly."""
        dataset = OlmOCRDataset(
            bench_data_folder=self.bench_data_folder,
            processor=None,
            max_samples=None,
            target_longest_image_dim=1024,
        )

        # Should have 4 unique PDF+page combinations:
        # test_0.pdf page 0, test_1.pdf page 0, test_1.pdf page 1, test_2.pdf page 0
        self.assertEqual(len(dataset), 4)

        # Check that samples have the correct structure
        for sample in dataset.samples:
            self.assertIn("pdf_path", sample)
            self.assertIn("pdf_name", sample)
            self.assertIn("page", sample)
            self.assertIn("jsonl_file", sample)
            self.assertIn("test_ids", sample)
            self.assertIn("entries", sample)

    def test_test_id_aggregation(self):
        """Test that test IDs are correctly aggregated per PDF+page."""
        dataset = OlmOCRDataset(
            bench_data_folder=self.bench_data_folder,
            processor=None,
            max_samples=None,
            target_longest_image_dim=1024,
        )

        # Find the sample for test_0.pdf page 0
        test_0_sample = None
        for sample in dataset.samples:
            if "test_0.pdf" in sample["pdf_name"] and sample["page"] == 0:
                test_0_sample = sample
                break

        self.assertIsNotNone(test_0_sample)
        # Should have 3 test IDs for test_0.pdf page 0
        self.assertEqual(len(test_0_sample["test_ids"]), 3)
        self.assertIn("test_0_001", test_0_sample["test_ids"])
        self.assertIn("test_0_002", test_0_sample["test_ids"])
        self.assertIn("test_0_003", test_0_sample["test_ids"])

    def test_max_samples_limit(self):
        """Test that max_samples correctly limits the dataset size."""
        dataset = OlmOCRDataset(
            bench_data_folder=self.bench_data_folder,
            processor=None,
            max_samples=2,
            target_longest_image_dim=1024,
        )

        self.assertEqual(len(dataset), 2)

    @patch("olmocr.train.grpo_train.render_pdf_to_base64png")
    @patch("olmocr.train.grpo_train.build_no_anchoring_v4_yaml_prompt")
    def test_getitem_format(self, mock_prompt, mock_render):
        """Test that __getitem__ returns the correct format."""
        # Mock the rendering and prompt functions
        mock_render.return_value = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="  # 1x1 white pixel PNG
        mock_prompt.return_value = "Test prompt"

        dataset = OlmOCRDataset(
            bench_data_folder=self.bench_data_folder,
            processor=None,
            max_samples=1,
            target_longest_image_dim=1024,
        )

        item = dataset[0]

        self.assertIsNotNone(item)
        self.assertIn("prompt", item)
        self.assertIn("pdf_path", item)
        self.assertIn("jsonl_file", item)
        self.assertIn("test_ids", item)
        self.assertIn("image", item)

        # Check the prompt structure
        self.assertIsInstance(item["prompt"], list)
        self.assertEqual(len(item["prompt"]), 1)
        self.assertEqual(item["prompt"][0]["role"], "user")
        self.assertIsInstance(item["prompt"][0]["content"], list)
        self.assertEqual(len(item["prompt"][0]["content"]), 2)

        # Check the other fields
        self.assertIsInstance(item["pdf_path"], str)
        self.assertIsInstance(item["jsonl_file"], str)
        self.assertIsInstance(item["test_ids"], list)
        self.assertTrue(len(item["test_ids"]) > 0)

    @patch("olmocr.train.grpo_train.render_pdf_to_base64png")
    @patch("olmocr.train.grpo_train.build_no_anchoring_v4_yaml_prompt")
    def test_collate_function(self, mock_prompt, mock_render):
        """Test that the collate function works correctly."""
        # Mock the rendering and prompt functions
        mock_render.return_value = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
        mock_prompt.return_value = "Test prompt"

        dataset = OlmOCRDataset(
            bench_data_folder=self.bench_data_folder,
            processor=None,
            max_samples=2,
            target_longest_image_dim=1024,
        )

        # Create a batch
        batch = [dataset[0], dataset[1]]
        collated = collate_fn(batch)

        self.assertIsNotNone(collated)
        self.assertIn("prompts", collated)
        self.assertIn("images", collated)
        self.assertIn("pdf_paths", collated)
        self.assertIn("jsonl_files", collated)
        self.assertIn("test_ids", collated)

        # Check batch size consistency
        self.assertEqual(len(collated["prompts"]), 2)
        self.assertEqual(len(collated["images"]), 2)
        self.assertEqual(len(collated["pdf_paths"]), 2)
        self.assertEqual(len(collated["jsonl_files"]), 2)
        self.assertEqual(len(collated["test_ids"]), 2)

    def test_collate_with_none_values(self):
        """Test that the collate function handles None values correctly."""
        batch = [
            None,
            {
                "prompt": [],
                "image": None,
                "pdf_path": "test.pdf",
                "jsonl_file": "test.jsonl",
                "test_ids": ["id1"],
            },
            None,
        ]

        collated = collate_fn(batch)

        self.assertIsNotNone(collated)
        self.assertEqual(len(collated["prompts"]), 1)

    def test_empty_jsonl_handling(self):
        """Test handling of empty JSONL files."""
        # Create an empty JSONL file
        empty_jsonl = os.path.join(self.bench_data_folder, "empty.jsonl")
        open(empty_jsonl, "w").close()

        # Should still work with the other, non-empty files
        dataset = OlmOCRDataset(
            bench_data_folder=self.bench_data_folder,
            processor=None,
            max_samples=None,
            target_longest_image_dim=1024,
        )

        self.assertTrue(len(dataset) > 0)

        # Clean up
        os.remove(empty_jsonl)

    def test_malformed_jsonl_handling(self):
        """Test handling of malformed JSONL entries."""
        # Create a JSONL file with some malformed entries
        malformed_jsonl = os.path.join(self.bench_data_folder, "malformed.jsonl")
        with open(malformed_jsonl, "w") as f:
            f.write('{"pdf": "test.pdf", "id": "valid_1"}\n')
            f.write("not valid json\n")
            f.write('{"pdf": "test2.pdf", "id": "valid_2"}\n')

        # Should skip malformed entries but process valid ones
        dataset = OlmOCRDataset(
            bench_data_folder=self.bench_data_folder,
            processor=None,
            max_samples=None,
            target_longest_image_dim=1024,
        )

        # Should still have entries from the valid files
        self.assertTrue(len(dataset) > 0)

        # Clean up
        os.remove(malformed_jsonl)

    def test_missing_pdf_folder(self):
        """Test error handling when the pdfs folder is missing."""
        temp_bad_folder = tempfile.mkdtemp()

        with self.assertRaises(ValueError) as context:
            OlmOCRDataset(
                bench_data_folder=temp_bad_folder,
                processor=None,
                max_samples=None,
                target_longest_image_dim=1024,
            )

        self.assertIn("PDFs folder not found", str(context.exception))

        # Clean up
        shutil.rmtree(temp_bad_folder)

    def test_no_jsonl_files(self):
        """Test error handling when no JSONL files are present."""
        temp_folder = tempfile.mkdtemp()
        os.makedirs(os.path.join(temp_folder, "pdfs"))

        with self.assertRaises(ValueError) as context:
            OlmOCRDataset(
                bench_data_folder=temp_folder,
                processor=None,
                max_samples=None,
                target_longest_image_dim=1024,
            )

        self.assertIn("No JSONL files found", str(context.exception))

        # Clean up
        shutil.rmtree(temp_folder)

class TestIntegrationWithRealData(unittest.TestCase):
    """Integration tests with real bench data, if available."""

    @unittest.skipUnless(
        os.path.exists("/home/ubuntu/olmocr/olmOCR-bench/bench_data"),
        "Real bench data not available",
    )
    def test_with_real_bench_data(self):
        """Test with real bench data if available."""
        bench_data_folder = "/home/ubuntu/olmocr/olmOCR-bench/bench_data"

        dataset = OlmOCRDataset(
            bench_data_folder=bench_data_folder,
            processor=None,
            max_samples=5,
            target_longest_image_dim=1024,
        )

        self.assertEqual(len(dataset), 5)

        # Test that we can iterate through the dataset
        for i in range(len(dataset)):
            item = dataset[i]
            if item is not None:  # Some PDFs might fail to render
                self.assertIn("prompt", item)
                self.assertIn("pdf_path", item)
                self.assertIn("jsonl_file", item)
                self.assertIn("test_ids", item)

                # Verify that the paths exist
                self.assertTrue(os.path.exists(item["pdf_path"]))
                self.assertTrue(os.path.exists(item["jsonl_file"]))


if __name__ == "__main__":
    unittest.main(verbosity=2)
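
The suite runs under the standard unittest runner from the repository root, for example:

    python -m unittest tests.test_grpo -v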