mirror of
				https://github.com/allenai/olmocr.git
				synced 2025-10-31 10:04:26 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			507 lines
		
	
	
		
			19 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			507 lines
		
	
	
		
			19 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import base64
 | |
| import json
 | |
| import os
 | |
| from dataclasses import dataclass
 | |
| from io import BytesIO
 | |
| from unittest.mock import AsyncMock, patch
 | |
| 
 | |
| import pytest
 | |
| from PIL import Image
 | |
| 
 | |
| from olmocr.pipeline import PageResult, build_page_query, process_page
 | |
| 
 | |
| 
 | |
def create_test_image(width=100, height=150):
    """Create a simple test image with distinct features to verify rotation.

    The image is white and carries three asymmetric markers so that any
    rotation can be detected by sampling a single pixel:

    * a red square near the top-left corner,
    * a blue rectangle near the bottom-right corner,
    * a green one-pixel-high line near the top edge.

    Markers are clipped to the image bounds, so sizes smaller than the
    hard-coded marker offsets still produce a valid image instead of
    indexing outside of it.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    img = Image.new("RGB", (width, height), color="white")
    pixels = img.load()

    # load() is typed as optional; check once instead of once per pixel.
    if pixels is None:
        return img

    def fill(x_start, x_stop, y_start, y_stop, color):
        """Paint the half-open box [x_start, x_stop) x [y_start, y_stop), clipped to the image."""
        for x in range(max(x_start, 0), min(x_stop, width)):
            for y in range(max(y_start, 0), min(y_stop, height)):
                pixels[x, y] = color

    # Red square in top-left corner
    fill(10, 30, 10, 30, (255, 0, 0))

    # Blue rectangle in bottom-right corner
    fill(width - 40, width - 10, height - 30, height - 10, (0, 0, 255))

    # Green line near the top (row y == 5 only)
    fill(20, 80, 5, 6, (0, 255, 0))

    return img
 | |
| 
 | |
| 
 | |
def image_to_base64(img):
    """Encode a PIL image as PNG and return the bytes as a base64 text string."""
    with BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
 | |
| 
 | |
| 
 | |
def base64_to_image(base64_str):
    """Decode a base64 string and open the resulting bytes as a PIL Image."""
    decoded_stream = BytesIO(base64.b64decode(base64_str))
    return Image.open(decoded_stream)
 | |
| 
 | |
| 
 | |
class TestImageRotation:
    """Checks that ``build_page_query`` applies the requested ``image_rotation``.

    The PDF renderer is patched out in every test, so the "rendered page" is
    always the synthetic image from ``create_test_image``; the tests then
    inspect the image embedded in the returned query.
    """

    @staticmethod
    async def _query_image(source_img, image_rotation):
        """Build a page query from ``source_img`` and return the embedded image.

        Args:
            source_img: PIL image handed back by the patched PDF renderer.
            image_rotation: Rotation value forwarded to ``build_page_query``.

        Returns:
            The PIL image decoded from the data URL found in the second
            content part of the first message of the query.
        """
        with patch("olmocr.pipeline.render_pdf_to_base64png") as mock_render:
            mock_render.return_value = image_to_base64(source_img)
            result = await build_page_query("fake_pdf.pdf", 1, 1000, image_rotation=image_rotation)

        image_url = result["messages"][0]["content"][1]["image_url"]["url"]
        # Data URL: everything after the first comma is the base64 payload.
        return base64_to_image(image_url.split(",")[1])

    @pytest.mark.asyncio
    async def test_no_rotation(self):
        """Test that image_rotation=0 returns the original image."""
        test_img = create_test_image()

        result_img = await self._query_image(test_img, 0)

        # Should be the same size as original
        assert result_img.size == test_img.size

        # Check pixel at specific location (red square should be in top-left)
        assert result_img.getpixel((20, 20)) == (255, 0, 0)

    @pytest.mark.asyncio
    async def test_rotate_90_degrees(self):
        """Test that image_rotation=90 rotates the image 90 degrees counter-clockwise."""
        test_img = create_test_image(100, 150)

        result_img = await self._query_image(test_img, 90)

        # After 90 degree counter-clockwise rotation, dimensions should be swapped
        assert result_img.size == (150, 100)

        # The red square that was at top-left should now be at bottom-left
        # Original (20, 20) -> After 90° CCW rotation -> (20, 80)
        assert result_img.getpixel((20, 80)) == (255, 0, 0)

    @pytest.mark.asyncio
    async def test_rotate_180_degrees(self):
        """Test that image_rotation=180 rotates the image 180 degrees."""
        test_img = create_test_image(100, 150)

        result_img = await self._query_image(test_img, 180)

        # After 180 degree rotation, dimensions should be the same
        assert result_img.size == (100, 150)

        # The red square that was at top-left should now be at bottom-right
        # Original (20, 20) -> After 180° rotation -> (80, 130)
        assert result_img.getpixel((80, 130)) == (255, 0, 0)

    @pytest.mark.asyncio
    async def test_rotate_270_degrees(self):
        """Test that image_rotation=270 rotates the image 270 degrees counter-clockwise (90 clockwise)."""
        test_img = create_test_image(100, 150)

        result_img = await self._query_image(test_img, 270)

        # After 270 degree counter-clockwise rotation, dimensions should be swapped
        assert result_img.size == (150, 100)

        # The red square that was at top-left should now be at top-right
        # Original (20, 20) -> After 270° CCW rotation -> (130, 20)
        assert result_img.getpixel((130, 20)) == (255, 0, 0)

    @pytest.mark.asyncio
    async def test_invalid_rotation_angle(self):
        """Test that invalid rotation angles raise an assertion error."""
        test_img = create_test_image()

        with pytest.raises(AssertionError, match="Invalid image rotation"):
            await self._query_image(test_img, 45)

    @pytest.mark.asyncio
    async def test_rotation_preserves_image_quality(self):
        """Test that rotation preserves the image without distortion."""
        # Create a larger test image
        test_img = create_test_image(200, 300)

        # Test all valid rotation angles
        for angle in [0, 90, 180, 270]:
            result_img = await self._query_image(test_img, angle)

            # Verify image format is preserved
            assert result_img.format in ("PNG", None)
            assert result_img.mode == "RGB"

            # Quarter turns swap the sides; nothing may be scaled or cropped
            expected_size = (200, 300) if angle in (0, 180) else (300, 200)
            assert result_img.size == expected_size
 | |
| 
 | |
| 
 | |
@dataclass
class MockArgs:
    """Minimal stand-in for the ``args`` object handed to ``process_page`` in these tests."""

    # Field names presumably mirror the pipeline's CLI options -- verify against olmocr.pipeline.
    max_page_retries: int = 8
    target_longest_image_dim: int = 1288
    guided_decoding: bool = False
 | |
| 
 | |
| 
 | |
class TestRotationCorrection:
    """Checks how ``process_page`` reacts to rotation hints in model responses.

    Every test scripts the sequence of model replies, runs ``process_page``
    on a test PDF with ``apost``, ``tracker`` and ``build_page_query``
    patched in ``olmocr.pipeline``, and then inspects the rotations that
    were requested from ``build_page_query``.
    """

    # Path to the test PDF that needs 90 degree rotation (any test PDF works,
    # since the model replies are scripted).
    PDF_PATH = "tests/gnarly_pdfs/edgar-rotated90.pdf"

    @staticmethod
    def _model_reply(is_rotation_valid, rotation_correction, text):
        """Return a model response: YAML front matter followed by ``text``.

        Args:
            is_rotation_valid: Value for the ``is_rotation_valid`` field.
            rotation_correction: Value for the ``rotation_correction`` field.
            text: Body placed after the front matter.
        """
        return (
            "---\n"
            "primary_language: en\n"
            f"is_rotation_valid: {'true' if is_rotation_valid else 'false'}\n"
            f"rotation_correction: {rotation_correction}\n"
            "is_table: false\n"
            "is_diagram: false\n"
            "---\n"
            "\n"
            f"{text}"
        )

    @classmethod
    async def _run_process_page(cls, replies, pdf_orig_path):
        """Run ``process_page`` on page 1 with scripted model replies.

        Args:
            replies: Response contents to serve, one per API call, in order.
                Any call beyond the end of the list raises ``ValueError``
                inside the mocked ``apost``.
            pdf_orig_path: Value passed as ``pdf_orig_path``.

        Returns:
            A tuple ``(result, api_call_count, requested_rotations, tracker)``
            where ``requested_rotations`` lists the ``image_rotation`` of each
            ``build_page_query`` call and ``tracker`` is the ``AsyncMock``
            that replaced ``olmocr.pipeline.tracker``.
        """
        # Ensure the test PDF exists
        assert os.path.exists(cls.PDF_PATH), f"Test PDF not found at {cls.PDF_PATH}"

        # Counter to track number of API calls
        call_count = 0

        async def mock_apost(url, json_data):
            nonlocal call_count
            call_count += 1

            if call_count > len(replies):
                raise ValueError(f"Unexpected call count: {call_count}")

            # Mock response structure
            response_body = {
                "choices": [{"message": {"content": replies[call_count - 1]}, "finish_reason": "stop"}],
                "usage": {"prompt_tokens": 1000, "completion_tokens": 100, "total_tokens": 1100},
            }
            return 200, json.dumps(response_body).encode()

        # Mock the worker tracker
        mock_tracker = AsyncMock()

        # Track calls to build_page_query while still delegating to the real one
        requested_rotations = []
        original_build_page_query = build_page_query

        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0):
            requested_rotations.append(image_rotation)
            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation)

        with patch("olmocr.pipeline.apost", side_effect=mock_apost), \
                patch("olmocr.pipeline.tracker", mock_tracker), \
                patch("olmocr.pipeline.build_page_query", side_effect=mock_build_page_query):
            result = await process_page(
                args=MockArgs(),
                worker_id=0,
                pdf_orig_path=pdf_orig_path,
                pdf_local_path=cls.PDF_PATH,
                page_num=1,
            )

        return result, call_count, requested_rotations, mock_tracker

    @pytest.mark.asyncio
    async def test_process_page_with_rotation_correction(self):
        """Test that process_page correctly handles rotation correction from model response."""
        replies = [
            # First call - model detects rotation is needed
            self._model_reply(False, 90, "This document appears to be rotated and needs correction."),
            # Second call - after rotation, model says it's correct
            self._model_reply(
                True,
                0,
                "UNITED STATES\n"
                "SECURITIES AND EXCHANGE COMMISSION\n"
                "Washington, D.C. 20549\n"
                "\n"
                "This is the corrected text from the document.",
            ),
        ]

        result, call_count, rotations, mock_tracker = await self._run_process_page(replies, "test-edgar-rotated90.pdf")

        # Verify the result
        assert isinstance(result, PageResult)
        assert result.page_num == 1
        assert result.is_fallback is False
        assert result.response.is_rotation_valid is True
        assert result.response.rotation_correction == 0
        assert result.response.natural_text is not None
        assert "SECURITIES AND EXCHANGE COMMISSION" in result.response.natural_text

        # Verify that exactly 2 API calls were made
        assert call_count == 2

        # First call with no rotation, second call with 90 degree rotation
        assert rotations == [0, 90]

        # Verify tracker was called correctly
        mock_tracker.track_work.assert_any_call(0, "test-edgar-rotated90.pdf-1", "started")
        mock_tracker.track_work.assert_any_call(0, "test-edgar-rotated90.pdf-1", "finished")

    @pytest.mark.asyncio
    async def test_process_page_with_cumulative_rotation(self):
        """Test that process_page correctly accumulates rotations across multiple attempts."""
        replies = [
            # First call - model detects rotation is needed (90 degrees)
            self._model_reply(False, 90, "This document appears to be rotated and needs correction."),
            # Second call - model still detects rotation is needed (another 90 degrees)
            self._model_reply(False, 90, "Document still needs rotation."),
            # Third call - after 180 total degrees of rotation, model says it's correct
            self._model_reply(
                True,
                0,
                "UNITED STATES\n"
                "SECURITIES AND EXCHANGE COMMISSION\n"
                "Washington, D.C. 20549\n"
                "\n"
                "Document is now correctly oriented after 180 degree rotation.",
            ),
        ]

        result, call_count, rotations, mock_tracker = await self._run_process_page(replies, "test-cumulative-rotation.pdf")

        # Verify the result
        assert isinstance(result, PageResult)
        assert result.page_num == 1
        assert result.is_fallback is False
        assert result.response.is_rotation_valid is True
        assert result.response.rotation_correction == 0
        assert result.response.natural_text is not None
        assert "180 degree rotation" in result.response.natural_text

        # Verify that exactly 3 API calls were made
        assert call_count == 3

        # No rotation, then 90, then the cumulative 180 degrees
        assert rotations == [0, 90, 180]

        # Verify tracker was called correctly
        mock_tracker.track_work.assert_any_call(0, "test-cumulative-rotation.pdf-1", "started")
        mock_tracker.track_work.assert_any_call(0, "test-cumulative-rotation.pdf-1", "finished")

    @pytest.mark.asyncio
    async def test_process_page_rotation_wraps_around(self):
        """Test that cumulative rotation correctly wraps around at 360 degrees."""
        replies = [
            # First call - model detects rotation is needed (270 degrees)
            self._model_reply(False, 270, "Document needs 270 degree rotation."),
            # Second call - model detects more rotation is needed (180 degrees)
            # Total would be 450, but should wrap to 90
            self._model_reply(False, 180, "Document needs additional rotation."),
            # Third call - after wrapped rotation (90 degrees), model says it's correct
            self._model_reply(True, 0, "Document correctly oriented at 90 degrees total rotation."),
        ]

        result, call_count, rotations, _ = await self._run_process_page(replies, "test-rotation-wrap.pdf")

        # Verify the result
        assert isinstance(result, PageResult)
        assert result.page_num == 1
        assert result.is_fallback is False
        assert result.response.is_rotation_valid is True

        # Verify that exactly 3 API calls were made
        assert call_count == 3

        # No rotation, then 270, then wrapped rotation (270 + 180 = 450 % 360 = 90)
        assert rotations == [0, 270, 90]
 | 
