olmocr/tests/test_pipeline.py
2025-08-04 17:53:48 +00:00

299 lines
11 KiB
Python

import base64
import json
import os
from dataclasses import dataclass
from io import BytesIO
from unittest.mock import AsyncMock, patch
import pytest
from PIL import Image
from olmocr.pipeline import PageResult, build_page_query, process_page
def create_test_image(width=100, height=150):
    """Build a white RGB test image carrying coloured markers for orientation checks.

    Markers drawn (in this order):

    * red (255, 0, 0) square covering x and y in [10, 30),
    * blue (0, 0, 255) rectangle covering x in [width - 40, width - 10)
      and y in [height - 30, height - 10),
    * green (0, 255, 0) horizontal line at y == 5 covering x in [20, 80).

    Args:
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        The generated ``PIL.Image.Image``.
    """
    canvas = Image.new("RGB", (width, height), color="white")
    pixel_access = canvas.load()
    if pixel_access is None:
        # No pixel access available: hand back the plain white canvas untouched.
        return canvas

    def fill(xs, ys, rgb):
        # Column-major walk, matching the x-outer / y-inner order of the drawing loops.
        for col in xs:
            for row in ys:
                pixel_access[col, row] = rgb

    # Paint order matters if markers overlap on very small canvases; keep red, blue, green.
    fill(range(10, 30), range(10, 30), (255, 0, 0))
    fill(range(width - 40, width - 10), range(height - 30, height - 10), (0, 0, 255))
    fill(range(20, 80), (5,), (0, 255, 0))
    return canvas
def image_to_base64(img):
    """Serialize *img* as PNG and return the PNG bytes base64-encoded as a ``str``.

    Args:
        img: Any object exposing ``save(fp, format=...)`` such as a PIL image.

    Returns:
        The base64 text of the PNG-encoded image.
    """
    with BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()
def base64_to_image(base64_str):
    """Decode base64 text into bytes and open them as a ``PIL.Image.Image``.

    Args:
        base64_str: Base64-encoded image file contents.

    Returns:
        The image opened from an in-memory buffer.
    """
    return Image.open(BytesIO(base64.b64decode(base64_str)))
class TestImageRotation:
    """Tests for how ``build_page_query`` applies ``image_rotation``.

    ``render_pdf_to_base64png`` is patched so every test controls exactly which
    image ``build_page_query`` starts from. Each test decodes the image embedded
    in the returned query and checks its size and where the red marker square
    drawn by ``create_test_image`` ended up. As asserted below, positive angles
    rotate the image counter-clockwise.
    """

    @staticmethod
    def _extract_query_image(result):
        """Decode and return the image embedded in a ``build_page_query`` result.

        Uses the first content part of the first message that carries an
        ``image_url`` entry, instead of assuming that part sits at index 0.
        """
        content = result["messages"][0]["content"]
        image_part = next((part for part in content if "image_url" in part), None)
        assert image_part is not None, "build_page_query result contains no image_url part"
        image_url = image_part["image_url"]["url"]
        # Everything after the first comma is the base64 payload (presumably a
        # "data:image/png;base64,..." URL); base64 text never contains commas.
        image_base64 = image_url.split(",", 1)[1]
        return base64_to_image(image_base64)

    async def _query_image(self, source_img, image_rotation):
        """Run ``build_page_query`` on *source_img* (renderer patched) and return the embedded image."""
        with patch("olmocr.pipeline.render_pdf_to_base64png") as mock_render:
            mock_render.return_value = image_to_base64(source_img)
            result = await build_page_query("fake_pdf.pdf", 1, 1000, image_rotation=image_rotation)
        return self._extract_query_image(result)

    @pytest.mark.asyncio
    async def test_no_rotation(self):
        """Test that image_rotation=0 returns the original image."""
        test_img = create_test_image()
        result_img = await self._query_image(test_img, 0)
        # Should be the same size as original
        assert result_img.size == test_img.size
        # Red square should still be in the top-left corner
        assert result_img.getpixel((20, 20)) == (255, 0, 0)

    @pytest.mark.asyncio
    async def test_rotate_90_degrees(self):
        """Test that image_rotation=90 rotates the image 90 degrees counter-clockwise."""
        test_img = create_test_image(100, 150)
        result_img = await self._query_image(test_img, 90)
        # After 90 degree counter-clockwise rotation, dimensions should be swapped
        assert result_img.size == (150, 100)
        # The red square moves from top-left to bottom-left: source (x, y) -> (y, 99 - x),
        # so it now covers x in [10, 30) and y in [70, 90).
        assert result_img.getpixel((20, 80)) == (255, 0, 0)

    @pytest.mark.asyncio
    async def test_rotate_180_degrees(self):
        """Test that image_rotation=180 rotates the image 180 degrees."""
        test_img = create_test_image(100, 150)
        result_img = await self._query_image(test_img, 180)
        # After 180 degree rotation, dimensions should be the same
        assert result_img.size == (100, 150)
        # The red square moves from top-left to bottom-right: source (x, y) -> (99 - x, 149 - y),
        # so it now covers x in [70, 90) and y in [120, 140).
        assert result_img.getpixel((80, 130)) == (255, 0, 0)

    @pytest.mark.asyncio
    async def test_rotate_270_degrees(self):
        """Test that image_rotation=270 rotates the image 270 degrees counter-clockwise (90 clockwise)."""
        test_img = create_test_image(100, 150)
        result_img = await self._query_image(test_img, 270)
        # After 270 degree counter-clockwise rotation, dimensions should be swapped
        assert result_img.size == (150, 100)
        # The red square moves from top-left to top-right: source (x, y) -> (149 - y, x),
        # so it now covers x in [120, 140) and y in [10, 30).
        assert result_img.getpixel((130, 20)) == (255, 0, 0)

    @pytest.mark.asyncio
    async def test_invalid_rotation_angle(self):
        """Test that invalid rotation angles raise an assertion error."""
        test_img = create_test_image()
        with patch("olmocr.pipeline.render_pdf_to_base64png") as mock_render:
            mock_render.return_value = image_to_base64(test_img)
            with pytest.raises(AssertionError, match="Invalid image rotation"):
                await build_page_query("fake_pdf.pdf", 1, 1000, image_rotation=45)

    @pytest.mark.asyncio
    async def test_rotation_preserves_image_quality(self):
        """Test that every valid rotation yields an undistorted RGB image of the expected size."""
        width, height = 200, 300
        test_img = create_test_image(width, height)
        # Per angle: expected output size, and a pixel that must lie inside the rotated
        # red square (source footprint x and y in [10, 30); rotation is counter-clockwise).
        expectations = {
            0: ((width, height), (20, 20)),
            90: ((height, width), (20, 180)),
            180: ((width, height), (180, 280)),
            270: ((height, width), (280, 20)),
        }
        for angle, (expected_size, red_pixel) in expectations.items():
            result_img = await self._query_image(test_img, angle)
            assert result_img.size == expected_size, f"unexpected size for angle={angle}"
            # Pixel-exact marker placement shows the image was rotated, not resampled or distorted.
            assert result_img.getpixel(red_pixel) == (255, 0, 0), f"red marker misplaced for angle={angle}"
            # Verify image format and mode are preserved
            assert result_img.format == "PNG" or result_img.format is None
            assert result_img.mode == "RGB"
@dataclass()
class MockArgs:
    """Lightweight stand-in for the parsed pipeline arguments handed to ``process_page``.

    NOTE(review): only these three settings are modelled; presumably they are
    all that ``process_page`` reads in these tests — confirm if it grows more.
    """

    # Presumably the per-page retry budget used by process_page — confirm.
    max_page_retries: int = 8
    # Presumably forwarded as build_page_query's target_longest_image_dim — confirm.
    target_longest_image_dim: int = 1288
    # Presumably toggles guided decoding in the model request — confirm.
    guided_decoding: bool = False
class TestRotationCorrection:
    """Checks ``process_page``'s rotation-correction retry path against a mocked model."""

    @pytest.mark.asyncio
    async def test_process_page_with_rotation_correction(self):
        """Test that process_page correctly handles rotation correction from model response.

        The mocked model first reports the page as mis-rotated
        (``rotation_correction: 90``). On the second request it reports the page
        as correctly oriented. The test checks that ``process_page`` re-queries
        with rotation 90 and returns the second response.
        """
        # Path to the test PDF that needs 90 degree rotation. Resolve it next to this
        # file so the test does not depend on the working directory, falling back to
        # the repository-root-relative path.
        test_pdf_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "gnarly_pdfs", "edgar-rotated90.pdf")
        if not os.path.exists(test_pdf_path):
            test_pdf_path = "tests/gnarly_pdfs/edgar-rotated90.pdf"
        # Mock arguments
        args = MockArgs()
        # Counter to track number of API calls
        call_count = 0

        async def mock_apost(url, json_data):
            # Canned model responses: call 1 asks for a 90 degree correction,
            # call 2 reports the (now rotated) page as valid; any further call is a failure.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                response_content = """---
primary_language: en
is_rotation_valid: false
rotation_correction: 90
is_table: false
is_diagram: false
---
This document appears to be rotated and needs correction."""
            elif call_count == 2:
                response_content = """---
primary_language: en
is_rotation_valid: true
rotation_correction: 0
is_table: false
is_diagram: false
---
UNITED STATES
SECURITIES AND EXCHANGE COMMISSION
Washington, D.C. 20549
This is the corrected text from the document."""
            else:
                raise ValueError(f"Unexpected call count: {call_count}")
            # Mock response structure
            response_body = {
                "choices": [{"message": {"content": response_content}, "finish_reason": "stop"}],
                "usage": {"prompt_tokens": 1000, "completion_tokens": 100, "total_tokens": 1100},
            }
            return 200, json.dumps(response_body).encode()

        # Mock the worker tracker
        mock_tracker = AsyncMock()
        # Ensure the test PDF exists
        assert os.path.exists(test_pdf_path), f"Test PDF not found at {test_pdf_path}"

        # Record the rotation of every build_page_query call while still delegating to the real implementation.
        build_page_query_calls = []
        original_build_page_query = build_page_query

        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0):
            build_page_query_calls.append(image_rotation)
            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation)

        with patch("olmocr.pipeline.apost", side_effect=mock_apost):
            with patch("olmocr.pipeline.tracker", mock_tracker):
                with patch("olmocr.pipeline.build_page_query", side_effect=mock_build_page_query):
                    result = await process_page(args=args, worker_id=0, pdf_orig_path="test-edgar-rotated90.pdf", pdf_local_path=test_pdf_path, page_num=1)

        # Verify the result
        assert isinstance(result, PageResult)
        assert result.page_num == 1
        assert result.is_fallback is False
        assert result.response.is_rotation_valid is True
        assert result.response.rotation_correction == 0
        assert result.response.natural_text is not None
        assert "SECURITIES AND EXCHANGE COMMISSION" in result.response.natural_text
        # Verify that exactly 2 API calls were made
        assert call_count == 2
        # Verify build_page_query was called with correct rotations
        assert len(build_page_query_calls) == 2
        assert build_page_query_calls[0] == 0  # First call with no rotation
        assert build_page_query_calls[1] == 90  # Second call with 90 degree rotation
        # Verify tracker was called correctly
        mock_tracker.track_work.assert_any_call(0, "test-edgar-rotated90.pdf-1", "started")
        mock_tracker.track_work.assert_any_call(0, "test-edgar-rotated90.pdf-1", "finished")