mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-03 03:25:22 +00:00
Adding some rotation retry contrl
This commit is contained in:
parent
7678f31aa9
commit
08d51b7183
@ -4,7 +4,7 @@ import boto3
|
||||
import sqlite3
|
||||
import orjson
|
||||
import argparse
|
||||
import uuid
|
||||
import base64
|
||||
import tempfile
|
||||
import datetime
|
||||
import posixpath
|
||||
@ -15,6 +15,8 @@ import urllib3.exceptions
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pypdf import PdfReader
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from typing import Optional, List, Tuple, Dict, Callable, Any
|
||||
@ -383,8 +385,23 @@ class BatchWriter:
|
||||
thread.join()
|
||||
|
||||
|
||||
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int) -> dict:
|
||||
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
|
||||
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
|
||||
image_base64 = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)
|
||||
|
||||
if image_rotation != 0:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
with Image.open(BytesIO(image_bytes)) as img:
|
||||
rotated_img = img.rotate(-image_rotation, expand=True)
|
||||
|
||||
# Save the rotated image to a bytes buffer
|
||||
buffered = BytesIO()
|
||||
rotated_img.save(buffered, format="PNG")
|
||||
|
||||
# Encode the rotated image back to base64
|
||||
image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||
|
||||
|
||||
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=target_anchor_text_len)
|
||||
|
||||
return {
|
||||
@ -511,6 +528,7 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_rou
|
||||
new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})
|
||||
|
||||
# TODO: If the rotation was previously invalid, then apply a rotation
|
||||
|
||||
|
||||
# TODO: Try to provide a smaller prompt hint
|
||||
else:
|
||||
|
||||
@ -2,9 +2,12 @@ import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import base64
|
||||
from PIL import Image
|
||||
|
||||
# Adjust the import path to match where your code resides
|
||||
from pdelfin.birrpipeline import build_dolma_doc, DatabaseManager
|
||||
from pdelfin.birrpipeline import build_dolma_doc, DatabaseManager, build_finetuning_prompt, build_page_query
|
||||
|
||||
class TestBuildDolmaDoc(unittest.TestCase):
|
||||
@patch('pdelfin.birrpipeline.DatabaseManager')
|
||||
@ -121,6 +124,87 @@ class TestBuildDolmaDoc(unittest.TestCase):
|
||||
expected_id = hashlib.sha1(expected_text.encode()).hexdigest()
|
||||
self.assertEqual(dolma_doc['id'], expected_id)
|
||||
|
||||
|
||||
class TestBuildPageQuery(unittest.TestCase):
|
||||
def testRotation(self):
|
||||
# First, generate and save the non-rotated image
|
||||
query = build_page_query(os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"gnarly_pdfs",
|
||||
"edgar.pdf"
|
||||
), "edgar.pdf", 1, 1024, 6000, 0)
|
||||
|
||||
# Extract the base64 image from the query
|
||||
image_content = query["chat_messages"][0]["content"][1]
|
||||
self.assertEqual(image_content["type"], "image_url")
|
||||
image_url = image_content["image_url"]["url"]
|
||||
|
||||
# Extract base64 string from the data URL
|
||||
prefix = "data:image/png;base64,"
|
||||
self.assertTrue(image_url.startswith(prefix))
|
||||
image_base64 = image_url[len(prefix):]
|
||||
|
||||
# Decode the base64 string
|
||||
image_data = base64.b64decode(image_base64)
|
||||
|
||||
# Define the output file path for the non-rotated image
|
||||
output_image_path = os.path.join(os.path.dirname(__file__), "test_renders", "output_image.png")
|
||||
|
||||
# Save the non-rotated image to a file
|
||||
with open(output_image_path, "wb") as image_file:
|
||||
image_file.write(image_data)
|
||||
|
||||
# Now, generate and save the rotated image (90 degrees clockwise)
|
||||
query_rotated = build_page_query(os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"gnarly_pdfs",
|
||||
"edgar.pdf"
|
||||
), "edgar.pdf", 1, 1024, 6000, 90)
|
||||
|
||||
# Extract the base64 image from the rotated query
|
||||
image_content_rotated = query_rotated["chat_messages"][0]["content"][1]
|
||||
self.assertEqual(image_content_rotated["type"], "image_url")
|
||||
image_url_rotated = image_content_rotated["image_url"]["url"]
|
||||
|
||||
# Extract base64 string from the data URL for the rotated image
|
||||
self.assertTrue(image_url_rotated.startswith(prefix))
|
||||
image_base64_rotated = image_url_rotated[len(prefix):]
|
||||
|
||||
# Decode the base64 string for the rotated image
|
||||
image_data_rotated = base64.b64decode(image_base64_rotated)
|
||||
|
||||
# Define the output file path for the rotated image
|
||||
output_image_rotated_path = os.path.join(os.path.dirname(__file__), "test_renders", "output_image_rotated90.png")
|
||||
|
||||
# Save the rotated image to a file
|
||||
with open(output_image_rotated_path, "wb") as image_file_rotated:
|
||||
image_file_rotated.write(image_data_rotated)
|
||||
|
||||
# Verification Step: Ensure the rotated image is 90 degrees clockwise rotated
|
||||
|
||||
# Open both images using PIL
|
||||
with Image.open(output_image_path) as original_image:
|
||||
with Image.open(output_image_rotated_path) as rotated_image:
|
||||
|
||||
# Compare pixel by pixel
|
||||
original_pixels = original_image.load()
|
||||
rotated_pixels = rotated_image.load()
|
||||
width, height = original_image.size
|
||||
|
||||
self.assertEqual(width, rotated_image.size[1])
|
||||
self.assertEqual(height, rotated_image.size[0])
|
||||
|
||||
for x in range(width):
|
||||
for y in range(height):
|
||||
|
||||
self.assertEqual(
|
||||
original_pixels[x, y], rotated_pixels[height - 1 - y, x],
|
||||
f"Pixel mismatch at ({x}, {y})"
|
||||
)
|
||||
|
||||
print("Rotation verification passed: The rotated image is correctly rotated 90 degrees clockwise.")
|
||||
|
||||
|
||||
# Run the test
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
BIN
tests/test_renders/output_image.png
Normal file
BIN
tests/test_renders/output_image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 357 KiB |
BIN
tests/test_renders/output_image_rotated90.png
Normal file
BIN
tests/test_renders/output_image_rotated90.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 362 KiB |
Loading…
x
Reference in New Issue
Block a user