mirror of
				https://github.com/allenai/olmocr.git
				synced 2025-11-04 12:07:15 +00:00 
			
		
		
		
	Adding some rotation retry control
This commit is contained in:
		
							parent
							
								
									7678f31aa9
								
							
						
					
					
						commit
						08d51b7183
					
				@ -4,7 +4,7 @@ import boto3
 | 
			
		||||
import sqlite3
 | 
			
		||||
import orjson
 | 
			
		||||
import argparse
 | 
			
		||||
import uuid
 | 
			
		||||
import base64
 | 
			
		||||
import tempfile
 | 
			
		||||
import datetime
 | 
			
		||||
import posixpath
 | 
			
		||||
@ -15,6 +15,8 @@ import urllib3.exceptions
 | 
			
		||||
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from pypdf import PdfReader
 | 
			
		||||
from io import BytesIO
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
from functools import partial
 | 
			
		||||
from typing import Optional, List, Tuple, Dict, Callable, Any
 | 
			
		||||
@ -383,8 +385,23 @@ class BatchWriter:
 | 
			
		||||
            thread.join()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int) -> dict:
 | 
			
		||||
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
 | 
			
		||||
    assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
 | 
			
		||||
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)
 | 
			
		||||
 | 
			
		||||
    if image_rotation != 0:
 | 
			
		||||
        image_bytes = base64.b64decode(image_base64)
 | 
			
		||||
        with Image.open(BytesIO(image_bytes)) as img:
 | 
			
		||||
            rotated_img = img.rotate(-image_rotation, expand=True)
 | 
			
		||||
 | 
			
		||||
            # Save the rotated image to a bytes buffer
 | 
			
		||||
            buffered = BytesIO()
 | 
			
		||||
            rotated_img.save(buffered, format="PNG")
 | 
			
		||||
 | 
			
		||||
        # Encode the rotated image back to base64
 | 
			
		||||
        image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=target_anchor_text_len)
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
@ -512,6 +529,7 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_rou
 | 
			
		||||
                    
 | 
			
		||||
                    # TODO: If the rotation was previously invalid, then apply a rotation  
 | 
			
		||||
                    
 | 
			
		||||
 | 
			
		||||
                    # TODO: Try to provide a smaller prompt hint
 | 
			
		||||
                else:
 | 
			
		||||
                    new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})
 | 
			
		||||
 | 
			
		||||
@ -2,9 +2,12 @@ import unittest
 | 
			
		||||
from unittest.mock import MagicMock, patch
 | 
			
		||||
import hashlib
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
import base64
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
# Adjust the import path to match where your code resides
 | 
			
		||||
from pdelfin.birrpipeline import build_dolma_doc, DatabaseManager
 | 
			
		||||
from pdelfin.birrpipeline import build_dolma_doc, DatabaseManager, build_finetuning_prompt, build_page_query
 | 
			
		||||
 | 
			
		||||
class TestBuildDolmaDoc(unittest.TestCase):
 | 
			
		||||
    @patch('pdelfin.birrpipeline.DatabaseManager')
 | 
			
		||||
@ -121,6 +124,87 @@ class TestBuildDolmaDoc(unittest.TestCase):
 | 
			
		||||
        expected_id = hashlib.sha1(expected_text.encode()).hexdigest()
 | 
			
		||||
        self.assertEqual(dolma_doc['id'], expected_id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestBuildPageQuery(unittest.TestCase):
    """Checks that ``build_page_query`` applies its ``image_rotation`` argument."""

    # Prefix of the data URL under which the query carries the page image.
    DATA_URL_PREFIX = "data:image/png;base64,"

    def _render_page_png(self, rotation, output_name):
        """Build a query for ``gnarly_pdfs/edgar.pdf`` and save its PNG payload.

        Calls ``build_page_query`` with page ``1``, a target longest image
        dimension of ``1024``, a target anchor text length of ``6000`` and the
        given ``rotation``, then decodes the base64 image found in the second
        content entry of the first chat message.

        Args:
            rotation: Value passed as ``image_rotation`` to ``build_page_query``.
            output_name: File name to write inside the ``test_renders`` directory.

        Returns:
            The path of the PNG file that was written.
        """
        tests_dir = os.path.dirname(__file__)
        pdf_path = os.path.join(tests_dir, "gnarly_pdfs", "edgar.pdf")
        query = build_page_query(pdf_path, "edgar.pdf", 1, 1024, 6000, rotation)

        # The image is expected as the second content entry of the first message.
        image_content = query["chat_messages"][0]["content"][1]
        self.assertEqual(image_content["type"], "image_url")

        image_url = image_content["image_url"]["url"]
        self.assertTrue(image_url.startswith(self.DATA_URL_PREFIX))
        image_data = base64.b64decode(image_url[len(self.DATA_URL_PREFIX):])

        # Create the render directory on demand so a missing directory does not
        # surface as an unrelated FileNotFoundError.
        render_dir = os.path.join(tests_dir, "test_renders")
        os.makedirs(render_dir, exist_ok=True)

        output_path = os.path.join(render_dir, output_name)
        with open(output_path, "wb") as image_file:
            image_file.write(image_data)
        return output_path

    def testRotation(self):
        """A rotation of 90 must yield the page turned 90 degrees clockwise."""
        # Render the same page without rotation and with a 90 degree rotation.
        output_image_path = self._render_page_png(0, "output_image.png")
        output_image_rotated_path = self._render_page_png(90, "output_image_rotated90.png")

        with Image.open(output_image_path) as original_image:
            with Image.open(output_image_rotated_path) as rotated_image:
                original_pixels = original_image.load()
                rotated_pixels = rotated_image.load()
                width, height = original_image.size

                # A quarter turn swaps the two dimensions.
                self.assertEqual(width, rotated_image.size[1])
                self.assertEqual(height, rotated_image.size[0])

                # Under a clockwise quarter turn the original pixel (x, y)
                # lands at (height - 1 - y, x) in the rotated image.
                for x in range(width):
                    for y in range(height):
                        expected = original_pixels[x, y]
                        actual = rotated_pixels[height - 1 - y, x]
                        # Only build the failure message when a pixel differs;
                        # formatting it for every pixel is wasted work.
                        if expected != actual:
                            self.fail(f"Pixel mismatch at ({x}, {y}): {expected} != {actual}")

        print("Rotation verification passed: The rotated image is correctly rotated 90 degrees clockwise.")


# Run the test
if __name__ == '__main__':
    unittest.main()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								tests/test_renders/output_image.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/test_renders/output_image.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 357 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								tests/test_renders/output_image_rotated90.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/test_renders/output_image_rotated90.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 362 KiB  | 
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user